Replace STJ boilerplate in the leaf clients with AIJsonUtilities calls. (#5630)

* Replace STJ boilerplate in the leaf clients with AIJsonUtilities calls.

* Address feedback.

* Address feedback.

* Remove redundant using
This commit is contained in:
Eirik Tsarpalis 2024-11-13 18:35:39 +00:00 коммит произвёл GitHub
Родитель ad3b5d0338
Коммит 73962c60f1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 84 добавлений и 122 удалений

Просмотреть файл

@ -7,6 +7,7 @@ using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;
@ -27,6 +28,9 @@ public sealed class AzureAIInferenceChatClient : IChatClient
/// <summary>The underlying <see cref="ChatCompletionsClient" />.</summary>
private readonly ChatCompletionsClient _chatCompletionsClient;
/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;
/// <summary>Initializes a new instance of the <see cref="AzureAIInferenceChatClient"/> class for the specified <see cref="ChatCompletionsClient"/>.</summary>
/// <param name="chatCompletionsClient">The underlying client.</param>
/// <param name="modelId">The ID of the model to use. If null, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
@ -51,7 +55,11 @@ public sealed class AzureAIInferenceChatClient : IChatClient
}
/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}
/// <inheritdoc />
public ChatClientMetadata Metadata { get; }
@ -304,7 +312,7 @@ public sealed class AzureAIInferenceChatClient : IChatClient
// These properties are strongly typed on ChatOptions but not on ChatCompletionsOptions.
if (options.TopK is int topK)
{
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, JsonContext.Default.Int32));
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(int))));
}
if (options.AdditionalProperties is { } props)
@ -317,7 +325,7 @@ public sealed class AzureAIInferenceChatClient : IChatClient
default:
if (prop.Value is not null)
{
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, JsonContext.GetTypeInfo(prop.Value.GetType(), ToolCallJsonSerializerOptions));
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
result.AdditionalProperties[prop.Key] = new BinaryData(data);
}
@ -419,7 +427,7 @@ public sealed class AzureAIInferenceChatClient : IChatClient
{
try
{
result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
@ -449,7 +457,7 @@ public sealed class AzureAIInferenceChatClient : IChatClient
callRequest.CallId,
new FunctionCall(
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, JsonContext.GetTypeInfo(typeof(IDictionary<string, object>), ToolCallJsonSerializerOptions)))));
JsonSerializer.Serialize(callRequest.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object>))))));
}
}
@ -490,5 +498,6 @@ public sealed class AzureAIInferenceChatClient : IChatClient
private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);
}

Просмотреть файл

@ -173,7 +173,7 @@ public sealed class AzureAIInferenceEmbeddingGenerator :
{
if (prop.Value is not null)
{
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, JsonContext.GetTypeInfo(prop.Value.GetType(), null));
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
result.AdditionalProperties[prop.Key] = new BinaryData(data);
}
}

Просмотреть файл

@ -1,12 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Extensions.AI;
@ -16,55 +12,4 @@ namespace Microsoft.Extensions.AI;
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(AzureAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(long))]
[JsonSerializable(typeof(float))]
[JsonSerializable(typeof(double))]
[JsonSerializable(typeof(bool))]
[JsonSerializable(typeof(float[]))]
[JsonSerializable(typeof(byte[]))]
[JsonSerializable(typeof(sbyte[]))]
internal sealed partial class JsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();
/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
_defaultToolJsonOptions.GetTypeInfo(type);
/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};
options.MakeReadOnly();
return options;
}
return Default.Options;
}
}
internal sealed partial class JsonContext : JsonSerializerContext;

Просмотреть файл

