Change ChatClientBuilder to register singletons and support lambda-less chaining (#5642)

* Change ChatClientBuilder to register singletons and support lambda-less chaining

* Add generic keyed version

* Improve XML doc

* Update README files

* Remove generic DI registration methods
This commit is contained in:
Steve Sanderson 2024-11-14 05:36:08 -08:00 коммит произвёл Stephen Toub
Родитель 2de8e89f9e
Коммит 1f8ae147e5
21 изменённых файлов: 256 добавлений и 204 удалений

Просмотреть файл

@ -150,9 +150,9 @@ using Microsoft.Extensions.AI;
[Description("Gets the current weather")]
string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining";
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseFunctionInvocation()
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();
var response = client.CompleteStreamingAsync(
"Should I wear a rain coat?",
@ -174,9 +174,9 @@ using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();
string[] prompts = ["What is AI?", "What is .NET?", "What is AI?"];
@ -205,9 +205,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
.AddConsoleExporter()
.Build();
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();
Console.WriteLine((await client.CompleteAsync("What is AI?")).Message);
```
@ -220,9 +220,9 @@ Options may also be baked into an `IChatClient` via the `ConfigureOptions` exten
```csharp
using Microsoft.Extensions.AI;
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434")))
.ConfigureOptions(options => options.ModelId ??= "phi3")
.Use(new OllamaChatClient(new Uri("http://localhost:11434")));
.Build();
Console.WriteLine(await client.CompleteAsync("What is AI?")); // will request "phi3"
Console.WriteLine(await client.CompleteAsync("What is AI?", new() { ModelId = "llama3.1" })); // will request "llama3.1"
@ -248,11 +248,11 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
// Explore changing the order of the intermediate "Use" calls to see that impact
// that has on what gets cached, traced, etc.
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();
ChatOptions options = new()
{
@ -341,9 +341,8 @@ using Microsoft.Extensions.Hosting;
// App Setup
var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddChatClient(b => b
.UseDistributedCache()
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")));
builder.Services.AddChatClient(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache();
var host = builder.Build();
// Elsewhere in the app

Просмотреть файл

@ -85,9 +85,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseFunctionInvocation()
.Use(azureClient);
.Build();
ChatOptions chatOptions = new()
{
@ -120,9 +120,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.Use(azureClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -156,9 +156,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();
Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
@ -196,11 +196,11 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -236,10 +236,9 @@ builder.Services.AddSingleton(
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));
builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();
var app = builder.Build();
@ -261,8 +260,8 @@ builder.Services.AddSingleton(new ChatCompletionsClient(
new("https://models.inference.ai.azure.com"),
new AzureKeyCredential(builder.Configuration["GH_TOKEN"]!)));
builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"));
var app = builder.Build();

Просмотреть файл

@ -70,9 +70,9 @@ using Microsoft.Extensions.AI;
IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseFunctionInvocation()
.Use(ollamaClient);
.Build();
ChatOptions chatOptions = new()
{
@ -97,9 +97,9 @@ IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDi
IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.Use(ollamaClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -128,9 +128,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();
Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
@ -163,11 +163,11 @@ var chatOptions = new ChatOptions
IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -235,10 +235,9 @@ var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));
builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"))
.UseDistributedCache()
.UseLogging()
.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
.UseLogging();
var app = builder.Build();
@ -254,8 +253,8 @@ using Microsoft.Extensions.AI;
var builder = WebApplication.CreateBuilder(args);
builder.Services.AddChatClient(c =>
c.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
builder.Services.AddChatClient(
new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"));
builder.Services.AddEmbeddingGenerator<string,Embedding<float>>(g =>
g.Use(new OllamaEmbeddingGenerator(endpoint, "all-minilm")));

Просмотреть файл

@ -77,9 +77,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseFunctionInvocation()
.Use(openaiClient);
.Build();
ChatOptions chatOptions = new()
{
@ -110,9 +110,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.Use(openaiClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -144,9 +144,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();
Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
@ -182,11 +182,11 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();
for (int i = 0; i < 3; i++)
{
@ -260,10 +260,9 @@ builder.Services.AddSingleton(new OpenAIClient(Environment.GetEnvironmentVariabl
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));
builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();
var app = builder.Build();
@ -282,8 +281,8 @@ var builder = WebApplication.CreateBuilder(args);
builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API_KEY"]));
builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"));
builder.Services.AddEmbeddingGenerator<string, Embedding<float>>(g =>
g.Use(g.Services.GetRequiredService<OpenAIClient>().AsEmbeddingGenerator("text-embedding-3-small")));

Просмотреть файл

@ -10,32 +10,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>A builder for creating pipelines of <see cref="IChatClient"/>.</summary>
public sealed class ChatClientBuilder
{
private Func<IServiceProvider, IChatClient> _innerClientFactory;
/// <summary>The registered client factory instances.</summary>
private List<Func<IServiceProvider, IChatClient, IChatClient>>? _clientFactories;
/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="services">The service provider to use for dependency injection.</param>
public ChatClientBuilder(IServiceProvider? services = null)
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(IChatClient innerClient)
{
Services = services ?? EmptyServiceProvider.Instance;
_ = Throw.IfNull(innerClient);
_innerClientFactory = _ => innerClient;
}
/// <summary>Gets the <see cref="IServiceProvider"/> associated with the builder instance.</summary>
public IServiceProvider Services { get; }
/// <summary>Completes the pipeline by adding a final <see cref="IChatClient"/> that represents the underlying backend. This is typically a client for an LLM service.</summary>
/// <param name="innerClient">The inner client to use.</param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</returns>
public IChatClient Use(IChatClient innerClient)
/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(Func<IServiceProvider, IChatClient> innerClientFactory)
{
var chatClient = Throw.IfNull(innerClient);
_innerClientFactory = Throw.IfNull(innerClientFactory);
}
/// <summary>Returns an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
/// <param name="services">
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IChatClient"/> instances.
/// If null, an empty <see cref="IServiceProvider"/> will be used.
/// </param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline.</returns>
public IChatClient Build(IServiceProvider? services = null)
{
services ??= EmptyServiceProvider.Instance;
var chatClient = _innerClientFactory(services);
// To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost.
if (_clientFactories is not null)
{
for (var i = _clientFactories.Count - 1; i >= 0; i--)
{
chatClient = _clientFactories[i](Services, chatClient) ??
chatClient = _clientFactories[i](services, chatClient) ??
throw new InvalidOperationException(
$"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " +
$"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances.");

Просмотреть файл

@ -10,38 +10,62 @@ namespace Microsoft.Extensions.DependencyInjection;
/// <summary>Provides extension methods for registering <see cref="IChatClient"/> with a <see cref="IServiceCollection"/>.</summary>
public static class ChatClientBuilderServiceCollectionExtensions
{
/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="services">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="clientFactory">The factory to use to construct the <see cref="IChatClient"/> instance.</param>
/// <returns>The <paramref name="services"/> collection.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static IServiceCollection AddChatClient(
this IServiceCollection services,
Func<ChatClientBuilder, IChatClient> clientFactory)
{
_ = Throw.IfNull(services);
_ = Throw.IfNull(clientFactory);
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a singleton service.</remarks>
public static ChatClientBuilder AddChatClient(
this IServiceCollection serviceCollection,
IChatClient innerClient)
=> AddChatClient(serviceCollection, _ => innerClient);
return services.AddScoped(services =>
clientFactory(new ChatClientBuilder(services)));
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a singleton service.</remarks>
public static ChatClientBuilder AddChatClient(
this IServiceCollection serviceCollection,
Func<IServiceProvider, IChatClient> innerClientFactory)
{
_ = Throw.IfNull(serviceCollection);
_ = Throw.IfNull(innerClientFactory);
var builder = new ChatClientBuilder(innerClientFactory);
_ = serviceCollection.AddSingleton(builder.Build);
return builder;
}
/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="services">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="serviceKey">The key with which to associate the client.</param>
/// <param name="clientFactory">The factory to use to construct the <see cref="IChatClient"/> instance.</param>
/// <returns>The <paramref name="services"/> collection.</returns>
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static IServiceCollection AddKeyedChatClient(
this IServiceCollection services,
public static ChatClientBuilder AddKeyedChatClient(
this IServiceCollection serviceCollection,
object serviceKey,
Func<ChatClientBuilder, IChatClient> clientFactory)
{
_ = Throw.IfNull(services);
_ = Throw.IfNull(serviceKey);
_ = Throw.IfNull(clientFactory);
IChatClient innerClient)
=> AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient);
return services.AddKeyedScoped(serviceKey, (services, _) =>
clientFactory(new ChatClientBuilder(services)));
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="serviceKey">The key with which to associate the client.</param>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static ChatClientBuilder AddKeyedChatClient(
this IServiceCollection serviceCollection,
object serviceKey,
Func<IServiceProvider, IChatClient> innerClientFactory)
{
_ = Throw.IfNull(serviceCollection);
_ = Throw.IfNull(serviceKey);
_ = Throw.IfNull(innerClientFactory);
var builder = new ChatClientBuilder(innerClientFactory);
_ = serviceCollection.AddKeyedSingleton(serviceKey, (services, _) => builder.Build(services));
return builder;
}
}

Просмотреть файл

@ -77,11 +77,11 @@ public class AzureAIInferenceChatClientTests
Assert.Same(client, chatClient.GetService<ChatCompletionsClient>());
using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();
Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());

Просмотреть файл

@ -377,12 +377,12 @@ public abstract class ChatClientIntegrationTests : IDisposable
}, "GetTemperature");
// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();
var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
@ -415,12 +415,12 @@ public abstract class ChatClientIntegrationTests : IDisposable
}, "GetTemperature");
// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();
var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
@ -454,12 +454,12 @@ public abstract class ChatClientIntegrationTests : IDisposable
}, "GetTemperature");
// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();
var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
@ -573,9 +573,9 @@ public abstract class ChatClientIntegrationTests : IDisposable
.AddInMemoryExporter(activities)
.Build();
var chatClient = new ChatClientBuilder()
var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseOpenTelemetry(sourceName: sourceName)
.Use(CreateChatClient()!);
.Build();
var response = await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]);

Просмотреть файл

@ -37,9 +37,9 @@ public class ReducingChatClientTests
}
};
using var client = new ChatClientBuilder()
using var client = new ChatClientBuilder(innerClient)
.UseChatReducer(new TokenCountingChatReducer(_gpt4oTokenizer, 40))
.Use(innerClient);
.Build();
List<ChatMessage> messages =
[

Просмотреть файл

@ -37,11 +37,11 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
{
SkipIfNotEnabled();
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
.Use(CreateChatClient()!);
.Build();
var secretNumber = 42;
var response = await chatClient.CompleteAsync("What is the current secret number? Answer with digits only.", new ChatOptions
@ -61,11 +61,11 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
{
SkipIfNotEnabled();
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
.Use(CreateChatClient()!);
.Build();
var stockPriceTool = AIFunctionFactory.Create([Description("Returns the stock price for a given ticker symbol")] (
[Description("The ticker symbol")] string symbol,

Просмотреть файл

@ -48,11 +48,11 @@ public class OllamaChatClientTests
Assert.Same(client, client.GetService<OllamaChatClient>());
Assert.Same(client, client.GetService<IChatClient>());
using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(client)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(client);
.Build();
Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());

Просмотреть файл

@ -95,11 +95,11 @@ public class OpenAIChatClientTests
Assert.NotNull(chatClient.GetService<ChatClient>());
using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();
Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
@ -119,11 +119,11 @@ public class OpenAIChatClientTests
Assert.Same(chatClient, chatClient.GetService<IChatClient>());
Assert.Same(openAIClient, chatClient.GetService<ChatClient>());
using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();
Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());

Просмотреть файл

@ -13,17 +13,23 @@ public class ChatClientBuilderTest
public void PassesServiceProviderToFactories()
{
var expectedServiceProvider = new ServiceCollection().BuildServiceProvider();
using TestChatClient expectedResult = new();
var builder = new ChatClientBuilder(expectedServiceProvider);
using TestChatClient expectedInnerClient = new();
using TestChatClient expectedOuterClient = new();
var builder = new ChatClientBuilder(services =>
{
Assert.Same(expectedServiceProvider, services);
return expectedInnerClient;
});
builder.Use((serviceProvider, innerClient) =>
{
Assert.Same(expectedServiceProvider, serviceProvider);
return expectedResult;
Assert.Same(expectedInnerClient, innerClient);
return expectedOuterClient;
});
using TestChatClient innerClient = new();
Assert.Equal(expectedResult, builder.Use(innerClient: innerClient));
Assert.Same(expectedOuterClient, builder.Build(expectedServiceProvider));
}
[Fact]
@ -31,14 +37,14 @@ public class ChatClientBuilderTest
{
// Arrange
using TestChatClient expectedInnerClient = new();
var builder = new ChatClientBuilder();
var builder = new ChatClientBuilder(expectedInnerClient);
builder.Use(next => new InnerClientCapturingChatClient("First", next));
builder.Use(next => new InnerClientCapturingChatClient("Second", next));
builder.Use(next => new InnerClientCapturingChatClient("Third", next));
// Act
var first = (InnerClientCapturingChatClient)builder.Use(expectedInnerClient);
var first = (InnerClientCapturingChatClient)builder.Build();
// Assert
Assert.Equal("First", first.Name);
@ -52,23 +58,22 @@ public class ChatClientBuilderTest
[Fact]
public void DoesNotAcceptNullInnerService()
{
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder().Use((IChatClient)null!));
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder((IChatClient)null!));
}
[Fact]
public void DoesNotAcceptNullFactories()
{
ChatClientBuilder builder = new();
Assert.Throws<ArgumentNullException>(() => builder.Use((Func<IChatClient, IChatClient>)null!));
Assert.Throws<ArgumentNullException>(() => builder.Use((Func<IServiceProvider, IChatClient, IChatClient>)null!));
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder((Func<IServiceProvider, IChatClient>)null!));
}
[Fact]
public void DoesNotAllowFactoriesToReturnNull()
{
ChatClientBuilder builder = new();
using var innerClient = new TestChatClient();
ChatClientBuilder builder = new(innerClient);
builder.Use(_ => null!);
var ex = Assert.Throws<InvalidOperationException>(() => builder.Use(new TestChatClient()));
var ex = Assert.Throws<InvalidOperationException>(() => builder.Build());
Assert.Contains("entry at index 0", ex.Message);
}

Просмотреть файл

@ -22,7 +22,8 @@ public class ConfigureOptionsChatClientTests
[Fact]
public void ConfigureOptions_InvalidArgs_Throws()
{
var builder = new ChatClientBuilder();
using var innerClient = new TestChatClient();
var builder = new ChatClientBuilder(innerClient);
Assert.Throws<ArgumentNullException>("configure", () => builder.ConfigureOptions(null!));
}
@ -54,7 +55,7 @@ public class ConfigureOptionsChatClientTests
},
};
using var client = new ChatClientBuilder()
using var client = new ChatClientBuilder(innerClient)
.ConfigureOptions(options =>
{
Assert.NotSame(providedOptions, options);
@ -69,7 +70,7 @@ public class ConfigureOptionsChatClientTests
returnedOptions = options;
})
.Use(innerClient);
.Build();
var completion = await client.CompleteAsync(Array.Empty<ChatMessage>(), providedOptions, cts.Token);
Assert.Same(expectedCompletion, completion);

Просмотреть файл

@ -12,12 +12,11 @@ public class DependencyInjectionPatterns
private IServiceCollection ServiceCollection { get; } = new ServiceCollection();
[Fact]
public void CanRegisterScopedUsingGenericType()
public void CanRegisterSingletonUsingFactory()
{
// Arrange/Act
ServiceCollection.AddChatClient(builder => builder
.UseScopedMiddleware()
.Use(new TestChatClient()));
ServiceCollection.AddChatClient(services => new TestChatClient { Services = services })
.UseSingletonMiddleware();
// Assert
var services = ServiceCollection.BuildServiceProvider();
@ -28,73 +27,89 @@ public class DependencyInjectionPatterns
var instance1Copy = scope1.ServiceProvider.GetRequiredService<IChatClient>();
var instance2 = scope2.ServiceProvider.GetRequiredService<IChatClient>();
// Each scope gets a distinct outer *AND* inner client
var outer1 = Assert.IsType<ScopedChatClient>(instance1);
var outer2 = Assert.IsType<ScopedChatClient>(instance2);
var inner1 = Assert.IsType<TestChatClient>(((ScopedChatClient)instance1).InnerClient);
var inner2 = Assert.IsType<TestChatClient>(((ScopedChatClient)instance2).InnerClient);
Assert.NotSame(outer1.Services, outer2.Services);
Assert.NotSame(instance1, instance2);
Assert.NotSame(inner1, inner2);
Assert.Same(instance1, instance1Copy); // From the same scope
// Each scope gets the same instance, because it's singleton
var instance = Assert.IsType<SingletonMiddleware>(instance1);
Assert.Same(instance, instance1Copy);
Assert.Same(instance, instance2);
Assert.IsType<TestChatClient>(instance.InnerClient);
}
[Fact]
public void CanRegisterScopedUsingFactory()
{
// Arrange/Act
ServiceCollection.AddChatClient(builder =>
{
builder.UseScopedMiddleware();
return builder.Use(new TestChatClient { Services = builder.Services });
});
// Assert
var services = ServiceCollection.BuildServiceProvider();
using var scope1 = services.CreateScope();
using var scope2 = services.CreateScope();
var instance1 = scope1.ServiceProvider.GetRequiredService<IChatClient>();
var instance2 = scope2.ServiceProvider.GetRequiredService<IChatClient>();
// Each scope gets a distinct outer *AND* inner client
var outer1 = Assert.IsType<ScopedChatClient>(instance1);
var outer2 = Assert.IsType<ScopedChatClient>(instance2);
var inner1 = Assert.IsType<TestChatClient>(((ScopedChatClient)instance1).InnerClient);
var inner2 = Assert.IsType<TestChatClient>(((ScopedChatClient)instance2).InnerClient);
Assert.Same(outer1.Services, inner1.Services);
Assert.Same(outer2.Services, inner2.Services);
Assert.NotSame(outer1.Services, outer2.Services);
}
[Fact]
public void CanRegisterScopedUsingSharedInstance()
public void CanRegisterSingletonUsingSharedInstance()
{
// Arrange/Act
using var singleton = new TestChatClient();
ServiceCollection.AddChatClient(builder =>
{
builder.UseScopedMiddleware();
return builder.Use(singleton);
});
ServiceCollection.AddChatClient(singleton)
.UseSingletonMiddleware();
// Assert
var services = ServiceCollection.BuildServiceProvider();
using var scope1 = services.CreateScope();
using var scope2 = services.CreateScope();
var instance1 = scope1.ServiceProvider.GetRequiredService<IChatClient>();
var instance1Copy = scope1.ServiceProvider.GetRequiredService<IChatClient>();
var instance2 = scope2.ServiceProvider.GetRequiredService<IChatClient>();
// Each scope gets a distinct outer instance, but the same inner client
Assert.IsType<ScopedChatClient>(instance1);
Assert.IsType<ScopedChatClient>(instance2);
Assert.Same(singleton, ((ScopedChatClient)instance1).InnerClient);
Assert.Same(singleton, ((ScopedChatClient)instance2).InnerClient);
// Each scope gets the same instance, because it's singleton
var instance = Assert.IsType<SingletonMiddleware>(instance1);
Assert.Same(instance, instance1Copy);
Assert.Same(instance, instance2);
Assert.IsType<TestChatClient>(instance.InnerClient);
}
public class ScopedChatClient(IServiceProvider services, IChatClient inner) : DelegatingChatClient(inner)
[Fact]
public void CanRegisterKeyedSingletonUsingFactory()
{
// Arrange/Act
ServiceCollection.AddKeyedChatClient("mykey", services => new TestChatClient { Services = services })
.UseSingletonMiddleware();
// Assert
var services = ServiceCollection.BuildServiceProvider();
using var scope1 = services.CreateScope();
using var scope2 = services.CreateScope();
Assert.Null(services.GetService<IChatClient>());
var instance1 = scope1.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
var instance1Copy = scope1.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
var instance2 = scope2.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
// Each scope gets the same instance, because it's singleton
var instance = Assert.IsType<SingletonMiddleware>(instance1);
Assert.Same(instance, instance1Copy);
Assert.Same(instance, instance2);
Assert.IsType<TestChatClient>(instance.InnerClient);
}
[Fact]
public void CanRegisterKeyedSingletonUsingSharedInstance()
{
// Arrange/Act
using var singleton = new TestChatClient();
ServiceCollection.AddKeyedChatClient("mykey", singleton)
.UseSingletonMiddleware();
// Assert
var services = ServiceCollection.BuildServiceProvider();
using var scope1 = services.CreateScope();
using var scope2 = services.CreateScope();
Assert.Null(services.GetService<IChatClient>());
var instance1 = scope1.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
var instance1Copy = scope1.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
var instance2 = scope2.ServiceProvider.GetRequiredKeyedService<IChatClient>("mykey");
// Each scope gets the same instance, because it's singleton
var instance = Assert.IsType<SingletonMiddleware>(instance1);
Assert.Same(instance, instance1Copy);
Assert.Same(instance, instance2);
Assert.IsType<TestChatClient>(instance.InnerClient);
}
public class SingletonMiddleware(IServiceProvider services, IChatClient inner) : DelegatingChatClient(inner)
{
public new IChatClient InnerClient => base.InnerClient;
public IServiceProvider Services => services;

Просмотреть файл

@ -681,12 +681,12 @@ public class DistributedCachingChatClientTest
new(ChatRole.Assistant, [new TextContent("Hey")])]));
}
};
using var outer = new ChatClientBuilder(services)
using var outer = new ChatClientBuilder(testClient)
.UseDistributedCache(configure: options =>
{
options.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
})
.Use(testClient);
.Build(services);
// Act: Make a request that should populate the cache
Assert.Empty(_storage.Keys);

