diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs index 8f484f1cbc..9b4fe7e9c8 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs @@ -20,7 +20,9 @@ internal sealed class RequestMessageSnapshotStrategy : ResilienceStrategy ResilienceContext context, TState state) { - if (!context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out var request) || request is null) + HttpRequestMessage? request = context.GetRequestMessage(); + + if (request is null) { Throw.InvalidOperationException("The HTTP request message was not found in the resilience context."); } diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/ResilienceHttpClientBuilderExtensions.Hedging.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/ResilienceHttpClientBuilderExtensions.Hedging.cs index 3f7b82ba0d..7d46a6efe5 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/ResilienceHttpClientBuilderExtensions.Hedging.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/ResilienceHttpClientBuilderExtensions.Hedging.cs @@ -94,7 +94,7 @@ public static partial class ResilienceHttpClientBuilderExtensions requestMessage.SetResilienceContext(args.ActionContext); // replace the request message - args.ActionContext.Properties.Set(ResilienceKeys.RequestMessage, requestMessage); + args.ActionContext.SetRequestMessage(requestMessage); if (args.PrimaryContext.Properties.TryGetValue(ResilienceKeys.RoutingStrategy, out var routingPipeline)) { diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs index 6ca28d8d7f..4112666f37 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.Http.Resilience.Internal; internal static class ResilienceKeys { - public static readonly ResiliencePropertyKey RequestMessage = new("Resilience.Http.RequestMessage"); + public static readonly ResiliencePropertyKey RequestMessage = new("Resilience.Http.RequestMessage"); public static readonly ResiliencePropertyKey RoutingStrategy = new("Resilience.Http.RequestRoutingStrategy"); diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs new file mode 100644 index 0000000000..b04bd37a2a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http; +using Microsoft.Extensions.Http.Resilience.Internal; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; +using Polly; + +namespace Polly; + +/// +/// Provides utility methods for working with . +/// +[Experimental(diagnosticId: DiagnosticIds.Experiments.Resilience, UrlFormat = DiagnosticIds.UrlFormat)] +public static class HttpResilienceContextExtensions +{ + /// + /// Gets the request message from the . + /// + /// The resilience context. + /// + /// The request message. + /// If the request message is not present in the the method returns . + /// + /// is . + public static HttpRequestMessage? GetRequestMessage(this ResilienceContext context) + { + _ = Throw.IfNull(context); + return context.Properties.GetValue(ResilienceKeys.RequestMessage, default); + } + + /// + /// Sets the request message on the . + /// + /// The resilience context. + /// The request message. + /// is . + public static void SetRequestMessage(this ResilienceContext context, HttpRequestMessage? requestMessage) + { + _ = Throw.IfNull(context); + context.Properties.Set(ResilienceKeys.RequestMessage, requestMessage); + } +} diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/ResilienceHandler.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/ResilienceHandler.cs index 41a408fa71..aff8226036 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/ResilienceHandler.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/ResilienceHandler.cs @@ -58,7 +58,7 @@ public class ResilienceHandler : DelegatingHandler ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created); TrySetRequestMetadata(context, request); - SetRequestMessage(context, request); + context.SetRequestMessage(request); try { @@ -117,7 +117,7 @@ public class ResilienceHandler : DelegatingHandler ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created); TrySetRequestMetadata(context, request); - SetRequestMessage(context, request); + context.SetRequestMessage(request); try { @@ -165,11 +165,8 @@ public class ResilienceHandler : DelegatingHandler } } - private static void SetRequestMessage(ResilienceContext context, HttpRequestMessage request) - => context.Properties.Set(ResilienceKeys.RequestMessage, request); - private static HttpRequestMessage GetRequestMessage(ResilienceContext context, HttpRequestMessage request) - => context.Properties.GetValue(ResilienceKeys.RequestMessage, request); + => context.GetRequestMessage() ?? request; private static void RestoreResilienceContext(ResilienceContext context, HttpRequestMessage request, bool created) { diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs index 0aefc6cf17..2054cbcccb 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Shared.Diagnostics; @@ -26,7 +27,9 @@ internal sealed class RoutingResilienceStrategy : ResilienceStrategy ResilienceContext context, TState state) { - if (!context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out var request)) + HttpRequestMessage? request = context.GetRequestMessage(); + + if (request is null) { Throw.InvalidOperationException("The HTTP request message was not found in the resilience context."); } diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Internal/RequestMessageSnapshotStrategyTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Internal/RequestMessageSnapshotStrategyTests.cs index 6d1713150d..34bb4df92c 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Internal/RequestMessageSnapshotStrategyTests.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Internal/RequestMessageSnapshotStrategyTests.cs @@ -20,7 +20,7 @@ public class RequestMessageSnapshotStrategyTests var strategy = Create(); var context = ResilienceContextPool.Shared.Get(); using var request = new HttpRequestMessage(); - context.Properties.Set(ResilienceKeys.RequestMessage, request); + context.SetRequestMessage(request); using var response = await strategy.ExecuteAsync( context => @@ -39,5 +39,15 @@ public class RequestMessageSnapshotStrategyTests strategy.Invoking(s => s.Execute(() => { })).Should().Throw(); } + [Fact] + public void ExecuteAsync_RequestMessageIsNull_Throws() + { + var strategy = Create(); + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw(); + } + private static ResiliencePipeline Create() => new ResiliencePipelineBuilder().AddStrategy(_ => new RequestMessageSnapshotStrategy(), Mock.Of()).Build(); } diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs new file mode 100644 index 0000000000..9cd1255c99 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using Microsoft.Extensions.Http.Resilience.Internal; +using Polly; +using Xunit; + +namespace Microsoft.Extensions.Http.Resilience.Test.Resilience; + +public class HttpResilienceContextExtensionsTests +{ + [Fact] + public void GetRequestMessage_ResilienceContextIsNull_Throws() + { + ResilienceContext context = null!; + Assert.Throws(context.GetRequestMessage); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsMissing_ReturnsNull() + { + var context = ResilienceContextPool.Shared.Get(); + + Assert.Null(context.GetRequestMessage()); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsNull_ReturnsNull() + { + var context = ResilienceContextPool.Shared.Get(); + context.Properties.Set(ResilienceKeys.RequestMessage, null); + + Assert.Null(context.GetRequestMessage()); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsPresent_ReturnsRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + using var request = new HttpRequestMessage(); + context.Properties.Set(ResilienceKeys.RequestMessage, request); + + Assert.Same(request, context.GetRequestMessage()); + } + + [Fact] + public void SetRequestMessage_ResilienceContextIsNull_Throws() + { + ResilienceContext context = null!; + using var request = new HttpRequestMessage(); + + Assert.Throws(() => context.SetRequestMessage(request)); + } + + [Fact] + public void SetRequestMessage_RequestMessageIsNull_SetsNullRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? request)); + Assert.Null(request); + } + + [Fact] + public void SetRequestMessage_RequestMessageIsNotNull_SetsRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + using var request = new HttpRequestMessage(); + context.SetRequestMessage(request); + + Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? actualRequest)); + Assert.Same(request, actualRequest); + } +} diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs index 8a1831d67b..9de7255540 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs @@ -8,7 +8,6 @@ using System.Threading; using System.Threading.Tasks; using FluentAssertions; using Microsoft.Extensions.Http.Diagnostics; -using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Extensions.Http.Resilience.Test.Helpers; using Polly; using Xunit; @@ -108,7 +107,7 @@ public class ResilienceHandlerTest handler.InnerHandler = new TestHandlerStub((r, _) => { r.GetResilienceContext().Should().NotBeNull(); - r.GetResilienceContext()!.Properties.GetValue(ResilienceKeys.RequestMessage, null!).Should().BeSameAs(r); + r.GetResilienceContext()!.GetRequestMessage().Should().BeSameAs(r); return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.Created }); }); diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs index 341d32087f..ed1a6d5c6b 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs @@ -4,7 +4,6 @@ using System; using System.Net.Http; using FluentAssertions; -using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Extensions.Http.Resilience.Routing.Internal; using Moq; using Polly; @@ -22,6 +21,16 @@ public class RoutingResilienceStrategyTests strategy.Invoking(s => s.Execute(() => { })).Should().Throw().WithMessage("The HTTP request message was not found in the resilience context."); } + [Fact] + public void RequestMessageIsNull_Throws() + { + var strategy = Create(() => Mock.Of()); + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw().WithMessage("The HTTP request message was not found in the resilience context."); + } + [Fact] public void NoRoutingProvider_Ok() { @@ -29,7 +38,7 @@ public class RoutingResilienceStrategyTests var strategy = Create(null); var context = ResilienceContextPool.Shared.Get(); - context.Properties.Set(ResilienceKeys.RequestMessage, request); + context.SetRequestMessage(request); strategy.Invoking(s => s.Execute(_ => { }, context)).Should().NotThrow(); }