Make IChatClient/IEmbeddingGenerator.GetService non-generic (#5608)

This commit is contained in:
Stephen Toub 2024-11-07 14:08:34 -05:00
Родитель 4f775a0b0e
Коммит ca7d3f28fb
19 изменённых файлов: 173 добавлений и 59 удалений

Просмотреть файл

@ -11,6 +11,22 @@ namespace Microsoft.Extensions.AI;
/// <summary>Provides a collection of static methods for extending <see cref="IChatClient"/> instances.</summary>
public static class ChatClientExtensions
{
/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="client">The client.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TService>(this IChatClient client, object? serviceKey = null)
{
_ = Throw.IfNull(client);
return (TService?)client.GetService(typeof(TService), serviceKey);
}
/// <summary>Sends a user chat text message to the model and returns the response messages.</summary>
/// <param name="client">The chat client.</param>
/// <param name="chatMessage">The text content for the chat message to send.</param>

Просмотреть файл

@ -63,12 +63,13 @@ public class DelegatingChatClient : IChatClient
}
/// <inheritdoc />
public virtual TService? GetService<TService>(object? key = null)
where TService : class
public virtual object? GetService(Type serviceType, object? serviceKey = null)
{
#pragma warning disable S3060 // "is" should not be used with "this"
// If the key is non-null, we don't know what it means so pass through to the inner service
return key is null && this is TService service ? service : InnerClient.GetService<TService>(key);
#pragma warning restore S3060
_ = Throw.IfNull(serviceType);
// If the key is non-null, we don't know what it means so pass through to the inner service.
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
InnerClient.GetService(serviceType, serviceKey);
}
}

Просмотреть файл

@ -56,14 +56,13 @@ public interface IChatClient : IDisposable
/// <summary>Gets metadata that describes the <see cref="IChatClient"/>.</summary>
ChatClientMetadata Metadata { get; }
/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="key">An optional key that may be used to help identify the target service.</param>
/// <summary>Asks the <see cref="IChatClient"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
TService? GetService<TService>(object? key = null)
where TService : class;
object? GetService(Type serviceType, object? serviceKey = null);
}

Просмотреть файл

@ -59,12 +59,13 @@ public class DelegatingEmbeddingGenerator<TInput, TEmbedding> : IEmbeddingGenera
InnerGenerator.GenerateAsync(values, options, cancellationToken);
/// <inheritdoc />
public virtual TService? GetService<TService>(object? key = null)
where TService : class
public virtual object? GetService(Type serviceType, object? serviceKey = null)
{
#pragma warning disable S3060 // "is" should not be used with "this"
// If the key is non-null, we don't know what it means so pass through to the inner service
return key is null && this is TService service ? service : InnerGenerator.GetService<TService>(key);
#pragma warning restore S3060
_ = Throw.IfNull(serviceType);
// If the key is non-null, we don't know what it means so pass through to the inner service.
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
InnerGenerator.GetService(serviceType, serviceKey);
}
}

Просмотреть файл

@ -15,6 +15,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
public static class EmbeddingGeneratorExtensions
{
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TInput, TEmbedding, TService>(this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);
return (TService?)generator.GetService(typeof(TService), serviceKey);
}
// The following overload exists purely to work around the lack of partial generic type inference.
// Given an IEmbeddingGenerator<TInput, TEmbedding> generator, to call GetService with TService, you still need
// to re-specify both TInput and TEmbedding, e.g. generator.GetService<string, Embedding<float>, TService>.
// The case of string/Embedding<float> is by far the most common case today, so this overload exists as an
// accelerator to allow it to be written simply as generator.GetService<TService>.
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
GetService<string, Embedding<float>, TService>(generator, serviceKey);
/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>

Просмотреть файл

@ -40,14 +40,13 @@ public interface IEmbeddingGenerator<TInput, TEmbedding> : IDisposable
/// <summary>Gets metadata that describes the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</summary>
EmbeddingGeneratorMetadata Metadata { get; }
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="key">An optional key that may be used to help identify the target service.</param>
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>,
/// including itself or any services it might be wrapping.
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
TService? GetService<TService>(object? key = null)
where TService : class;
object? GetService(Type serviceType, object? serviceKey = null);
}

Просмотреть файл

