Expose a schema transformer on AIJsonSchemaCreateOptions. (#5677)

* Expose a schema transformer on AIJsonSchemaCreateOptions.

* Address feedback

* Disable caching if a transformer is specified.

* Remove `FilterDisallowedKeywords`.

* Document caching.

* Apply suggestions from code review
This commit is contained in:
Eirik Tsarpalis 2024-11-20 21:46:16 +00:00 коммит произвёл Stephen Toub
Родитель 1f47a84f52
Коммит 7f2d9003f8
6 изменённых файлов: 228 добавлений и 120 удалений

Просмотреть файл

@ -25,6 +25,7 @@
<InjectSharedEmptyCollections>true</InjectSharedEmptyCollections>
<InjectStringHashOnLegacy>true</InjectStringHashOnLegacy>
<InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy>
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
</PropertyGroup>
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">

Просмотреть файл

@ -0,0 +1,105 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Linq;
using System.Reflection;
using System.Text.Json.Schema;
using System.Text.Json.Serialization.Metadata;
#pragma warning disable CA1815 // Override equals and operator equals on value types
namespace Microsoft.Extensions.AI;
/// <summary>
/// Defines the context in which a JSON schema within a type graph is being generated.
/// </summary>
/// <remarks>
/// This struct is being passed to the user-provided <see cref="AIJsonSchemaCreateOptions.TransformSchemaNode"/>
/// callback by the <see cref="AIJsonUtilities.CreateJsonSchema"/> method and cannot be instantiated directly.
/// </remarks>
public readonly struct AIJsonSchemaCreateContext
{
private readonly JsonSchemaExporterContext _exporterContext;
internal AIJsonSchemaCreateContext(JsonSchemaExporterContext exporterContext)
{
_exporterContext = exporterContext;
}
/// <summary>
/// Gets the path to the schema document currently being generated.
/// </summary>
public ReadOnlySpan<string> Path => _exporterContext.Path;
/// <summary>
/// Gets the <see cref="JsonTypeInfo"/> for the type being processed.
/// </summary>
public JsonTypeInfo TypeInfo => _exporterContext.TypeInfo;
/// <summary>
/// Gets the type info for the polymorphic base type if generated as a derived type.
/// </summary>
public JsonTypeInfo? BaseTypeInfo => _exporterContext.BaseTypeInfo;
/// <summary>
/// Gets the <see cref="JsonPropertyInfo"/> if the schema is being generated for a property.
/// </summary>
public JsonPropertyInfo? PropertyInfo => _exporterContext.PropertyInfo;
/// <summary>
/// Gets the declaring type of the property or parameter being processed.
/// </summary>
public Type? DeclaringType =>
#if NET9_0_OR_GREATER
_exporterContext.PropertyInfo?.DeclaringType;
#else
_exporterContext.DeclaringType;
#endif
/// <summary>
/// Gets the <see cref="ICustomAttributeProvider"/> corresponding to the property or field being processed.
/// </summary>
public ICustomAttributeProvider? PropertyAttributeProvider =>
#if NET9_0_OR_GREATER
_exporterContext.PropertyInfo?.AttributeProvider;
#else
_exporterContext.PropertyAttributeProvider;
#endif
/// <summary>
/// Gets the <see cref="System.Reflection.ICustomAttributeProvider"/> of the
/// constructor parameter associated with the accompanying <see cref="PropertyInfo"/>.
/// </summary>
public ICustomAttributeProvider? ParameterAttributeProvider =>
#if NET9_0_OR_GREATER
_exporterContext.PropertyInfo?.AssociatedParameter?.AttributeProvider;
#else
_exporterContext.ParameterInfo;
#endif
/// <summary>
/// Retrieves a custom attribute of a specified type that is applied to the specified schema node context.
/// </summary>
/// <typeparam name="TAttribute">The type of attribute to search for.</typeparam>
/// <param name="inherit">If <see langword="true"/>, specifies to also search the ancestors of the context members for custom attributes.</param>
/// <returns>The first occurrence of <typeparamref name="TAttribute"/> if found, or <see langword="null"/> otherwise.</returns>
/// <remarks>
/// This helper method resolves attributes from context locations in the following order:
/// <list type="number">
/// <item>Attributes specified on the property of the context, if specified.</item>
/// <item>Attributes specified on the constructor parameter of the context, if specified.</item>
/// <item>Attributes specified on the type of the context.</item>
/// </list>
/// </remarks>
public TAttribute? GetCustomAttribute<TAttribute>(bool inherit = false)
where TAttribute : Attribute
{
return GetCustomAttr(PropertyAttributeProvider) ??
GetCustomAttr(ParameterAttributeProvider) ??
GetCustomAttr(TypeInfo.Type);
TAttribute? GetCustomAttr(ICustomAttributeProvider? provider) =>
(TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit).FirstOrDefault();
}
}

Просмотреть файл

@ -1,6 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Text.Json.Nodes;
namespace Microsoft.Extensions.AI;
/// <summary>
@ -13,6 +16,11 @@ public sealed class AIJsonSchemaCreateOptions
/// </summary>
public static AIJsonSchemaCreateOptions Default { get; } = new AIJsonSchemaCreateOptions();
/// <summary>
/// Gets a callback that is invoked for every schema that is generated within the type graph.
/// </summary>
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }
/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
@ -32,20 +40,4 @@ public sealed class AIJsonSchemaCreateOptions
/// Gets a value indicating whether to mark all properties as required in the schema.
/// </summary>
public bool RequireAllProperties { get; init; } = true;
/// <summary>
/// Gets a value indicating whether to filter keywords that are disallowed by certain AI vendors.
/// </summary>
/// <remarks>
/// Filters a number of non-essential schema keywords that are not yet supported by some AI vendors.
/// These include:
/// <list type="bullet">
/// <item>The "minLength", "maxLength", "pattern", and "format" keywords.</item>
/// <item>The "minimum", "maximum", and "multipleOf" keywords.</item>
/// <item>The "patternProperties", "unevaluatedProperties", "propertyNames", "minProperties", and "maxProperties" keywords.</item>
/// <item>The "unevaluatedItems", "contains", "minContains", "maxContains", "minItems", "maxItems", and "uniqueItems" keywords.</item>
/// </list>
/// See also https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported.
/// </remarks>
public bool FilterDisallowedKeywords { get; init; } = true;
}