@ -1,8 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI;
@ -23,6 +21,4 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(OllamaToolCall))]
[JsonSerializable(typeof(OllamaEmbeddingRequest))]
[JsonSerializable(typeof(OllamaEmbeddingResponse))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
internal sealed partial class JsonContext : JsonSerializerContext;

Просмотреть файл

@ -30,6 +30,9 @@ public sealed class OllamaChatClient : IChatClient
/// <summary>The <see cref="HttpClient"/> to use for sending requests.</summary>
private readonly HttpClient _httpClient;
/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;
/// <summary>Initializes a new instance of the <see cref="OllamaChatClient"/> class.</summary>
/// <param name="endpoint">The endpoint URI where Ollama is hosted.</param>
/// <param name="modelId">
@ -66,7 +69,11 @@ public sealed class OllamaChatClient : IChatClient
public ChatClientMetadata Metadata { get; }
/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}
/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
@ -388,7 +395,6 @@ public sealed class OllamaChatClient : IChatClient
case FunctionCallContent fcc:
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
yield return new OllamaChatRequestMessage
{
Role = "assistant",
@ -396,7 +402,7 @@ public sealed class OllamaChatClient : IChatClient
{
CallId = fcc.CallId,
Name = fcc.Name,
Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
}, JsonContext.Default.OllamaFunctionCallContent)
};
break;
@ -404,8 +410,7 @@ public sealed class OllamaChatClient : IChatClient
case FunctionResultContent frc:
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, serializerOptions.GetTypeInfo(typeof(object)));
JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
yield return new OllamaChatRequestMessage
{
Role = "tool",

Просмотреть файл

@ -3,7 +3,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
@ -38,6 +37,9 @@ public sealed partial class OpenAIChatClient : IChatClient
/// <summary>The underlying <see cref="ChatClient" />.</summary>
private readonly ChatClient _chatClient;
/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;
/// <summary>Initializes a new instance of the <see cref="OpenAIChatClient"/> class for the specified <see cref="OpenAIClient"/>.</summary>
/// <param name="openAIClient">The underlying client.</param>
/// <param name="modelId">The model to use.</param>
@ -80,7 +82,11 @@ public sealed partial class OpenAIChatClient : IChatClient
}
/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}
/// <inheritdoc />
public ChatClientMetadata Metadata { get; }
@ -593,7 +599,7 @@ public sealed partial class OpenAIChatClient : IChatClient
{
try
{
result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
@ -622,7 +628,7 @@ public sealed partial class OpenAIChatClient : IChatClient
callRequest.Name,
new(JsonSerializer.SerializeToUtf8Bytes(
callRequest.Arguments,
JsonContext.GetTypeInfo(typeof(IDictionary<string, object?>), ToolCallJsonSerializerOptions)))));
ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))))));
}
}
@ -668,11 +674,13 @@ public sealed partial class OpenAIChatClient : IChatClient
private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);
private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8Json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);
/// <summary>Source-generated JSON type information.</summary>
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
@ -680,48 +688,5 @@ public sealed partial class OpenAIChatClient : IChatClient
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(OpenAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
private sealed partial class JsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();
/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
_defaultToolJsonOptions.GetTypeInfo(type);
/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};
options.MakeReadOnly();
return options;
}
return Default.Options;
}
}
private sealed partial class JsonContext : JsonSerializerContext;
}

Просмотреть файл

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Net.Http;
using System.Text.Json;
using System.Threading.Tasks;
using Azure;
using Azure.AI.Inference;
@ -29,6 +30,19 @@ public class AzureAIInferenceChatClientTests
Assert.Throws<ArgumentException>("modelId", () => new AzureAIInferenceChatClient(client, " "));
}
[Fact]
public void ToolCallJsonSerializerOptions_HasExpectedValue()
{
using AzureAIInferenceChatClient client = new(new(new("http://somewhere"), new AzureKeyCredential("key")), "mode");
Assert.Same(client.ToolCallJsonSerializerOptions, AIJsonUtilities.DefaultOptions);
Assert.Throws<ArgumentNullException>("value", () => client.ToolCallJsonSerializerOptions = null!);
JsonSerializerOptions options = new();
client.ToolCallJsonSerializerOptions = options;
Assert.Same(options, client.ToolCallJsonSerializerOptions);
}
[Fact]
public void AsChatClient_InvalidArgs_Throws()
{

Просмотреть файл

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Net.Http;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
@ -26,6 +27,19 @@ public class OllamaChatClientTests
Assert.Throws<ArgumentException>("modelId", () => new OllamaChatClient("http://localhost", " "));
}
[Fact]
public void ToolCallJsonSerializerOptions_HasExpectedValue()
{
using OllamaChatClient client = new("http://localhost", "model");
Assert.Same(client.ToolCallJsonSerializerOptions, AIJsonUtilities.DefaultOptions);
Assert.Throws<ArgumentNullException>("value", () => client.ToolCallJsonSerializerOptions = null!);
JsonSerializerOptions options = new();
client.ToolCallJsonSerializerOptions = options;
Assert.Same(options, client.ToolCallJsonSerializerOptions);
}
[Fact]
public void GetService_SuccessfullyReturnsUnderlyingClient()
{

Просмотреть файл

@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Net.Http;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Caching.Distributed;
@ -34,6 +35,19 @@ public class OpenAIChatClientTests
Assert.Throws<ArgumentException>("modelId", () => new OpenAIChatClient(openAIClient, " "));
}
[Fact]
public void ToolCallJsonSerializerOptions_HasExpectedValue()
{
using OpenAIChatClient client = new(new("key"), "model");
Assert.Same(client.ToolCallJsonSerializerOptions, AIJsonUtilities.DefaultOptions);
Assert.Throws<ArgumentNullException>("value", () => client.ToolCallJsonSerializerOptions = null!);
JsonSerializerOptions options = new();
client.ToolCallJsonSerializerOptions = options;
Assert.Same(options, client.ToolCallJsonSerializerOptions);
}
[Fact]
public void AsChatClient_InvalidArgs_Throws()
{