@ -57,10 +57,16 @@ public sealed class AzureAIInferenceChatClient : IChatClient
public ChatClientMetadata Metadata { get; }
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(ChatCompletionsClient) ? (TService?)(object?)_chatCompletionsClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is not null ? null :
serviceType == typeof(ChatCompletionsClient) ? _chatCompletionsClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(

Просмотреть файл

@ -70,10 +70,16 @@ public sealed class AzureAIInferenceEmbeddingGenerator :
public EmbeddingGeneratorMetadata Metadata { get; }
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(EmbeddingsClient) ? (TService)(object)_embeddingsClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is not null ? null :
serviceType == typeof(EmbeddingsClient) ? _embeddingsClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(

Просмотреть файл

@ -166,9 +166,14 @@ public sealed class OllamaChatClient : IChatClient
}
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=> key is null ? this as TService : null;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public void Dispose()

Просмотреть файл

@ -57,9 +57,14 @@ public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator<string, Embed
public EmbeddingGeneratorMetadata Metadata { get; }
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=> key is null ? this as TService : null;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public void Dispose()

Просмотреть файл

@ -16,6 +16,7 @@ using Microsoft.Shared.Diagnostics;
using OpenAI;
using OpenAI.Chat;
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S1135 // Track uses of "TODO" tags
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1204 // Static elements should appear before instance elements
@ -85,11 +86,17 @@ public sealed partial class OpenAIChatClient : IChatClient
public ChatClientMetadata Metadata { get; }
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is not null ? null :
serviceType == typeof(OpenAIClient) ? _openAIClient :
serviceType == typeof(ChatClient) ? _chatClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(

Просмотреть файл

@ -11,6 +11,7 @@ using Microsoft.Shared.Diagnostics;
using OpenAI;
using OpenAI.Embeddings;
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
namespace Microsoft.Extensions.AI;
@ -95,12 +96,17 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator<string, Embed
public EmbeddingGeneratorMetadata Metadata { get; }
/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=>
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);
return
serviceKey is not null ? null :
serviceType == typeof(OpenAIClient) ? _openAIClient :
serviceType == typeof(EmbeddingClient) ? _embeddingClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}
/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)

Просмотреть файл

@ -11,6 +11,12 @@ namespace Microsoft.Extensions.AI;
public class ChatClientExtensionsTests
{
[Fact]
public void GetService_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetService<object>(null!));
}
[Fact]
public void CompleteAsync_InvalidArgs_Throws()
{

Просмотреть файл

@ -96,6 +96,14 @@ public class DelegatingChatClientTests
Assert.False(await enumerator.MoveNextAsync());
}
[Fact]
public void GetServiceThrowsForNullType()
{
using var inner = new TestChatClient();
using var delegating = new NoOpDelegatingChatClient(inner);
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
}
[Fact]
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
{

Просмотреть файл

@ -57,6 +57,14 @@ public class DelegatingEmbeddingGeneratorTests
Assert.Same(expectedEmbedding, await resultTask);
}
[Fact]
public void GetServiceThrowsForNullType()
{
using var inner = new TestEmbeddingGenerator();
using var delegating = new NoOpDelegatingEmbeddingGenerator(inner);
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
}
[Fact]
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
{

Просмотреть файл

@ -10,6 +10,13 @@ namespace Microsoft.Extensions.AI;
public class EmbeddingGeneratorExtensionsTests
{
[Fact]
public void GetService_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<object>(null!));
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<string, Embedding<double>, object>(null!));
}
[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{

Просмотреть файл

@ -26,9 +26,8 @@ public sealed class TestChatClient : IChatClient
public IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken);
public TService? GetService<TService>(object? key = null)
where TService : class
=> (TService?)GetServiceCallback!(typeof(TService), key);
public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);
void IDisposable.Dispose()
{

Просмотреть файл

@ -19,9 +19,8 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator<string, Embeddi
public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
=> GenerateAsyncCallback!.Invoke(values, options, cancellationToken);
public TService? GetService<TService>(object? key = null)
where TService : class
=> (TService?)GetServiceCallback!(typeof(TService), key);
public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);
void IDisposable.Dispose()
{

Просмотреть файл

@ -29,10 +29,9 @@ internal sealed class QuantizationEmbeddingGenerator :
void IDisposable.Dispose() => _floatService.Dispose();
public TService? GetService<TService>(object? key = null)
where TService : class =>
key is null && this is TService ? (TService?)(object)this :
_floatService.GetService<TService>(key);
public object? GetService(Type serviceType, object? serviceKey = null) =>
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
_floatService.GetService(serviceType, serviceKey);
async Task<GeneratedEmbeddings<BinaryEmbedding>> IEmbeddingGenerator<string, BinaryEmbedding>.GenerateAsync(
IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)