Просмотреть файл

@ -295,7 +295,7 @@ public class FunctionInvokingChatClientTests
}
};
IChatClient service = new ChatClientBuilder().UseFunctionInvocation().Use(innerClient);
IChatClient service = new ChatClientBuilder(innerClient).UseFunctionInvocation().Build();
List<ChatMessage> chat = [new ChatMessage(ChatRole.User, "hello")];
var ex = await Assert.ThrowsAsync<InvalidOperationException>(
@ -415,7 +415,7 @@ public class FunctionInvokingChatClientTests
}
};
IChatClient service = configurePipeline(new ChatClientBuilder()).Use(innerClient);
IChatClient service = configurePipeline(new ChatClientBuilder(innerClient)).Build();
var result = await service.CompleteAsync(chat, options, cts.Token);
chat.Add(result.Message);

Просмотреть файл

@ -40,9 +40,9 @@ public class LoggingChatClientTests
},
};
using IChatClient client = new ChatClientBuilder(services)
using IChatClient client = new ChatClientBuilder(innerClient)
.UseLogging()
.Use(innerClient);
.Build(services);
await client.CompleteAsync(
[new(ChatRole.User, "What's the biggest animal?")],
@ -86,9 +86,9 @@ public class LoggingChatClientTests
yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "whale" };
}
using IChatClient client = new ChatClientBuilder()
using IChatClient client = new ChatClientBuilder(innerClient)
.UseLogging(logger)
.Use(innerClient);
.Build();
await foreach (var update in client.CompleteStreamingAsync(
[new(ChatRole.User, "What's the biggest animal?")],

Просмотреть файл

@ -86,13 +86,13 @@ public class OpenTelemetryChatClientTests
};
}
var chatClient = new ChatClientBuilder()
var chatClient = new ChatClientBuilder(innerClient)
.UseOpenTelemetry(loggerFactory, sourceName, configure: instance =>
{
instance.EnableSensitiveData = enableSensitiveData;
instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
})
.Use(innerClient);
.Build();
List<ChatMessage> chatMessages =
[

Просмотреть файл

@ -1,11 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
namespace Microsoft.Extensions.AI;
public static class ScopedChatClientExtensions
{
public static ChatClientBuilder UseScopedMiddleware(this ChatClientBuilder builder)
=> builder.Use((services, inner)
=> new DependencyInjectionPatterns.ScopedChatClient(services, inner));
}

Просмотреть файл

@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
namespace Microsoft.Extensions.AI;
public static class SingletonChatClientExtensions
{
public static ChatClientBuilder UseSingletonMiddleware(this ChatClientBuilder builder)
=> builder.Use((services, inner)
=> new DependencyInjectionPatterns.SingletonMiddleware(services, inner));
}