зеркало из https://github.com/dotnet/extensions.git
Make IChatClient/IEmbeddingGenerator.GetService non-generic (#5608)
This commit is contained in:
Родитель
4f775a0b0e
Коммит
ca7d3f28fb
|
@ -11,6 +11,22 @@ namespace Microsoft.Extensions.AI;
|
|||
/// <summary>Provides a collection of static methods for extending <see cref="IChatClient"/> instances.</summary>
|
||||
public static class ChatClientExtensions
|
||||
{
|
||||
/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
|
||||
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
|
||||
/// <param name="client">The client.</param>
|
||||
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
|
||||
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
|
||||
/// <remarks>
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
|
||||
/// including itself or any services it might be wrapping.
|
||||
/// </remarks>
|
||||
public static TService? GetService<TService>(this IChatClient client, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(client);
|
||||
|
||||
return (TService?)client.GetService(typeof(TService), serviceKey);
|
||||
}
|
||||
|
||||
/// <summary>Sends a user chat text message to the model and returns the response messages.</summary>
|
||||
/// <param name="client">The chat client.</param>
|
||||
/// <param name="chatMessage">The text content for the chat message to send.</param>
|
||||
|
|
|
@ -63,12 +63,13 @@ public class DelegatingChatClient : IChatClient
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public virtual TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
public virtual object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
#pragma warning disable S3060 // "is" should not be used with "this"
|
||||
// If the key is non-null, we don't know what it means so pass through to the inner service
|
||||
return key is null && this is TService service ? service : InnerClient.GetService<TService>(key);
|
||||
#pragma warning restore S3060
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
// If the key is non-null, we don't know what it means so pass through to the inner service.
|
||||
return
|
||||
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
|
||||
InnerClient.GetService(serviceType, serviceKey);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,14 +56,13 @@ public interface IChatClient : IDisposable
|
|||
/// <summary>Gets metadata that describes the <see cref="IChatClient"/>.</summary>
|
||||
ChatClientMetadata Metadata { get; }
|
||||
|
||||
/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
|
||||
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
|
||||
/// <param name="key">An optional key that may be used to help identify the target service.</param>
|
||||
/// <summary>Asks the <see cref="IChatClient"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
|
||||
/// <param name="serviceType">The type of object being requested.</param>
|
||||
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
|
||||
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
|
||||
/// <remarks>
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
|
||||
/// including itself or any services it might be wrapping.
|
||||
/// </remarks>
|
||||
TService? GetService<TService>(object? key = null)
|
||||
where TService : class;
|
||||
object? GetService(Type serviceType, object? serviceKey = null);
|
||||
}
|
||||
|
|
|
@ -59,12 +59,13 @@ public class DelegatingEmbeddingGenerator<TInput, TEmbedding> : IEmbeddingGenera
|
|||
InnerGenerator.GenerateAsync(values, options, cancellationToken);
|
||||
|
||||
/// <inheritdoc />
|
||||
public virtual TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
public virtual object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
#pragma warning disable S3060 // "is" should not be used with "this"
|
||||
// If the key is non-null, we don't know what it means so pass through to the inner service
|
||||
return key is null && this is TService service ? service : InnerGenerator.GetService<TService>(key);
|
||||
#pragma warning restore S3060
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
// If the key is non-null, we don't know what it means so pass through to the inner service.
|
||||
return
|
||||
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
|
||||
InnerGenerator.GetService(serviceType, serviceKey);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,43 @@ namespace Microsoft.Extensions.AI;
|
|||
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
|
||||
public static class EmbeddingGeneratorExtensions
|
||||
{
|
||||
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
|
||||
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
|
||||
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
|
||||
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
|
||||
/// <param name="generator">The generator.</param>
|
||||
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
|
||||
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
|
||||
/// <remarks>
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
|
||||
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
|
||||
/// </remarks>
|
||||
public static TService? GetService<TInput, TEmbedding, TService>(this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
|
||||
where TEmbedding : Embedding
|
||||
{
|
||||
_ = Throw.IfNull(generator);
|
||||
|
||||
return (TService?)generator.GetService(typeof(TService), serviceKey);
|
||||
}
|
||||
|
||||
// The following overload exists purely to work around the lack of partial generic type inference.
|
||||
// Given an IEmbeddingGenerator<TInput, TEmbedding> generator, to call GetService with TService, you still need
|
||||
// to re-specify both TInput and TEmbedding, e.g. generator.GetService<string, Embedding<float>, TService>.
|
||||
// The case of string/Embedding<float> is by far the most common case today, so this overload exists as an
|
||||
// accelerator to allow it to be written simply as generator.GetService<TService>.
|
||||
|
||||
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
|
||||
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
|
||||
/// <param name="generator">The generator.</param>
|
||||
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
|
||||
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
|
||||
/// <remarks>
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
|
||||
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
|
||||
/// </remarks>
|
||||
public static TService? GetService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
|
||||
GetService<string, Embedding<float>, TService>(generator, serviceKey);
|
||||
|
||||
/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
|
||||
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
|
||||
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
|
||||
|
|
|
@ -40,14 +40,13 @@ public interface IEmbeddingGenerator<TInput, TEmbedding> : IDisposable
|
|||
/// <summary>Gets metadata that describes the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</summary>
|
||||
EmbeddingGeneratorMetadata Metadata { get; }
|
||||
|
||||
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
|
||||
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
|
||||
/// <param name="key">An optional key that may be used to help identify the target service.</param>
|
||||
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
|
||||
/// <param name="serviceType">The type of object being requested.</param>
|
||||
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
|
||||
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
|
||||
/// <remarks>
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>,
|
||||
/// including itself or any services it might be wrapping.
|
||||
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
|
||||
/// <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>, including itself or any services it might be wrapping.
|
||||
/// </remarks>
|
||||
TService? GetService<TService>(object? key = null)
|
||||
where TService : class;
|
||||
object? GetService(Type serviceType, object? serviceKey = null);
|
||||
}
|
||||
|
|
|
@ -57,10 +57,16 @@ public sealed class AzureAIInferenceChatClient : IChatClient
|
|||
public ChatClientMetadata Metadata { get; }
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class =>
|
||||
typeof(TService) == typeof(ChatCompletionsClient) ? (TService?)(object?)_chatCompletionsClient :
|
||||
this as TService;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is not null ? null :
|
||||
serviceType == typeof(ChatCompletionsClient) ? _chatCompletionsClient :
|
||||
serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<ChatCompletion> CompleteAsync(
|
||||
|
|
|
@ -70,10 +70,16 @@ public sealed class AzureAIInferenceEmbeddingGenerator :
|
|||
public EmbeddingGeneratorMetadata Metadata { get; }
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class =>
|
||||
typeof(TService) == typeof(EmbeddingsClient) ? (TService)(object)_embeddingsClient :
|
||||
this as TService;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is not null ? null :
|
||||
serviceType == typeof(EmbeddingsClient) ? _embeddingsClient :
|
||||
serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
|
||||
|
|
|
@ -166,9 +166,14 @@ public sealed class OllamaChatClient : IChatClient
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
=> key is null ? this as TService : null;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
|
|
|
@ -57,9 +57,14 @@ public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator<string, Embed
|
|||
public EmbeddingGeneratorMetadata Metadata { get; }
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
=> key is null ? this as TService : null;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
|
|
|
@ -16,6 +16,7 @@ using Microsoft.Shared.Diagnostics;
|
|||
using OpenAI;
|
||||
using OpenAI.Chat;
|
||||
|
||||
#pragma warning disable S1067 // Expressions should not be too complex
|
||||
#pragma warning disable S1135 // Track uses of "TODO" tags
|
||||
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
|
||||
#pragma warning disable SA1204 // Static elements should appear before instance elements
|
||||
|
@ -85,11 +86,17 @@ public sealed partial class OpenAIChatClient : IChatClient
|
|||
public ChatClientMetadata Metadata { get; }
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class =>
|
||||
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
|
||||
typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient :
|
||||
this as TService;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is not null ? null :
|
||||
serviceType == typeof(OpenAIClient) ? _openAIClient :
|
||||
serviceType == typeof(ChatClient) ? _chatClient :
|
||||
serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<ChatCompletion> CompleteAsync(
|
||||
|
|
|
@ -11,6 +11,7 @@ using Microsoft.Shared.Diagnostics;
|
|||
using OpenAI;
|
||||
using OpenAI.Embeddings;
|
||||
|
||||
#pragma warning disable S1067 // Expressions should not be too complex
|
||||
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
|
||||
|
||||
namespace Microsoft.Extensions.AI;
|
||||
|
@ -95,12 +96,17 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator<string, Embed
|
|||
public EmbeddingGeneratorMetadata Metadata { get; }
|
||||
|
||||
/// <inheritdoc />
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
=>
|
||||
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
|
||||
typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient :
|
||||
this as TService;
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
{
|
||||
_ = Throw.IfNull(serviceType);
|
||||
|
||||
return
|
||||
serviceKey is not null ? null :
|
||||
serviceType == typeof(OpenAIClient) ? _openAIClient :
|
||||
serviceType == typeof(EmbeddingClient) ? _embeddingClient :
|
||||
serviceType.IsInstanceOfType(this) ? this :
|
||||
null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
|
||||
|
|
|
@ -11,6 +11,12 @@ namespace Microsoft.Extensions.AI;
|
|||
|
||||
public class ChatClientExtensionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void GetService_InvalidArgs_Throws()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetService<object>(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CompleteAsync_InvalidArgs_Throws()
|
||||
{
|
||||
|
|
|
@ -96,6 +96,14 @@ public class DelegatingChatClientTests
|
|||
Assert.False(await enumerator.MoveNextAsync());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetServiceThrowsForNullType()
|
||||
{
|
||||
using var inner = new TestChatClient();
|
||||
using var delegating = new NoOpDelegatingChatClient(inner);
|
||||
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
|
||||
{
|
||||
|
|
|
@ -57,6 +57,14 @@ public class DelegatingEmbeddingGeneratorTests
|
|||
Assert.Same(expectedEmbedding, await resultTask);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetServiceThrowsForNullType()
|
||||
{
|
||||
using var inner = new TestEmbeddingGenerator();
|
||||
using var delegating = new NoOpDelegatingEmbeddingGenerator(inner);
|
||||
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
|
||||
{
|
||||
|
|
|
@ -10,6 +10,13 @@ namespace Microsoft.Extensions.AI;
|
|||
|
||||
public class EmbeddingGeneratorExtensionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void GetService_InvalidArgs_Throws()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<object>(null!));
|
||||
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<string, Embedding<double>, object>(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
|
||||
{
|
||||
|
|
|
@ -26,9 +26,8 @@ public sealed class TestChatClient : IChatClient
|
|||
public IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
|
||||
=> CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken);
|
||||
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
=> (TService?)GetServiceCallback!(typeof(TService), key);
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
=> GetServiceCallback!(serviceType, serviceKey);
|
||||
|
||||
void IDisposable.Dispose()
|
||||
{
|
||||
|
|
|
@ -19,9 +19,8 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator<string, Embeddi
|
|||
public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
|
||||
=> GenerateAsyncCallback!.Invoke(values, options, cancellationToken);
|
||||
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class
|
||||
=> (TService?)GetServiceCallback!(typeof(TService), key);
|
||||
public object? GetService(Type serviceType, object? serviceKey = null)
|
||||
=> GetServiceCallback!(serviceType, serviceKey);
|
||||
|
||||
void IDisposable.Dispose()
|
||||
{
|
||||
|
|
|
@ -29,10 +29,9 @@ internal sealed class QuantizationEmbeddingGenerator :
|
|||
|
||||
void IDisposable.Dispose() => _floatService.Dispose();
|
||||
|
||||
public TService? GetService<TService>(object? key = null)
|
||||
where TService : class =>
|
||||
key is null && this is TService ? (TService?)(object)this :
|
||||
_floatService.GetService<TService>(key);
|
||||
public object? GetService(Type serviceType, object? serviceKey = null) =>
|
||||
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
|
||||
_floatService.GetService(serviceType, serviceKey);
|
||||
|
||||
async Task<GeneratedEmbeddings<BinaryEmbedding>> IEmbeddingGenerator<string, BinaryEmbedding>.GenerateAsync(
|
||||
IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
|
||||
|
|
Загрузка…
Ссылка в новой задаче