Просмотреть файл

@ -10,7 +10,6 @@ using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
#endif
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
@ -23,18 +22,6 @@ using Microsoft.Shared.Diagnostics;
#pragma warning disable S1075 // URIs should not be hardcoded
#pragma warning disable SA1118 // Parameter should not span multiple lines
using FunctionParameterKey = (
System.Type? Type,
string? ParameterName,
string? Description,
bool HasDefaultValue,
object? DefaultValue,
bool IncludeSchemaUri,
bool DisallowAdditionalProperties,
bool IncludeTypeInEnumSchemas,
bool RequireAllProperties,
bool FilterDisallowedKeywords);
namespace Microsoft.Extensions.AI;
/// <summary>Provides a collection of utility methods for marshalling JSON data.</summary>
@ -47,7 +34,7 @@ public static partial class AIJsonUtilities
private const int CacheSoftLimit = 4096;
/// <summary>Caches of generated schemas for each <see cref="JsonSerializerOptions"/> that's employed.</summary>
private static readonly ConditionalWeakTable<JsonSerializerOptions, ConcurrentDictionary<FunctionParameterKey, JsonElement>> _schemaCaches = new();
private static readonly ConditionalWeakTable<JsonSerializerOptions, ConcurrentDictionary<SchemaGenerationKey, JsonElement>> _schemaCaches = new();
/// <summary>Gets a JSON schema accepting all values.</summary>
private static readonly JsonElement _trueJsonSchema = ParseJsonElement("true"u8);
@ -107,6 +94,10 @@ public static partial class AIJsonUtilities
/// <param name="serializerOptions">The options used to extract the schema from the specified type.</param>
/// <param name="inferenceOptions">The options controlling schema inference.</param>
/// <returns>A JSON schema document encoded as a <see cref="JsonElement"/>.</returns>
/// <remarks>
/// Uses a cache keyed on the <paramref name="serializerOptions"/> to store schema result,
/// unless a <see cref="AIJsonSchemaCreateOptions.TransformSchemaNode" /> delegate has been specified.
/// </remarks>
public static JsonElement CreateParameterJsonSchema(
Type? type,
string parameterName,
@ -121,17 +112,13 @@ public static partial class AIJsonUtilities
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
FunctionParameterKey key = (
SchemaGenerationKey key = new(
type,
parameterName,
description,
hasDefaultValue,
defaultValue,
IncludeSchemaUri: false,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);
inferenceOptions);
return GetJsonSchemaCached(serializerOptions, key);
}
@ -144,6 +131,10 @@ public static partial class AIJsonUtilities
/// <param name="serializerOptions">The options used to extract the schema from the specified type.</param>
/// <param name="inferenceOptions">The options controlling schema inference.</param>
/// <returns>A <see cref="JsonElement"/> representing the schema.</returns>
/// <remarks>
/// Uses a cache keyed on the <paramref name="serializerOptions"/> to store schema result,
/// unless a <see cref="AIJsonSchemaCreateOptions.TransformSchemaNode" /> delegate has been specified.
/// </remarks>
public static JsonElement CreateJsonSchema(
Type? type,
string? description = null,
@ -155,27 +146,23 @@ public static partial class AIJsonUtilities
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
FunctionParameterKey key = (
SchemaGenerationKey key = new(
type,
ParameterName: null,
parameterName: null,
description,
hasDefaultValue,
defaultValue,
inferenceOptions.IncludeSchemaKeyword,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);
inferenceOptions);
return GetJsonSchemaCached(serializerOptions, key);
}
private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, FunctionParameterKey key)
private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, SchemaGenerationKey key)
{
options.MakeReadOnly();
ConcurrentDictionary<FunctionParameterKey, JsonElement> cache = _schemaCaches.GetOrCreateValue(options);
ConcurrentDictionary<SchemaGenerationKey, JsonElement> cache = _schemaCaches.GetOrCreateValue(options);
if (cache.Count >= CacheSoftLimit)
if (key.TransformSchemaNode is not null || cache.Count >= CacheSoftLimit)
{
return GetJsonSchemaCore(options, key);
}
@ -195,7 +182,7 @@ public static partial class AIJsonUtilities
Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " +
"The exception message will guide users to turn off 'IlcTrimMetadata' which resolves all issues.")]
#endif
private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key)
private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, SchemaGenerationKey key)
{
_ = Throw.IfNull(options);
options.MakeReadOnly();
@ -206,7 +193,7 @@ public static partial class AIJsonUtilities
JsonObject? schemaObj = null;
if (key.IncludeSchemaUri)
if (key.IncludeSchemaKeyword)
{
(schemaObj = [])["$schema"] = SchemaKeywordUri;
}
@ -244,7 +231,7 @@ public static partial class AIJsonUtilities
JsonNode node = options.GetJsonSchemaAsNode(key.Type, exporterOptions);
return JsonSerializer.SerializeToElement(node, JsonContext.Default.JsonNode);
JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, JsonNode schema)
{
const string SchemaPropertyName = "$schema";
const string DescriptionPropertyName = "description";
@ -258,7 +245,9 @@ public static partial class AIJsonUtilities
const string DefaultPropertyName = "default";
const string RefPropertyName = "$ref";
if (ctx.ResolveAttribute<DescriptionAttribute>() is { } attr)
AIJsonSchemaCreateContext ctx = new(schemaExporterContext);
if (ctx.GetCustomAttribute<DescriptionAttribute>() is { } attr)
{
ConvertSchemaToObject(ref schema).InsertAtStart(DescriptionPropertyName, (JsonNode)attr.Description);
}
@ -308,12 +297,9 @@ public static partial class AIJsonUtilities
}
// Filter potentially disallowed keywords.
if (key.FilterDisallowedKeywords)
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
{
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
{
_ = objSchema.Remove(keyword);
}
_ = objSchema.Remove(keyword);
}
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
@ -357,13 +343,19 @@ public static partial class AIJsonUtilities
ConvertSchemaToObject(ref schema)[DefaultPropertyName] = defaultValue;
}
if (key.IncludeSchemaUri)
if (key.IncludeSchemaKeyword)
{
// The $schema property must be the first keyword in the object
ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri);
}
}
// Finally, apply any user-defined transformations if specified.
if (key.TransformSchemaNode is { } transformer)
{
schema = transformer(ctx, schema);
}
return schema;
static JsonObject ConvertSchemaToObject(ref JsonNode schema)
@ -388,7 +380,7 @@ public static partial class AIJsonUtilities
}
}
private static bool TypeIsIntegerWithStringNumberHandling(JsonSchemaExporterContext ctx, JsonObject schema)
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema)
{
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{
@ -443,30 +435,44 @@ public static partial class AIJsonUtilities
}
#endif
private static TAttribute? ResolveAttribute<TAttribute>(this JsonSchemaExporterContext ctx)
where TAttribute : Attribute
{
// Resolve attributes from locations in the following order:
// 1. Property-level attributes
// 2. Parameter-level attributes and
// 3. Type-level attributes.
return
#if NET9_0_OR_GREATER
GetAttrs(ctx.PropertyInfo?.AttributeProvider) ??
GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ??
#else
GetAttrs(ctx.PropertyAttributeProvider) ??
GetAttrs(ctx.ParameterInfo) ??
#endif
GetAttrs(ctx.TypeInfo.Type);
static TAttribute? GetAttrs(ICustomAttributeProvider? provider) =>
(TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault();
}
private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
return JsonElement.ParseValue(ref reader);
}
/// <summary>The equatable key used to look up cached schemas.</summary>
private readonly record struct SchemaGenerationKey
{
public SchemaGenerationKey(
Type? type,
string? parameterName,
string? description,
bool hasDefaultValue,
object? defaultValue,
AIJsonSchemaCreateOptions options)
{
Type = type;
ParameterName = parameterName;
Description = description;
HasDefaultValue = hasDefaultValue;
DefaultValue = defaultValue;
IncludeSchemaKeyword = options.IncludeSchemaKeyword;
DisallowAdditionalProperties = options.DisallowAdditionalProperties;
IncludeTypeInEnumSchemas = options.IncludeTypeInEnumSchemas;
RequireAllProperties = options.RequireAllProperties;
TransformSchemaNode = options.TransformSchemaNode;
}
public Type? Type { get; }
public string? ParameterName { get; }
public string? Description { get; }
public bool HasDefaultValue { get; }
public object? DefaultValue { get; }
public bool IncludeSchemaKeyword { get; }
public bool DisallowAdditionalProperties { get; }
public bool IncludeTypeInEnumSchemas { get; }
public bool RequireAllProperties { get; }
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; }
}
}

