From 43b91dabd1b1174d6cf14e207effeb385e87abac Mon Sep 17 00:00:00 2001 From: Carol Wang Date: Fri, 15 Nov 2024 13:14:01 +0800 Subject: [PATCH] Stop splitting the UPN credential into username and domain parts unless specified (#5663) * Stop splitting the UPN credential into username and domain parts unless specified and add unit test. --- .../ServiceModel/Security/SecurityUtils.cs | 9 ++- .../tests/Security/SecurityUtilsTest.cs | 57 +++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 src/System.ServiceModel.Primitives/tests/Security/SecurityUtilsTest.cs diff --git a/src/System.ServiceModel.Primitives/src/System/ServiceModel/Security/SecurityUtils.cs b/src/System.ServiceModel.Primitives/src/System/ServiceModel/Security/SecurityUtils.cs index 941f9c546..41fc68f66 100644 --- a/src/System.ServiceModel.Primitives/src/System/ServiceModel/Security/SecurityUtils.cs +++ b/src/System.ServiceModel.Primitives/src/System/ServiceModel/Security/SecurityUtils.cs @@ -201,6 +201,8 @@ namespace System.ServiceModel.Security public const string Principal = "Principal"; private static IIdentity s_anonymousIdentity; private static X509SecurityTokenAuthenticator s_nonValidatingX509Authenticator; + internal const string EnableLegacyUpnUsernameFixString = "Switch.System.ServiceModel.EnableLegacyUpnUsernameFix"; + internal static bool s_enableLegacyUpnUsernameFix = AppContext.TryGetSwitch(EnableLegacyUpnUsernameFixString, out bool enabled) && enabled; internal static string GetSpnFromIdentity(EndpointIdentity identity, EndpointAddress target) { @@ -932,6 +934,11 @@ namespace System.ServiceModel.Security } internal static void FixNetworkCredential(ref NetworkCredential credential) + { + FixNetworkCredential(ref credential, s_enableLegacyUpnUsernameFix); + } + + internal static void FixNetworkCredential(ref NetworkCredential credential, bool enableLegacyUpnUsernameFix) { if (credential == null) { @@ -952,7 +959,7 @@ namespace System.ServiceModel.Security credential = new NetworkCredential(partsWithSlashDelimiter[1], credential.Password, partsWithSlashDelimiter[0]); } } - else if (partsWithSlashDelimiter.Length == 1 && partsWithAtDelimiter.Length == 2) + else if (enableLegacyUpnUsernameFix && partsWithSlashDelimiter.Length == 1 && partsWithAtDelimiter.Length == 2) { if (!string.IsNullOrEmpty(partsWithAtDelimiter[0]) && !string.IsNullOrEmpty(partsWithAtDelimiter[1])) { diff --git a/src/System.ServiceModel.Primitives/tests/Security/SecurityUtilsTest.cs b/src/System.ServiceModel.Primitives/tests/Security/SecurityUtilsTest.cs new file mode 100644 index 000000000..59fb48661 --- /dev/null +++ b/src/System.ServiceModel.Primitives/tests/Security/SecurityUtilsTest.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Net; +using System.Reflection; +using System.ServiceModel.Security; +using Infrastructure.Common; +using Xunit; + +public static class SecurityUtilsTest +{ + [WcfFact] + public static void FixNetworkCredential_AppContext_EnableLegacyUpnUsernameFix() + { + Type t = Assembly.GetAssembly(typeof(WindowsClientCredential)) + .GetType(typeof(WindowsClientCredential).Namespace + ".SecurityUtils"); + + MethodInfo method = t.GetMethod("FixNetworkCredential", BindingFlags.NonPublic | BindingFlags.Static, + null, new[] { typeof(NetworkCredential).MakeByRefType() }, null); + + FieldInfo f = t.GetField("s_enableLegacyUpnUsernameFix", BindingFlags.Static | BindingFlags.NonPublic); + + //default + var credential = new NetworkCredential("user@domain.com", "password"); + var parameters = new object[] { credential }; + method.Invoke(null, parameters); + credential = (NetworkCredential)parameters[0]; + Assert.NotNull(credential); + Assert.Equal("user@domain.com", credential.UserName); + Assert.Equal("password", credential.Password); + Assert.Equal(string.Empty, credential.Domain); + + //switch on + f.SetValue(t, true); + credential = new NetworkCredential("user@domain.com", "password"); + parameters = new object[] { credential }; + method.Invoke(null, parameters); + credential = (NetworkCredential)parameters[0]; + Assert.NotNull(credential); + Assert.Equal("user", credential.UserName); + Assert.Equal("password", credential.Password); + Assert.Equal("domain.com", credential.Domain); + + //switch off + f.SetValue(t, false); + credential = new NetworkCredential("user@domain.com", "password"); + parameters = new object[] { credential }; + method.Invoke(null, parameters); + credential = (NetworkCredential)parameters[0]; + Assert.NotNull(credential); + Assert.Equal("user@domain.com", credential.UserName); + Assert.Equal("password", credential.Password); + Assert.Equal(string.Empty, credential.Domain); + } +}