From 7f2d9003f870bfff4a4bcf68c31604079b202662 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Wed, 20 Nov 2024 21:46:16 +0000 Subject: [PATCH] Expose a schema transformer on AIJsonSchemaCreateOptions. (#5677) * Expose a schema transformer on AIJsonSchemaCreateOptions. * Address feedback * Disable caching if a transformer is specified. * Remove `FilterDisallowedKeywords`. * Document caching. * Apply suggestions from code review --- ...icrosoft.Extensions.AI.Abstractions.csproj | 1 + .../Utilities/AIJsonSchemaCreateContext.cs | 105 ++++++++++++++ .../Utilities/AIJsonSchemaCreateOptions.cs | 24 ++-- .../Utilities/AIJsonUtilities.Schema.cs | 130 +++++++++--------- .../Utilities/AIJsonUtilitiesTests.cs | 87 ++++++------ .../Functions/AIFunctionFactoryTest.cs | 1 - 6 files changed, 228 insertions(+), 120 deletions(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateContext.cs diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj index b96b4dca92..8b8541688e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -25,6 +25,7 @@ true true true + true diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateContext.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateContext.cs new file mode 100644 index 0000000000..22e3bc6066 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateContext.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System; +using System.Linq; +using System.Reflection; +using System.Text.Json.Schema; +using System.Text.Json.Serialization.Metadata; + +#pragma warning disable CA1815 // Override equals and operator equals on value types + +namespace Microsoft.Extensions.AI; + +/// +/// Defines the context in which a JSON schema within a type graph is being generated. +/// +/// +/// This struct is being passed to the user-provided +/// callback by the method and cannot be instantiated directly. +/// +public readonly struct AIJsonSchemaCreateContext +{ + private readonly JsonSchemaExporterContext _exporterContext; + + internal AIJsonSchemaCreateContext(JsonSchemaExporterContext exporterContext) + { + _exporterContext = exporterContext; + } + + /// + /// Gets the path to the schema document currently being generated. + /// + public ReadOnlySpan Path => _exporterContext.Path; + + /// + /// Gets the for the type being processed. + /// + public JsonTypeInfo TypeInfo => _exporterContext.TypeInfo; + + /// + /// Gets the type info for the polymorphic base type if generated as a derived type. + /// + public JsonTypeInfo? BaseTypeInfo => _exporterContext.BaseTypeInfo; + + /// + /// Gets the if the schema is being generated for a property. + /// + public JsonPropertyInfo? PropertyInfo => _exporterContext.PropertyInfo; + + /// + /// Gets the declaring type of the property or parameter being processed. + /// + public Type? DeclaringType => +#if NET9_0_OR_GREATER + _exporterContext.PropertyInfo?.DeclaringType; +#else + _exporterContext.DeclaringType; +#endif + + /// + /// Gets the corresponding to the property or field being processed. + /// + public ICustomAttributeProvider? PropertyAttributeProvider => +#if NET9_0_OR_GREATER + _exporterContext.PropertyInfo?.AttributeProvider; +#else + _exporterContext.PropertyAttributeProvider; +#endif + + /// + /// Gets the of the + /// constructor parameter associated with the accompanying . + /// + public ICustomAttributeProvider? 
ParameterAttributeProvider => +#if NET9_0_OR_GREATER + _exporterContext.PropertyInfo?.AssociatedParameter?.AttributeProvider; +#else + _exporterContext.ParameterInfo; +#endif + + /// + /// Retrieves a custom attribute of a specified type that is applied to the specified schema node context. + /// + /// The type of attribute to search for. + /// If , specifies to also search the ancestors of the context members for custom attributes. + /// The first occurrence of if found, or otherwise. + /// + /// This helper method resolves attributes from context locations in the following order: + /// + /// Attributes specified on the property of the context, if specified. + /// Attributes specified on the constructor parameter of the context, if specified. + /// Attributes specified on the type of the context. + /// + /// + public TAttribute? GetCustomAttribute(bool inherit = false) + where TAttribute : Attribute + { + return GetCustomAttr(PropertyAttributeProvider) ?? + GetCustomAttr(ParameterAttributeProvider) ?? + GetCustomAttr(TypeInfo.Type); + + TAttribute? GetCustomAttr(ICustomAttributeProvider? provider) => + (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit).FirstOrDefault(); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs index 2ce42c3e61..ea1f393f7e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
+using System; +using System.Text.Json.Nodes; + namespace Microsoft.Extensions.AI; /// @@ -13,6 +16,11 @@ public sealed class AIJsonSchemaCreateOptions /// public static AIJsonSchemaCreateOptions Default { get; } = new AIJsonSchemaCreateOptions(); + /// + /// Gets a callback that is invoked for every schema that is generated within the type graph. + /// + public Func? TransformSchemaNode { get; init; } + /// /// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums. /// @@ -32,20 +40,4 @@ public sealed class AIJsonSchemaCreateOptions /// Gets a value indicating whether to mark all properties as required in the schema. /// public bool RequireAllProperties { get; init; } = true; - - /// - /// Gets a value indicating whether to filter keywords that are disallowed by certain AI vendors. - /// - /// - /// Filters a number of non-essential schema keywords that are not yet supported by some AI vendors. - /// These include: - /// - /// The "minLength", "maxLength", "pattern", and "format" keywords. - /// The "minimum", "maximum", and "multipleOf" keywords. - /// The "patternProperties", "unevaluatedProperties", "propertyNames", "minProperties", and "maxProperties" keywords. - /// The "unevaluatedItems", "contains", "minContains", "maxContains", "minItems", "maxItems", and "uniqueItems" keywords. - /// - /// See also https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported. 
- /// - public bool FilterDisallowedKeywords { get; init; } = true; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index 4e3f90aa47..01f3d23e4c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -10,7 +10,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; #endif using System.Linq; -using System.Reflection; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -23,18 +22,6 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S1075 // URIs should not be hardcoded #pragma warning disable SA1118 // Parameter should not span multiple lines -using FunctionParameterKey = ( - System.Type? Type, - string? ParameterName, - string? Description, - bool HasDefaultValue, - object? DefaultValue, - bool IncludeSchemaUri, - bool DisallowAdditionalProperties, - bool IncludeTypeInEnumSchemas, - bool RequireAllProperties, - bool FilterDisallowedKeywords); - namespace Microsoft.Extensions.AI; /// Provides a collection of utility methods for marshalling JSON data. @@ -47,7 +34,7 @@ public static partial class AIJsonUtilities private const int CacheSoftLimit = 4096; /// Caches of generated schemas for each that's employed. - private static readonly ConditionalWeakTable> _schemaCaches = new(); + private static readonly ConditionalWeakTable> _schemaCaches = new(); /// Gets a JSON schema accepting all values. private static readonly JsonElement _trueJsonSchema = ParseJsonElement("true"u8); @@ -107,6 +94,10 @@ public static partial class AIJsonUtilities /// The options used to extract the schema from the specified type. /// The options controlling schema inference. /// A JSON schema document encoded as a . 
+ /// <remarks> + /// Uses a cache keyed on the <paramref name="serializerOptions"/> to store schema result, + /// unless a <see cref="AIJsonSchemaCreateOptions.TransformSchemaNode"/> delegate has been specified. + /// </remarks> public static JsonElement CreateParameterJsonSchema( Type? type, string parameterName, @@ -121,17 +112,13 @@ public static partial class AIJsonUtilities serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; - FunctionParameterKey key = ( + SchemaGenerationKey key = new( type, parameterName, description, hasDefaultValue, defaultValue, - IncludeSchemaUri: false, - inferenceOptions.DisallowAdditionalProperties, - inferenceOptions.IncludeTypeInEnumSchemas, - inferenceOptions.RequireAllProperties, - inferenceOptions.FilterDisallowedKeywords); + inferenceOptions); return GetJsonSchemaCached(serializerOptions, key); } @@ -144,6 +131,10 @@ public static partial class AIJsonUtilities /// <param name="serializerOptions">The options used to extract the schema from the specified type.</param> /// <param name="inferenceOptions">The options controlling schema inference.</param> /// <returns>A <see cref="JsonElement"/> representing the schema.</returns> + /// <remarks> + /// Uses a cache keyed on the <paramref name="serializerOptions"/> to store schema result, + /// unless a <see cref="AIJsonSchemaCreateOptions.TransformSchemaNode"/> delegate has been specified. + /// </remarks> public static JsonElement CreateJsonSchema( Type? type, string? 
description = null, @@ -155,27 +146,23 @@ public static partial class AIJsonUtilities serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; - FunctionParameterKey key = ( + SchemaGenerationKey key = new( type, - ParameterName: null, + parameterName: null, description, hasDefaultValue, defaultValue, - inferenceOptions.IncludeSchemaKeyword, - inferenceOptions.DisallowAdditionalProperties, - inferenceOptions.IncludeTypeInEnumSchemas, - inferenceOptions.RequireAllProperties, - inferenceOptions.FilterDisallowedKeywords); + inferenceOptions); return GetJsonSchemaCached(serializerOptions, key); } - private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, FunctionParameterKey key) + private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, SchemaGenerationKey key) { options.MakeReadOnly(); - ConcurrentDictionary cache = _schemaCaches.GetOrCreateValue(options); + ConcurrentDictionary cache = _schemaCaches.GetOrCreateValue(options); - if (cache.Count >= CacheSoftLimit) + if (key.TransformSchemaNode is not null || cache.Count >= CacheSoftLimit) { return GetJsonSchemaCore(options, key); } @@ -195,7 +182,7 @@ public static partial class AIJsonUtilities Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " + "The exception message will guide users to turn off 'IlcTrimMetadata' which resolves all issues.")] #endif - private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key) + private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, SchemaGenerationKey key) { _ = Throw.IfNull(options); options.MakeReadOnly(); @@ -206,7 +193,7 @@ public static partial class AIJsonUtilities JsonObject? 
schemaObj = null; - if (key.IncludeSchemaUri) + if (key.IncludeSchemaKeyword) { (schemaObj = [])["$schema"] = SchemaKeywordUri; } @@ -244,7 +231,7 @@ public static partial class AIJsonUtilities JsonNode node = options.GetJsonSchemaAsNode(key.Type, exporterOptions); return JsonSerializer.SerializeToElement(node, JsonContext.Default.JsonNode); - JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) + JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, JsonNode schema) { const string SchemaPropertyName = "$schema"; const string DescriptionPropertyName = "description"; @@ -258,7 +245,9 @@ public static partial class AIJsonUtilities const string DefaultPropertyName = "default"; const string RefPropertyName = "$ref"; - if (ctx.ResolveAttribute() is { } attr) + AIJsonSchemaCreateContext ctx = new(schemaExporterContext); + + if (ctx.GetCustomAttribute() is { } attr) { ConvertSchemaToObject(ref schema).InsertAtStart(DescriptionPropertyName, (JsonNode)attr.Description); } @@ -308,12 +297,9 @@ public static partial class AIJsonUtilities } // Filter potentially disallowed keywords. - if (key.FilterDisallowedKeywords) + foreach (string keyword in _schemaKeywordsDisallowedByAIVendors) { - foreach (string keyword in _schemaKeywordsDisallowedByAIVendors) - { - _ = objSchema.Remove(keyword); - } + _ = objSchema.Remove(keyword); } // Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand @@ -357,13 +343,19 @@ public static partial class AIJsonUtilities ConvertSchemaToObject(ref schema)[DefaultPropertyName] = defaultValue; } - if (key.IncludeSchemaUri) + if (key.IncludeSchemaKeyword) { // The $schema property must be the first keyword in the object ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri); } } + // Finally, apply any user-defined transformations if specified. 
+ if (key.TransformSchemaNode is { } transformer) + { + schema = transformer(ctx, schema); + } + return schema; static JsonObject ConvertSchemaToObject(ref JsonNode schema) @@ -388,7 +380,7 @@ public static partial class AIJsonUtilities } } - private static bool TypeIsIntegerWithStringNumberHandling(JsonSchemaExporterContext ctx, JsonObject schema) + private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema) { if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray) { @@ -443,30 +435,44 @@ public static partial class AIJsonUtilities } #endif - private static TAttribute? ResolveAttribute(this JsonSchemaExporterContext ctx) - where TAttribute : Attribute - { - // Resolve attributes from locations in the following order: - // 1. Property-level attributes - // 2. Parameter-level attributes and - // 3. Type-level attributes. - return -#if NET9_0_OR_GREATER - GetAttrs(ctx.PropertyInfo?.AttributeProvider) ?? - GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ?? -#else - GetAttrs(ctx.PropertyAttributeProvider) ?? - GetAttrs(ctx.ParameterInfo) ?? -#endif - GetAttrs(ctx.TypeInfo.Type); - - static TAttribute? GetAttrs(ICustomAttributeProvider? provider) => - (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault(); - } - private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) { Utf8JsonReader reader = new(utf8Json); return JsonElement.ParseValue(ref reader); } + + /// The equatable key used to look up cached schemas. + private readonly record struct SchemaGenerationKey + { + public SchemaGenerationKey( + Type? type, + string? parameterName, + string? description, + bool hasDefaultValue, + object? 
defaultValue, + AIJsonSchemaCreateOptions options) + { + Type = type; + ParameterName = parameterName; + Description = description; + HasDefaultValue = hasDefaultValue; + DefaultValue = defaultValue; + IncludeSchemaKeyword = options.IncludeSchemaKeyword; + DisallowAdditionalProperties = options.DisallowAdditionalProperties; + IncludeTypeInEnumSchemas = options.IncludeTypeInEnumSchemas; + RequireAllProperties = options.RequireAllProperties; + TransformSchemaNode = options.TransformSchemaNode; + } + + public Type? Type { get; } + public string? ParameterName { get; } + public string? Description { get; } + public bool HasDefaultValue { get; } + public object? DefaultValue { get; } + public bool IncludeSchemaKeyword { get; } + public bool DisallowAdditionalProperties { get; } + public bool IncludeTypeInEnumSchemas { get; } + public bool RequireAllProperties { get; } + public Func? TransformSchemaNode { get; } + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index 4107618d85..fb8501909c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -45,7 +45,7 @@ public static class AIJsonUtilitiesTests Assert.True(options.DisallowAdditionalProperties); Assert.False(options.IncludeSchemaKeyword); Assert.True(options.RequireAllProperties); - Assert.True(options.FilterDisallowedKeywords); + Assert.Null(options.TransformSchemaNode); } [Fact] @@ -124,6 +124,51 @@ public static class AIJsonUtilitiesTests Assert.True(JsonElement.DeepEquals(expected, actual)); } + [Fact] + public static void CreateJsonSchema_UserDefinedTransformer() + { + JsonElement expected = JsonDocument.Parse(""" + { + "description": "The type", + "type": "object", + "properties": { + "Key": { + "$comment": 
"Contains a DescriptionAttribute declaration with the text 'The parameter'.", + "type": "integer" + }, + "EnumValue": { + "type": "string", + "enum": ["A", "B"] + }, + "Value": { + "type": ["string", "null"], + "default": null + } + }, + "required": ["Key", "EnumValue", "Value"], + "additionalProperties": false + } + """).RootElement; + + AIJsonSchemaCreateOptions inferenceOptions = new() + { + TransformSchemaNode = static (context, schema) => + { + return context.TypeInfo.Type == typeof(int) && context.GetCustomAttribute() is DescriptionAttribute attr + ? new JsonObject + { + ["$comment"] = $"Contains a DescriptionAttribute declaration with the text '{attr.Description}'.", + ["type"] = "integer", + } + : schema; + } + }; + + JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default, inferenceOptions: inferenceOptions); + + Assert.True(JsonElement.DeepEquals(expected, actual)); + } + [Fact] public static void CreateJsonSchema_FiltersDisallowedKeywords() { @@ -152,46 +197,6 @@ public static class AIJsonUtilitiesTests Assert.True(JsonElement.DeepEquals(expected, actual)); } - [Fact] - public static void CreateJsonSchema_FilterDisallowedKeywords_Disabled() - { - JsonElement expected = JsonDocument.Parse(""" - { - "type": "object", - "properties": { - "Date": { - "type": "string", - "format": "date-time" - }, - "TimeSpan": { - "$comment": "Represents a System.TimeSpan value.", - "type": "string", - "pattern": "^-?(\\d+\\.)?\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,7})?$" - }, - "Char" : { - "type": "string", - "minLength": 1, - "maxLength": 1 - } - }, - "required": ["Date","TimeSpan","Char"], - "additionalProperties": false - } - """).RootElement; - - AIJsonSchemaCreateOptions inferenceOptions = new() - { - FilterDisallowedKeywords = false - }; - - JsonElement actual = AIJsonUtilities.CreateJsonSchema( - typeof(PocoWithTypesWithOpenAIUnsupportedKeywords), - serializerOptions: JsonSerializerOptions.Default, - 
inferenceOptions: inferenceOptions); - - Assert.True(JsonElement.DeepEquals(expected, actual)); - } - public class PocoWithTypesWithOpenAIUnsupportedKeywords { // Uses the unsupported "format" keyword diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index c72a2f3082..207a470575 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -191,7 +191,6 @@ public class AIFunctionFactoryTest Assert.NotNull(schemaOptions); Assert.True(schemaOptions.IncludeTypeInEnumSchemas); - Assert.True(schemaOptions.FilterDisallowedKeywords); Assert.True(schemaOptions.RequireAllProperties); Assert.True(schemaOptions.DisallowAdditionalProperties); }