Просмотреть файл

@ -45,7 +45,7 @@ public static class AIJsonUtilitiesTests
Assert.True(options.DisallowAdditionalProperties);
Assert.False(options.IncludeSchemaKeyword);
Assert.True(options.RequireAllProperties);
Assert.True(options.FilterDisallowedKeywords);
Assert.Null(options.TransformSchemaNode);
}
[Fact]
@ -124,6 +124,51 @@ public static class AIJsonUtilitiesTests
Assert.True(JsonElement.DeepEquals(expected, actual));
}
[Fact]
public static void CreateJsonSchema_UserDefinedTransformer()
{
JsonElement expected = JsonDocument.Parse("""
{
"description": "The type",
"type": "object",
"properties": {
"Key": {
"$comment": "Contains a DescriptionAttribute declaration with the text 'The parameter'.",
"type": "integer"
},
"EnumValue": {
"type": "string",
"enum": ["A", "B"]
},
"Value": {
"type": ["string", "null"],
"default": null
}
},
"required": ["Key", "EnumValue", "Value"],
"additionalProperties": false
}
""").RootElement;
AIJsonSchemaCreateOptions inferenceOptions = new()
{
TransformSchemaNode = static (context, schema) =>
{
return context.TypeInfo.Type == typeof(int) && context.GetCustomAttribute<DescriptionAttribute>() is DescriptionAttribute attr
? new JsonObject
{
["$comment"] = $"Contains a DescriptionAttribute declaration with the text '{attr.Description}'.",
["type"] = "integer",
}
: schema;
}
};
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default, inferenceOptions: inferenceOptions);
Assert.True(JsonElement.DeepEquals(expected, actual));
}
[Fact]
public static void CreateJsonSchema_FiltersDisallowedKeywords()
{
@ -152,46 +197,6 @@ public static class AIJsonUtilitiesTests
Assert.True(JsonElement.DeepEquals(expected, actual));
}
[Fact]
public static void CreateJsonSchema_FilterDisallowedKeywords_Disabled()
{
JsonElement expected = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"Date": {
"type": "string",
"format": "date-time"
},
"TimeSpan": {
"$comment": "Represents a System.TimeSpan value.",
"type": "string",
"pattern": "^-?(\\d+\\.)?\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,7})?$"
},
"Char" : {
"type": "string",
"minLength": 1,
"maxLength": 1
}
},
"required": ["Date","TimeSpan","Char"],
"additionalProperties": false
}
""").RootElement;
AIJsonSchemaCreateOptions inferenceOptions = new()
{
FilterDisallowedKeywords = false
};
JsonElement actual = AIJsonUtilities.CreateJsonSchema(
typeof(PocoWithTypesWithOpenAIUnsupportedKeywords),
serializerOptions: JsonSerializerOptions.Default,
inferenceOptions: inferenceOptions);
Assert.True(JsonElement.DeepEquals(expected, actual));
}
public class PocoWithTypesWithOpenAIUnsupportedKeywords
{
// Uses the unsupported "format" keyword

Просмотреть файл

@ -191,7 +191,6 @@ public class AIFunctionFactoryTest
Assert.NotNull(schemaOptions);
Assert.True(schemaOptions.IncludeTypeInEnumSchemas);
Assert.True(schemaOptions.FilterDisallowedKeywords);
Assert.True(schemaOptions.RequireAllProperties);
Assert.True(schemaOptions.DisallowAdditionalProperties);
}