diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs index 944283ccd8..9e2019d9e5 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -11,6 +11,22 @@ namespace Microsoft.Extensions.AI; /// Provides a collection of static methods for extending instances. public static class ChatClientExtensions { + /// Asks the for an object of type . + /// The type of the object to be retrieved. + /// The client. + /// An optional key that may be used to help identify the target service. + /// The found object, otherwise . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + public static TService? GetService(this IChatClient client, object? serviceKey = null) + { + _ = Throw.IfNull(client); + + return (TService?)client.GetService(typeof(TService), serviceKey); + } + /// Sends a user chat text message to the model and returns the response messages. /// The chat client. /// The text content for the chat message to send. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs index a6fb40b355..d92590bad9 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs @@ -63,12 +63,13 @@ public class DelegatingChatClient : IChatClient } /// - public virtual TService? GetService(object? key = null) - where TService : class + public virtual object? GetService(Type serviceType, object? serviceKey = null) { -#pragma warning disable S3060 // "is" should not be used with "this" - // If the key is non-null, we don't know what it means so pass through to the inner service - return key is null && this is TService service ? service : InnerClient.GetService(key); -#pragma warning restore S3060 + _ = Throw.IfNull(serviceType); + + // If the key is non-null, we don't know what it means so pass through to the inner service. + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + InnerClient.GetService(serviceType, serviceKey); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 8cbfa1314f..4e3fd126b3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -56,14 +56,13 @@ public interface IChatClient : IDisposable /// Gets metadata that describes the . ChatClientMetadata Metadata { get; } - /// Asks the for an object of type . - /// The type of the object to be retrieved. - /// An optional key that may be used to help identify the target service. + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that may be used to help identify the target service. /// The found object, otherwise . /// /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the , /// including itself or any services it might be wrapping. /// - TService? GetService(object? key = null) - where TService : class; + object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs index 6b06d32d6d..590817d4e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs @@ -59,12 +59,13 @@ public class DelegatingEmbeddingGenerator : IEmbeddingGenera InnerGenerator.GenerateAsync(values, options, cancellationToken); /// - public virtual TService? GetService(object? key = null) - where TService : class + public virtual object? GetService(Type serviceType, object? serviceKey = null) { -#pragma warning disable S3060 // "is" should not be used with "this" - // If the key is non-null, we don't know what it means so pass through to the inner service - return key is null && this is TService service ? service : InnerGenerator.GetService(key); -#pragma warning restore S3060 + _ = Throw.IfNull(serviceType); + + // If the key is non-null, we don't know what it means so pass through to the inner service. + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + InnerGenerator.GetService(serviceType, serviceKey); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index efa804fd0e..8a388d361b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -15,6 +15,43 @@ namespace Microsoft.Extensions.AI; /// Provides a collection of static methods for extending instances. public static class EmbeddingGeneratorExtensions { + /// Asks the for an object of type . + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that may be used to help identify the target service. + /// The found object, otherwise . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService? GetService(this IEmbeddingGenerator generator, object? serviceKey = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + + return (TService?)generator.GetService(typeof(TService), serviceKey); + } + + // The following overload exists purely to work around the lack of partial generic type inference. + // Given an IEmbeddingGenerator generator, to call GetService with TService, you still need + // to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>. + // The case of string/Embedding is by far the most common case today, so this overload exists as an + // accelerator to allow it to be written simply as generator.GetService. + + /// Asks the for an object of type . + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that may be used to help identify the target service. + /// The found object, otherwise . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) => + GetService, TService>(generator, serviceKey); + /// Generates an embedding vector from the specified . /// The type from which embeddings will be generated. /// The numeric type of the embedding data. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 5cc289fbb5..9f9c9f1325 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -40,14 +40,13 @@ public interface IEmbeddingGenerator : IDisposable /// Gets metadata that describes the . EmbeddingGeneratorMetadata Metadata { get; } - /// Asks the for an object of type . - /// The type of the object to be retrieved. - /// An optional key that may be used to help identify the target service. + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that may be used to help identify the target service. /// The found object, otherwise . /// - /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the , - /// including itself or any services it might be wrapping. + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the + /// , including itself or any services it might be wrapping. /// - TService? GetService(object? key = null) - where TService : class; + object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index ba76f5c3c9..143d592810 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -57,10 +57,16 @@ public sealed class AzureAIInferenceChatClient : IChatClient public ChatClientMetadata Metadata { get; } /// - public TService? GetService(object? key = null) - where TService : class => - typeof(TService) == typeof(ChatCompletionsClient) ? (TService?)(object?)_chatCompletionsClient : - this as TService; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(ChatCompletionsClient) ? _chatCompletionsClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } /// public async Task CompleteAsync( diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index 866e55ad87..3f8f2adb3f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -70,10 +70,16 @@ public sealed class AzureAIInferenceEmbeddingGenerator : public EmbeddingGeneratorMetadata Metadata { get; } /// - public TService? GetService(object? key = null) - where TService : class => - typeof(TService) == typeof(EmbeddingsClient) ? (TService)(object)_embeddingsClient : - this as TService; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(EmbeddingsClient) ? _embeddingsClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } /// public async Task>> GenerateAsync( diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index 18ff5d50b7..e6084e94ab 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -166,9 +166,14 @@ public sealed class OllamaChatClient : IChatClient } /// - public TService? GetService(object? key = null) - where TService : class - => key is null ? this as TService : null; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + null; + } /// public void Dispose() diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 5779b60cbc..ea273c31b4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -57,9 +57,14 @@ public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator - public TService? GetService(object? key = null) - where TService : class - => key is null ? this as TService : null; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + null; + } /// public void Dispose() diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index 985060256f..5490466b66 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -16,6 +16,7 @@ using Microsoft.Shared.Diagnostics; using OpenAI; using OpenAI.Chat; +#pragma warning disable S1067 // Expressions should not be too complex #pragma warning disable S1135 // Track uses of "TODO" tags #pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields #pragma warning disable SA1204 // Static elements should appear before instance elements @@ -85,11 +86,17 @@ public sealed partial class OpenAIChatClient : IChatClient public ChatClientMetadata Metadata { get; } /// - public TService? GetService(object? key = null) - where TService : class => - typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : - typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient : - this as TService; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(OpenAIClient) ? _openAIClient : + serviceType == typeof(ChatClient) ? _chatClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } /// public async Task CompleteAsync( diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 155e047279..5c34a8028a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -11,6 +11,7 @@ using Microsoft.Shared.Diagnostics; using OpenAI; using OpenAI.Embeddings; +#pragma warning disable S1067 // Expressions should not be too complex #pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields namespace Microsoft.Extensions.AI; @@ -95,12 +96,17 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator - public TService? GetService(object? key = null) - where TService : class - => - typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : - typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient : - this as TService; + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(OpenAIClient) ? _openAIClient : + serviceType == typeof(EmbeddingClient) ? _embeddingClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } /// public async Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 68f5ad1224..3732e80503 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -11,6 +11,12 @@ namespace Microsoft.Extensions.AI; public class ChatClientExtensionsTests { + [Fact] + public void GetService_InvalidArgs_Throws() + { + Assert.Throws("client", () => ChatClientExtensions.GetService(null!)); + } + [Fact] public void CompleteAsync_InvalidArgs_Throws() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs index 51c82c7dcb..35027bb71f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -96,6 +96,14 @@ public class DelegatingChatClientTests Assert.False(await enumerator.MoveNextAsync()); } + [Fact] + public void GetServiceThrowsForNullType() + { + using var inner = new TestChatClient(); + using var delegating = new NoOpDelegatingChatClient(inner); + Assert.Throws("serviceType", () => delegating.GetService(null!)); + } + [Fact] public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs index 91640e62f4..3f6732a410 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs @@ -57,6 +57,14 @@ public class DelegatingEmbeddingGeneratorTests Assert.Same(expectedEmbedding, await resultTask); } + [Fact] + public void GetServiceThrowsForNullType() + { + using var inner = new TestEmbeddingGenerator(); + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + Assert.Throws("serviceType", () => delegating.GetService(null!)); + } + [Fact] public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index b6deb1ccd0..4466dd85d1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -10,6 +10,13 @@ namespace Microsoft.Extensions.AI; public class EmbeddingGeneratorExtensionsTests { + [Fact] + public void GetService_InvalidArgs_Throws() + { + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService(null!)); + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); + } + [Fact] public async Task GenerateAsync_InvalidArgs_ThrowsAsync() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs index 55f4c48648..5eacced35b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -26,9 +26,8 @@ public sealed class TestChatClient : IChatClient public IAsyncEnumerable CompleteStreamingAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) => CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken); - public TService? GetService(object? key = null) - where TService : class - => (TService?)GetServiceCallback!(typeof(TService), key); + public object? GetService(Type serviceType, object? serviceKey = null) + => GetServiceCallback!(serviceType, serviceKey); void IDisposable.Dispose() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs index 83680a2be1..5b79b1908d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs @@ -19,9 +19,8 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) => GenerateAsyncCallback!.Invoke(values, options, cancellationToken); - public TService? GetService(object? key = null) - where TService : class - => (TService?)GetServiceCallback!(typeof(TService), key); + public object? GetService(Type serviceType, object? serviceKey = null) + => GetServiceCallback!(serviceType, serviceKey); void IDisposable.Dispose() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs index 90032f1643..c48dc2e23e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs @@ -29,10 +29,9 @@ internal sealed class QuantizationEmbeddingGenerator : void IDisposable.Dispose() => _floatService.Dispose(); - public TService? GetService(object? key = null) - where TService : class => - key is null && this is TService ? (TService?)(object)this : - _floatService.GetService(key); + public object? GetService(Type serviceType, object? serviceKey = null) => + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + _floatService.GetService(serviceType, serviceKey); async Task> IEmbeddingGenerator.GenerateAsync( IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)