From c1318b3d038fff9cb49c4a40642947307fb41e2b Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Mon, 4 Nov 2013 18:32:05 -0800 Subject: [PATCH 1/4] Fix for codeplex-1357 This change allows users to configure the casing behavior of simple membership provider. The simple membership provider will by default generate a database query that normalizes the case of usernames on the database side. This comes with the side effect of obviating any index that the user may have configured for the user name column. The fix is to make this behavior configurable. With the new option, it will be possible to turn off casing normalization, and allow the database to handle it specific to its collation. --- .../SimpleMembershipProvider.cs | 67 ++++++++++---- .../SimpleMembershipProviderCasingBehavior.cs | 32 +++++++ src/WebMatrix.WebData/SimpleRoleProvider.cs | 6 +- .../WebMatrix.WebData.csproj | 1 + src/WebMatrix.WebData/WebSecurity.cs | 87 ++++++++++++++++--- .../SimpleMembershipProviderTest.cs | 42 +++++++++ 6 files changed, 204 insertions(+), 31 deletions(-) create mode 100644 src/WebMatrix.WebData/SimpleMembershipProviderCasingBehavior.cs diff --git a/src/WebMatrix.WebData/SimpleMembershipProvider.cs b/src/WebMatrix.WebData/SimpleMembershipProvider.cs index 87583867..a4b54cd7 100644 --- a/src/WebMatrix.WebData/SimpleMembershipProvider.cs +++ b/src/WebMatrix.WebData/SimpleMembershipProvider.cs @@ -157,7 +157,7 @@ namespace WebMatrix.WebData get { return "webpages_OAuthMembership"; } } - internal static string OAuthTokenTableName + internal static string OAuthTokenTableName { get { return "webpages_OAuthToken"; } } @@ -187,6 +187,8 @@ namespace WebMatrix.WebData // REVIEW: we could get this from the primary key of UserTable in the future public string UserIdColumn { get; set; } + public SimpleMembershipProviderCasingBehavior CasingBehavior { get; set; } + internal DatabaseConnectionInfo ConnectionInfo { get; set; } internal bool InitializeCalled { get; set; } @@ -297,15 +299,42 @@ namespace WebMatrix.WebData VerifyInitialized(); using (var db = ConnectToDatabase()) { - return GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + return GetUserId(db, userName); } } - internal static int GetUserId(IDatabase db, string userTableName, string userNameColumn, string userIdColumn, string userName) + private int GetUserId(IDatabase db, string userName) { - // Casing is normalized in Sql to allow the database to normalize username according to its collation. The common issue - // that can occur here is the 'Turkish i problem', where the uppercase of 'i' is not 'I' in Turkish. - var result = db.QueryValue(@"SELECT " + userIdColumn + " FROM " + userTableName + " WHERE (UPPER(" + userNameColumn + ") = UPPER(@0))", userName); + return GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, CasingBehavior, userName); + } + + internal static int GetUserId( + IDatabase db, + string userTableName, + string userNameColumn, + string userIdColumn, + SimpleMembershipProviderCasingBehavior casingBehavior, + string userName) + { + dynamic result; + if (casingBehavior == SimpleMembershipProviderCasingBehavior.NormalizeCasing) + { + // Casing is normalized in Sql to allow the database to normalize username according to its collation. The common issue + // that can occur here is the 'Turkish i problem', where the uppercase of 'i' is not 'I' in Turkish. + result = db.QueryValue(@"SELECT " + userIdColumn + " FROM " + userTableName + " WHERE (UPPER(" + userNameColumn + ") = UPPER(@0))", userName); + } + else if (casingBehavior == SimpleMembershipProviderCasingBehavior.RelyOnDatabaseCollation) + { + // When this option is supplied we assume the database has been configured with an appropriate casing, and don't normalize + // the user name. This is performant but requires appropriate configuration on the database. + result = db.QueryValue(@"SELECT " + userIdColumn + " FROM " + userTableName + " WHERE (" + userNameColumn + " = @0)", userName); + } + else + { + Debug.Fail("Unexpected enum value"); + return -1; + } + if (result != null) { return (int)result; @@ -429,7 +458,7 @@ namespace WebMatrix.WebData using (var db = ConnectToDatabase()) { // Step 1: Check if the user exists in the Users table - int uid = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int uid = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, CasingBehavior, userName); if (uid == -1) { // User not found @@ -476,7 +505,7 @@ namespace WebMatrix.WebData private void CreateUserRow(IDatabase db, string userName, IDictionary values) { // Make sure user doesn't exist - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId != -1) { throw new MembershipCreateUserException(MembershipCreateStatus.DuplicateUserName); @@ -575,7 +604,7 @@ namespace WebMatrix.WebData using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int userId = GetUserId(db, username); if (userId == -1) { return false; // User not found @@ -622,7 +651,7 @@ namespace WebMatrix.WebData // Due to a bug in v1, GetUser allows passing null / empty values. using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int userId = GetUserId(db, username); if (userId == -1) { return null; // User not found @@ -649,7 +678,7 @@ namespace WebMatrix.WebData using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId == -1) { return false; // User not found @@ -670,7 +699,7 @@ namespace WebMatrix.WebData using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int userId = GetUserId(db, username); if (userId == -1) { return false; // User not found @@ -746,7 +775,7 @@ namespace WebMatrix.WebData { using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, userName)); @@ -761,7 +790,7 @@ namespace WebMatrix.WebData { using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, userName)); @@ -781,7 +810,7 @@ namespace WebMatrix.WebData { using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, userName)); @@ -801,7 +830,7 @@ namespace WebMatrix.WebData { using (var db = ConnectToDatabase()) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName); + int userId = GetUserId(db, userName); if (userId == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, userName)); @@ -852,7 +881,7 @@ namespace WebMatrix.WebData // Ensures the user exists in the accounts table private int VerifyUserNameHasConfirmedAccount(IDatabase db, string username, bool throwException) { - int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int userId = GetUserId(db, username); if (userId == -1) { if (throwException) @@ -1004,7 +1033,7 @@ namespace WebMatrix.WebData // GetUser will fail with an exception if the user table isn't set up properly try { - GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, "z"); + GetUserId(db, "z"); } catch (Exception e) { @@ -1255,7 +1284,7 @@ namespace WebMatrix.WebData { dynamic id = db.QueryValue(@"SELECT UserId FROM [" + MembershipTableName + "] WHERE UserId=@0", userId); return id != null; - } + } } } } \ No newline at end of file diff --git a/src/WebMatrix.WebData/SimpleMembershipProviderCasingBehavior.cs b/src/WebMatrix.WebData/SimpleMembershipProviderCasingBehavior.cs new file mode 100644 index 00000000..3d6bf05c --- /dev/null +++ b/src/WebMatrix.WebData/SimpleMembershipProviderCasingBehavior.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +namespace WebMatrix.WebData +{ + /// + /// Configures the behavior of SimpleMembershipProvider for the casing of user name queries. + /// + public enum SimpleMembershipProviderCasingBehavior + { + /// + /// Uses the SQL Upper function to normalize the casing of user names for a case-insensitive comparion. + /// This is the default value. + /// + /// + /// This option uses the SQL Upper function to perform case-normalization. This guarantees that the + /// the user name is searched case-insensitively, but can have a performance impact when a large number + /// of users exist. + /// + NormalizeCasing, + + /// + /// Relies on the database's configured collation to normalize casing for the comparison of user names. User + /// names are provided to the database exactly as entered by the user. + /// + /// + /// This option relies on the configured collection of database table for user names to perform a correct comparison. + /// This is guaranteed to be correct for the chosen collation and performant. Only choose this option if the table storing + /// user names is configured with the desired collation. + /// + RelyOnDatabaseCollation, + } +} diff --git a/src/WebMatrix.WebData/SimpleRoleProvider.cs b/src/WebMatrix.WebData/SimpleRoleProvider.cs index 0e03b6fb..c151621b 100644 --- a/src/WebMatrix.WebData/SimpleRoleProvider.cs +++ b/src/WebMatrix.WebData/SimpleRoleProvider.cs @@ -74,6 +74,8 @@ namespace WebMatrix.WebData // REVIEW: we could get this from the primary key of UserTable in the future public string UserIdColumn { get; set; } + public SimpleMembershipProviderCasingBehavior CasingBehavior { get; set; } + internal DatabaseConnectionInfo ConnectionInfo { get; set; } internal bool InitializeCalled { get; set; } @@ -142,7 +144,7 @@ namespace WebMatrix.WebData List userIds = new List(usernames.Length); foreach (string username in usernames) { - int id = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int id = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, CasingBehavior, username); if (id == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username)); @@ -307,7 +309,7 @@ namespace WebMatrix.WebData } using (var db = ConnectToDatabase()) { - int userId = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username); + int userId = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, CasingBehavior, username); if (userId == -1) { throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username)); diff --git a/src/WebMatrix.WebData/WebMatrix.WebData.csproj b/src/WebMatrix.WebData/WebMatrix.WebData.csproj index e203b482..003f3aee 100644 --- a/src/WebMatrix.WebData/WebMatrix.WebData.csproj +++ b/src/WebMatrix.WebData/WebMatrix.WebData.csproj @@ -50,6 +50,7 @@ WebDataResources.resx + diff --git a/src/WebMatrix.WebData/WebSecurity.cs b/src/WebMatrix.WebData/WebSecurity.cs index d481b4ad..0fb1ee37 100644 --- a/src/WebMatrix.WebData/WebSecurity.cs +++ b/src/WebMatrix.WebData/WebSecurity.cs @@ -102,42 +102,99 @@ namespace WebMatrix.WebData public static void InitializeDatabaseConnection(string connectionStringName, string userTableName, string userIdColumn, string userNameColumn, bool autoCreateTables) { - DatabaseConnectionInfo connect = new DatabaseConnectionInfo(); - connect.ConnectionStringName = connectionStringName; - InitializeProviders(connect, userTableName, userIdColumn, userNameColumn, autoCreateTables); + InitializeDatabaseConnection( + connectionStringName, + userTableName, + userIdColumn, + userNameColumn, + autoCreateTables, + SimpleMembershipProviderCasingBehavior.NormalizeCasing); } - public static void InitializeDatabaseConnection(string connectionString, string providerName, string userTableName, string userIdColumn, string userNameColumn, bool autoCreateTables) + public static void InitializeDatabaseConnection( + string connectionStringName, + string userTableName, + string userIdColumn, + string userNameColumn, + bool autoCreateTables, + SimpleMembershipProviderCasingBehavior casingBehavior) + { + DatabaseConnectionInfo connect = new DatabaseConnectionInfo(); + connect.ConnectionStringName = connectionStringName; + InitializeProviders(connect, userTableName, userIdColumn, userNameColumn, autoCreateTables, casingBehavior); + } + + public static void InitializeDatabaseConnection( + string connectionString, + string providerName, + string userTableName, + string userIdColumn, + string userNameColumn, + bool autoCreateTables) + { + InitializeDatabaseConnection( + connectionString, + providerName, + userTableName, + userIdColumn, + userNameColumn, + autoCreateTables, + SimpleMembershipProviderCasingBehavior.NormalizeCasing); + } + + public static void InitializeDatabaseConnection( + string connectionString, + string providerName, + string userTableName, + string userIdColumn, + string userNameColumn, + bool autoCreateTables, + SimpleMembershipProviderCasingBehavior casingBehavior) { DatabaseConnectionInfo connect = new DatabaseConnectionInfo(); connect.ConnectionString = connectionString; connect.ProviderName = providerName; - InitializeProviders(connect, userTableName, userIdColumn, userNameColumn, autoCreateTables); + InitializeProviders(connect, userTableName, userIdColumn, userNameColumn, autoCreateTables, casingBehavior); } - private static void InitializeProviders(DatabaseConnectionInfo connect, string userTableName, string userIdColumn, string userNameColumn, bool autoCreateTables) + private static void InitializeProviders( + DatabaseConnectionInfo connect, + string userTableName, + string userIdColumn, + string userNameColumn, + bool autoCreateTables, + SimpleMembershipProviderCasingBehavior casingBehavior) { SimpleMembershipProvider simpleMembership = Membership.Provider as SimpleMembershipProvider; if (simpleMembership != null) { - InitializeMembershipProvider(simpleMembership, connect, userTableName, userIdColumn, userNameColumn, autoCreateTables); + InitializeMembershipProvider(simpleMembership, connect, userTableName, userIdColumn, userNameColumn, autoCreateTables, casingBehavior); } SimpleRoleProvider simpleRoles = Roles.Provider as SimpleRoleProvider; if (simpleRoles != null) { - InitializeRoleProvider(simpleRoles, connect, userTableName, userIdColumn, userNameColumn, autoCreateTables); + InitializeRoleProvider(simpleRoles, connect, userTableName, userIdColumn, userNameColumn, autoCreateTables, casingBehavior); } Initialized = true; } - internal static void InitializeMembershipProvider(SimpleMembershipProvider simpleMembership, DatabaseConnectionInfo connect, string userTableName, string userIdColumn, string userNameColumn, bool createTables) + internal static void InitializeMembershipProvider( + SimpleMembershipProvider simpleMembership, + DatabaseConnectionInfo connect, + string userTableName, + string userIdColumn, + string userNameColumn, + bool createTables, + SimpleMembershipProviderCasingBehavior casingBehavior) { if (simpleMembership.InitializeCalled) { throw new InvalidOperationException(WebDataResources.Security_InitializeAlreadyCalled); } + + simpleMembership.CasingBehavior = casingBehavior; simpleMembership.ConnectionInfo = connect; simpleMembership.UserIdColumn = userIdColumn; simpleMembership.UserNameColumn = userNameColumn; @@ -154,16 +211,26 @@ namespace WebMatrix.WebData simpleMembership.InitializeCalled = true; } - internal static void InitializeRoleProvider(SimpleRoleProvider simpleRoles, DatabaseConnectionInfo connect, string userTableName, string userIdColumn, string userNameColumn, bool createTables) + internal static void InitializeRoleProvider( + SimpleRoleProvider simpleRoles, + DatabaseConnectionInfo connect, + string userTableName, + string userIdColumn, + string userNameColumn, + bool createTables, + SimpleMembershipProviderCasingBehavior casingBehavior) { if (simpleRoles.InitializeCalled) { throw new InvalidOperationException(WebDataResources.Security_InitializeAlreadyCalled); } + + simpleRoles.CasingBehavior = casingBehavior; simpleRoles.ConnectionInfo = connect; simpleRoles.UserTableName = userTableName; simpleRoles.UserIdColumn = userIdColumn; simpleRoles.UserNameColumn = userNameColumn; + if (createTables) { simpleRoles.CreateTablesIfNeeded(); diff --git a/test/WebMatrix.WebData.Test/SimpleMembershipProviderTest.cs b/test/WebMatrix.WebData.Test/SimpleMembershipProviderTest.cs index f94f9d15..d2b7fb1b 100644 --- a/test/WebMatrix.WebData.Test/SimpleMembershipProviderTest.cs +++ b/test/WebMatrix.WebData.Test/SimpleMembershipProviderTest.cs @@ -170,6 +170,48 @@ namespace WebMatrix.WebData.Test Assert.Equal("fGH_eKcjvW__P-5BOEW1AA2", result); } + [Fact] + public void GetUserId_WithCaseNormalization() + { + // Arrange + var database = new Mock(MockBehavior.Strict); + var expectedQuery = @"SELECT userId FROM users WHERE (UPPER(userName) = UPPER(@0))"; + database.Setup(d => d.QueryValue(expectedQuery, "zeke")).Returns(999); + + // Act + var result = SimpleMembershipProvider.GetUserId( + database.Object, + "users", + "userName", + "userId", + SimpleMembershipProviderCasingBehavior.NormalizeCasing, + "zeke"); + + // Assert + Assert.Equal(999, result); + } + + [Fact] + public void GetUserId_WithoutCaseNormalization() + { + // Arrange + var database = new Mock(MockBehavior.Strict); + var expectedQuery = @"SELECT userId FROM users WHERE (userName = @0)"; + database.Setup(d => d.QueryValue(expectedQuery, "zeke")).Returns(999); + + // Act + var result = SimpleMembershipProvider.GetUserId( + database.Object, + "users", + "userName", + "userId", + SimpleMembershipProviderCasingBehavior.RelyOnDatabaseCollation, + "zeke"); + + // Assert + Assert.Equal(999, result); + } + private static DynamicRecord GetRecord(int userId, string confirmationToken) { var data = new Mock(MockBehavior.Strict); From 6f15e848550f9f8b9e32648d53684163ef2436b6 Mon Sep 17 00:00:00 2001 From: dougbu Date: Mon, 4 Nov 2013 23:52:07 -0800 Subject: [PATCH 2/4] Address FxCop warnings that showed up in CI builds - For e.g. http://wsr-teamcity/viewLog.html?buildId=10081&tab=buildResultsDiv&buildTypeId=bt4 --- src/CodeAnalysisDictionary.xml | 1 + .../Formatting/BaseJsonMediaTypeFormatter.cs | 6 +++--- .../Formatting/BsonMediaTypeFormatter.cs | 17 +++++++++-------- .../Formatting/JsonMediaTypeFormatter.cs | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/CodeAnalysisDictionary.xml b/src/CodeAnalysisDictionary.xml index 5338c8f7..ddbf56bf 100644 --- a/src/CodeAnalysisDictionary.xml +++ b/src/CodeAnalysisDictionary.xml @@ -53,6 +53,7 @@ Rfc Realtime ModelName + BSON WebPage diff --git a/src/System.Net.Http.Formatting/Formatting/BaseJsonMediaTypeFormatter.cs b/src/System.Net.Http.Formatting/Formatting/BaseJsonMediaTypeFormatter.cs index e7e2033f..2cb1e103 100644 --- a/src/System.Net.Http.Formatting/Formatting/BaseJsonMediaTypeFormatter.cs +++ b/src/System.Net.Http.Formatting/Formatting/BaseJsonMediaTypeFormatter.cs @@ -31,7 +31,7 @@ namespace System.Net.Http.Formatting /// /// Initializes a new instance of the class. /// - public BaseJsonMediaTypeFormatter() + protected BaseJsonMediaTypeFormatter() { // Initialize serializer settings #if !NETFX_CORE // DataContractResolver is not supported in portable library @@ -54,8 +54,8 @@ namespace System.Net.Http.Formatting Contract.Assert(formatter != null); SerializerSettings = formatter.SerializerSettings; -#if !NETFX_CORE // MaxDepth is not supported in portable library - MaxDepth = formatter.MaxDepth; +#if !NETFX_CORE // MaxDepth is not supported in portable library and so _maxDepth never changes there + _maxDepth = formatter._maxDepth; #endif } diff --git a/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs b/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs index 0a0a7eb9..9605a518 100644 --- a/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs +++ b/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs @@ -147,13 +147,14 @@ namespace System.Net.Http.Formatting throw Error.ArgumentNull("effectiveEncoding"); } - return new BsonReader(new BinaryReader(readStream, effectiveEncoding)) - { - // Special case discussed at http://stackoverflow.com/questions/16910369/bson-array-deserialization-with-json-net - // Dispensed with string (aka IEnumerable) case above in ReadFromStream() - ReadRootValueAsArray = - typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IDictionary).IsAssignableFrom(type), - }; + BsonReader reader = new BsonReader(new BinaryReader(readStream, effectiveEncoding)); + + // Special case discussed at http://stackoverflow.com/questions/16910369/bson-array-deserialization-with-json-net + // Dispensed with string (aka IEnumerable) case above in ReadFromStream() + reader.ReadRootValueAsArray = + typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IDictionary).IsAssignableFrom(type); + + return reader; } /// @@ -218,7 +219,7 @@ namespace System.Net.Http.Formatting return new BsonWriter(new BinaryWriter(writeStream, effectiveEncoding)); } - private bool IsSimpleType(Type type) + private static bool IsSimpleType(Type type) { bool isSimpleType; #if NETFX_CORE // TypeDescriptor is not supported in portable library diff --git a/src/System.Net.Http.Formatting/Formatting/JsonMediaTypeFormatter.cs b/src/System.Net.Http.Formatting/Formatting/JsonMediaTypeFormatter.cs index 61c49f46..e4f51a94 100644 --- a/src/System.Net.Http.Formatting/Formatting/JsonMediaTypeFormatter.cs +++ b/src/System.Net.Http.Formatting/Formatting/JsonMediaTypeFormatter.cs @@ -59,7 +59,7 @@ namespace System.Net.Http.Formatting Contract.Assert(formatter != null); #if !NETFX_CORE // MaxDepth and UseDataContractJsonSerializer are not supported in portable library - MaxDepth = formatter.MaxDepth; + _readerQuotas.MaxDepth = formatter._readerQuotas.MaxDepth; UseDataContractJsonSerializer = formatter.UseDataContractJsonSerializer; #endif From 7cdd9084081476fbe7944c4afc4a1826d2f53790 Mon Sep 17 00:00:00 2001 From: dougbu Date: Tue, 5 Nov 2013 09:06:18 -0800 Subject: [PATCH 3/4] BSON support fixup round 3: One last FxCop warning - Ensure JSON reader gets cleaned up --- .../Formatting/BsonMediaTypeFormatter.cs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs b/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs index 9605a518..fa7feaae 100644 --- a/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs +++ b/src/System.Net.Http.Formatting/Formatting/BsonMediaTypeFormatter.cs @@ -149,10 +149,19 @@ namespace System.Net.Http.Formatting BsonReader reader = new BsonReader(new BinaryReader(readStream, effectiveEncoding)); - // Special case discussed at http://stackoverflow.com/questions/16910369/bson-array-deserialization-with-json-net - // Dispensed with string (aka IEnumerable) case above in ReadFromStream() - reader.ReadRootValueAsArray = + try + { + // Special case discussed at http://stackoverflow.com/questions/16910369/bson-array-deserialization-with-json-net + // Dispensed with string (aka IEnumerable) case above in ReadFromStream() + reader.ReadRootValueAsArray = typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IDictionary).IsAssignableFrom(type); + } + catch + { + // Ensure instance is cleaned up in case of an issue + ((IDisposable)reader).Dispose(); + throw; + } return reader; } From aed4ac8667b28bae0db7825562cbcc07d53a1e2d Mon Sep 17 00:00:00 2001 From: davidmatson Date: Mon, 4 Nov 2013 11:40:48 -0800 Subject: [PATCH 4/4] Preserve exception stack trace in ActionFilterAttribute (fixes #1316). --- .../Filters/ActionFilterAttribute.cs | 52 ++++++++--- .../Filters/ActionFilterAttributeTest.cs | 92 +++++++++++++++++++ 2 files changed, 131 insertions(+), 13 deletions(-) diff --git a/src/System.Web.Http/Filters/ActionFilterAttribute.cs b/src/System.Web.Http/Filters/ActionFilterAttribute.cs index bafd204c..7128c2f1 100644 --- a/src/System.Web.Http/Filters/ActionFilterAttribute.cs +++ b/src/System.Web.Http/Filters/ActionFilterAttribute.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Net.Http; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; using System.Web.Http.Controllers; @@ -83,38 +84,63 @@ namespace System.Web.Http.Filters cancellationToken.ThrowIfCancellationRequested(); HttpResponseMessage response = null; - Exception exception = null; + ExceptionDispatchInfo exceptionInfo = null; try { response = await continuation(); } catch (Exception e) { - exception = e; + exceptionInfo = ExceptionDispatchInfo.Capture(e); } + Exception exception; + + if (exceptionInfo == null) + { + exception = null; + } + else + { + exception = exceptionInfo.SourceException; + } + + HttpActionExecutedContext executedContext = new HttpActionExecutedContext(actionContext, exception) + { + Response = response + }; + try { - HttpActionExecutedContext executedContext = new HttpActionExecutedContext(actionContext, exception) { Response = response }; await OnActionExecutedAsync(executedContext, cancellationToken); - - if (executedContext.Response != null) - { - return executedContext.Response; - } - if (executedContext.Exception != null) - { - throw executedContext.Exception; - } } catch { - // Catch is running because OnActionExecuted threw an exception, so we just want to re-throw the exception. + // Catch is running because OnActionExecuted threw an exception, so we just want to re-throw. // We also need to reset the response to forget about it since a filter threw an exception. actionContext.Response = null; throw; } + if (executedContext.Response != null) + { + return executedContext.Response; + } + + Exception newException = executedContext.Exception; + + if (newException != null) + { + if (newException == exception) + { + exceptionInfo.Throw(); + } + else + { + throw newException; + } + } + throw Error.InvalidOperation(SRResources.ActionFilterAttribute_MustSupplyResponseOrException, GetType().Name); } } diff --git a/test/System.Web.Http.Test/Filters/ActionFilterAttributeTest.cs b/test/System.Web.Http.Test/Filters/ActionFilterAttributeTest.cs index 0e5ab70d..9af566cc 100644 --- a/test/System.Web.Http.Test/Filters/ActionFilterAttributeTest.cs +++ b/test/System.Web.Http.Test/Filters/ActionFilterAttributeTest.cs @@ -472,6 +472,98 @@ namespace System.Web.Http.Filters ); } + [Fact] + public void ExecuteActionFilterAsync_IfOnActionExecutedReplacesException_ThrowsNewException() + { + // Arrange + Exception expectedReplacementException = CreateException(); + + using (HttpRequestMessage request = new HttpRequestMessage()) + { + Mock mock = new Mock(); + mock.CallBase = true; + mock + .Setup(f => f.OnActionExecuted(It.IsAny())) + .Callback((c) => c.Exception = expectedReplacementException); + IActionFilter product = mock.Object; + + HttpActionContext context = ContextUtil.CreateActionContext(); + Func> continuation = () => + CreateFaultedTask(CreateException()); + + // Act + Task task = product.ExecuteActionFilterAsync(context, CancellationToken.None, + continuation); + + // Assert + Assert.NotNull(task); + task.WaitUntilCompleted(); + Assert.Equal(TaskStatus.Faulted, task.Status); + Assert.NotNull(task.Exception); + Exception exception = task.Exception.GetBaseException(); + Assert.Same(expectedReplacementException, exception); + } + } + + [Fact] + public void ExecuteActionFilterAsync_IfFaultedTaskExceptionIsUnhandled_PreservesExceptionStackTrace() + { + // Arrange + Exception originalException = CreateExceptionWithStackTrace(); + string expectedStackTrace = originalException.StackTrace; + + using (HttpRequestMessage request = new HttpRequestMessage()) + { + IActionFilter product = new TestableActionFilter(); + HttpActionContext context = ContextUtil.CreateActionContext(); + Func> continuation = () => CreateFaultedTask( + originalException); + + // Act + Task task = product.ExecuteActionFilterAsync(context, CancellationToken.None, + continuation); + + // Assert + Assert.NotNull(task); + task.WaitUntilCompleted(); + Assert.Equal(TaskStatus.Faulted, task.Status); + Assert.NotNull(task.Exception); + Exception exception = task.Exception.GetBaseException(); + Assert.NotNull(expectedStackTrace); + Assert.NotNull(exception); + Assert.NotNull(exception.StackTrace); + Assert.True(exception.StackTrace.StartsWith(expectedStackTrace)); + } + } + + private static Exception CreateException() + { + return new InvalidOperationException(); + } + + private static Exception CreateExceptionWithStackTrace() + { + Exception exception; + + try + { + throw CreateException(); + } + catch (Exception ex) + { + exception = ex; + } + + return exception; + } + + private static Task CreateFaultedTask(Exception exception) + { + TaskCompletionSource source = new TaskCompletionSource(); + source.SetException(exception); + return source.Task; + } + public class TestableActionFilter : ActionFilterAttribute { }