[release/9.0] Merging changes from main into release branch for November release (#5602)

This commit is contained in:
Jose Perez Rodriguez 2024-11-06 09:38:41 -08:00 коммит произвёл GitHub
Родитель 7574980e14 f902047c64
Коммит d468173295
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
81 изменённых файлов: 6683 добавлений и 178 удалений

Просмотреть файл

@ -43,6 +43,10 @@
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\StringSyntaxAttribute\*.cs" LinkBase="LegacySupport\StringSyntaxAttribute" /> <Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\StringSyntaxAttribute\*.cs" LinkBase="LegacySupport\StringSyntaxAttribute" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(InjectJsonSchemaExporterOnLegacy)' == 'true' AND ('$(TargetFramework)' == 'net462' or '$(TargetFramework)' == 'netstandard2.0' or '$(TargetFramework)' == 'netcoreapp3.1' or '$(TargetFramework)' == 'net6.0' or '$(TargetFramework)' == 'net7.0' or '$(TargetFramework)' == 'net8.0')">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\Shared\JsonSchemaExporter\**\*.cs" LinkBase="Shared\EmptyCollections" />
</ItemGroup>
<ItemGroup Condition="'$(InjectGetOrAddOnLegacy)' == 'true' AND ('$(TargetFramework)' == 'net462' or '$(TargetFramework)' == 'netstandard2.0')"> <ItemGroup Condition="'$(InjectGetOrAddOnLegacy)' == 'true' AND ('$(TargetFramework)' == 'net462' or '$(TargetFramework)' == 'netstandard2.0')">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\GetOrAdd\*.cs" LinkBase="LegacySupport\GetOrAdd" /> <Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\GetOrAdd\*.cs" LinkBase="LegacySupport\GetOrAdd" />
</ItemGroup> </ItemGroup>

Просмотреть файл

@ -7,6 +7,7 @@
<PackageVersion Include="BenchmarkDotNet" Version="0.13.5" /> <PackageVersion Include="BenchmarkDotNet" Version="0.13.5" />
<PackageVersion Include="FluentAssertions" Version="6.11.0" /> <PackageVersion Include="FluentAssertions" Version="6.11.0" />
<PackageVersion Include="Grpc.AspNetCore" Version="2.65.0" /> <PackageVersion Include="Grpc.AspNetCore" Version="2.65.0" />
<PackageVersion Include="JsonSchema.Net" Version="7.2.3" />
<PackageVersion Include="Microsoft.Data.SqlClient" Version="5.2.2" /> <PackageVersion Include="Microsoft.Data.SqlClient" Version="5.2.2" />
<PackageVersion Include="Microsoft.Diagnostics.Tracing.TraceEvent" Version="3.1.3" /> <PackageVersion Include="Microsoft.Diagnostics.Tracing.TraceEvent" Version="3.1.3" />
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24378.1" /> <PackageVersion Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24378.1" />
@ -20,6 +21,7 @@
<PackageVersion Include="Verify.Xunit" Version="20.4.0" /> <PackageVersion Include="Verify.Xunit" Version="20.4.0" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.6.24" /> <PackageVersion Include="Xunit.Combinatorial" Version="1.6.24" />
<PackageVersion Include="xunit.extensibility.execution" Version="2.4.2" /> <PackageVersion Include="xunit.extensibility.execution" Version="2.4.2" />
<PackageVersion Include="Xunit.SkippableFact" Version="1.4.13" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net462'"> <ItemGroup Condition="'$(TargetFramework)' == 'net462'">

Двоичные данные
eng/spellchecking_exclusions.dic

Двоичный файл не отображается.

Просмотреть файл

@ -4,13 +4,21 @@
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Globalization; using System.Globalization;
using System.Linq; using System.Linq;
using Microsoft.Shared.Diagnostics;
#pragma warning disable S1144 // Unused private types or members should be removed
#pragma warning disable S2365 // Properties should not make collection or array copies
#pragma warning disable S3604 // Member initializer values should not be redundant
namespace Microsoft.Extensions.AI; namespace Microsoft.Extensions.AI;
/// <summary>Provides a dictionary used as the AdditionalProperties dictionary on Microsoft.Extensions.AI objects.</summary> /// <summary>Provides a dictionary used as the AdditionalProperties dictionary on Microsoft.Extensions.AI objects.</summary>
[DebuggerTypeProxy(typeof(DebugView))]
[DebuggerDisplay("Count = {Count}")]
public sealed class AdditionalPropertiesDictionary : IDictionary<string, object?>, IReadOnlyDictionary<string, object?> public sealed class AdditionalPropertiesDictionary : IDictionary<string, object?>, IReadOnlyDictionary<string, object?>
{ {
/// <summary>The underlying dictionary.</summary> /// <summary>The underlying dictionary.</summary>
@ -77,6 +85,25 @@ public sealed class AdditionalPropertiesDictionary : IDictionary<string, object?
/// <inheritdoc /> /// <inheritdoc />
public void Add(string key, object? value) => _dictionary.Add(key, value); public void Add(string key, object? value) => _dictionary.Add(key, value);
/// <summary>Attempts to add the specified key and value to the dictionary.</summary>
/// <param name="key">The key of the element to add.</param>
/// <param name="value">The value of the element to add.</param>
/// <returns><see langword="true"/> if the key/value pair was added to the dictionary successfully; otherwise, <see langword="false"/>.</returns>
public bool TryAdd(string key, object? value)
{
#if NET
return _dictionary.TryAdd(key, value);
#else
if (!_dictionary.ContainsKey(key))
{
_dictionary.Add(key, value);
return true;
}
return false;
#endif
}
/// <inheritdoc /> /// <inheritdoc />
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) => ((ICollection<KeyValuePair<string, object?>>)_dictionary).Add(item); void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) => ((ICollection<KeyValuePair<string, object?>>)_dictionary).Add(item);
@ -93,11 +120,17 @@ public sealed class AdditionalPropertiesDictionary : IDictionary<string, object?
void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) => void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) =>
((ICollection<KeyValuePair<string, object?>>)_dictionary).CopyTo(array, arrayIndex); ((ICollection<KeyValuePair<string, object?>>)_dictionary).CopyTo(array, arrayIndex);
/// <inheritdoc /> /// <summary>
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _dictionary.GetEnumerator(); /// Returns an enumerator that iterates through the <see cref="AdditionalPropertiesDictionary"/>.
/// </summary>
/// <returns>An <see cref="AdditionalPropertiesDictionary.Enumerator"/> that enumerates the contents of the <see cref="AdditionalPropertiesDictionary"/>.</returns>
public Enumerator GetEnumerator() => new(_dictionary.GetEnumerator());
/// <inheritdoc /> /// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => _dictionary.GetEnumerator(); IEnumerator<KeyValuePair<string, object?>> IEnumerable<KeyValuePair<string, object?>>.GetEnumerator() => GetEnumerator();
/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
/// <inheritdoc /> /// <inheritdoc />
public bool Remove(string key) => _dictionary.Remove(key); public bool Remove(string key) => _dictionary.Remove(key);
@ -156,4 +189,59 @@ public sealed class AdditionalPropertiesDictionary : IDictionary<string, object?
value = default; value = default;
return false; return false;
} }
/// <summary>Enumerates the elements of an <see cref="AdditionalPropertiesDictionary"/>.</summary>
public struct Enumerator : IEnumerator<KeyValuePair<string, object?>>
{
/// <summary>The wrapped dictionary enumerator.</summary>
private Dictionary<string, object?>.Enumerator _dictionaryEnumerator;
/// <summary>Initializes a new instance of the <see cref="Enumerator"/> struct with the dictionary enumerator to wrap.</summary>
/// <param name="dictionaryEnumerator">The dictionary enumerator to wrap.</param>
internal Enumerator(Dictionary<string, object?>.Enumerator dictionaryEnumerator)
{
_dictionaryEnumerator = dictionaryEnumerator;
}
/// <inheritdoc />
public KeyValuePair<string, object?> Current => _dictionaryEnumerator.Current;
/// <inheritdoc />
object IEnumerator.Current => Current;
/// <inheritdoc />
public void Dispose() => _dictionaryEnumerator.Dispose();
/// <inheritdoc />
public bool MoveNext() => _dictionaryEnumerator.MoveNext();
/// <inheritdoc />
public void Reset() => Reset(ref _dictionaryEnumerator);
/// <summary>Calls <see cref="IEnumerator.Reset"/> on an enumerator.</summary>
private static void Reset<TEnumerator>(ref TEnumerator enumerator)
where TEnumerator : struct, IEnumerator
{
enumerator.Reset();
}
}
/// <summary>Provides a debugger view for the collection.</summary>
private sealed class DebugView(AdditionalPropertiesDictionary properties)
{
private readonly AdditionalPropertiesDictionary _properties = Throw.IfNull(properties);
[DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
public AdditionalProperty[] Items => (from p in _properties select new AdditionalProperty(p.Key, p.Value)).ToArray();
[DebuggerDisplay("{Value}", Name = "[{Key}]")]
public readonly struct AdditionalProperty(string key, object? value)
{
[DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
public string Key { get; } = key;
[DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
public object? Value { get; } = value;
}
}
} }

Просмотреть файл

@ -0,0 +1,19 @@
# Release History
## 9.0.0-preview.9.24525.1
- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
- Annotated `FunctionCallContent.Exception` and `FunctionResultContent.Exception` as `[JsonIgnore]`, such that they're ignored when serializing instances with `JsonSerializer`. The corresponding constructors accepting an `Exception` were removed.
- Annotated `ChatCompletion.Message` as `[JsonIgnore]`, such that it's ignored when serializing instances with `JsonSerializer`.
- Added the `FunctionCallContent.CreateFromParsedArguments` method.
- Added the `AdditionalPropertiesDictionary.TryGetValue<T>` method.
- Added the `StreamingChatCompletionUpdate.ModelId` property and removed the `AIContent.ModelId` property.
- Renamed the `GenerateAsync` extension method on `IEmbeddingGenerator<,>` to `GenerateEmbeddingsAsync` and updated it to return `Embedding<T>` rather than `GeneratedEmbeddings`.
- Added `GenerateAndZipAsync` and `GenerateEmbeddingVectorAsync` extension methods for `IEmbeddingGenerator<,>`.
- Added the `EmbeddingGeneratorOptions.Dimensions` property.
- Added the `ChatOptions.TopK` property.
- Normalized `null` inputs in `TextContent` to be empty strings.
## 9.0.0-preview.9.24507.7
Initial Preview

Просмотреть файл

@ -27,6 +27,9 @@ public class ChatOptions
/// <summary>Gets or sets the presence penalty for generating chat responses.</summary> /// <summary>Gets or sets the presence penalty for generating chat responses.</summary>
public float? PresencePenalty { get; set; } public float? PresencePenalty { get; set; }
/// <summary>Gets or sets a seed value used by a service to control the reproducability of results.</summary>
public long? Seed { get; set; }
/// <summary> /// <summary>
/// Gets or sets the response format for the chat request. /// Gets or sets the response format for the chat request.
/// </summary> /// </summary>
@ -74,6 +77,7 @@ public class ChatOptions
TopK = TopK, TopK = TopK,
FrequencyPenalty = FrequencyPenalty, FrequencyPenalty = FrequencyPenalty,
PresencePenalty = PresencePenalty, PresencePenalty = PresencePenalty,
Seed = Seed,
ResponseFormat = ResponseFormat, ResponseFormat = ResponseFormat,
ModelId = ModelId, ModelId = ModelId,
ToolMode = ToolMode, ToolMode = ToolMode,

Просмотреть файл

@ -16,9 +16,11 @@
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks> <TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA2227;CA1034;SA1316;S3253</NoWarn> <NoWarn>$(NoWarn);CA2227;CA1034;SA1316;S3253</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<InjectJsonSchemaExporterOnLegacy>true</InjectJsonSchemaExporterOnLegacy>
<InjectSharedEmptyCollections>true</InjectSharedEmptyCollections> <InjectSharedEmptyCollections>true</InjectSharedEmptyCollections>
<InjectStringHashOnLegacy>true</InjectStringHashOnLegacy> <InjectStringHashOnLegacy>true</InjectStringHashOnLegacy>
<InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy> <InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy>

Просмотреть файл

@ -23,11 +23,11 @@ public static partial class AIJsonUtilities
{ {
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model. // and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable Native AOT. // Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault) if (JsonSerializer.IsReflectionEnabledByDefault)
{ {
// Keep in sync with the JsonSourceGenerationOptions on JsonContext below. // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web) JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{ {
TypeInfoResolver = new DefaultJsonTypeInfoResolver(), TypeInfoResolver = new DefaultJsonTypeInfoResolver(),

Просмотреть файл

@ -5,17 +5,22 @@ using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.ComponentModel; using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
#if !NET9_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
#endif
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using System.Text.Json.Schema; using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics; using Microsoft.Shared.Diagnostics;
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions #pragma warning disable S1121 // Assignments should not be made from within sub-expressions
#pragma warning disable S107 // Methods should not have too many parameters #pragma warning disable S107 // Methods should not have too many parameters
#pragma warning disable S1075 // URIs should not be hardcoded #pragma warning disable S1075 // URIs should not be hardcoded
#pragma warning disable SA1118 // Parameter should not span multiple lines
using FunctionParameterKey = ( using FunctionParameterKey = (
System.Type? Type, System.Type? Type,
@ -138,8 +143,6 @@ public static partial class AIJsonUtilities
JsonSerializerOptions? serializerOptions = null, JsonSerializerOptions? serializerOptions = null,
AIJsonSchemaCreateOptions? inferenceOptions = null) AIJsonSchemaCreateOptions? inferenceOptions = null)
{ {
_ = Throw.IfNull(serializerOptions);
serializerOptions ??= DefaultOptions; serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default; inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
@ -176,6 +179,11 @@ public static partial class AIJsonUtilities
#endif #endif
} }
#if !NET9_0_OR_GREATER
[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access",
Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " +
"The exception message will guide users to turn off 'IlcTrimMetadata' which resolves all issues.")]
#endif
private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key) private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key)
{ {
_ = Throw.IfNull(options); _ = Throw.IfNull(options);
@ -238,16 +246,9 @@ public static partial class AIJsonUtilities
const string DefaultPropertyName = "default"; const string DefaultPropertyName = "default";
const string RefPropertyName = "$ref"; const string RefPropertyName = "$ref";
// Find the first DescriptionAttribute, starting first from the property, then the parameter, and finally the type itself. if (ctx.ResolveAttribute<DescriptionAttribute>() is { } attr)
Type descAttrType = typeof(DescriptionAttribute);
var descriptionAttribute =
GetAttrs(descAttrType, ctx.PropertyInfo?.AttributeProvider)?.FirstOrDefault() ??
GetAttrs(descAttrType, ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider)?.FirstOrDefault() ??
GetAttrs(descAttrType, ctx.TypeInfo.Type)?.FirstOrDefault();
if (descriptionAttribute is DescriptionAttribute attr)
{ {
ConvertSchemaToObject(ref schema).Insert(0, DescriptionPropertyName, (JsonNode)attr.Description); ConvertSchemaToObject(ref schema).InsertAtStart(DescriptionPropertyName, (JsonNode)attr.Description);
} }
if (schema is JsonObject objSchema) if (schema is JsonObject objSchema)
@ -270,7 +271,7 @@ public static partial class AIJsonUtilities
// Include the type keyword in enum types // Include the type keyword in enum types
if (key.IncludeTypeInEnumSchemas && ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName)) if (key.IncludeTypeInEnumSchemas && ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{ {
objSchema.Insert(0, TypePropertyName, "string"); objSchema.InsertAtStart(TypePropertyName, "string");
} }
// Disallow additional properties in object schemas // Disallow additional properties in object schemas
@ -278,25 +279,25 @@ public static partial class AIJsonUtilities
{ {
objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false); objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false);
} }
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama.
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
_ = obj.Remove(PatternPropertyName);
}
} }
if (ctx.Path.IsEmpty) if (ctx.Path.IsEmpty)
{ {
// We are at the root-level schema node, update/append parameter-specific metadata // We are at the root-level schema node, update/append parameter-specific metadata
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
if (TypeIsArrayContainingInteger(schema))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
_ = obj.Remove(PatternPropertyName);
}
if (!string.IsNullOrWhiteSpace(key.Description)) if (!string.IsNullOrWhiteSpace(key.Description))
{ {
JsonObject obj = ConvertSchemaToObject(ref schema); JsonObject obj = ConvertSchemaToObject(ref schema);
@ -305,7 +306,7 @@ public static partial class AIJsonUtilities
if (index < 0) if (index < 0)
{ {
// If there's no description property, insert it at the beginning of the doc. // If there's no description property, insert it at the beginning of the doc.
obj.Insert(0, DescriptionPropertyName, (JsonNode)key.Description!); obj.InsertAtStart(DescriptionPropertyName, (JsonNode)key.Description!);
} }
else else
{ {
@ -323,15 +324,12 @@ public static partial class AIJsonUtilities
if (key.IncludeSchemaUri) if (key.IncludeSchemaUri)
{ {
// The $schema property must be the first keyword in the object // The $schema property must be the first keyword in the object
ConvertSchemaToObject(ref schema).Insert(0, SchemaPropertyName, (JsonNode)SchemaKeywordUri); ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri);
} }
} }
return schema; return schema;
static object[]? GetAttrs(Type attrType, ICustomAttributeProvider? provider) =>
provider?.GetCustomAttributes(attrType, inherit: false);
static JsonObject ConvertSchemaToObject(ref JsonNode schema) static JsonObject ConvertSchemaToObject(ref JsonNode schema)
{ {
JsonObject obj; JsonObject obj;
@ -354,22 +352,82 @@ public static partial class AIJsonUtilities
} }
} }
private static bool TypeIsArrayContainingInteger(JsonNode schema) private static bool TypeIsIntegerWithStringNumberHandling(JsonSchemaExporterContext ctx, JsonObject schema)
{ {
if (schema["type"] is JsonArray typeArray) if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{ {
foreach (var entry in typeArray) int count = 0;
foreach (JsonNode? entry in typeArray)
{ {
if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue<string>() == "integer") if (entry?.GetValueKind() is JsonValueKind.String &&
entry.GetValue<string>() is "integer" or "string")
{ {
return true; count++;
} }
} }
return count == typeArray.Count;
} }
return false; return false;
} }
private static void InsertAtStart(this JsonObject jsonObject, string key, JsonNode value)
{
#if NET9_0_OR_GREATER
jsonObject.Insert(0, key, value);
#else
jsonObject.Remove(key);
var copiedEntries = jsonObject.ToArray();
jsonObject.Clear();
jsonObject.Add(key, value);
foreach (var entry in copiedEntries)
{
jsonObject[entry.Key] = entry.Value;
}
#endif
}
#if !NET9_0_OR_GREATER
private static int IndexOf(this JsonObject jsonObject, string key)
{
int i = 0;
foreach (var entry in jsonObject)
{
if (string.Equals(entry.Key, key, StringComparison.Ordinal))
{
return i;
}
i++;
}
return -1;
}
#endif
private static TAttribute? ResolveAttribute<TAttribute>(this JsonSchemaExporterContext ctx)
where TAttribute : Attribute
{
// Resolve attributes from locations in the following order:
// 1. Property-level attributes
// 2. Parameter-level attributes and
// 3. Type-level attributes.
return
#if NET9_0_OR_GREATER
GetAttrs(ctx.PropertyInfo?.AttributeProvider) ??
GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ??
#else
GetAttrs(ctx.PropertyAttributeProvider) ??
GetAttrs(ctx.ParameterInfo) ??
#endif
GetAttrs(ctx.TypeInfo.Type);
static TAttribute? GetAttrs(ICustomAttributeProvider? provider) =>
(TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault();
}
private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json) private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{ {
Utf8JsonReader reader = new(utf8Json); Utf8JsonReader reader = new(utf8Json);

Просмотреть файл

@ -285,6 +285,7 @@ public sealed class AzureAIInferenceChatClient : IChatClient
result.NucleusSamplingFactor = options.TopP; result.NucleusSamplingFactor = options.TopP;
result.PresencePenalty = options.PresencePenalty; result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature; result.Temperature = options.Temperature;
result.Seed = options.Seed;
if (options.StopSequences is { Count: > 0 } stopSequences) if (options.StopSequences is { Count: > 0 } stopSequences)
{ {
@ -306,11 +307,6 @@ public sealed class AzureAIInferenceChatClient : IChatClient
{ {
switch (prop.Key) switch (prop.Key)
{ {
// These properties are strongly-typed on the ChatCompletionsOptions class but not on the ChatOptions class.
case nameof(result.Seed) when prop.Value is long seed:
result.Seed = seed;
break;
// Propagate everything else to the ChatCompletionOptions' AdditionalProperties. // Propagate everything else to the ChatCompletionOptions' AdditionalProperties.
default: default:
if (prop.Value is not null) if (prop.Value is not null)

Просмотреть файл

@ -156,7 +156,7 @@ public sealed class AzureAIInferenceEmbeddingGenerator :
{ {
EmbeddingsOptions result = new(inputs) EmbeddingsOptions result = new(inputs)
{ {
Dimensions = _dimensions, Dimensions = options?.Dimensions ?? _dimensions,
Model = options?.ModelId ?? Metadata.ModelId, Model = options?.ModelId ?? Metadata.ModelId,
EncodingFormat = format, EncodingFormat = format,
}; };

Просмотреть файл

@ -0,0 +1,12 @@
# Release History
## 9.0.0-preview.9.24525.1
- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
- Updated to use Azure.AI.Inference 1.0.0-beta.2.
- Added `AzureAIInferenceEmbeddingGenerator` and corresponding `AsEmbeddingGenerator` extension method.
- Improved handling of assistant messages that include both text and function call content.
## 9.0.0-preview.9.24507.7
Initial Preview

Просмотреть файл

@ -48,11 +48,11 @@ internal sealed partial class JsonContext : JsonSerializerContext
{ {
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model. // and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable Native AOT. // Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault) if (JsonSerializer.IsReflectionEnabledByDefault)
{ {
// Keep in sync with the JsonSourceGenerationOptions on JsonContext below. // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web) JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{ {
TypeInfoResolver = new DefaultJsonTypeInfoResolver(), TypeInfoResolver = new DefaultJsonTypeInfoResolver(),

Просмотреть файл

@ -16,6 +16,7 @@
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks> <TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA1063;CA2227;SA1316;S1067;S1121;S3358</NoWarn> <NoWarn>$(NoWarn);CA1063;CA2227;SA1316;S1067;S1121;S3358</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
@ -28,6 +29,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Azure.AI.Inference" /> <PackageReference Include="Azure.AI.Inference" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" /> <PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="System.Memory.Data" />
<PackageReference Include="System.Text.Json" /> <PackageReference Include="System.Text.Json" />
</ItemGroup> </ItemGroup>

Просмотреть файл

@ -0,0 +1,10 @@
# Release History
## 9.0.0-preview.9.24525.1
- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
- Added additional constructors to `OllamaChatClient` and `OllamaEmbeddingGenerator` that accept `string` endpoints, in addition to the existing ones accepting `Uri` endpoints.
## 9.0.0-preview.9.24507.7
Initial Preview

Просмотреть файл

@ -16,6 +16,7 @@
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks> <TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA2227;SA1316;S1121;EA0002</NoWarn> <NoWarn>$(NoWarn);CA2227;SA1316;S1121;EA0002</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>

Просмотреть файл

@ -273,7 +273,6 @@ public sealed class OllamaChatClient : IChatClient
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value); TransferMetadataValue<bool>(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value);
TransferMetadataValue<int>(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value); TransferMetadataValue<int>(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value); TransferMetadataValue<float>(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value);
TransferMetadataValue<long>(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value); TransferMetadataValue<float>(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value); TransferMetadataValue<float>(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value);
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value); TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value);
@ -314,6 +313,11 @@ public sealed class OllamaChatClient : IChatClient
{ {
(request.Options ??= new()).top_k = topK; (request.Options ??= new()).top_k = topK;
} }
if (options.Seed is long seed)
{
(request.Options ??= new()).seed = seed;
}
} }
return request; return request;

Просмотреть файл

@ -0,0 +1,12 @@
# Release History
## 9.0.0-preview.9.24525.1
- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
- Improved handling of system messages that include multiple content items.
- Improved handling of assistant messages that include both text and function call content.
- Fixed handling of streaming updates containing empty payloads.
## 9.0.0-preview.9.24507.7
Initial Preview

Просмотреть файл

@ -16,6 +16,7 @@
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks> <TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002</NoWarn> <NoWarn>$(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
@ -25,6 +26,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="OpenAI" /> <PackageReference Include="OpenAI" />
<PackageReference Include="System.Memory.Data" />
<PackageReference Include="System.Text.Json" /> <PackageReference Include="System.Text.Json" />
</ItemGroup> </ItemGroup>

Просмотреть файл

@ -3,11 +3,13 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection; using System.Reflection;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics; using Microsoft.Shared.Diagnostics;
@ -265,8 +267,7 @@ public sealed partial class OpenAIChatClient : IChatClient
existing.CallId ??= toolCallUpdate.ToolCallId; existing.CallId ??= toolCallUpdate.ToolCallId;
existing.Name ??= toolCallUpdate.FunctionName; existing.Name ??= toolCallUpdate.FunctionName;
if (toolCallUpdate.FunctionArgumentsUpdate is { } update && if (toolCallUpdate.FunctionArgumentsUpdate is { } update && !update.ToMemory().IsEmpty)
!update.ToMemory().IsEmpty) // workaround for https://github.com/dotnet/runtime/issues/68262 in 6.0.0 package
{ {
_ = (existing.Arguments ??= new()).Append(update.ToString()); _ = (existing.Arguments ??= new()).Append(update.ToString());
} }
@ -391,6 +392,9 @@ public sealed partial class OpenAIChatClient : IChatClient
result.TopP = options.TopP; result.TopP = options.TopP;
result.PresencePenalty = options.PresencePenalty; result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature; result.Temperature = options.Temperature;
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates.
result.Seed = options.Seed;
#pragma warning restore OPENAI001
if (options.StopSequences is { Count: > 0 } stopSequences) if (options.StopSequences is { Count: > 0 } stopSequences)
{ {
@ -425,13 +429,6 @@ public sealed partial class OpenAIChatClient : IChatClient
result.AllowParallelToolCalls = allowParallelToolCalls; result.AllowParallelToolCalls = allowParallelToolCalls;
} }
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
if (additionalProperties.TryGetValue(nameof(result.Seed), out long seed))
{
result.Seed = seed;
}
#pragma warning restore OPENAI001
if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
{ {
result.TopLogProbabilityCount = topLogProbabilityCountInt; result.TopLogProbabilityCount = topLogProbabilityCountInt;
@ -587,10 +584,9 @@ public sealed partial class OpenAIChatClient : IChatClient
string? result = resultContent.Result as string; string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null) if (result is null && resultContent.Result is not null)
{ {
JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
try try
{ {
result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object))); result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
} }
catch (NotSupportedException) catch (NotSupportedException)
{ {
@ -617,7 +613,9 @@ public sealed partial class OpenAIChatClient : IChatClient
ChatToolCall.CreateFunctionToolCall( ChatToolCall.CreateFunctionToolCall(
callRequest.CallId, callRequest.CallId,
callRequest.Name, callRequest.Name,
BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions))); new(JsonSerializer.SerializeToUtf8Bytes(
callRequest.Arguments,
JsonContext.GetTypeInfo(typeof(IDictionary<string, object?>), ToolCallJsonSerializerOptions)))));
} }
} }
@ -670,8 +668,53 @@ public sealed partial class OpenAIChatClient : IChatClient
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!); argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
/// <summary>Source-generated JSON type information.</summary> /// <summary>Source-generated JSON type information.</summary>
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(OpenAIChatToolJson))] [JsonSerializable(typeof(OpenAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))] [JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(JsonElement))]
private sealed partial class JsonContext : JsonSerializerContext; private sealed partial class JsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();
/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
_defaultToolJsonOptions.GetTypeInfo(type);
/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};
options.MakeReadOnly();
return options;
}
return Default.Options;
}
}
} }

Просмотреть файл

@ -0,0 +1,17 @@
# Release History
## 9.0.0-preview.9.24525.1
- Added new `AIJsonUtilities` and `AIJsonSchemaCreateOptions` classes.
- Made `AIFunctionFactory.Create` safe for use with Native AOT.
- Simplified the set of `AIFunctionFactory.Create` overloads.
- Changed the default for `FunctionInvokingChatClient.ConcurrentInvocation` from `true` to `false`.
- Improved the readability of JSON generated as part of logging.
- Fixed handling of generated JSON schema names when using arrays or generic types.
- Improved `CachingChatClient`'s coalescing of streaming updates, including reduced memory allocation and enhanced metadata propagation.
- Updated `OpenTelemetryChatClient` and `OpenTelemetryEmbeddingGenerator` to conform to the latest 1.28.0 draft specification of the Semantic Conventions for Generative AI systems.
- Improved `CompleteAsync<T>`'s structured output support to handle primitive types, enums, and arrays.
## 9.0.0-preview.9.24507.7
Initial Preview

Просмотреть файл

@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI;
/// <para> /// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options /// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide /// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide
/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example /// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the /// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example: /// and mutating the clone, for example:
@ -31,6 +31,9 @@ namespace Microsoft.Extensions.AI;
/// </c> /// </c>
/// </para> /// </para>
/// <para> /// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next client in the pipeline.
/// </para>
/// <para>
/// The provided implementation of <see cref="IChatClient"/> is thread-safe for concurrent use so long as the employed configuration /// The provided implementation of <see cref="IChatClient"/> is thread-safe for concurrent use so long as the employed configuration
/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the /// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance. /// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
@ -39,7 +42,7 @@ namespace Microsoft.Extensions.AI;
public sealed class ConfigureOptionsChatClient : DelegatingChatClient public sealed class ConfigureOptionsChatClient : DelegatingChatClient
{ {
/// <summary>The callback delegate used to configure options.</summary> /// <summary>The callback delegate used to configure options.</summary>
private readonly Func<ChatOptions?, ChatOptions> _configureOptions; private readonly Func<ChatOptions?, ChatOptions?> _configureOptions;
/// <summary>Initializes a new instance of the <see cref="ConfigureOptionsChatClient"/> class with the specified <paramref name="configureOptions"/> callback.</summary> /// <summary>Initializes a new instance of the <see cref="ConfigureOptionsChatClient"/> class with the specified <paramref name="configureOptions"/> callback.</summary>
/// <param name="innerClient">The inner client.</param> /// <param name="innerClient">The inner client.</param>
@ -47,7 +50,7 @@ public sealed class ConfigureOptionsChatClient : DelegatingChatClient
/// The delegate to invoke to configure the <see cref="ChatOptions"/> instance. It is passed the caller-supplied <see cref="ChatOptions"/> /// The delegate to invoke to configure the <see cref="ChatOptions"/> instance. It is passed the caller-supplied <see cref="ChatOptions"/>
/// instance and should return the configured <see cref="ChatOptions"/> instance to use. /// instance and should return the configured <see cref="ChatOptions"/> instance to use.
/// </param> /// </param>
public ConfigureOptionsChatClient(IChatClient innerClient, Func<ChatOptions?, ChatOptions> configureOptions) public ConfigureOptionsChatClient(IChatClient innerClient, Func<ChatOptions?, ChatOptions?> configureOptions)
: base(innerClient) : base(innerClient)
{ {
_configureOptions = Throw.IfNull(configureOptions); _configureOptions = Throw.IfNull(configureOptions);

Просмотреть файл

@ -21,9 +21,10 @@ public static class ConfigureOptionsChatClientBuilderExtensions
/// </param> /// </param>
/// <returns>The <paramref name="builder"/>.</returns> /// <returns>The <paramref name="builder"/>.</returns>
/// <remarks> /// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options /// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide /// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide
/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example /// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the /// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example: /// and mutating the clone, for example:
@ -35,9 +36,13 @@ public static class ConfigureOptionsChatClientBuilderExtensions
/// return newOptions; /// return newOptions;
/// } /// }
/// </c> /// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next client in the pipeline.
/// </para>
/// </remarks> /// </remarks>
public static ChatClientBuilder UseChatOptions( public static ChatClientBuilder UseChatOptions(
this ChatClientBuilder builder, Func<ChatOptions?, ChatOptions> configureOptions) this ChatClientBuilder builder, Func<ChatOptions?, ChatOptions?> configureOptions)
{ {
_ = Throw.IfNull(builder); _ = Throw.IfNull(builder);
_ = Throw.IfNull(configureOptions); _ = Throw.IfNull(configureOptions);

Просмотреть файл

@ -322,7 +322,7 @@ public sealed partial class OpenTelemetryChatClient : DelegatingChatClient
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "response_format"), responseFormat); _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "response_format"), responseFormat);
} }
if (options.AdditionalProperties?.TryGetValue("seed", out long seed) is true) if (options.Seed is long seed)
{ {
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "seed"), seed); _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "seed"), seed);
} }

Просмотреть файл

@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
#pragma warning disable SA1629 // Documentation text should end with a period
namespace Microsoft.Extensions.AI;
/// <summary>A delegating embedding generator that updates or replaces the <see cref="EmbeddingGenerationOptions"/> used by the remainder of the pipeline.</summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
/// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="EmbeddingGenerationOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. To provide
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
/// <c>
/// options =>
/// {
/// var newOptions = options?.Clone() ?? new();
/// newOptions.Dimensions = 100;
/// return newOptions;
/// }
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next generator in the pipeline.
/// </para>
/// <para>
/// The provided implementation of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> is thread-safe for concurrent use so long as the employed configuration
/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
/// </para>
/// </remarks>
public sealed class ConfigureOptionsEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
where TEmbedding : Embedding
{
/// <summary>The callback delegate used to configure options.</summary>
private readonly Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> _configureOptions;
/// <summary>
/// Initializes a new instance of the <see cref="ConfigureOptionsEmbeddingGenerator{TInput, TEmbedding}"/> class with the
/// specified <paramref name="configureOptions"/> callback.
/// </summary>
/// <param name="innerGenerator">The inner generator.</param>
/// <param name="configureOptions">
/// The delegate to invoke to configure the <see cref="EmbeddingGenerationOptions"/> instance. It is passed the caller-supplied
/// <see cref="EmbeddingGenerationOptions"/> instance and should return the configured <see cref="EmbeddingGenerationOptions"/> instance to use.
/// </param>
public ConfigureOptionsEmbeddingGenerator(
IEmbeddingGenerator<TInput, TEmbedding> innerGenerator,
Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> configureOptions)
: base(innerGenerator)
{
_configureOptions = Throw.IfNull(configureOptions);
}
/// <inheritdoc/>
public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
IEnumerable<TInput> values,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
return await base.GenerateAsync(values, _configureOptions(options), cancellationToken).ConfigureAwait(false);
}
}

Просмотреть файл

@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using Microsoft.Shared.Diagnostics;
#pragma warning disable SA1629 // Documentation text should end with a period
namespace Microsoft.Extensions.AI;
/// <summary>Provides extensions for configuring <see cref="ConfigureOptionsEmbeddingGenerator{TInput, TEmbedding}"/> instances.</summary>
public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions
{
/// <summary>
/// Adds a callback that updates or replaces <see cref="EmbeddingGenerationOptions"/>. This can be used to set default options.
/// </summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
/// <param name="builder">The <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/>.</param>
/// <param name="configureOptions">
/// The delegate to invoke to configure the <see cref="EmbeddingGenerationOptions"/> instance. It is passed the caller-supplied
/// <see cref="EmbeddingGenerationOptions"/> instance and should return the configured <see cref="EmbeddingGenerationOptions"/> instance to use.
/// </param>
/// <returns>The <paramref name="builder"/>.</returns>
/// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="EmbeddingGenerationOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. To provide
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
/// <c>
/// options =>
/// {
/// var newOptions = options?.Clone() ?? new();
/// newOptions.Dimensions = 100;
/// return newOptions;
/// }
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next generator in the pipeline.
/// </para>
/// </remarks>
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> UseEmbeddingGenerationOptions<TInput, TEmbedding>(
this EmbeddingGeneratorBuilder<TInput, TEmbedding> builder,
Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> configureOptions)
where TEmbedding : Embedding
{
_ = Throw.IfNull(builder);
_ = Throw.IfNull(configureOptions);
return builder.Use(innerGenerator => new ConfigureOptionsEmbeddingGenerator<TInput, TEmbedding>(innerGenerator, configureOptions));
}
}

Просмотреть файл

@ -18,6 +18,7 @@
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks> <TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253</NoWarn> <NoWarn>$(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>

Просмотреть файл

@ -5,6 +5,7 @@ using System;
using System.Diagnostics; using System.Diagnostics;
using System.Threading; using System.Threading;
using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal; namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@ -22,7 +23,7 @@ internal partial class DefaultHybridCache
// zero. // zero.
// This counter also drives cache lifetime, with the cache itself incrementing the count by one. In the // This counter also drives cache lifetime, with the cache itself incrementing the count by one. In the
// case of mutable data, cache eviction may reduce this to zero (in cooperation with any concurrent readers, // case of mutable data, cache eviction may reduce this to zero (in cooperation with any concurrent readers,
// who incr/decr around their fetch), allowing safe buffer recycling. // who increment/decrement around their fetch), allowing safe buffer recycling.
internal int RefCount => Volatile.Read(ref _refCount); internal int RefCount => Volatile.Read(ref _refCount);
@ -89,13 +90,18 @@ internal partial class DefaultHybridCache
{ {
public abstract bool TryGetSize(out long size); public abstract bool TryGetSize(out long size);
// attempt to get a value that was *not* previously reserved // Attempt to get a value that was *not* previously reserved.
public abstract bool TryGetValue(out T value); // Note on ILogger usage: we don't want to propagate and store this everywhere.
// It is used for reporting deserialization problems - pass it as needed.
// (CacheItem gets into the IMemoryCache - let's minimize the onward reachable set
// of that cache, by only handing it leaf nodes of a "tree", not a "graph" with
// backwards access - we can also limit object size at the same time)
public abstract bool TryGetValue(ILogger log, out T value);
// get a value that *was* reserved, countermanding our reservation in the process // get a value that *was* reserved, countermanding our reservation in the process
public T GetReservedValue() public T GetReservedValue(ILogger log)
{ {
if (!TryGetValue(out var value)) if (!TryGetValue(log, out var value))
{ {
Throw(); Throw();
} }

Просмотреть файл

@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license. // The .NET Foundation licenses this file to you under the MIT license.
using System.Threading; using System.Threading;
using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal; namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@ -38,7 +39,7 @@ internal partial class DefaultHybridCache
Size = size; Size = size;
} }
public override bool TryGetValue(out T value) public override bool TryGetValue(ILogger log, out T value)
{ {
value = _value; value = _value;
return true; // always available return true; // always available

Просмотреть файл

@ -16,12 +16,16 @@ internal partial class DefaultHybridCache
{ {
[SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Manual sync check")] [SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Manual sync check")]
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Manual sync check")] [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Manual sync check")]
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Explicit async exception handling")]
[SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Deliberate recycle only on success")]
internal ValueTask<BufferChunk> GetFromL2Async(string key, CancellationToken token) internal ValueTask<BufferChunk> GetFromL2Async(string key, CancellationToken token)
{ {
switch (GetFeatures(CacheFeatures.BackendCache | CacheFeatures.BackendBuffers)) switch (GetFeatures(CacheFeatures.BackendCache | CacheFeatures.BackendBuffers))
{ {
case CacheFeatures.BackendCache: // legacy byte[]-based case CacheFeatures.BackendCache: // legacy byte[]-based
var pendingLegacy = _backendCache!.GetAsync(key, token); var pendingLegacy = _backendCache!.GetAsync(key, token);
#if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER #if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
if (!pendingLegacy.IsCompletedSuccessfully) if (!pendingLegacy.IsCompletedSuccessfully)
#else #else
@ -36,6 +40,7 @@ internal partial class DefaultHybridCache
case CacheFeatures.BackendCache | CacheFeatures.BackendBuffers: // IBufferWriter<byte>-based case CacheFeatures.BackendCache | CacheFeatures.BackendBuffers: // IBufferWriter<byte>-based
RecyclableArrayBufferWriter<byte> writer = RecyclableArrayBufferWriter<byte>.Create(MaximumPayloadBytes); RecyclableArrayBufferWriter<byte> writer = RecyclableArrayBufferWriter<byte>.Create(MaximumPayloadBytes);
var cache = Unsafe.As<IBufferDistributedCache>(_backendCache!); // type-checked already var cache = Unsafe.As<IBufferDistributedCache>(_backendCache!); // type-checked already
var pendingBuffers = cache.TryGetAsync(key, writer, token); var pendingBuffers = cache.TryGetAsync(key, writer, token);
if (!pendingBuffers.IsCompletedSuccessfully) if (!pendingBuffers.IsCompletedSuccessfully)
{ {
@ -49,7 +54,7 @@ internal partial class DefaultHybridCache
return new(result); return new(result);
} }
return default; return default; // treat as a "miss"
static async Task<BufferChunk> AwaitedLegacyAsync(Task<byte[]?> pending, DefaultHybridCache @this) static async Task<BufferChunk> AwaitedLegacyAsync(Task<byte[]?> pending, DefaultHybridCache @this)
{ {
@ -115,6 +120,11 @@ internal partial class DefaultHybridCache
// commit // commit
cacheEntry.Dispose(); cacheEntry.Dispose();
if (HybridCacheEventSource.Log.IsEnabled())
{
HybridCacheEventSource.Log.LocalCacheWrite();
}
} }
} }

Просмотреть файл

@ -1,14 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements. // Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license. // The .NET Foundation licenses this file to you under the MIT license.
using System;
using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal; namespace Microsoft.Extensions.Caching.Hybrid.Internal;
internal partial class DefaultHybridCache internal partial class DefaultHybridCache
{ {
private sealed partial class MutableCacheItem<T> : CacheItem<T> // used to hold types that require defensive copies private sealed partial class MutableCacheItem<T> : CacheItem<T> // used to hold types that require defensive copies
{ {
private IHybridCacheSerializer<T> _serializer = null!; // deferred until SetValue private IHybridCacheSerializer<T>? _serializer;
private BufferChunk _buffer; private BufferChunk _buffer;
private T? _fallbackValue; // only used in the case of serialization failures
public override bool NeedsEvictionCallback => _buffer.ReturnToPool; public override bool NeedsEvictionCallback => _buffer.ReturnToPool;
@ -21,16 +25,27 @@ internal partial class DefaultHybridCache
buffer = default; // we're taking over the lifetime; the caller no longer has it! buffer = default; // we're taking over the lifetime; the caller no longer has it!
} }
public override bool TryGetValue(out T value) public void SetFallbackValue(T fallbackValue)
{
_fallbackValue = fallbackValue;
}
public override bool TryGetValue(ILogger log, out T value)
{ {
// only if we haven't already burned // only if we haven't already burned
if (TryReserve()) if (TryReserve())
{ {
try try
{ {
value = _serializer.Deserialize(_buffer.AsSequence()); var serializer = _serializer;
value = serializer is null ? _fallbackValue! : serializer.Deserialize(_buffer.AsSequence());
return true; return true;
} }
catch (Exception ex)
{
log.DeserializationFailure(ex);
throw;
}
finally finally
{ {
_ = Release(); _ = Release();

Просмотреть файл

@ -3,7 +3,7 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Reflection; using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
@ -51,4 +51,54 @@ internal partial class DefaultHybridCache
return serializer; return serializer;
} }
} }
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Intentional for logged failure mode")]
private bool TrySerialize<T>(T value, out BufferChunk buffer, out IHybridCacheSerializer<T>? serializer)
{
// note: also returns the serializer we resolved, because most-any time we want to serialize, we'll also want
// to make sure we use that same instance later (without needing to re-resolve and/or store the entire HC machinery)
RecyclableArrayBufferWriter<byte>? writer = null;
buffer = default;
try
{
writer = RecyclableArrayBufferWriter<byte>.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async
serializer = GetSerializer<T>();
serializer.Serialize(value, writer);
buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer
writer.Dispose(); // we're done with the writer
return true;
}
catch (Exception ex)
{
bool knownCause = false;
// ^^^ if we know what happened, we can record directly via cause-specific events
// and treat as a handled failure (i.e. return false) - otherwise, we'll bubble
// the fault up a few layers *in addition to* logging in a failure event
if (writer is not null)
{
if (writer.QuotaExceeded)
{
_logger.MaximumPayloadBytesExceeded(ex, MaximumPayloadBytes);
knownCause = true;
}
writer.Dispose();
}
if (!knownCause)
{
_logger.SerializationFailure(ex);
throw;
}
buffer = default;
serializer = null;
return false;
}
}
} }

Просмотреть файл

@ -74,8 +74,6 @@ internal partial class DefaultHybridCache
public abstract void Execute(); public abstract void Execute();
protected int MaximumPayloadBytes => _cache.MaximumPayloadBytes;
public override string ToString() => Key.ToString(); public override string ToString() => Key.ToString();
public abstract void SetCanceled(); public abstract void SetCanceled();

Просмотреть файл

@ -6,6 +6,7 @@ using System.Diagnostics;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using static Microsoft.Extensions.Caching.Hybrid.Internal.DefaultHybridCache; using static Microsoft.Extensions.Caching.Hybrid.Internal.DefaultHybridCache;
namespace Microsoft.Extensions.Caching.Hybrid.Internal; namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@ -14,7 +15,8 @@ internal partial class DefaultHybridCache
{ {
internal sealed class StampedeState<TState, T> : StampedeState internal sealed class StampedeState<TState, T> : StampedeState
{ {
private const HybridCacheEntryFlags FlagsDisableL1AndL2 = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite; // note on terminology: L1 and L2 are, for brevity, used interchangeably with "local" and "distributed" cache, i.e. `IMemoryCache` and `IDistributedCache`
private const HybridCacheEntryFlags FlagsDisableL1AndL2Write = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite;
private readonly TaskCompletionSource<CacheItem<T>>? _result; private readonly TaskCompletionSource<CacheItem<T>>? _result;
private TState? _state; private TState? _state;
@ -76,13 +78,13 @@ internal partial class DefaultHybridCache
public override void SetCanceled() => _result?.TrySetCanceled(SharedToken); public override void SetCanceled() => _result?.TrySetCanceled(SharedToken);
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Custom task management")] [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Custom task management")]
public ValueTask<T> JoinAsync(CancellationToken token) public ValueTask<T> JoinAsync(ILogger log, CancellationToken token)
{ {
// If the underlying has already completed, and/or our local token can't cancel: we // If the underlying has already completed, and/or our local token can't cancel: we
// can simply wrap the shared task; otherwise, we need our own cancellation state. // can simply wrap the shared task; otherwise, we need our own cancellation state.
return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(this, token) : UnwrapReservedAsync(); return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(log, this, token) : UnwrapReservedAsync(log);
static async ValueTask<T> WithCancellationAsync(StampedeState<TState, T> stampede, CancellationToken token) static async ValueTask<T> WithCancellationAsync(ILogger log, StampedeState<TState, T> stampede, CancellationToken token)
{ {
var cancelStub = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously); var cancelStub = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
using var reg = token.Register(static obj => using var reg = token.Register(static obj =>
@ -112,7 +114,7 @@ internal partial class DefaultHybridCache
} }
// outside the catch, so we know we only decrement one way or the other // outside the catch, so we know we only decrement one way or the other
return result.GetReservedValue(); return result.GetReservedValue(log);
} }
} }
@ -133,7 +135,7 @@ internal partial class DefaultHybridCache
[SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Checked manual unwrap")] [SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Checked manual unwrap")]
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Checked manual unwrap")] [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Checked manual unwrap")]
[SuppressMessage("Major Code Smell", "S1121:Assignments should not be made from within sub-expressions", Justification = "Unusual, but legit here")] [SuppressMessage("Major Code Smell", "S1121:Assignments should not be made from within sub-expressions", Justification = "Unusual, but legit here")]
internal ValueTask<T> UnwrapReservedAsync() internal ValueTask<T> UnwrapReservedAsync(ILogger log)
{ {
var task = Task; var task = Task;
#if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER #if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
@ -142,16 +144,16 @@ internal partial class DefaultHybridCache
if (task.Status == TaskStatus.RanToCompletion) if (task.Status == TaskStatus.RanToCompletion)
#endif #endif
{ {
return new(task.Result.GetReservedValue()); return new(task.Result.GetReservedValue(log));
} }
// if the type is immutable, callers can share the final step too (this may leave dangling // if the type is immutable, callers can share the final step too (this may leave dangling
// reservation counters, but that's OK) // reservation counters, but that's OK)
var result = ImmutableTypeCache<T>.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(Task)) : AwaitedAsync(Task); var result = ImmutableTypeCache<T>.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(log, Task)) : AwaitedAsync(log, Task);
return new(result); return new(result);
static async Task<T> AwaitedAsync(Task<CacheItem<T>> task) static async Task<T> AwaitedAsync(ILogger log, Task<CacheItem<T>> task)
=> (await task.ConfigureAwait(false)).GetReservedValue(); => (await task.ConfigureAwait(false)).GetReservedValue(log);
} }
[DoesNotReturn] [DoesNotReturn]
@ -161,12 +163,43 @@ internal partial class DefaultHybridCache
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Exception is passed through to faulted task result")] [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Exception is passed through to faulted task result")]
private async Task BackgroundFetchAsync() private async Task BackgroundFetchAsync()
{ {
bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled();
try try
{ {
// read from L2 if appropriate // read from L2 if appropriate
if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheRead) == 0) if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheRead) == 0)
{ {
var result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false); BufferChunk result;
try
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.DistributedCacheGet();
}
result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false);
if (eventSourceEnabled)
{
if (result.Array is not null)
{
HybridCacheEventSource.Log.DistributedCacheHit();
}
else
{
HybridCacheEventSource.Log.DistributedCacheMiss();
}
}
}
catch (Exception ex)
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.DistributedCacheFailed();
}
Cache._logger.CacheUnderlyingDataQueryFailure(ex);
result = default; // treat as "miss"
}
if (result.Array is not null) if (result.Array is not null)
{ {
@ -179,7 +212,30 @@ internal partial class DefaultHybridCache
if ((Key.Flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0) if ((Key.Flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0)
{ {
// invoke the callback supplied by the caller // invoke the callback supplied by the caller
T newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false); T newValue;
try
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.UnderlyingDataQueryStart();
}
newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false);
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.UnderlyingDataQueryComplete();
}
}
catch
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.UnderlyingDataQueryFailed();
}
throw;
}
// If we're writing this value *anywhere*, we're going to need to serialize; this is obvious // If we're writing this value *anywhere*, we're going to need to serialize; this is obvious
// in the case of L2, but we also need it for L1, because MemoryCache might be enforcing // in the case of L2, but we also need it for L1, because MemoryCache might be enforcing
@ -187,11 +243,11 @@ internal partial class DefaultHybridCache
// Likewise, if we're writing to a MutableCacheItem, we'll be serializing *anyway* for the payload. // Likewise, if we're writing to a MutableCacheItem, we'll be serializing *anyway* for the payload.
// //
// Rephrasing that: the only scenario in which we *do not* need to serialize is if: // Rephrasing that: the only scenario in which we *do not* need to serialize is if:
// - it is an ImmutableCacheItem // - it is an ImmutableCacheItem (so we don't need bytes for the CacheItem, L1)
// - we're writing neither to L1 nor L2 // - we're not writing to L2
CacheItem cacheItem = CacheItem; CacheItem cacheItem = CacheItem;
bool skipSerialize = cacheItem is ImmutableCacheItem<T> && (Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2; bool skipSerialize = cacheItem is ImmutableCacheItem<T> && (Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write;
if (skipSerialize) if (skipSerialize)
{ {
@ -202,33 +258,55 @@ internal partial class DefaultHybridCache
// ^^^ The first thing we need to do is make sure we're not getting into a thread race over buffer disposal. // ^^^ The first thing we need to do is make sure we're not getting into a thread race over buffer disposal.
// In particular, if this cache item is somehow so short-lived that the buffers would be released *before* we're // In particular, if this cache item is somehow so short-lived that the buffers would be released *before* we're
// done writing them to L2, which happens *after* we've provided the value to consumers. // done writing them to L2, which happens *after* we've provided the value to consumers.
RecyclableArrayBufferWriter<byte> writer = RecyclableArrayBufferWriter<byte>.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async
IHybridCacheSerializer<T> serializer = Cache.GetSerializer<T>();
serializer.Serialize(newValue, writer);
BufferChunk buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer
writer.Dispose(); // we're done with the writer
// protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized BufferChunk bufferToRelease = default;
// *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and if (Cache.TrySerialize(newValue, out var buffer, out var serializer))
// the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event,
// (with TryReserve above guaranteeing that we aren't in a race condition).
BufferChunk bufferToRelease = buffer;
// and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit
// that we do not need or want "buffer" to do any recycling (they're the same memory)
buffer = buffer.DoNotReturnToPool();
// set the underlying result for this operation (includes L1 write if appropriate)
SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
// Note that at this point we've already released most or all of the waiting callers. Everything
// from this point onwards happens in the background, from the perspective of the calling code.
// Write to L2 if appropriate.
if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0)
{ {
// We already have the payload serialized, so this is trivial to do. // note we also capture the resolved serializer ^^^ - we'll need it again later
await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false);
// protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized
// *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and
// the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event,
// (with TryReserve above guaranteeing that we aren't in a race condition).
bufferToRelease = buffer;
// and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit
// that we do not need or want "buffer" to do any recycling (they're the same memory)
buffer = buffer.DoNotReturnToPool();
// set the underlying result for this operation (includes L1 write if appropriate)
SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
// Note that at this point we've already released most or all of the waiting callers. Everything
// from this point onwards happens in the background, from the perspective of the calling code.
// Write to L2 if appropriate.
if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0)
{
// We already have the payload serialized, so this is trivial to do.
try
{
await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false);
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.DistributedCacheWrite();
}
}
catch (Exception ex)
{
// log the L2 write failure, but that doesn't need to interrupt the app flow (so:
// don't rethrow); L1 will still reduce impact, and L1 without L2 is better than
// hard failure every time
Cache._logger.CacheBackendWriteFailure(ex);
}
}
}
else
{
// unable to serialize (or quota exceeded); try to at least store the onwards value; this is
// especially useful for immutable data types
SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
} }
// Release our hook on the CacheItem (only really important for "mutable"). // Release our hook on the CacheItem (only really important for "mutable").
@ -309,7 +387,7 @@ internal partial class DefaultHybridCache
private void SetImmutableResultWithoutSerialize(T value) private void SetImmutableResultWithoutSerialize(T value)
{ {
Debug.Assert((Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2, "Only expected if L1+L2 disabled"); Debug.Assert((Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write, "Only expected if L1+L2 disabled");
// set a result from a value we calculated directly // set a result from a value we calculated directly
CacheItem<T> cacheItem; CacheItem<T> cacheItem;
@ -328,7 +406,7 @@ internal partial class DefaultHybridCache
SetResult(cacheItem); SetResult(cacheItem);
} }
private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer<T> serializer) private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer<T>? serializer)
{ {
// set a result from a value we calculated directly that // set a result from a value we calculated directly that
// has ALREADY BEEN SERIALIZED (we can optionally consume this buffer) // has ALREADY BEEN SERIALIZED (we can optionally consume this buffer)
@ -343,8 +421,17 @@ internal partial class DefaultHybridCache
// (but leave the buffer alone) // (but leave the buffer alone)
break; break;
case MutableCacheItem<T> mutable: case MutableCacheItem<T> mutable:
mutable.SetValue(ref buffer, serializer); if (serializer is null)
mutable.DebugOnlyTrackBuffer(Cache); {
// serialization is failing; set fallback value
mutable.SetFallbackValue(value);
}
else
{
mutable.SetValue(ref buffer, serializer);
mutable.DebugOnlyTrackBuffer(Cache);
}
cacheItem = mutable; cacheItem = mutable;
break; break;
default: default:

Просмотреть файл

@ -22,6 +22,9 @@ namespace Microsoft.Extensions.Caching.Hybrid.Internal;
/// </summary> /// </summary>
internal sealed partial class DefaultHybridCache : HybridCache internal sealed partial class DefaultHybridCache : HybridCache
{ {
// reserve non-printable characters from keys, to prevent potential L2 abuse
private static readonly char[] _keyReservedCharacters = Enumerable.Range(0, 32).Select(i => (char)i).ToArray();
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
private readonly IDistributedCache? _backendCache; private readonly IDistributedCache? _backendCache;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
@ -37,6 +40,7 @@ internal sealed partial class DefaultHybridCache : HybridCache
private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags
private readonly TimeSpan _defaultExpiration; private readonly TimeSpan _defaultExpiration;
private readonly TimeSpan _defaultLocalCacheExpiration; private readonly TimeSpan _defaultLocalCacheExpiration;
private readonly int _maximumKeyLength;
private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration; private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration;
@ -90,6 +94,7 @@ internal sealed partial class DefaultHybridCache : HybridCache
_serializerFactories = factories; _serializerFactories = factories;
MaximumPayloadBytes = checked((int)_options.MaximumPayloadBytes); // for now hard-limit to 2GiB MaximumPayloadBytes = checked((int)_options.MaximumPayloadBytes); // for now hard-limit to 2GiB
_maximumKeyLength = _options.MaximumKeyLength;
var defaultEntryOptions = _options.DefaultEntryOptions; var defaultEntryOptions = _options.DefaultEntryOptions;
@ -119,11 +124,33 @@ internal sealed partial class DefaultHybridCache : HybridCache
} }
var flags = GetEffectiveFlags(options); var flags = GetEffectiveFlags(options);
if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0 && _localCache.TryGetValue(key, out var untyped) if (!ValidateKey(key))
&& untyped is CacheItem<T> typed && typed.TryGetValue(out var value))
{ {
// short-circuit // we can't use cache, but we can still provide the data
return new(value); return RunWithoutCacheAsync(flags, state, underlyingDataCallback, cancellationToken);
}
bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled();
if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0)
{
if (_localCache.TryGetValue(key, out var untyped)
&& untyped is CacheItem<T> typed && typed.TryGetValue(_logger, out var value))
{
// short-circuit
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.LocalCacheHit();
}
return new(value);
}
else
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.LocalCacheMiss();
}
}
} }
if (GetOrCreateStampedeState<TState, T>(key, flags, out var stampede, canBeCanceled)) if (GetOrCreateStampedeState<TState, T>(key, flags, out var stampede, canBeCanceled))
@ -139,11 +166,19 @@ internal sealed partial class DefaultHybridCache : HybridCache
{ {
// we're going to run to completion; no need to get complicated // we're going to run to completion; no need to get complicated
_ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc _ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc
return stampede.UnwrapReservedAsync(); return stampede.UnwrapReservedAsync(_logger);
}
}
else
{
// pre-existing query
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.StampedeJoin();
} }
} }
return stampede.JoinAsync(cancellationToken); return stampede.JoinAsync(_logger, cancellationToken);
} }
public override ValueTask RemoveAsync(string key, CancellationToken token = default) public override ValueTask RemoveAsync(string key, CancellationToken token = default)
@ -164,7 +199,39 @@ internal sealed partial class DefaultHybridCache : HybridCache
return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc
} }
private static ValueTask<T> RunWithoutCacheAsync<TState, T>(HybridCacheEntryFlags flags, TState state,
Func<TState, CancellationToken, ValueTask<T>> underlyingDataCallback,
CancellationToken cancellationToken)
{
return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0
? underlyingDataCallback(state, cancellationToken) : default;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options) private HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options)
=> (options?.Flags | _hardFlags) ?? _defaultFlags; => (options?.Flags | _hardFlags) ?? _defaultFlags;
private bool ValidateKey(string key)
{
if (string.IsNullOrWhiteSpace(key))
{
_logger.KeyEmptyOrWhitespace();
return false;
}
if (key.Length > _maximumKeyLength)
{
_logger.MaximumKeyLengthExceeded(_maximumKeyLength, key.Length);
return false;
}
if (key.IndexOfAny(_keyReservedCharacters) >= 0)
{
_logger.KeyInvalidContent();
return false;
}
// nothing to complain about
return true;
}
} }

Просмотреть файл

@ -0,0 +1,203 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.Runtime.CompilerServices;
using System.Threading;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
[EventSource(Name = "Microsoft-Extensions-HybridCache")]
internal sealed class HybridCacheEventSource : EventSource
{
public static readonly HybridCacheEventSource Log = new();
internal const int EventIdLocalCacheHit = 1;
internal const int EventIdLocalCacheMiss = 2;
internal const int EventIdDistributedCacheGet = 3;
internal const int EventIdDistributedCacheHit = 4;
internal const int EventIdDistributedCacheMiss = 5;
internal const int EventIdDistributedCacheFailed = 6;
internal const int EventIdUnderlyingDataQueryStart = 7;
internal const int EventIdUnderlyingDataQueryComplete = 8;
internal const int EventIdUnderlyingDataQueryFailed = 9;
internal const int EventIdLocalCacheWrite = 10;
internal const int EventIdDistributedCacheWrite = 11;
internal const int EventIdStampedeJoin = 12;
// fast local counters
private long _totalLocalCacheHit;
private long _totalLocalCacheMiss;
private long _totalDistributedCacheHit;
private long _totalDistributedCacheMiss;
private long _totalUnderlyingDataQuery;
private long _currentUnderlyingDataQuery;
private long _currentDistributedFetch;
private long _totalLocalCacheWrite;
private long _totalDistributedCacheWrite;
private long _totalStampedeJoin;
#if !(NETSTANDARD2_0 || NET462)
// full Counter infrastructure
private DiagnosticCounter[]? _counters;
#endif
[NonEvent]
public void ResetCounters()
{
Debug.WriteLine($"{nameof(HybridCacheEventSource)} counters reset!");
Volatile.Write(ref _totalLocalCacheHit, 0);
Volatile.Write(ref _totalLocalCacheMiss, 0);
Volatile.Write(ref _totalDistributedCacheHit, 0);
Volatile.Write(ref _totalDistributedCacheMiss, 0);
Volatile.Write(ref _totalUnderlyingDataQuery, 0);
Volatile.Write(ref _currentUnderlyingDataQuery, 0);
Volatile.Write(ref _currentDistributedFetch, 0);
Volatile.Write(ref _totalLocalCacheWrite, 0);
Volatile.Write(ref _totalDistributedCacheWrite, 0);
Volatile.Write(ref _totalStampedeJoin, 0);
}
[Event(EventIdLocalCacheHit, Level = EventLevel.Verbose)]
public void LocalCacheHit()
{
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalLocalCacheHit);
WriteEvent(EventIdLocalCacheHit);
}
[Event(EventIdLocalCacheMiss, Level = EventLevel.Verbose)]
public void LocalCacheMiss()
{
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalLocalCacheMiss);
WriteEvent(EventIdLocalCacheMiss);
}
[Event(EventIdDistributedCacheGet, Level = EventLevel.Verbose)]
public void DistributedCacheGet()
{
// should be followed by DistributedCacheHit, DistributedCacheMiss or DistributedCacheFailed
DebugAssertEnabled();
_ = Interlocked.Increment(ref _currentDistributedFetch);
WriteEvent(EventIdDistributedCacheGet);
}
[Event(EventIdDistributedCacheHit, Level = EventLevel.Verbose)]
public void DistributedCacheHit()
{
DebugAssertEnabled();
// note: not concerned about off-by-one here, i.e. don't panic
// about these two being atomic ref each-other - just the overall shape
_ = Interlocked.Increment(ref _totalDistributedCacheHit);
_ = Interlocked.Decrement(ref _currentDistributedFetch);
WriteEvent(EventIdDistributedCacheHit);
}
[Event(EventIdDistributedCacheMiss, Level = EventLevel.Verbose)]
public void DistributedCacheMiss()
{
DebugAssertEnabled();
// note: not concerned about off-by-one here, i.e. don't panic
// about these two being atomic ref each-other - just the overall shape
_ = Interlocked.Increment(ref _totalDistributedCacheMiss);
_ = Interlocked.Decrement(ref _currentDistributedFetch);
WriteEvent(EventIdDistributedCacheMiss);
}
[Event(EventIdDistributedCacheFailed, Level = EventLevel.Error)]
public void DistributedCacheFailed()
{
DebugAssertEnabled();
_ = Interlocked.Decrement(ref _currentDistributedFetch);
WriteEvent(EventIdDistributedCacheFailed);
}
[Event(EventIdUnderlyingDataQueryStart, Level = EventLevel.Verbose)]
public void UnderlyingDataQueryStart()
{
// should be followed by UnderlyingDataQueryComplete or UnderlyingDataQueryFailed
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalUnderlyingDataQuery);
_ = Interlocked.Increment(ref _currentUnderlyingDataQuery);
WriteEvent(EventIdUnderlyingDataQueryStart);
}
[Event(EventIdUnderlyingDataQueryComplete, Level = EventLevel.Verbose)]
public void UnderlyingDataQueryComplete()
{
DebugAssertEnabled();
_ = Interlocked.Decrement(ref _currentUnderlyingDataQuery);
WriteEvent(EventIdUnderlyingDataQueryComplete);
}
[Event(EventIdUnderlyingDataQueryFailed, Level = EventLevel.Error)]
public void UnderlyingDataQueryFailed()
{
DebugAssertEnabled();
_ = Interlocked.Decrement(ref _currentUnderlyingDataQuery);
WriteEvent(EventIdUnderlyingDataQueryFailed);
}
[Event(EventIdLocalCacheWrite, Level = EventLevel.Verbose)]
public void LocalCacheWrite()
{
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalLocalCacheWrite);
WriteEvent(EventIdLocalCacheWrite);
}
[Event(EventIdDistributedCacheWrite, Level = EventLevel.Verbose)]
public void DistributedCacheWrite()
{
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalDistributedCacheWrite);
WriteEvent(EventIdDistributedCacheWrite);
}
[Event(EventIdStampedeJoin, Level = EventLevel.Verbose)]
internal void StampedeJoin()
{
DebugAssertEnabled();
_ = Interlocked.Increment(ref _totalStampedeJoin);
WriteEvent(EventIdStampedeJoin);
}
#if !(NETSTANDARD2_0 || NET462)
[System.Diagnostics.CodeAnalysis.SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Lifetime exceeds obvious scope; handed to event source")]
[NonEvent]
protected override void OnEventCommand(EventCommandEventArgs command)
{
if (command.Command == EventCommand.Enable)
{
// lazily create counters on first Enable
_counters ??= [
new PollingCounter("total-local-cache-hits", this, () => Volatile.Read(ref _totalLocalCacheHit)) { DisplayName = "Total Local Cache Hits" },
new PollingCounter("total-local-cache-misses", this, () => Volatile.Read(ref _totalLocalCacheMiss)) { DisplayName = "Total Local Cache Misses" },
new PollingCounter("total-distributed-cache-hits", this, () => Volatile.Read(ref _totalDistributedCacheHit)) { DisplayName = "Total Distributed Cache Hits" },
new PollingCounter("total-distributed-cache-misses", this, () => Volatile.Read(ref _totalDistributedCacheMiss)) { DisplayName = "Total Distributed Cache Misses" },
new PollingCounter("total-data-query", this, () => Volatile.Read(ref _totalUnderlyingDataQuery)) { DisplayName = "Total Data Queries" },
new PollingCounter("current-data-query", this, () => Volatile.Read(ref _currentUnderlyingDataQuery)) { DisplayName = "Current Data Queries" },
new PollingCounter("current-distributed-cache-fetches", this, () => Volatile.Read(ref _currentDistributedFetch)) { DisplayName = "Current Distributed Cache Fetches" },
new PollingCounter("total-local-cache-writes", this, () => Volatile.Read(ref _totalLocalCacheWrite)) { DisplayName = "Total Local Cache Writes" },
new PollingCounter("total-distributed-cache-writes", this, () => Volatile.Read(ref _totalDistributedCacheWrite)) { DisplayName = "Total Distributed Cache Writes" },
new PollingCounter("total-stampede-joins", this, () => Volatile.Read(ref _totalStampedeJoin)) { DisplayName = "Total Stampede Joins" },
];
}
base.OnEventCommand(command);
}
#endif
[NonEvent]
[Conditional("DEBUG")]
private void DebugAssertEnabled([CallerMemberName] string caller = "")
{
Debug.Assert(IsEnabled(), $"Missing check to {nameof(HybridCacheEventSource)}.{nameof(Log)}.{nameof(IsEnabled)} from {caller}");
Debug.WriteLine($"{nameof(HybridCacheEventSource)}: {caller}"); // also log all event calls, for visibility
}
}

Просмотреть файл

@ -17,6 +17,18 @@ internal sealed class InbuiltTypeSerializer : IHybridCacheSerializer<string>, IH
public static InbuiltTypeSerializer Instance { get; } = new(); public static InbuiltTypeSerializer Instance { get; } = new();
string IHybridCacheSerializer<string>.Deserialize(ReadOnlySequence<byte> source) string IHybridCacheSerializer<string>.Deserialize(ReadOnlySequence<byte> source)
=> DeserializeString(source);
void IHybridCacheSerializer<string>.Serialize(string value, IBufferWriter<byte> target)
=> SerializeString(value, target);
byte[] IHybridCacheSerializer<byte[]>.Deserialize(ReadOnlySequence<byte> source)
=> source.ToArray();
void IHybridCacheSerializer<byte[]>.Serialize(byte[] value, IBufferWriter<byte> target)
=> target.Write(value);
internal static string DeserializeString(ReadOnlySequence<byte> source)
{ {
#if NET5_0_OR_GREATER #if NET5_0_OR_GREATER
return Encoding.UTF8.GetString(source); return Encoding.UTF8.GetString(source);
@ -36,7 +48,7 @@ internal sealed class InbuiltTypeSerializer : IHybridCacheSerializer<string>, IH
#endif #endif
} }
void IHybridCacheSerializer<string>.Serialize(string value, IBufferWriter<byte> target) internal static void SerializeString(string value, IBufferWriter<byte> target)
{ {
#if NET5_0_OR_GREATER #if NET5_0_OR_GREATER
Encoding.UTF8.GetBytes(value, target); Encoding.UTF8.GetBytes(value, target);
@ -49,10 +61,4 @@ internal sealed class InbuiltTypeSerializer : IHybridCacheSerializer<string>, IH
ArrayPool<byte>.Shared.Return(oversized); ArrayPool<byte>.Shared.Return(oversized);
#endif #endif
} }
byte[] IHybridCacheSerializer<byte[]>.Deserialize(ReadOnlySequence<byte> source)
=> source.ToArray();
void IHybridCacheSerializer<byte[]>.Serialize(byte[] value, IBufferWriter<byte> target)
=> target.Write(value);
} }

Просмотреть файл

@ -0,0 +1,49 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
internal static partial class Log
{
internal const int IdMaximumPayloadBytesExceeded = 1;
internal const int IdSerializationFailure = 2;
internal const int IdDeserializationFailure = 3;
internal const int IdKeyEmptyOrWhitespace = 4;
internal const int IdMaximumKeyLengthExceeded = 5;
internal const int IdCacheBackendReadFailure = 6;
internal const int IdCacheBackendWriteFailure = 7;
internal const int IdKeyInvalidContent = 8;
[LoggerMessage(LogLevel.Error, "Cache MaximumPayloadBytes ({Bytes}) exceeded.", EventName = "MaximumPayloadBytesExceeded", EventId = IdMaximumPayloadBytesExceeded, SkipEnabledCheck = false)]
internal static partial void MaximumPayloadBytesExceeded(this ILogger logger, Exception e, int bytes);
// note that serialization is critical enough that we perform hard failures in addition to logging; serialization
// failures are unlikely to be transient (i.e. connectivity); we would rather this shows up in QA, rather than
// being invisible and people *thinking* they're using cache, when actually they are not
[LoggerMessage(LogLevel.Error, "Cache serialization failure.", EventName = "SerializationFailure", EventId = IdSerializationFailure, SkipEnabledCheck = false)]
internal static partial void SerializationFailure(this ILogger logger, Exception e);
// (see same notes per SerializationFailure)
[LoggerMessage(LogLevel.Error, "Cache deserialization failure.", EventName = "DeserializationFailure", EventId = IdDeserializationFailure, SkipEnabledCheck = false)]
internal static partial void DeserializationFailure(this ILogger logger, Exception e);
[LoggerMessage(LogLevel.Error, "Cache key empty or whitespace.", EventName = "KeyEmptyOrWhitespace", EventId = IdKeyEmptyOrWhitespace, SkipEnabledCheck = false)]
internal static partial void KeyEmptyOrWhitespace(this ILogger logger);
[LoggerMessage(LogLevel.Error, "Cache key maximum length exceeded (maximum: {MaxLength}, actual: {KeyLength}).", EventName = "MaximumKeyLengthExceeded",
EventId = IdMaximumKeyLengthExceeded, SkipEnabledCheck = false)]
internal static partial void MaximumKeyLengthExceeded(this ILogger logger, int maxLength, int keyLength);
[LoggerMessage(LogLevel.Error, "Cache backend read failure.", EventName = "CacheBackendReadFailure", EventId = IdCacheBackendReadFailure, SkipEnabledCheck = false)]
internal static partial void CacheUnderlyingDataQueryFailure(this ILogger logger, Exception ex);
[LoggerMessage(LogLevel.Error, "Cache backend write failure.", EventName = "CacheBackendWriteFailure", EventId = IdCacheBackendWriteFailure, SkipEnabledCheck = false)]
internal static partial void CacheBackendWriteFailure(this ILogger logger, Exception ex);
[LoggerMessage(LogLevel.Error, "Cache key contains invalid content.", EventName = "KeyInvalidContent", EventId = IdKeyInvalidContent, SkipEnabledCheck = false)]
internal static partial void KeyInvalidContent(this ILogger logger); // for PII etc reasons, we won't include the actual key
}

Просмотреть файл

@ -46,20 +46,20 @@ internal sealed class RecyclableArrayBufferWriter<T> : IBufferWriter<T>, IDispos
public int CommittedBytes => _index; public int CommittedBytes => _index;
public int FreeCapacity => _buffer.Length - _index; public int FreeCapacity => _buffer.Length - _index;
public bool QuotaExceeded { get; private set; }
private static RecyclableArrayBufferWriter<T>? _spare; private static RecyclableArrayBufferWriter<T>? _spare;
public static RecyclableArrayBufferWriter<T> Create(int maxLength) public static RecyclableArrayBufferWriter<T> Create(int maxLength)
{ {
var obj = Interlocked.Exchange(ref _spare, null) ?? new(); var obj = Interlocked.Exchange(ref _spare, null) ?? new();
Debug.Assert(obj._index == 0, "index should be zero initially"); obj.Initialize(maxLength);
obj._maxLength = maxLength;
return obj; return obj;
} }
private RecyclableArrayBufferWriter() private RecyclableArrayBufferWriter()
{ {
_buffer = []; _buffer = [];
_index = 0;
_maxLength = int.MaxValue;
} }
public void Dispose() public void Dispose()
@ -91,6 +91,7 @@ internal sealed class RecyclableArrayBufferWriter<T> : IBufferWriter<T>, IDispos
if (_index + count > _maxLength) if (_index + count > _maxLength)
{ {
QuotaExceeded = true;
ThrowQuota(); ThrowQuota();
} }
@ -199,4 +200,12 @@ internal sealed class RecyclableArrayBufferWriter<T> : IBufferWriter<T>, IDispos
static void ThrowOutOfMemoryException() => throw new InvalidOperationException("Unable to grow buffer as requested"); static void ThrowOutOfMemoryException() => throw new InvalidOperationException("Unable to grow buffer as requested");
} }
private void Initialize(int maxLength)
{
// think .ctor, but with pooled object re-use
_index = 0;
_maxLength = maxLength;
QuotaExceeded = false;
}
} }

Просмотреть файл

@ -4,7 +4,7 @@
<Description>Multi-level caching implementation building on and extending IDistributedCache</Description> <Description>Multi-level caching implementation building on and extending IDistributedCache</Description>
<TargetFrameworks>$(NetCoreTargetFrameworks)$(ConditionalNet462);netstandard2.0;netstandard2.1</TargetFrameworks> <TargetFrameworks>$(NetCoreTargetFrameworks)$(ConditionalNet462);netstandard2.0;netstandard2.1</TargetFrameworks>
<GenerateDocumentationFile>true</GenerateDocumentationFile> <GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>cache;distributedcache;hybrid</PackageTags> <PackageTags>cache;distributedcache;hybridcache</PackageTags>
<SuppressTfmSupportBuildWarnings>true</SuppressTfmSupportBuildWarnings> <SuppressTfmSupportBuildWarnings>true</SuppressTfmSupportBuildWarnings>
<InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy> <InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy>
<InjectCallerAttributesOnLegacy>true</InjectCallerAttributesOnLegacy> <InjectCallerAttributesOnLegacy>true</InjectCallerAttributesOnLegacy>
@ -20,6 +20,11 @@
<!-- This package needs to reference the dotnet9 versions of Microsoft.Extensions packages as it depends on <!-- This package needs to reference the dotnet9 versions of Microsoft.Extensions packages as it depends on
surface area that was added in .NET 9. --> surface area that was added in .NET 9. -->
<ForceLatestDotnetVersions>true</ForceLatestDotnetVersions> <ForceLatestDotnetVersions>true</ForceLatestDotnetVersions>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<UseLoggingGenerator>true</UseLoggingGenerator>
<!-- prefer the dotnet/runtime logging generator; we don't use the extra features, so: don't take the ref -->
<DisableMicrosoftExtensionsLoggingSourceGenerator>false</DisableMicrosoftExtensionsLoggingSourceGenerator>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>

Просмотреть файл

@ -1,6 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<RootNamespace>Microsoft.Extensions.Compliance</RootNamespace> <RootNamespace>Microsoft.Extensions.Compliance</RootNamespace>
<TargetFrameworks>$(NetCoreTargetFrameworks);netstandard2.0;</TargetFrameworks>
<Description>Abstractions to help ensure compliant data management.</Description> <Description>Abstractions to help ensure compliant data management.</Description>
<Workstream>Fundamentals</Workstream> <Workstream>Fundamentals</Workstream>
</PropertyGroup> </PropertyGroup>

Просмотреть файл

@ -0,0 +1,545 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET9_0_OR_GREATER
using System.Collections.Generic;
using System.Diagnostics;
using System.Text.Json.Nodes;
namespace System.Text.Json.Schema;
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable S1144 // Unused private types or members should be removed
internal static partial class JsonSchemaExporter
{
// Simple JSON schema representation taken from System.Text.Json
// https://github.com/dotnet/runtime/blob/50d6cad649aad2bfa4069268eddd16fd51ec5cf3/src/libraries/System.Text.Json/src/System/Text/Json/Schema/JsonSchema.cs
private sealed class JsonSchema
{
public static JsonSchema False { get; } = new(false);
public static JsonSchema True { get; } = new(true);
public JsonSchema()
{
}
private JsonSchema(bool trueOrFalse)
{
_trueOrFalse = trueOrFalse;
}
public bool IsTrue => _trueOrFalse is true;
public bool IsFalse => _trueOrFalse is false;
private readonly bool? _trueOrFalse;
public string? Schema
{
get => _schema;
set
{
VerifyMutable();
_schema = value;
}
}
private string? _schema;
public string? Title
{
get => _title;
set
{
VerifyMutable();
_title = value;
}
}
private string? _title;
public string? Description
{
get => _description;
set
{
VerifyMutable();
_description = value;
}
}
private string? _description;
public string? Ref
{
get => _ref;
set
{
VerifyMutable();
_ref = value;
}
}
private string? _ref;
public string? Comment
{
get => _comment;
set
{
VerifyMutable();
_comment = value;
}
}
private string? _comment;
public JsonSchemaType Type
{
get => _type;
set
{
VerifyMutable();
_type = value;
}
}
private JsonSchemaType _type = JsonSchemaType.Any;
public string? Format
{
get => _format;
set
{
VerifyMutable();
_format = value;
}
}
private string? _format;
public string? Pattern
{
get => _pattern;
set
{
VerifyMutable();
_pattern = value;
}
}
private string? _pattern;
public JsonNode? Constant
{
get => _constant;
set
{
VerifyMutable();
_constant = value;
}
}
private JsonNode? _constant;
public List<KeyValuePair<string, JsonSchema>>? Properties
{
get => _properties;
set
{
VerifyMutable();
_properties = value;
}
}
private List<KeyValuePair<string, JsonSchema>>? _properties;
public List<string>? Required
{
get => _required;
set
{
VerifyMutable();
_required = value;
}
}
private List<string>? _required;
public JsonSchema? Items
{
get => _items;
set
{
VerifyMutable();
_items = value;
}
}
private JsonSchema? _items;
public JsonSchema? AdditionalProperties
{
get => _additionalProperties;
set
{
VerifyMutable();
_additionalProperties = value;
}
}
private JsonSchema? _additionalProperties;
public JsonArray? Enum
{
get => _enum;
set
{
VerifyMutable();
_enum = value;
}
}
private JsonArray? _enum;
public JsonSchema? Not
{
get => _not;
set
{
VerifyMutable();
_not = value;
}
}
private JsonSchema? _not;
public List<JsonSchema>? AnyOf
{
get => _anyOf;
set
{
VerifyMutable();
_anyOf = value;
}
}
private List<JsonSchema>? _anyOf;
public bool HasDefaultValue
{
get => _hasDefaultValue;
set
{
VerifyMutable();
_hasDefaultValue = value;
}
}
private bool _hasDefaultValue;
public JsonNode? DefaultValue
{
get => _defaultValue;
set
{
VerifyMutable();
_defaultValue = value;
}
}
private JsonNode? _defaultValue;
public int? MinLength
{
get => _minLength;
set
{
VerifyMutable();
_minLength = value;
}
}
private int? _minLength;
public int? MaxLength
{
get => _maxLength;
set
{
VerifyMutable();
_maxLength = value;
}
}
private int? _maxLength;
public JsonSchemaExporterContext? GenerationContext { get; set; }
public int KeywordCount
{
get
{
if (_trueOrFalse != null)
{
return 0;
}
int count = 0;
Count(Schema != null);
Count(Ref != null);
Count(Comment != null);
Count(Title != null);
Count(Description != null);
Count(Type != JsonSchemaType.Any);
Count(Format != null);
Count(Pattern != null);
Count(Constant != null);
Count(Properties != null);
Count(Required != null);
Count(Items != null);
Count(AdditionalProperties != null);
Count(Enum != null);
Count(Not != null);
Count(AnyOf != null);
Count(HasDefaultValue);
Count(MinLength != null);
Count(MaxLength != null);
return count;
void Count(bool isKeywordSpecified) => count += isKeywordSpecified ? 1 : 0;
}
}
public void MakeNullable()
{
if (_trueOrFalse != null)
{
return;
}
if (Type != JsonSchemaType.Any)
{
Type |= JsonSchemaType.Null;
}
}
public JsonNode ToJsonNode(JsonSchemaExporterOptions options)
{
if (_trueOrFalse is { } boolSchema)
{
return CompleteSchema((JsonNode)boolSchema);
}
var objSchema = new JsonObject();
if (Schema != null)
{
objSchema.Add(JsonSchemaConstants.SchemaPropertyName, Schema);
}
if (Title != null)
{
objSchema.Add(JsonSchemaConstants.TitlePropertyName, Title);
}
if (Description != null)
{
objSchema.Add(JsonSchemaConstants.DescriptionPropertyName, Description);
}
if (Ref != null)
{
objSchema.Add(JsonSchemaConstants.RefPropertyName, Ref);
}
if (Comment != null)
{
objSchema.Add(JsonSchemaConstants.CommentPropertyName, Comment);
}
if (MapSchemaType(Type) is JsonNode type)
{
objSchema.Add(JsonSchemaConstants.TypePropertyName, type);
}
if (Format != null)
{
objSchema.Add(JsonSchemaConstants.FormatPropertyName, Format);
}
if (Pattern != null)
{
objSchema.Add(JsonSchemaConstants.PatternPropertyName, Pattern);
}
if (Constant != null)
{
objSchema.Add(JsonSchemaConstants.ConstPropertyName, Constant);
}
if (Properties != null)
{
var properties = new JsonObject();
foreach (KeyValuePair<string, JsonSchema> property in Properties)
{
properties.Add(property.Key, property.Value.ToJsonNode(options));
}
objSchema.Add(JsonSchemaConstants.PropertiesPropertyName, properties);
}
if (Required != null)
{
var requiredArray = new JsonArray();
foreach (string requiredProperty in Required)
{
requiredArray.Add((JsonNode)requiredProperty);
}
objSchema.Add(JsonSchemaConstants.RequiredPropertyName, requiredArray);
}
if (Items != null)
{
objSchema.Add(JsonSchemaConstants.ItemsPropertyName, Items.ToJsonNode(options));
}
if (AdditionalProperties != null)
{
objSchema.Add(JsonSchemaConstants.AdditionalPropertiesPropertyName, AdditionalProperties.ToJsonNode(options));
}
if (Enum != null)
{
objSchema.Add(JsonSchemaConstants.EnumPropertyName, Enum);
}
if (Not != null)
{
objSchema.Add(JsonSchemaConstants.NotPropertyName, Not.ToJsonNode(options));
}
if (AnyOf != null)
{
JsonArray anyOfArray = new();
foreach (JsonSchema schema in AnyOf)
{
anyOfArray.Add(schema.ToJsonNode(options));
}
objSchema.Add(JsonSchemaConstants.AnyOfPropertyName, anyOfArray);
}
if (HasDefaultValue)
{
objSchema.Add(JsonSchemaConstants.DefaultPropertyName, DefaultValue);
}
if (MinLength is int minLength)
{
objSchema.Add(JsonSchemaConstants.MinLengthPropertyName, (JsonNode)minLength);
}
if (MaxLength is int maxLength)
{
objSchema.Add(JsonSchemaConstants.MaxLengthPropertyName, (JsonNode)maxLength);
}
return CompleteSchema(objSchema);
JsonNode CompleteSchema(JsonNode schema)
{
if (GenerationContext is { } context)
{
Debug.Assert(options.TransformSchemaNode != null, "context should only be populated if a callback is present.");
// Apply any user-defined transformations to the schema.
return options.TransformSchemaNode!(context, schema);
}
return schema;
}
}
public static void EnsureMutable(ref JsonSchema schema)
{
switch (schema._trueOrFalse)
{
case false:
schema = new JsonSchema { Not = JsonSchema.True };
break;
case true:
schema = new JsonSchema();
break;
}
}
private static readonly JsonSchemaType[] _schemaValues = new JsonSchemaType[]
{
// NB the order of these values influences order of types in the rendered schema
JsonSchemaType.String,
JsonSchemaType.Integer,
JsonSchemaType.Number,
JsonSchemaType.Boolean,
JsonSchemaType.Array,
JsonSchemaType.Object,
JsonSchemaType.Null,
};
private void VerifyMutable()
{
Debug.Assert(_trueOrFalse is null, "Schema is not mutable");
}
private static JsonNode? MapSchemaType(JsonSchemaType schemaType)
{
if (schemaType is JsonSchemaType.Any)
{
return null;
}
if (ToIdentifier(schemaType) is string identifier)
{
return identifier;
}
var array = new JsonArray();
foreach (JsonSchemaType type in _schemaValues)
{
if ((schemaType & type) != 0)
{
array.Add((JsonNode)ToIdentifier(type)!);
}
}
return array;
static string? ToIdentifier(JsonSchemaType schemaType) => schemaType switch
{
JsonSchemaType.Null => "null",
JsonSchemaType.Boolean => "boolean",
JsonSchemaType.Integer => "integer",
JsonSchemaType.Number => "number",
JsonSchemaType.String => "string",
JsonSchemaType.Array => "array",
JsonSchemaType.Object => "object",
_ => null,
};
}
}
[Flags]
private enum JsonSchemaType
{
Any = 0, // No type declared on the schema
Null = 1,
Boolean = 2,
Integer = 4,
Number = 8,
String = 16,
Array = 32,
Object = 64,
}
}
#endif

Просмотреть файл

@ -0,0 +1,427 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET9_0_OR_GREATER
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
#if !NET
using System.Linq;
#endif
using System.Reflection;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Shared.Diagnostics;
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
namespace System.Text.Json.Schema;
internal static partial class JsonSchemaExporter
{
private static class ReflectionHelpers
{
private const BindingFlags AllInstance = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
private static PropertyInfo? _jsonTypeInfo_ElementType;
private static PropertyInfo? _jsonPropertyInfo_MemberName;
private static FieldInfo? _nullableConverter_ElementConverter_Generic;
private static FieldInfo? _enumConverter_Options_Generic;
private static FieldInfo? _enumConverter_NamingPolicy_Generic;
public static bool IsBuiltInConverter(JsonConverter converter) =>
converter.GetType().Assembly == typeof(JsonConverter).Assembly;
public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;
public static Type GetElementType(JsonTypeInfo typeInfo)
{
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");
// Uses reflection to access the element type encapsulated by a JsonTypeInfo.
if (_jsonTypeInfo_ElementType is null)
{
PropertyInfo? elementTypeProperty = typeof(JsonTypeInfo).GetProperty("ElementType", AllInstance);
_jsonTypeInfo_ElementType = Throw.IfNull(elementTypeProperty);
}
return (Type)_jsonTypeInfo_ElementType.GetValue(typeInfo)!;
}
public static string? GetMemberName(JsonPropertyInfo propertyInfo)
{
// Uses reflection to the member name encapsulated by a JsonPropertyInfo.
if (_jsonPropertyInfo_MemberName is null)
{
PropertyInfo? memberName = typeof(JsonPropertyInfo).GetProperty("MemberName", AllInstance);
_jsonPropertyInfo_MemberName = Throw.IfNull(memberName);
}
return (string?)_jsonPropertyInfo_MemberName.GetValue(propertyInfo);
}
public static JsonConverter GetElementConverter(JsonConverter nullableConverter)
{
// Uses reflection to access the element converter encapsulated by a nullable converter.
if (_nullableConverter_ElementConverter_Generic is null)
{
FieldInfo? genericFieldInfo = Type
.GetType("System.Text.Json.Serialization.Converters.NullableConverter`1, System.Text.Json")!
.GetField("_elementConverter", AllInstance);
_nullableConverter_ElementConverter_Generic = Throw.IfNull(genericFieldInfo);
}
Type converterType = nullableConverter.GetType();
var thisFieldInfo = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_nullableConverter_ElementConverter_Generic);
return (JsonConverter)thisFieldInfo.GetValue(nullableConverter)!;
}
public static void GetEnumConverterConfig(JsonConverter enumConverter, out JsonNamingPolicy? namingPolicy, out bool allowString)
{
// Uses reflection to access configuration encapsulated by an enum converter.
if (_enumConverter_Options_Generic is null)
{
FieldInfo? genericFieldInfo = Type
.GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")!
.GetField("_converterOptions", AllInstance);
_enumConverter_Options_Generic = Throw.IfNull(genericFieldInfo);
}
if (_enumConverter_NamingPolicy_Generic is null)
{
FieldInfo? genericFieldInfo = Type
.GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")!
.GetField("_namingPolicy", AllInstance);
_enumConverter_NamingPolicy_Generic = Throw.IfNull(genericFieldInfo);
}
const int EnumConverterOptionsAllowStrings = 1;
Type converterType = enumConverter.GetType();
var converterOptionsField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_Options_Generic);
var namingPolicyField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_NamingPolicy_Generic);
namingPolicy = (JsonNamingPolicy?)namingPolicyField.GetValue(enumConverter);
int converterOptions = (int)converterOptionsField.GetValue(enumConverter)!;
allowString = (converterOptions & EnumConverterOptionsAllowStrings) != 0;
}
// The .NET 8 source generator doesn't populate attribute providers for properties
// cf. https://github.com/dotnet/runtime/issues/100095
// Work around the issue by running a query for the relevant MemberInfo using the internal MemberName property
// https://github.com/dotnet/runtime/blob/de774ff9ee1a2c06663ab35be34b755cd8d29731/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs#L206
public static ICustomAttributeProvider? ResolveAttributeProvider(
[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.NonPublicProperties |
DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.NonPublicFields)]
Type? declaringType,
JsonPropertyInfo? propertyInfo)
{
if (declaringType is null || propertyInfo is null)
{
return null;
}
if (propertyInfo.AttributeProvider is { } provider)
{
return provider;
}
string? memberName = ReflectionHelpers.GetMemberName(propertyInfo);
if (memberName is not null)
{
return (MemberInfo?)declaringType.GetProperty(memberName, AllInstance) ??
declaringType.GetField(memberName, AllInstance);
}
return null;
}
// Resolves the parameters of the deserialization constructor for a type, if they exist.
public static Func<JsonPropertyInfo, ParameterInfo?>? ResolveJsonConstructorParameterMapper(
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
Type type,
JsonTypeInfo typeInfo)
{
Debug.Assert(type == typeInfo.Type, "The declaring type must match the typeInfo type.");
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Object, "Should only be passed object JSON kinds.");
if (typeInfo.Properties.Count > 0 &&
typeInfo.CreateObject is null && // Ensure that a default constructor isn't being used
TryGetDeserializationConstructor(type, useDefaultCtorInAnnotatedStructs: true, out ConstructorInfo? ctor))
{
ParameterInfo[]? parameters = ctor?.GetParameters();
if (parameters?.Length > 0)
{
Dictionary<ParameterLookupKey, ParameterInfo> dict = new(parameters.Length);
foreach (ParameterInfo parameter in parameters)
{
if (parameter.Name is not null)
{
// We don't care about null parameter names or conflicts since they
// would have already been rejected by JsonTypeInfo exporterOptions.
dict[new(parameter.Name, parameter.ParameterType)] = parameter;
}
}
return prop => dict.TryGetValue(new(prop.Name, prop.PropertyType), out ParameterInfo? parameter) ? parameter : null;
}
}
return null;
}
// Resolves the nullable reference type annotations for a property or field,
// additionally addressing a few known bugs of the NullabilityInfo pre .NET 9.
public static NullabilityInfo GetMemberNullability(NullabilityInfoContext context, MemberInfo memberInfo)
{
Debug.Assert(memberInfo is PropertyInfo or FieldInfo, "Member must be property or field.");
return memberInfo is PropertyInfo prop
? context.Create(prop)
: context.Create((FieldInfo)memberInfo);
}
public static NullabilityState GetParameterNullability(NullabilityInfoContext context, ParameterInfo parameterInfo)
{
#if NET8_0
// Workaround for https://github.com/dotnet/runtime/issues/92487
// The fix has been incorporated into .NET 9 (and the polyfilled implementations in netfx).
// Should be removed once .NET 8 support is dropped.
if (GetGenericParameterDefinition(parameterInfo) is { ParameterType: { IsGenericParameter: true } typeParam })
{
// Step 1. Look for nullable annotations on the type parameter.
if (GetNullableFlags(typeParam) is byte[] flags)
{
return TranslateByte(flags[0]);
}
// Step 2. Look for nullable annotations on the generic method declaration.
if (typeParam.DeclaringMethod != null && GetNullableContextFlag(typeParam.DeclaringMethod) is byte flag)
{
return TranslateByte(flag);
}
// Step 3. Look for nullable annotations on the generic method declaration.
if (GetNullableContextFlag(typeParam.DeclaringType!) is byte flag2)
{
return TranslateByte(flag2);
}
// Default to nullable.
return NullabilityState.Nullable;
static byte[]? GetNullableFlags(MemberInfo member)
{
foreach (CustomAttributeData attr in member.GetCustomAttributesData())
{
Type attrType = attr.AttributeType;
if (attrType.Name == "NullableAttribute" && attrType.Namespace == "System.Runtime.CompilerServices")
{
foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments)
{
switch (ctorArg.Value)
{
case byte flag:
return [flag];
case byte[] flags:
return flags;
}
}
}
}
return null;
}
static byte? GetNullableContextFlag(MemberInfo member)
{
foreach (CustomAttributeData attr in member.GetCustomAttributesData())
{
Type attrType = attr.AttributeType;
if (attrType.Name == "NullableContextAttribute" && attrType.Namespace == "System.Runtime.CompilerServices")
{
foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments)
{
if (ctorArg.Value is byte flag)
{
return flag;
}
}
}
}
return null;
}
#pragma warning disable S109 // Magic numbers should not be used
static NullabilityState TranslateByte(byte b) => b switch
{
1 => NullabilityState.NotNull,
2 => NullabilityState.Nullable,
_ => NullabilityState.Unknown
};
#pragma warning restore S109 // Magic numbers should not be used
}
static ParameterInfo GetGenericParameterDefinition(ParameterInfo parameter)
{
if (parameter.Member is { DeclaringType.IsConstructedGenericType: true }
or MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false })
{
var genericMethod = (MethodBase)GetGenericMemberDefinition(parameter.Member);
return genericMethod.GetParameters()[parameter.Position];
}
return parameter;
}
static MemberInfo GetGenericMemberDefinition(MemberInfo member)
{
if (member is Type type)
{
return type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : type;
}
if (member.DeclaringType?.IsConstructedGenericType is true)
{
return member.DeclaringType.GetGenericTypeDefinition().GetMemberWithSameMetadataDefinitionAs(member);
}
if (member is MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false } method)
{
return method.GetGenericMethodDefinition();
}
return member;
}
#endif
return context.Create(parameterInfo).WriteState;
}
// Taken from https://github.com/dotnet/runtime/blob/903bc019427ca07080530751151ea636168ad334/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L288-L317
public static object? GetNormalizedDefaultValue(ParameterInfo parameterInfo)
{
Type parameterType = parameterInfo.ParameterType;
object? defaultValue = parameterInfo.DefaultValue;
if (defaultValue is null)
{
return null;
}
// DBNull.Value is sometimes used as the default value (returned by reflection) of nullable params in place of null.
if (defaultValue == DBNull.Value && parameterType != typeof(DBNull))
{
return null;
}
// Default values of enums or nullable enums are represented using the underlying type and need to be cast explicitly
// cf. https://github.com/dotnet/runtime/issues/68647
if (parameterType.IsEnum)
{
return Enum.ToObject(parameterType, defaultValue);
}
if (Nullable.GetUnderlyingType(parameterType) is Type underlyingType && underlyingType.IsEnum)
{
return Enum.ToObject(underlyingType, defaultValue);
}
return defaultValue;
}
// Resolves the deserialization constructor for a type using logic copied from
// https://github.com/dotnet/runtime/blob/e12e2fa6cbdd1f4b0c8ad1b1e2d960a480c21703/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L227-L286
private static bool TryGetDeserializationConstructor(
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
Type type,
bool useDefaultCtorInAnnotatedStructs,
out ConstructorInfo? deserializationCtor)
{
ConstructorInfo? ctorWithAttribute = null;
ConstructorInfo? publicParameterlessCtor = null;
ConstructorInfo? lonePublicCtor = null;
ConstructorInfo[] constructors = type.GetConstructors(BindingFlags.Public | BindingFlags.Instance);
if (constructors.Length == 1)
{
lonePublicCtor = constructors[0];
}
foreach (ConstructorInfo constructor in constructors)
{
if (HasJsonConstructorAttribute(constructor))
{
if (ctorWithAttribute != null)
{
deserializationCtor = null;
return false;
}
ctorWithAttribute = constructor;
}
else if (constructor.GetParameters().Length == 0)
{
publicParameterlessCtor = constructor;
}
}
// Search for non-public ctors with [JsonConstructor].
foreach (ConstructorInfo constructor in type.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance))
{
if (HasJsonConstructorAttribute(constructor))
{
if (ctorWithAttribute != null)
{
deserializationCtor = null;
return false;
}
ctorWithAttribute = constructor;
}
}
// Structs will use default constructor if attribute isn't used.
if (useDefaultCtorInAnnotatedStructs && type.IsValueType && ctorWithAttribute == null)
{
deserializationCtor = null;
return true;
}
deserializationCtor = ctorWithAttribute ?? publicParameterlessCtor ?? lonePublicCtor;
return true;
static bool HasJsonConstructorAttribute(ConstructorInfo constructorInfo) =>
constructorInfo.GetCustomAttribute<JsonConstructorAttribute>() != null;
}
// Parameter to property matching semantics as declared in
// https://github.com/dotnet/runtime/blob/12d96ccfaed98e23c345188ee08f8cfe211c03e7/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs#L1007-L1030
private readonly struct ParameterLookupKey : IEquatable<ParameterLookupKey>
{
public ParameterLookupKey(string name, Type type)
{
Name = name;
Type = type;
}
public string Name { get; }
public Type Type { get; }
public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Name);
public bool Equals(ParameterLookupKey other) => Type == other.Type && string.Equals(Name, other.Name, StringComparison.OrdinalIgnoreCase);
public override bool Equals(object? obj) => obj is ParameterLookupKey key && Equals(key);
}
}
#if !NET
private static MemberInfo GetMemberWithSameMetadataDefinitionAs(this Type specializedType, MemberInfo member)
{
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
return specializedType.GetMember(member.Name, member.MemberType, All).First(m => m.MetadataToken == member.MetadataToken);
}
#endif
}
#endif

Просмотреть файл

@ -0,0 +1,801 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET9_0_OR_GREATER
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Reflection;
#if NET
using System.Runtime.InteropServices;
#endif
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Shared.Diagnostics;
#pragma warning disable LA0002 // Use 'Microsoft.Shared.Text.NumericExtensions.ToInvariantString' for improved performance
#pragma warning disable S107 // Methods should not have too many parameters
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
namespace System.Text.Json.Schema;
/// <summary>
/// Maps .NET types to JSON schema objects using contract metadata from <see cref="JsonTypeInfo"/> instances.
/// </summary>
#if !SHARED_PROJECT
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
#endif
internal static partial class JsonSchemaExporter
{
// Polyfill implementation of JsonSchemaExporter for System.Text.Json version 8.0.0.
// Uses private reflection to access metadata not available with the older APIs of STJ.
private const string RequiresUnreferencedCodeMessage =
"Uses private reflection on System.Text.Json components to access converter metadata. " +
"If running Native AOT ensure that the 'IlcTrimMetadata' property has been disabled.";
/// <summary>
/// Generates a JSON schema corresponding to the contract metadata of the specified type.
/// </summary>
/// <param name="options">The options instance from which to resolve the contract metadata.</param>
/// <param name="type">The root type for which to generate the JSON schema.</param>
/// <param name="exporterOptions">The exporterOptions object controlling the schema generation.</param>
/// <returns>A new <see cref="JsonNode"/> instance defining the JSON schema for <paramref name="type"/>.</returns>
/// <exception cref="ArgumentNullException">One of the specified parameters is <see langword="null" />.</exception>
/// <exception cref="NotSupportedException">The <paramref name="options"/> parameter contains unsupported exporterOptions.</exception>
[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
public static JsonNode GetJsonSchemaAsNode(this JsonSerializerOptions options, Type type, JsonSchemaExporterOptions? exporterOptions = null)
{
_ = Throw.IfNull(options);
_ = Throw.IfNull(type);
ValidateOptions(options);
exporterOptions ??= JsonSchemaExporterOptions.Default;
JsonTypeInfo typeInfo = options.GetTypeInfo(type);
return MapRootTypeJsonSchema(typeInfo, exporterOptions);
}
/// <summary>
/// Generates a JSON schema corresponding to the specified contract metadata.
/// </summary>
/// <param name="typeInfo">The contract metadata for which to generate the schema.</param>
/// <param name="exporterOptions">The exporterOptions object controlling the schema generation.</param>
/// <returns>A new <see cref="JsonNode"/> instance defining the JSON schema for <paramref name="typeInfo"/>.</returns>
/// <exception cref="ArgumentNullException">One of the specified parameters is <see langword="null" />.</exception>
/// <exception cref="NotSupportedException">The <paramref name="typeInfo"/> parameter contains unsupported exporterOptions.</exception>
[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
public static JsonNode GetJsonSchemaAsNode(this JsonTypeInfo typeInfo, JsonSchemaExporterOptions? exporterOptions = null)
{
_ = Throw.IfNull(typeInfo);
ValidateOptions(typeInfo.Options);
exporterOptions ??= JsonSchemaExporterOptions.Default;
return MapRootTypeJsonSchema(typeInfo, exporterOptions);
}
[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
private static JsonNode MapRootTypeJsonSchema(JsonTypeInfo typeInfo, JsonSchemaExporterOptions exporterOptions)
{
GenerationState state = new(exporterOptions, typeInfo.Options);
JsonSchema schema = MapJsonSchemaCore(ref state, typeInfo);
return schema.ToJsonNode(exporterOptions);
}
[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
private static JsonSchema MapJsonSchemaCore(
ref GenerationState state,
JsonTypeInfo typeInfo,
Type? parentType = null,
JsonPropertyInfo? propertyInfo = null,
ICustomAttributeProvider? propertyAttributeProvider = null,
ParameterInfo? parameterInfo = null,
bool isNonNullableType = false,
JsonConverter? customConverter = null,
JsonNumberHandling? customNumberHandling = null,
JsonTypeInfo? parentPolymorphicTypeInfo = null,
bool parentPolymorphicTypeContainsTypesWithoutDiscriminator = false,
bool parentPolymorphicTypeIsNonNullable = false,
KeyValuePair<string, JsonSchema>? typeDiscriminator = null,
bool cacheResult = true)
{
Debug.Assert(typeInfo.IsReadOnly, "The specified contract must have been made read-only.");
JsonSchemaExporterContext exporterContext = state.CreateContext(typeInfo, parentPolymorphicTypeInfo, parentType, propertyInfo, parameterInfo, propertyAttributeProvider);
if (cacheResult && typeInfo.Kind is not JsonTypeInfoKind.None &&
state.TryGetExistingJsonPointer(exporterContext, out string? existingJsonPointer))
{
// The schema context has already been generated in the schema document, return a reference to it.
return CompleteSchema(ref state, new JsonSchema { Ref = existingJsonPointer });
}
JsonSchema schema;
JsonConverter effectiveConverter = customConverter ?? typeInfo.Converter;
JsonNumberHandling effectiveNumberHandling = customNumberHandling ?? typeInfo.NumberHandling ?? typeInfo.Options.NumberHandling;
if (!ReflectionHelpers.IsBuiltInConverter(effectiveConverter))
{
// Return a `true` schema for types with user-defined converters.
return CompleteSchema(ref state, JsonSchema.True);
}
if (parentPolymorphicTypeInfo is null && typeInfo.PolymorphismOptions is { DerivedTypes.Count: > 0 } polyOptions)
{
// This is the base type of a polymorphic type hierarchy. The schema for this type
// will include an "anyOf" property with the schemas for all derived types.
string typeDiscriminatorKey = polyOptions.TypeDiscriminatorPropertyName;
List<JsonDerivedType> derivedTypes = polyOptions.DerivedTypes.ToList();
if (!typeInfo.Type.IsAbstract && !derivedTypes.Any(derived => derived.DerivedType == typeInfo.Type))
{
// For non-abstract base types that haven't been explicitly configured,
// add a trivial schema to the derived types since we should support it.
derivedTypes.Add(new JsonDerivedType(typeInfo.Type));
}
bool containsTypesWithoutDiscriminator = derivedTypes.Exists(static derivedTypes => derivedTypes.TypeDiscriminator is null);
JsonSchemaType schemaType = JsonSchemaType.Any;
List<JsonSchema>? anyOf = new(derivedTypes.Count);
state.PushSchemaNode(JsonSchemaConstants.AnyOfPropertyName);
foreach (JsonDerivedType derivedType in derivedTypes)
{
Debug.Assert(derivedType.TypeDiscriminator is null or int or string, "Type discriminator does not have the expected type.");
KeyValuePair<string, JsonSchema>? derivedTypeDiscriminator = null;
if (derivedType.TypeDiscriminator is { } discriminatorValue)
{
JsonNode discriminatorNode = discriminatorValue switch
{
string stringId => (JsonNode)stringId,
_ => (JsonNode)(int)discriminatorValue,
};
JsonSchema discriminatorSchema = new() { Constant = discriminatorNode };
derivedTypeDiscriminator = new(typeDiscriminatorKey, discriminatorSchema);
}
JsonTypeInfo derivedTypeInfo = typeInfo.Options.GetTypeInfo(derivedType.DerivedType);
state.PushSchemaNode(anyOf.Count.ToString(CultureInfo.InvariantCulture));
JsonSchema derivedSchema = MapJsonSchemaCore(
ref state,
derivedTypeInfo,
parentPolymorphicTypeInfo: typeInfo,
typeDiscriminator: derivedTypeDiscriminator,
parentPolymorphicTypeContainsTypesWithoutDiscriminator: containsTypesWithoutDiscriminator,
parentPolymorphicTypeIsNonNullable: isNonNullableType,
cacheResult: false);
state.PopSchemaNode();
// Determine if all derived schemas have the same type.
if (anyOf.Count == 0)
{
schemaType = derivedSchema.Type;
}
else if (schemaType != derivedSchema.Type)
{
schemaType = JsonSchemaType.Any;
}
anyOf.Add(derivedSchema);
}
state.PopSchemaNode();
if (schemaType is not JsonSchemaType.Any)
{
// If all derived types have the same schema type, we can simplify the schema
// by moving the type keyword to the base schema and removing it from the derived schemas.
foreach (JsonSchema derivedSchema in anyOf)
{
derivedSchema.Type = JsonSchemaType.Any;
if (derivedSchema.KeywordCount == 0)
{
// if removing the type results in an empty schema,
// remove the anyOf array entirely since it's always true.
anyOf = null;
break;
}
}
}
schema = new()
{
Type = schemaType,
AnyOf = anyOf,
// If all derived types have a discriminator, we can require it in the base schema.
Required = containsTypesWithoutDiscriminator ? null : new() { typeDiscriminatorKey },
};
return CompleteSchema(ref state, schema);
}
if (Nullable.GetUnderlyingType(typeInfo.Type) is Type nullableElementType)
{
JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(nullableElementType);
customConverter = ExtractCustomNullableConverter(customConverter);
schema = MapJsonSchemaCore(ref state, elementTypeInfo, customConverter: customConverter, cacheResult: false);
if (schema.Enum != null)
{
Debug.Assert(elementTypeInfo.Type.IsEnum, "The enum keyword should only be populated by schemas for enum types.");
schema.Enum.Add(null); // Append null to the enum array.
}
return CompleteSchema(ref state, schema);
}
switch (typeInfo.Kind)
{
case JsonTypeInfoKind.Object:
List<KeyValuePair<string, JsonSchema>>? properties = null;
List<string>? required = null;
JsonSchema? additionalProperties = null;
JsonUnmappedMemberHandling effectiveUnmappedMemberHandling = typeInfo.UnmappedMemberHandling ?? typeInfo.Options.UnmappedMemberHandling;
if (effectiveUnmappedMemberHandling is JsonUnmappedMemberHandling.Disallow)
{
// Disallow unspecified properties.
additionalProperties = JsonSchema.False;
}
if (typeDiscriminator is { } typeDiscriminatorPair)
{
(properties = new()).Add(typeDiscriminatorPair);
if (parentPolymorphicTypeContainsTypesWithoutDiscriminator)
{
// Require the discriminator here since it's not common to all derived types.
(required = new()).Add(typeDiscriminatorPair.Key);
}
}
Func<JsonPropertyInfo, ParameterInfo?>? parameterInfoMapper =
ReflectionHelpers.ResolveJsonConstructorParameterMapper(typeInfo.Type, typeInfo);
state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName);
foreach (JsonPropertyInfo property in typeInfo.Properties)
{
if (property is { Get: null, Set: null } or { IsExtensionData: true })
{
continue; // Skip JsonIgnored properties and extension data
}
JsonNumberHandling? propertyNumberHandling = property.NumberHandling ?? effectiveNumberHandling;
JsonTypeInfo propertyTypeInfo = typeInfo.Options.GetTypeInfo(property.PropertyType);
// Resolve the attribute provider for the property.
ICustomAttributeProvider? attributeProvider = ReflectionHelpers.ResolveAttributeProvider(typeInfo.Type, property);
// Declare the property as nullable if either getter or setter are nullable.
bool isNonNullableProperty = false;
if (attributeProvider is MemberInfo memberInfo)
{
NullabilityInfo nullabilityInfo = ReflectionHelpers.GetMemberNullability(state.NullabilityInfoContext, memberInfo);
isNonNullableProperty =
(property.Get is null || nullabilityInfo.ReadState is NullabilityState.NotNull) &&
(property.Set is null || nullabilityInfo.WriteState is NullabilityState.NotNull);
}
bool isRequired = property.IsRequired;
bool hasDefaultValue = false;
JsonNode? defaultValue = null;
ParameterInfo? associatedParameter = parameterInfoMapper?.Invoke(property);
if (associatedParameter != null)
{
ResolveParameterInfo(
associatedParameter,
propertyTypeInfo,
state.NullabilityInfoContext,
out hasDefaultValue,
out defaultValue,
out bool isNonNullableParameter,
ref isRequired);
isNonNullableProperty &= isNonNullableParameter;
}
state.PushSchemaNode(property.Name);
JsonSchema propertySchema = MapJsonSchemaCore(
ref state,
propertyTypeInfo,
parentType: typeInfo.Type,
propertyInfo: property,
parameterInfo: associatedParameter,
propertyAttributeProvider: attributeProvider,
isNonNullableType: isNonNullableProperty,
customConverter: property.CustomConverter,
customNumberHandling: propertyNumberHandling);
state.PopSchemaNode();
if (hasDefaultValue)
{
JsonSchema.EnsureMutable(ref propertySchema);
propertySchema.DefaultValue = defaultValue;
propertySchema.HasDefaultValue = true;
}
(properties ??= new()).Add(new(property.Name, propertySchema));
if (isRequired)
{
(required ??= new()).Add(property.Name);
}
}
state.PopSchemaNode();
return CompleteSchema(ref state, new()
{
Type = JsonSchemaType.Object,
Properties = properties,
Required = required,
AdditionalProperties = additionalProperties,
});
case JsonTypeInfoKind.Enumerable:
Type elementType = ReflectionHelpers.GetElementType(typeInfo);
JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(elementType);
if (typeDiscriminator is null)
{
state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName);
JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling);
state.PopSchemaNode();
return CompleteSchema(ref state, new()
{
Type = JsonSchemaType.Array,
Items = items.IsTrue ? null : items,
});
}
else
{
// Polymorphic enumerable types are represented using a wrapping object:
// { "$type" : "discriminator", "$values" : [element1, element2, ...] }
// Which corresponds to the schema
// { "properties" : { "$type" : { "const" : "discriminator" }, "$values" : { "type" : "array", "items" : { ... } } } }
const string ValuesKeyword = "$values";
state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName);
state.PushSchemaNode(ValuesKeyword);
state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName);
JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling);
state.PopSchemaNode();
state.PopSchemaNode();
state.PopSchemaNode();
return CompleteSchema(ref state, new()
{
Type = JsonSchemaType.Object,
Properties = new()
{
typeDiscriminator.Value,
new(ValuesKeyword,
new JsonSchema
{
Type = JsonSchemaType.Array,
Items = items.IsTrue ? null : items,
}),
},
Required = parentPolymorphicTypeContainsTypesWithoutDiscriminator ? new() { typeDiscriminator.Value.Key } : null,
});
}
case JsonTypeInfoKind.Dictionary:
Type valueType = ReflectionHelpers.GetElementType(typeInfo);
JsonTypeInfo valueTypeInfo = typeInfo.Options.GetTypeInfo(valueType);
List<KeyValuePair<string, JsonSchema>>? dictProps = null;
List<string>? dictRequired = null;
if (typeDiscriminator is { } dictDiscriminator)
{
dictProps = new() { dictDiscriminator };
if (parentPolymorphicTypeContainsTypesWithoutDiscriminator)
{
// Require the discriminator here since it's not common to all derived types.
dictRequired = new() { dictDiscriminator.Key };
}
}
state.PushSchemaNode(JsonSchemaConstants.AdditionalPropertiesPropertyName);
JsonSchema valueSchema = MapJsonSchemaCore(ref state, valueTypeInfo, customNumberHandling: effectiveNumberHandling);
state.PopSchemaNode();
return CompleteSchema(ref state, new()
{
Type = JsonSchemaType.Object,
Properties = dictProps,
Required = dictRequired,
AdditionalProperties = valueSchema.IsTrue ? null : valueSchema,
});
default:
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.None, "The default case should handle unrecognize type kinds.");
if (_simpleTypeSchemaFactories.TryGetValue(typeInfo.Type, out Func<JsonNumberHandling, JsonSchema>? simpleTypeSchemaFactory))
{
schema = simpleTypeSchemaFactory(effectiveNumberHandling);
}
else if (typeInfo.Type.IsEnum)
{
schema = GetEnumConverterSchema(typeInfo, effectiveConverter);
}
else
{
schema = JsonSchema.True;
}
return CompleteSchema(ref state, schema);
}
JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)
{
if (schema.Ref is null)
{
if (IsNullableSchema(ref state))
{
schema.MakeNullable();
}
bool IsNullableSchema(ref GenerationState state)
{
// A schema is marked as nullable if either
// 1. We have a schema for a property where either the getter or setter are marked as nullable.
// 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
if (propertyInfo != null || parameterInfo != null)
{
return !isNonNullableType;
}
else
{
return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
}
}
}
if (state.ExporterOptions.TransformSchemaNode != null)
{
// Prime the schema for invocation by the JsonNode transformer.
schema.GenerationContext = exporterContext;
}
return schema;
}
}
private readonly ref struct GenerationState
{
private const int DefaultMaxDepth = 64;
private readonly List<string> _currentPath = new();
private readonly Dictionary<(JsonTypeInfo, JsonPropertyInfo?), string[]> _generated = new();
private readonly int _maxDepth;
public GenerationState(JsonSchemaExporterOptions exporterOptions, JsonSerializerOptions options, NullabilityInfoContext? nullabilityInfoContext = null)
{
ExporterOptions = exporterOptions;
NullabilityInfoContext = nullabilityInfoContext ?? new();
_maxDepth = options.MaxDepth is 0 ? DefaultMaxDepth : options.MaxDepth;
}
public JsonSchemaExporterOptions ExporterOptions { get; }
public NullabilityInfoContext NullabilityInfoContext { get; }
public int CurrentDepth => _currentPath.Count;
public void PushSchemaNode(string nodeId)
{
if (CurrentDepth == _maxDepth)
{
ThrowHelpers.ThrowInvalidOperationException_MaxDepthReached();
}
_currentPath.Add(nodeId);
}
public void PopSchemaNode()
{
_currentPath.RemoveAt(_currentPath.Count - 1);
}
/// <summary>
/// Registers the current schema node generation context; if it has already been generated return a JSON pointer to its location.
/// </summary>
public bool TryGetExistingJsonPointer(in JsonSchemaExporterContext context, [NotNullWhen(true)] out string? existingJsonPointer)
{
(JsonTypeInfo, JsonPropertyInfo?) key = (context.TypeInfo, context.PropertyInfo);
#if NET
ref string[]? pathToSchema = ref CollectionsMarshal.GetValueRefOrAddDefault(_generated, key, out bool exists);
#else
bool exists = _generated.TryGetValue(key, out string[]? pathToSchema);
#endif
if (exists)
{
existingJsonPointer = FormatJsonPointer(pathToSchema);
return true;
}
#if NET
pathToSchema = context._path;
#else
_generated[key] = context._path;
#endif
existingJsonPointer = null;
return false;
}
public JsonSchemaExporterContext CreateContext(
JsonTypeInfo typeInfo,
JsonTypeInfo? baseTypeInfo,
Type? declaringType,
JsonPropertyInfo? propertyInfo,
ParameterInfo? parameterInfo,
ICustomAttributeProvider? propertyAttributeProvider)
{
return new JsonSchemaExporterContext(typeInfo, baseTypeInfo, declaringType, propertyInfo, parameterInfo, propertyAttributeProvider, _currentPath.ToArray());
}
private static string FormatJsonPointer(ReadOnlySpan<string> path)
{
if (path.IsEmpty)
{
return "#";
}
StringBuilder sb = new();
_ = sb.Append('#');
for (int i = 0; i < path.Length; i++)
{
string segment = path[i];
if (segment.AsSpan().IndexOfAny('~', '/') != -1)
{
#pragma warning disable CA1307 // Specify StringComparison for clarity
segment = segment.Replace("~", "~0").Replace("/", "~1");
#pragma warning restore CA1307
}
_ = sb.Append('/');
_ = sb.Append(segment);
}
return sb.ToString();
}
}
private static readonly Dictionary<Type, Func<JsonNumberHandling, JsonSchema>> _simpleTypeSchemaFactories = new()
{
[typeof(object)] = _ => JsonSchema.True,
[typeof(bool)] = _ => new JsonSchema { Type = JsonSchemaType.Boolean },
[typeof(byte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(ushort)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(uint)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(ulong)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(sbyte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(short)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(int)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(long)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(float)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
[typeof(double)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
[typeof(decimal)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling),
#if NET6_0_OR_GREATER
[typeof(Half)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
#endif
#if NET7_0_OR_GREATER
[typeof(UInt128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
[typeof(Int128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
#endif
[typeof(char)] = _ => new JsonSchema { Type = JsonSchemaType.String, MinLength = 1, MaxLength = 1 },
[typeof(string)] = _ => new JsonSchema { Type = JsonSchemaType.String },
[typeof(byte[])] = _ => new JsonSchema { Type = JsonSchemaType.String },
[typeof(Memory<byte>)] = _ => new JsonSchema { Type = JsonSchemaType.String },
[typeof(ReadOnlyMemory<byte>)] = _ => new JsonSchema { Type = JsonSchemaType.String },
[typeof(DateTime)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" },
[typeof(DateTimeOffset)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" },
[typeof(TimeSpan)] = _ => new JsonSchema
{
Comment = "Represents a System.TimeSpan value.",
Type = JsonSchemaType.String,
Pattern = @"^-?(\d+\.)?\d{2}:\d{2}:\d{2}(\.\d{1,7})?$",
},
#if NET6_0_OR_GREATER
[typeof(DateOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date" },
[typeof(TimeOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "time" },
#endif
[typeof(Guid)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uuid" },
[typeof(Uri)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uri" },
[typeof(Version)] = _ => new JsonSchema
{
Comment = "Represents a version string.",
Type = JsonSchemaType.String,
Pattern = @"^\d+(\.\d+){1,3}$",
},
[typeof(JsonDocument)] = _ => JsonSchema.True,
[typeof(JsonElement)] = _ => JsonSchema.True,
[typeof(JsonNode)] = _ => JsonSchema.True,
[typeof(JsonValue)] = _ => JsonSchema.True,
[typeof(JsonObject)] = _ => new JsonSchema { Type = JsonSchemaType.Object },
[typeof(JsonArray)] = _ => new JsonSchema { Type = JsonSchemaType.Array },
};
// Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/JsonPrimitiveConverter.cs#L36-L69
private static JsonSchema GetSchemaForNumericType(JsonSchemaType schemaType, JsonNumberHandling numberHandling, bool isIeeeFloatingPoint = false)
{
Debug.Assert(schemaType is JsonSchemaType.Integer or JsonSchemaType.Number, "schema type must be number or integer");
Debug.Assert(!isIeeeFloatingPoint || schemaType is JsonSchemaType.Number, "If specifying IEEE the schema type must be number");
string? pattern = null;
if ((numberHandling & (JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)) != 0)
{
if (schemaType is JsonSchemaType.Integer)
{
pattern = @"^-?(?:0|[1-9]\d*)$";
}
else if (isIeeeFloatingPoint)
{
pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$";
}
else
{
pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?$";
}
schemaType |= JsonSchemaType.String;
}
if (isIeeeFloatingPoint && (numberHandling & JsonNumberHandling.AllowNamedFloatingPointLiterals) != 0)
{
return new JsonSchema
{
AnyOf = new()
{
new JsonSchema { Type = schemaType, Pattern = pattern },
new JsonSchema { Enum = new() { (JsonNode)"NaN", (JsonNode)"Infinity", (JsonNode)"-Infinity" } },
},
};
}
return new JsonSchema { Type = schemaType, Pattern = pattern };
}
private static JsonConverter? ExtractCustomNullableConverter(JsonConverter? converter)
{
Debug.Assert(converter is null || ReflectionHelpers.IsBuiltInConverter(converter), "If specified the converter must be built-in.");
if (converter is null)
{
return null;
}
return ReflectionHelpers.GetElementConverter(converter);
}
private static void ValidateOptions(JsonSerializerOptions options)
{
if (options.ReferenceHandler == ReferenceHandler.Preserve)
{
ThrowHelpers.ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported();
}
options.MakeReadOnly();
}
private static void ResolveParameterInfo(
ParameterInfo parameter,
JsonTypeInfo parameterTypeInfo,
NullabilityInfoContext nullabilityInfoContext,
out bool hasDefaultValue,
out JsonNode? defaultValue,
out bool isNonNullable,
ref bool isRequired)
{
Debug.Assert(parameterTypeInfo.Type == parameter.ParameterType, "The typeInfo type must match the ParameterInfo type.");
// Incorporate the nullability information from the parameter.
isNonNullable = ReflectionHelpers.GetParameterNullability(nullabilityInfoContext, parameter) is NullabilityState.NotNull;
if (parameter.HasDefaultValue)
{
// Append the default value to the description.
object? defaultVal = ReflectionHelpers.GetNormalizedDefaultValue(parameter);
defaultValue = JsonSerializer.SerializeToNode(defaultVal, parameterTypeInfo);
hasDefaultValue = true;
}
else
{
// Parameter is not optional, mark as required.
isRequired = true;
defaultValue = null;
hasDefaultValue = false;
}
}
// Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/EnumConverter.cs#L498-L521
private static JsonSchema GetEnumConverterSchema(JsonTypeInfo typeInfo, JsonConverter converter)
{
Debug.Assert(typeInfo.Type.IsEnum && ReflectionHelpers.IsBuiltInConverter(converter), "must be using a built-in enum converter.");
if (converter is JsonConverterFactory factory)
{
converter = factory.CreateConverter(typeInfo.Type, typeInfo.Options)!;
}
ReflectionHelpers.GetEnumConverterConfig(converter, out JsonNamingPolicy? namingPolicy, out bool allowString);
if (allowString)
{
// This explicitly ignores the integer component in converters configured as AllowNumbers | AllowStrings
// which is the default for JsonStringEnumConverter. This sacrifices some precision in the schema for simplicity.
if (typeInfo.Type.GetCustomAttribute<FlagsAttribute>() is not null)
{
// Do not report enum values in case of flags.
return new() { Type = JsonSchemaType.String };
}
JsonArray enumValues = new();
foreach (string name in Enum.GetNames(typeInfo.Type))
{
// This does not account for custom names specified via the new
// JsonStringEnumMemberNameAttribute introduced in .NET 9.
string effectiveName = namingPolicy?.ConvertName(name) ?? name;
enumValues.Add((JsonNode)effectiveName);
}
return new() { Enum = enumValues };
}
return new() { Type = JsonSchemaType.Integer };
}
private static class JsonSchemaConstants
{
public const string SchemaPropertyName = "$schema";
public const string RefPropertyName = "$ref";
public const string CommentPropertyName = "$comment";
public const string TitlePropertyName = "title";
public const string DescriptionPropertyName = "description";
public const string TypePropertyName = "type";
public const string FormatPropertyName = "format";
public const string PatternPropertyName = "pattern";
public const string PropertiesPropertyName = "properties";
public const string RequiredPropertyName = "required";
public const string ItemsPropertyName = "items";
public const string AdditionalPropertiesPropertyName = "additionalProperties";
public const string EnumPropertyName = "enum";
public const string NotPropertyName = "not";
public const string AnyOfPropertyName = "anyOf";
public const string ConstPropertyName = "const";
public const string DefaultPropertyName = "default";
public const string MinLengthPropertyName = "minLength";
public const string MaxLengthPropertyName = "maxLength";
}
private static class ThrowHelpers
{
[DoesNotReturn]
public static void ThrowInvalidOperationException_MaxDepthReached() =>
throw new InvalidOperationException("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting.");
[DoesNotReturn]
public static void ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported() =>
throw new NotSupportedException("Schema generation not supported with ReferenceHandler.Preserve enabled.");
}
}
#endif

Просмотреть файл

@ -0,0 +1,77 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET9_0_OR_GREATER
using System;
using System.Reflection;
using System.Text.Json.Serialization.Metadata;
namespace System.Text.Json.Schema;
/// <summary>
/// Defines the context in which a JSON schema within a type graph is being generated.
/// </summary>
#if !SHARED_PROJECT
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
#endif
internal readonly struct JsonSchemaExporterContext
{
#pragma warning disable IDE1006 // Naming Styles
internal readonly string[] _path;
#pragma warning restore IDE1006 // Naming Styles
internal JsonSchemaExporterContext(
JsonTypeInfo typeInfo,
JsonTypeInfo? baseTypeInfo,
Type? declaringType,
JsonPropertyInfo? propertyInfo,
ParameterInfo? parameterInfo,
ICustomAttributeProvider? propertyAttributeProvider,
string[] path)
{
TypeInfo = typeInfo;
DeclaringType = declaringType;
BaseTypeInfo = baseTypeInfo;
PropertyInfo = propertyInfo;
ParameterInfo = parameterInfo;
PropertyAttributeProvider = propertyAttributeProvider;
_path = path;
}
/// <summary>
/// Gets the path to the schema document currently being generated.
/// </summary>
public ReadOnlySpan<string> Path => _path;
/// <summary>
/// Gets the <see cref="JsonTypeInfo"/> for the type being processed.
/// </summary>
public JsonTypeInfo TypeInfo { get; }
/// <summary>
/// Gets the declaring type of the property or parameter being processed.
/// </summary>
public Type? DeclaringType { get; }
/// <summary>
/// Gets the type info for the polymorphic base type if generated as a derived type.
/// </summary>
public JsonTypeInfo? BaseTypeInfo { get; }
/// <summary>
/// Gets the <see cref="JsonPropertyInfo"/> if the schema is being generated for a property.
/// </summary>
public JsonPropertyInfo? PropertyInfo { get; }
/// <summary>
/// Gets the <see cref="System.Reflection.ParameterInfo"/> if a constructor parameter
/// has been associated with the accompanying <see cref="PropertyInfo"/>.
/// </summary>
public ParameterInfo? ParameterInfo { get; }
/// <summary>
/// Gets the <see cref="ICustomAttributeProvider"/> corresponding to the property or field being processed.
/// </summary>
public ICustomAttributeProvider? PropertyAttributeProvider { get; }
}
#endif

Просмотреть файл

@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET9_0_OR_GREATER
using System;
using System.Text.Json.Nodes;
namespace System.Text.Json.Schema;
/// <summary>
/// Controls the behavior of the <see cref="JsonSchemaExporter"/> class.
/// </summary>
#if !SHARED_PROJECT
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
#endif
internal sealed class JsonSchemaExporterOptions
{
/// <summary>
/// Gets the default configuration object used by <see cref="JsonSchemaExporter"/>.
/// </summary>
public static JsonSchemaExporterOptions Default { get; } = new();
/// <summary>
/// Gets a value indicating whether non-nullable schemas should be generated for null oblivious reference types.
/// </summary>
/// <remarks>
/// Defaults to <see langword="false"/>. Due to restrictions in the run-time representation of nullable reference types
/// most occurrences are null oblivious and are treated as nullable by the serializer. A notable exception to that rule
/// are nullability annotations of field, property and constructor parameters which are represented in the contract metadata.
/// </remarks>
public bool TreatNullObliviousAsNonNullable { get; init; }
/// <summary>
/// Gets a callback that is invoked for every schema that is generated within the type graph.
/// </summary>
public Func<JsonSchemaExporterContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }
}
#endif

Просмотреть файл

@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET6_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
#pragma warning disable SA1623 // Property summary documentation should match accessors
namespace System.Reflection
{
/// <summary>
/// A class that represents nullability info.
/// </summary>
[ExcludeFromCodeCoverage]
internal sealed class NullabilityInfo
{
internal NullabilityInfo(Type type, NullabilityState readState, NullabilityState writeState,
NullabilityInfo? elementType, NullabilityInfo[] typeArguments)
{
Type = type;
ReadState = readState;
WriteState = writeState;
ElementType = elementType;
GenericTypeArguments = typeArguments;
}
/// <summary>
/// The <see cref="System.Type" /> of the member or generic parameter
/// to which this NullabilityInfo belongs.
/// </summary>
public Type Type { get; }
/// <summary>
/// The nullability read state of the member.
/// </summary>
public NullabilityState ReadState { get; internal set; }
/// <summary>
/// The nullability write state of the member.
/// </summary>
public NullabilityState WriteState { get; internal set; }
/// <summary>
/// If the member type is an array, gives the <see cref="NullabilityInfo" /> of the elements of the array, null otherwise.
/// </summary>
public NullabilityInfo? ElementType { get; }
/// <summary>
/// If the member type is a generic type, gives the array of <see cref="NullabilityInfo" /> for each type parameter.
/// </summary>
public NullabilityInfo[] GenericTypeArguments { get; }
}
/// <summary>
/// An enum that represents nullability state.
/// </summary>
internal enum NullabilityState
{
/// <summary>
/// Nullability context not enabled (oblivious).
/// </summary>
Unknown,
/// <summary>
/// Non nullable value or reference type.
/// </summary>
NotNull,
/// <summary>
/// Nullable value or reference type.
/// </summary>
Nullable,
}
}
#endif

Просмотреть файл

@ -0,0 +1,661 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET6_0_OR_GREATER
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S4136 // Method overloads should be grouped together
#pragma warning disable SA1202 // Elements should be ordered by access
#pragma warning disable IDE1006 // Naming Styles
namespace System.Reflection
{
/// <summary>
/// Provides APIs for populating nullability information/context from reflection members:
/// <see cref="ParameterInfo"/>, <see cref="FieldInfo"/>, <see cref="PropertyInfo"/> and <see cref="EventInfo"/>.
/// </summary>
[ExcludeFromCodeCoverage]
internal sealed class NullabilityInfoContext
{
private const string CompilerServicesNameSpace = "System.Runtime.CompilerServices";
private readonly Dictionary<Module, NotAnnotatedStatus> _publicOnlyModules = new();
private readonly Dictionary<MemberInfo, NullabilityState> _context = new();
[Flags]
private enum NotAnnotatedStatus
{
None = 0x0, // no restriction, all members annotated
Private = 0x1, // private members not annotated
Internal = 0x2, // internal members not annotated
}
private NullabilityState? GetNullableContext(MemberInfo? memberInfo)
{
while (memberInfo != null)
{
if (_context.TryGetValue(memberInfo, out NullabilityState state))
{
return state;
}
foreach (CustomAttributeData attribute in memberInfo.GetCustomAttributesData())
{
if (attribute.AttributeType.Name == "NullableContextAttribute" &&
attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
attribute.ConstructorArguments.Count == 1)
{
state = TranslateByte(attribute.ConstructorArguments[0].Value);
_context.Add(memberInfo, state);
return state;
}
}
memberInfo = memberInfo.DeclaringType;
}
return null;
}
/// <summary>
/// Populates <see cref="NullabilityInfo" /> for the given <see cref="ParameterInfo" />.
/// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
/// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
/// </summary>
/// <param name="parameterInfo">The parameter which nullability info gets populated.</param>
/// <exception cref="ArgumentNullException">If the parameterInfo parameter is null.</exception>
/// <returns><see cref="NullabilityInfo" />.</returns>
public NullabilityInfo Create(ParameterInfo parameterInfo)
{
IList<CustomAttributeData> attributes = parameterInfo.GetCustomAttributesData();
NullableAttributeStateParser parser = parameterInfo.Member is MethodBase method && IsPrivateOrInternalMethodAndAnnotationDisabled(method)
? NullableAttributeStateParser.Unknown
: CreateParser(attributes);
NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, parser);
if (nullability.ReadState != NullabilityState.Unknown)
{
CheckParameterMetadataType(parameterInfo, nullability);
}
CheckNullabilityAttributes(nullability, attributes);
return nullability;
}
private void CheckParameterMetadataType(ParameterInfo parameter, NullabilityInfo nullability)
{
ParameterInfo? metaParameter;
MemberInfo metaMember;
switch (parameter.Member)
{
case ConstructorInfo ctor:
var metaCtor = (ConstructorInfo)GetMemberMetadataDefinition(ctor);
metaMember = metaCtor;
metaParameter = GetMetaParameter(metaCtor, parameter);
break;
case MethodInfo method:
MethodInfo metaMethod = GetMethodMetadataDefinition(method);
metaMember = metaMethod;
metaParameter = string.IsNullOrEmpty(parameter.Name) ? metaMethod.ReturnParameter : GetMetaParameter(metaMethod, parameter);
break;
default:
return;
}
if (metaParameter != null)
{
CheckGenericParameters(nullability, metaMember, metaParameter.ParameterType, parameter.Member.ReflectedType);
}
}
private static ParameterInfo? GetMetaParameter(MethodBase metaMethod, ParameterInfo parameter)
{
var parameters = metaMethod.GetParameters();
for (int i = 0; i < parameters.Length; i++)
{
if (parameter.Position == i &&
parameter.Name == parameters[i].Name)
{
return parameters[i];
}
}
return null;
}
private static MethodInfo GetMethodMetadataDefinition(MethodInfo method)
{
if (method.IsGenericMethod && !method.IsGenericMethodDefinition)
{
method = method.GetGenericMethodDefinition();
}
return (MethodInfo)GetMemberMetadataDefinition(method);
}
private static void CheckNullabilityAttributes(NullabilityInfo nullability, IList<CustomAttributeData> attributes)
{
var codeAnalysisReadState = NullabilityState.Unknown;
var codeAnalysisWriteState = NullabilityState.Unknown;
foreach (CustomAttributeData attribute in attributes)
{
if (attribute.AttributeType.Namespace == "System.Diagnostics.CodeAnalysis")
{
if (attribute.AttributeType.Name == "NotNullAttribute")
{
codeAnalysisReadState = NullabilityState.NotNull;
}
else if ((attribute.AttributeType.Name == "MaybeNullAttribute" ||
attribute.AttributeType.Name == "MaybeNullWhenAttribute") &&
codeAnalysisReadState == NullabilityState.Unknown &&
!IsValueTypeOrValueTypeByRef(nullability.Type))
{
codeAnalysisReadState = NullabilityState.Nullable;
}
else if (attribute.AttributeType.Name == "DisallowNullAttribute")
{
codeAnalysisWriteState = NullabilityState.NotNull;
}
else if (attribute.AttributeType.Name == "AllowNullAttribute" &&
codeAnalysisWriteState == NullabilityState.Unknown &&
!IsValueTypeOrValueTypeByRef(nullability.Type))
{
codeAnalysisWriteState = NullabilityState.Nullable;
}
}
}
if (codeAnalysisReadState != NullabilityState.Unknown)
{
nullability.ReadState = codeAnalysisReadState;
}
if (codeAnalysisWriteState != NullabilityState.Unknown)
{
nullability.WriteState = codeAnalysisWriteState;
}
}
/// <summary>
/// Populates <see cref="NullabilityInfo" /> for the given <see cref="PropertyInfo" />.
/// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
/// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
/// </summary>
/// <param name="propertyInfo">The parameter which nullability info gets populated.</param>
/// <exception cref="ArgumentNullException">If the propertyInfo parameter is null.</exception>
/// <returns><see cref="NullabilityInfo" />.</returns>
public NullabilityInfo Create(PropertyInfo propertyInfo)
{
MethodInfo? getter = propertyInfo.GetGetMethod(true);
MethodInfo? setter = propertyInfo.GetSetMethod(true);
bool annotationsDisabled = (getter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(getter))
&& (setter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(setter));
NullableAttributeStateParser parser = annotationsDisabled ? NullableAttributeStateParser.Unknown : CreateParser(propertyInfo.GetCustomAttributesData());
NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, parser);
if (getter != null)
{
CheckNullabilityAttributes(nullability, getter.ReturnParameter.GetCustomAttributesData());
}
else
{
nullability.ReadState = NullabilityState.Unknown;
}
if (setter != null)
{
CheckNullabilityAttributes(nullability, setter.GetParameters().Last().GetCustomAttributesData());
}
else
{
nullability.WriteState = NullabilityState.Unknown;
}
return nullability;
}
private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodBase method)
{
if ((method.IsPrivate || method.IsFamilyAndAssembly || method.IsAssembly) &&
IsPublicOnly(method.IsPrivate, method.IsFamilyAndAssembly, method.IsAssembly, method.Module))
{
return true;
}
return false;
}
/// <summary>
/// Populates <see cref="NullabilityInfo" /> for the given <see cref="EventInfo" />.
/// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
/// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
/// </summary>
/// <param name="eventInfo">The parameter which nullability info gets populated.</param>
/// <exception cref="ArgumentNullException">If the eventInfo parameter is null.</exception>
/// <returns><see cref="NullabilityInfo" />.</returns>
public NullabilityInfo Create(EventInfo eventInfo)
{
return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, CreateParser(eventInfo.GetCustomAttributesData()));
}
/// <summary>
/// Populates <see cref="NullabilityInfo" /> for the given <see cref="FieldInfo" />
/// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
/// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
/// </summary>
/// <param name="fieldInfo">The parameter which nullability info gets populated.</param>
/// <exception cref="ArgumentNullException">If the fieldInfo parameter is null.</exception>
/// <returns><see cref="NullabilityInfo" />.</returns>
public NullabilityInfo Create(FieldInfo fieldInfo)
{
IList<CustomAttributeData> attributes = fieldInfo.GetCustomAttributesData();
NullableAttributeStateParser parser = IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo) ? NullableAttributeStateParser.Unknown : CreateParser(attributes);
NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, parser);
CheckNullabilityAttributes(nullability, attributes);
return nullability;
}
private bool IsPrivateOrInternalFieldAndAnnotationDisabled(FieldInfo fieldInfo)
{
if ((fieldInfo.IsPrivate || fieldInfo.IsFamilyAndAssembly || fieldInfo.IsAssembly) &&
IsPublicOnly(fieldInfo.IsPrivate, fieldInfo.IsFamilyAndAssembly, fieldInfo.IsAssembly, fieldInfo.Module))
{
return true;
}
return false;
}
private bool IsPublicOnly(bool isPrivate, bool isFamilyAndAssembly, bool isAssembly, Module module)
{
if (!_publicOnlyModules.TryGetValue(module, out NotAnnotatedStatus value))
{
value = PopulateAnnotationInfo(module.GetCustomAttributesData());
_publicOnlyModules.Add(module, value);
}
if (value == NotAnnotatedStatus.None)
{
return false;
}
if (((isPrivate || isFamilyAndAssembly) && value.HasFlag(NotAnnotatedStatus.Private)) ||
(isAssembly && value.HasFlag(NotAnnotatedStatus.Internal)))
{
return true;
}
return false;
}
private static NotAnnotatedStatus PopulateAnnotationInfo(IList<CustomAttributeData> customAttributes)
{
foreach (CustomAttributeData attribute in customAttributes)
{
if (attribute.AttributeType.Name == "NullablePublicOnlyAttribute" &&
attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
attribute.ConstructorArguments.Count == 1)
{
if (attribute.ConstructorArguments[0].Value is bool boolValue && boolValue)
{
return NotAnnotatedStatus.Internal | NotAnnotatedStatus.Private;
}
else
{
return NotAnnotatedStatus.Private;
}
}
}
return NotAnnotatedStatus.None;
}
private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser)
{
int index = 0;
NullabilityInfo nullability = GetNullabilityInfo(memberInfo, type, parser, ref index);
if (nullability.ReadState != NullabilityState.Unknown)
{
TryLoadGenericMetaTypeNullability(memberInfo, nullability);
}
return nullability;
}
private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser, ref int index)
{
NullabilityState state = NullabilityState.Unknown;
NullabilityInfo? elementState = null;
NullabilityInfo[] genericArgumentsState = Array.Empty<NullabilityInfo>();
Type underlyingType = type;
if (underlyingType.IsByRef || underlyingType.IsPointer)
{
underlyingType = underlyingType.GetElementType()!;
}
if (underlyingType.IsValueType)
{
if (Nullable.GetUnderlyingType(underlyingType) is { } nullableUnderlyingType)
{
underlyingType = nullableUnderlyingType;
state = NullabilityState.Nullable;
}
else
{
state = NullabilityState.NotNull;
}
if (underlyingType.IsGenericType)
{
++index;
}
}
else
{
if (!parser.ParseNullableState(index++, ref state)
&& GetNullableContext(memberInfo) is { } contextState)
{
state = contextState;
}
if (underlyingType.IsArray)
{
elementState = GetNullabilityInfo(memberInfo, underlyingType.GetElementType()!, parser, ref index);
}
}
if (underlyingType.IsGenericType)
{
Type[] genericArguments = underlyingType.GetGenericArguments();
genericArgumentsState = new NullabilityInfo[genericArguments.Length];
for (int i = 0; i < genericArguments.Length; i++)
{
genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], parser, ref index);
}
}
return new NullabilityInfo(type, state, state, elementState, genericArgumentsState);
}
private static NullableAttributeStateParser CreateParser(IList<CustomAttributeData> customAttributes)
{
foreach (CustomAttributeData attribute in customAttributes)
{
if (attribute.AttributeType.Name == "NullableAttribute" &&
attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
attribute.ConstructorArguments.Count == 1)
{
return new NullableAttributeStateParser(attribute.ConstructorArguments[0].Value);
}
}
return new NullableAttributeStateParser(null);
}
private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, NullabilityInfo nullability)
{
MemberInfo? metaMember = GetMemberMetadataDefinition(memberInfo);
Type? metaType = null;
if (metaMember is FieldInfo field)
{
metaType = field.FieldType;
}
else if (metaMember is PropertyInfo property)
{
metaType = GetPropertyMetaType(property);
}
if (metaType != null)
{
CheckGenericParameters(nullability, metaMember!, metaType, memberInfo.ReflectedType);
}
}
private static MemberInfo GetMemberMetadataDefinition(MemberInfo member)
{
Type? type = member.DeclaringType;
if ((type != null) && type.IsGenericType && !type.IsGenericTypeDefinition)
{
return NullabilityInfoHelpers.GetMemberWithSameMetadataDefinitionAs(type.GetGenericTypeDefinition(), member);
}
return member;
}
private static Type GetPropertyMetaType(PropertyInfo property)
{
if (property.GetGetMethod(true) is MethodInfo method)
{
return method.ReturnType;
}
return property.GetSetMethod(true)!.GetParameters()[0].ParameterType;
}
private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType, Type? reflectedType)
{
if (metaType.IsGenericParameter)
{
if (nullability.ReadState == NullabilityState.NotNull)
{
_ = TryUpdateGenericParameterNullability(nullability, metaType, reflectedType);
}
}
else if (metaType.ContainsGenericParameters)
{
if (nullability.GenericTypeArguments.Length > 0)
{
Type[] genericArguments = metaType.GetGenericArguments();
for (int i = 0; i < genericArguments.Length; i++)
{
CheckGenericParameters(nullability.GenericTypeArguments[i], metaMember, genericArguments[i], reflectedType);
}
}
else if (nullability.ElementType is { } elementNullability && metaType.IsArray)
{
CheckGenericParameters(elementNullability, metaMember, metaType.GetElementType()!, reflectedType);
}
// We could also follow this branch for metaType.IsPointer, but since pointers must be unmanaged this
// will be a no-op regardless
else if (metaType.IsByRef)
{
CheckGenericParameters(nullability, metaMember, metaType.GetElementType()!, reflectedType);
}
}
}
private bool TryUpdateGenericParameterNullability(NullabilityInfo nullability, Type genericParameter, Type? reflectedType)
{
Debug.Assert(genericParameter.IsGenericParameter, "must be generic parameter");
if (reflectedType is not null
&& !genericParameter.IsGenericMethodParameter()
&& TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, reflectedType, reflectedType))
{
return true;
}
if (IsValueTypeOrValueTypeByRef(nullability.Type))
{
return true;
}
var state = NullabilityState.Unknown;
if (CreateParser(genericParameter.GetCustomAttributesData()).ParseNullableState(0, ref state))
{
nullability.ReadState = state;
nullability.WriteState = state;
return true;
}
if (GetNullableContext(genericParameter) is { } contextState)
{
nullability.ReadState = contextState;
nullability.WriteState = contextState;
return true;
}
return false;
}
private bool TryUpdateGenericTypeParameterNullabilityFromReflectedType(NullabilityInfo nullability, Type genericParameter, Type context, Type reflectedType)
{
Debug.Assert(genericParameter.IsGenericParameter && !genericParameter.IsGenericMethodParameter(), "must be generic parameter");
Type contextTypeDefinition = context.IsGenericType && !context.IsGenericTypeDefinition ? context.GetGenericTypeDefinition() : context;
if (genericParameter.DeclaringType == contextTypeDefinition)
{
return false;
}
Type? baseType = contextTypeDefinition.BaseType;
if (baseType is null)
{
return false;
}
if (!baseType.IsGenericType
|| (baseType.IsGenericTypeDefinition ? baseType : baseType.GetGenericTypeDefinition()) != genericParameter.DeclaringType)
{
return TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, baseType, reflectedType);
}
Type[] genericArguments = baseType.GetGenericArguments();
Type genericArgument = genericArguments[genericParameter.GenericParameterPosition];
if (genericArgument.IsGenericParameter)
{
return TryUpdateGenericParameterNullability(nullability, genericArgument, reflectedType);
}
NullableAttributeStateParser parser = CreateParser(contextTypeDefinition.GetCustomAttributesData());
int nullabilityStateIndex = 1; // start at 1 since index 0 is the type itself
for (int i = 0; i < genericParameter.GenericParameterPosition; i++)
{
nullabilityStateIndex += CountNullabilityStates(genericArguments[i]);
}
return TryPopulateNullabilityInfo(nullability, parser, ref nullabilityStateIndex);
static int CountNullabilityStates(Type type)
{
Type underlyingType = Nullable.GetUnderlyingType(type) ?? type;
if (underlyingType.IsGenericType)
{
int count = 1;
foreach (Type genericArgument in underlyingType.GetGenericArguments())
{
count += CountNullabilityStates(genericArgument);
}
return count;
}
if (underlyingType.HasElementType)
{
return (underlyingType.IsArray ? 1 : 0) + CountNullabilityStates(underlyingType.GetElementType()!);
}
return type.IsValueType ? 0 : 1;
}
}
#pragma warning disable SA1204 // Static elements should appear before instance elements
private static bool TryPopulateNullabilityInfo(NullabilityInfo nullability, NullableAttributeStateParser parser, ref int index)
#pragma warning restore SA1204 // Static elements should appear before instance elements
{
bool isValueType = IsValueTypeOrValueTypeByRef(nullability.Type);
if (!isValueType)
{
var state = NullabilityState.Unknown;
if (!parser.ParseNullableState(index, ref state))
{
return false;
}
nullability.ReadState = state;
nullability.WriteState = state;
}
if (!isValueType || (Nullable.GetUnderlyingType(nullability.Type) ?? nullability.Type).IsGenericType)
{
index++;
}
if (nullability.GenericTypeArguments.Length > 0)
{
foreach (NullabilityInfo genericTypeArgumentNullability in nullability.GenericTypeArguments)
{
_ = TryPopulateNullabilityInfo(genericTypeArgumentNullability, parser, ref index);
}
}
else if (nullability.ElementType is { } elementTypeNullability)
{
_ = TryPopulateNullabilityInfo(elementTypeNullability, parser, ref index);
}
return true;
}
private static NullabilityState TranslateByte(object? value)
{
return value is byte b ? TranslateByte(b) : NullabilityState.Unknown;
}
private static NullabilityState TranslateByte(byte b) =>
b switch
{
1 => NullabilityState.NotNull,
2 => NullabilityState.Nullable,
_ => NullabilityState.Unknown
};
private static bool IsValueTypeOrValueTypeByRef(Type type) =>
type.IsValueType || ((type.IsByRef || type.IsPointer) && type.GetElementType()!.IsValueType);
private readonly struct NullableAttributeStateParser
{
private static readonly object UnknownByte = (byte)0;
private readonly object? _nullableAttributeArgument;
public NullableAttributeStateParser(object? nullableAttributeArgument)
{
_nullableAttributeArgument = nullableAttributeArgument;
}
public static NullableAttributeStateParser Unknown => new(UnknownByte);
public bool ParseNullableState(int index, ref NullabilityState state)
{
switch (_nullableAttributeArgument)
{
case byte b:
state = TranslateByte(b);
return true;
case ReadOnlyCollection<CustomAttributeTypedArgument> args
when index < args.Count && args[index].Value is byte elementB:
state = TranslateByte(elementB);
return true;
default:
return false;
}
}
}
}
}
#endif

Просмотреть файл

@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#if !NET6_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
#pragma warning disable IDE1006 // Naming Styles
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
namespace System.Reflection
{
/// <summary>
/// Polyfills for System.Private.CoreLib internals.
/// </summary>
[ExcludeFromCodeCoverage]
internal static class NullabilityInfoHelpers
{
public static MemberInfo GetMemberWithSameMetadataDefinitionAs(Type type, MemberInfo member)
{
const BindingFlags all = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
foreach (var info in type.GetMembers(all))
{
if (info.HasSameMetadataDefinitionAs(member))
{
return info;
}
}
throw new MissingMemberException(type.FullName, member.Name);
}
// https://github.com/dotnet/runtime/blob/main/src/coreclr/System.Private.CoreLib/src/System/Reflection/MemberInfo.Internal.cs
public static bool HasSameMetadataDefinitionAs(this MemberInfo target, MemberInfo other)
{
return target.MetadataToken == other.MetadataToken &&
target.Module.Equals(other.Module);
}
// https://github.com/dotnet/runtime/issues/23493
public static bool IsGenericMethodParameter(this Type target)
{
return target.IsGenericParameter &&
target.DeclaringMethod != null;
}
}
}
#endif

Просмотреть файл

@ -0,0 +1,11 @@
# JsonSchemaExporter
Provides a polyfill for the [.NET 9 `JsonSchemaExporter` component](https://learn.microsoft.com/dotnet/standard/serialization/system-text-json/extract-schema) that is compatible with all supported targets using System.Text.Json version 8.
To use this in your project, add the following to your `.csproj` file:
```xml
<PropertyGroup>
<InjectJsonSchemaExporterOnLegacy>true</InjectJsonSchemaExporterOnLegacy>
</PropertyGroup>
```

Просмотреть файл

@ -12,11 +12,12 @@
<InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy> <InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy>
<InjectCallerAttributesOnLegacy>true</InjectCallerAttributesOnLegacy> <InjectCallerAttributesOnLegacy>true</InjectCallerAttributesOnLegacy>
<InjectBitOperationsOnLegacy>true</InjectBitOperationsOnLegacy> <InjectBitOperationsOnLegacy>true</InjectBitOperationsOnLegacy>
<InjectIsExtenalInitOnLegacy>true</InjectIsExtenalInitOnLegacy> <InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy>
<InjectSkipLocalsInitAttributeOnLegacy>true</InjectSkipLocalsInitAttributeOnLegacy> <InjectSkipLocalsInitAttributeOnLegacy>true</InjectSkipLocalsInitAttributeOnLegacy>
<InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy> <InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy>
<InjectTrimAttributesOnLegacy>true</InjectTrimAttributesOnLegacy> <InjectTrimAttributesOnLegacy>true</InjectTrimAttributesOnLegacy>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<IsAotCompatible Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net7.0'))">true</IsAotCompatible>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
@ -33,6 +34,10 @@
<PackageReference Include="System.Memory" Condition="!$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0'))" /> <PackageReference Include="System.Memory" Condition="!$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0'))" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
<PackageReference Include="System.Text.Json" />
</ItemGroup>
<ItemGroup> <ItemGroup>
<InternalsVisibleToTest Include="$(AssemblyName).Tests" /> <InternalsVisibleToTest Include="$(AssemblyName).Tests" />
<InternalsVisibleToDynamicProxyGenAssembly2 Include="*" /> <InternalsVisibleToDynamicProxyGenAssembly2 Include="*" />

Просмотреть файл

@ -90,4 +90,45 @@ public class AdditionalPropertiesDictionaryTests
Assert.Equal(default(T2), value); Assert.Equal(default(T2), value);
} }
} }
[Fact]
public void TryAdd_AddsOnlyIfNonExistent()
{
AdditionalPropertiesDictionary d = [];
Assert.False(d.ContainsKey("key"));
Assert.True(d.TryAdd("key", "value"));
Assert.True(d.ContainsKey("key"));
Assert.Equal("value", d["key"]);
Assert.False(d.TryAdd("key", "value2"));
Assert.True(d.ContainsKey("key"));
Assert.Equal("value", d["key"]);
}
[Fact]
public void Enumerator_EnumeratesAllItems()
{
AdditionalPropertiesDictionary d = [];
const int NumProperties = 10;
for (int i = 0; i < NumProperties; i++)
{
d.Add($"key{i}", $"value{i}");
}
Assert.Equal(NumProperties, d.Count);
// This depends on an implementation detail of the ordering in which the dictionary
// enumerates items. If that ever changes, this test will need to be updated.
int count = 0;
foreach (KeyValuePair<string, object?> item in d)
{
Assert.Equal($"key{count}", item.Key);
Assert.Equal($"value{count}", item.Value);
count++;
}
Assert.Equal(NumProperties, count);
}
} }

Просмотреть файл

@ -19,6 +19,7 @@ public class ChatOptionsTests
Assert.Null(options.TopK); Assert.Null(options.TopK);
Assert.Null(options.FrequencyPenalty); Assert.Null(options.FrequencyPenalty);
Assert.Null(options.PresencePenalty); Assert.Null(options.PresencePenalty);
Assert.Null(options.Seed);
Assert.Null(options.ResponseFormat); Assert.Null(options.ResponseFormat);
Assert.Null(options.ModelId); Assert.Null(options.ModelId);
Assert.Null(options.StopSequences); Assert.Null(options.StopSequences);
@ -33,6 +34,7 @@ public class ChatOptionsTests
Assert.Null(clone.TopK); Assert.Null(clone.TopK);
Assert.Null(clone.FrequencyPenalty); Assert.Null(clone.FrequencyPenalty);
Assert.Null(clone.PresencePenalty); Assert.Null(clone.PresencePenalty);
Assert.Null(options.Seed);
Assert.Null(clone.ResponseFormat); Assert.Null(clone.ResponseFormat);
Assert.Null(clone.ModelId); Assert.Null(clone.ModelId);
Assert.Null(clone.StopSequences); Assert.Null(clone.StopSequences);
@ -69,6 +71,7 @@ public class ChatOptionsTests
options.TopK = 42; options.TopK = 42;
options.FrequencyPenalty = 0.4f; options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f; options.PresencePenalty = 0.5f;
options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json; options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId"; options.ModelId = "modelId";
options.StopSequences = stopSequences; options.StopSequences = stopSequences;
@ -82,6 +85,7 @@ public class ChatOptionsTests
Assert.Equal(42, options.TopK); Assert.Equal(42, options.TopK);
Assert.Equal(0.4f, options.FrequencyPenalty); Assert.Equal(0.4f, options.FrequencyPenalty);
Assert.Equal(0.5f, options.PresencePenalty); Assert.Equal(0.5f, options.PresencePenalty);
Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, options.ResponseFormat); Assert.Same(ChatResponseFormat.Json, options.ResponseFormat);
Assert.Equal("modelId", options.ModelId); Assert.Equal("modelId", options.ModelId);
Assert.Same(stopSequences, options.StopSequences); Assert.Same(stopSequences, options.StopSequences);
@ -96,6 +100,7 @@ public class ChatOptionsTests
Assert.Equal(42, clone.TopK); Assert.Equal(42, clone.TopK);
Assert.Equal(0.4f, clone.FrequencyPenalty); Assert.Equal(0.4f, clone.FrequencyPenalty);
Assert.Equal(0.5f, clone.PresencePenalty); Assert.Equal(0.5f, clone.PresencePenalty);
Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat); Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat);
Assert.Equal("modelId", clone.ModelId); Assert.Equal("modelId", clone.ModelId);
Assert.Equal(stopSequences, clone.StopSequences); Assert.Equal(stopSequences, clone.StopSequences);
@ -126,6 +131,7 @@ public class ChatOptionsTests
options.TopK = 42; options.TopK = 42;
options.FrequencyPenalty = 0.4f; options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f; options.PresencePenalty = 0.5f;
options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json; options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId"; options.ModelId = "modelId";
options.StopSequences = stopSequences; options.StopSequences = stopSequences;
@ -148,6 +154,7 @@ public class ChatOptionsTests
Assert.Equal(42, deserialized.TopK); Assert.Equal(42, deserialized.TopK);
Assert.Equal(0.4f, deserialized.FrequencyPenalty); Assert.Equal(0.4f, deserialized.FrequencyPenalty);
Assert.Equal(0.5f, deserialized.PresencePenalty); Assert.Equal(0.5f, deserialized.PresencePenalty);
Assert.Equal(12345, deserialized.Seed);
Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat); Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat); Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.Equal("modelId", deserialized.ModelId); Assert.Equal("modelId", deserialized.ModelId);

Просмотреть файл

@ -5,16 +5,27 @@
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<NoWarn>$(NoWarn);CA1063;CA1861;CA2201;VSTHRD003</NoWarn> <NoWarn>$(NoWarn);CA1063;CA1861;CA2201;VSTHRD003;S104</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy>
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
<InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy> <InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy>
<InjectStringSyntaxAttributeOnLegacy>true</InjectStringSyntaxAttributeOnLegacy>
</PropertyGroup> </PropertyGroup>
<ItemGroup>
<Compile Include="..\..\Shared\JsonSchemaExporter\SchemaTestHelpers.cs" Link="Utilities\SchemaTestHelpers.cs" />
<Compile Include="..\..\Shared\JsonSchemaExporter\TestData.cs" Link="Utilities\TestData.cs" />
<Compile Include="..\..\Shared\JsonSchemaExporter\TestTypes.cs" Link="Utilities\TestTypes.cs" />
</ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="System.Memory.Data" /> <PackageReference Include="System.Memory.Data" />
<PackageReference Include="JsonSchema.Net" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

Просмотреть файл

@ -3,7 +3,9 @@
using System.ComponentModel; using System.ComponentModel;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Microsoft.Extensions.AI.JsonSchemaExporter;
using Xunit; using Xunit;
namespace Microsoft.Extensions.AI; namespace Microsoft.Extensions.AI;
@ -130,7 +132,7 @@ public static class AIJsonUtilitiesTests
} }
[Fact] [Fact]
public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString() public static void CreateParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{ {
JsonElement expected = JsonDocument.Parse(""" JsonElement expected = JsonDocument.Parse("""
{ {
@ -158,4 +160,38 @@ public static class AIJsonUtilitiesTests
A = 1, A = 1,
B = 2 B = 2
} }
[Fact]
public static void CreateJsonSchema_CanBeBoolean()
{
JsonElement schema = AIJsonUtilities.CreateJsonSchema(typeof(object));
Assert.Equal(JsonValueKind.True, schema.ValueKind);
}
[Theory]
[MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))]
public static void CreateJsonSchema_ValidateWithTestData(ITestData testData)
{
// Stress tests the schema generation method using types from the JsonSchemaExporter test battery.
JsonSerializerOptions options = testData.Options is { } opts
? new(opts) { TypeInfoResolver = TestTypes.TestTypesContext.Default }
: TestTypes.TestTypesContext.Default.Options;
JsonElement schema = AIJsonUtilities.CreateJsonSchema(testData.Type, serializerOptions: options);
JsonNode? schemaAsNode = JsonSerializer.SerializeToNode(schema, options);
Assert.NotNull(schemaAsNode);
Assert.Equal(testData.ExpectedJsonSchema.GetValueKind(), schemaAsNode.GetValueKind());
if (testData.Value is null || testData.WritesNumbersAsStrings)
{
// By design, our generated schema does not accept null root values
// or numbers formatted as strings, so we skip schema validation.
return;
}
JsonNode? serializedValue = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options);
SchemaTestHelpers.AssertDocumentMatchesSchema(schemaAsNode, serializedValue);
}
} }

Просмотреть файл

@ -0,0 +1,26 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFrameworks>$(LatestTargetFramework)</TargetFrameworks>
<PublishAot>true</PublishAot>
<TrimmerSingleWarn>false</TrimmerSingleWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>
<ItemGroup>
<TrimmerRootAssembly Include="Microsoft.Extensions.AI" />
<TrimmerRootAssembly Include="Microsoft.Extensions.AI.Abstractions" />
<TrimmerRootAssembly Include="Microsoft.Extensions.AI.Ollama" />
<!-- Azure.AI.Inference produces many warnings
<TrimmerRootAssembly Include="Microsoft.Extensions.AI.AzureAIInference" />
-->
<!-- OpenAI produces a few warnings
<TrimmerRootAssembly Include="Microsoft.Extensions.AI.OpenAI" />
-->
<TrimmerRootAssembly Update="@(TrimmerRootAssembly)" Path="$(RepoRoot)\src\Libraries\%(Identity)\%(Identity).csproj" />
<ProjectReference Include="@(TrimmerRootAssembly->'%(Path)')" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,22 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#pragma warning disable S125 // Remove this commented out code
using Microsoft.Extensions.AI;
// Use types from each library.
// Microsoft.Extensions.AI.Ollama
using var b = new OllamaChatClient("http://localhost:11434", "llama3.2");
// Microsoft.Extensions.AI.AzureAIInference
// using var a = new Azure.AI.Inference.ChatCompletionClient(new Uri("http://localhost"), new("apikey")); // uncomment once warnings in Azure.AI.Inference are addressed
// Microsoft.Extensions.AI.OpenAI
// using var c = new OpenAI.OpenAIClient("apikey").AsChatClient("gpt-4o-mini"); // uncomment once warnings in OpenAI are addressed
// Microsoft.Extensions.AI
AIFunctionFactory.Create(() => { });
System.Console.WriteLine("Success!");

Просмотреть файл

@ -247,8 +247,8 @@ public class AzureAIInferenceChatClientTests
], ],
"presence_penalty": 0.5, "presence_penalty": 0.5,
"frequency_penalty": 0.75, "frequency_penalty": 0.75,
"model": "gpt-4o-mini", "seed": 42,
"seed": 42 "model": "gpt-4o-mini"
} }
"""; """;
@ -303,7 +303,7 @@ public class AzureAIInferenceChatClientTests
FrequencyPenalty = 0.75f, FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f, PresencePenalty = 0.5f,
StopSequences = ["great"], StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42L }, Seed = 42,
}); });
Assert.NotNull(response); Assert.NotNull(response);

Просмотреть файл

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
@ -132,6 +133,27 @@ public abstract class ChatClientIntegrationTests : IDisposable
Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount); Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount);
} }
protected virtual string? GetModel_MultiModal_DescribeImage() => null;
[ConditionalFact]
public virtual async Task MultiModal_DescribeImage()
{
SkipIfNotEnabled();
var response = await _chatClient.CompleteAsync(
[
new(ChatRole.User,
[
new TextContent("What does this logo say?"),
new ImageContent(GetImageDataUri()),
])
],
new() { ModelId = GetModel_MultiModal_DescribeImage() });
Assert.Single(response.Choices);
Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text);
}
[ConditionalFact] [ConditionalFact]
public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless() public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless()
{ {
@ -714,6 +736,15 @@ public abstract class ChatClientIntegrationTests : IDisposable
Unknown, Unknown,
} }
private static Uri GetImageDataUri()
{
using Stream? s = typeof(ChatClientIntegrationTests).Assembly.GetManifestResourceStream("Microsoft.Extensions.AI.dotnet.png");
Assert.NotNull(s);
MemoryStream ms = new();
s.CopyTo(ms);
return new Uri($"data:image/png;base64,{Convert.ToBase64String(ms.ToArray())}");
}
[MemberNotNull(nameof(_chatClient))] [MemberNotNull(nameof(_chatClient))]
protected void SkipIfNotEnabled() protected void SkipIfNotEnabled()
{ {

Просмотреть файл

@ -15,6 +15,10 @@
<InjectSharedThrow>true</InjectSharedThrow> <InjectSharedThrow>true</InjectSharedThrow>
</PropertyGroup> </PropertyGroup>
<ItemGroup>
<EmbeddedResource Include="dotnet.png" />
</ItemGroup>
<ItemGroup> <ItemGroup>
<Compile Include="..\Microsoft.Extensions.AI.Abstractions.Tests\CapturingLogger.cs" /> <Compile Include="..\Microsoft.Extensions.AI.Abstractions.Tests\CapturingLogger.cs" />
<Compile Include="..\Microsoft.Extensions.AI.Abstractions.Tests\TestChatClient.cs" /> <Compile Include="..\Microsoft.Extensions.AI.Abstractions.Tests\TestChatClient.cs" />

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 2.1 KiB

Просмотреть файл

@ -30,6 +30,8 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
public override Task FunctionInvocation_RequireSpecific() => public override Task FunctionInvocation_RequireSpecific() =>
throw new SkipTestException("Ollama does not currently support requiring function invocation."); throw new SkipTestException("Ollama does not currently support requiring function invocation.");
protected override string? GetModel_MultiModal_DescribeImage() => "llava";
[ConditionalFact] [ConditionalFact]
public async Task PromptBasedFunctionCalling_NoArgs() public async Task PromptBasedFunctionCalling_NoArgs()
{ {
@ -47,7 +49,7 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
ModelId = "llama3:8b", ModelId = "llama3:8b",
Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")], Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")],
Temperature = 0, Temperature = 0,
AdditionalProperties = new() { ["seed"] = 0L }, Seed = 0,
}); });
Assert.Single(response.Choices); Assert.Single(response.Choices);
@ -81,7 +83,7 @@ public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests
{ {
Tools = [stockPriceTool, irrelevantTool], Tools = [stockPriceTool, irrelevantTool],
Temperature = 0, Temperature = 0,
AdditionalProperties = new() { ["seed"] = 0L }, Seed = 0,
}); });
Assert.Single(response.Choices); Assert.Single(response.Choices);

Просмотреть файл

@ -254,7 +254,7 @@ public class OllamaChatClientTests
FrequencyPenalty = 0.75f, FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f, PresencePenalty = 0.5f,
StopSequences = ["great"], StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42 }, Seed = 42,
}); });
Assert.NotNull(response); Assert.NotNull(response);

Просмотреть файл

@ -348,7 +348,7 @@ public class OpenAIChatClientTests
FrequencyPenalty = 0.75f, FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f, PresencePenalty = 0.5f,
StopSequences = ["great"], StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42 }, Seed = 42,
}); });
Assert.NotNull(response); Assert.NotNull(response);

Просмотреть файл

@ -26,11 +26,13 @@ public class ConfigureOptionsChatClientTests
Assert.Throws<ArgumentNullException>("configureOptions", () => builder.UseChatOptions(null!)); Assert.Throws<ArgumentNullException>("configureOptions", () => builder.UseChatOptions(null!));
} }
[Fact] [Theory]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient() [InlineData(false)]
[InlineData(true)]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
{ {
ChatOptions providedOptions = new(); ChatOptions providedOptions = new();
ChatOptions returnedOptions = new(); ChatOptions? returnedOptions = nullReturned ? null : new();
ChatCompletion expectedCompletion = new(Array.Empty<ChatMessage>()); ChatCompletion expectedCompletion = new(Array.Empty<ChatMessage>());
var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray(); var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray();
using CancellationTokenSource cts = new(); using CancellationTokenSource cts = new();

Просмотреть файл

@ -0,0 +1,58 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.AI;
public class ConfigureOptionsEmbeddingGeneratorTests
{
[Fact]
public void ConfigureOptionsEmbeddingGenerator_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("innerGenerator", () => new ConfigureOptionsEmbeddingGenerator<string, Embedding<float>>(null!, _ => new EmbeddingGenerationOptions()));
Assert.Throws<ArgumentNullException>("configureOptions", () => new ConfigureOptionsEmbeddingGenerator<string, Embedding<float>>(new TestEmbeddingGenerator(), null!));
}
[Fact]
public void UseEmbeddingGenerationOptions_InvalidArgs_Throws()
{
var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>();
Assert.Throws<ArgumentNullException>("configureOptions", () => builder.UseEmbeddingGenerationOptions(null!));
}
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
{
EmbeddingGenerationOptions providedOptions = new();
EmbeddingGenerationOptions? returnedOptions = nullReturned ? null : new();
GeneratedEmbeddings<Embedding<float>> expectedEmbeddings = [];
using CancellationTokenSource cts = new();
using IEmbeddingGenerator<string, Embedding<float>> innerGenerator = new TestEmbeddingGenerator
{
GenerateAsyncCallback = (inputs, options, cancellationToken) =>
{
Assert.Same(returnedOptions, options);
Assert.Equal(cts.Token, cancellationToken);
return Task.FromResult(expectedEmbeddings);
}
};
using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
.UseEmbeddingGenerationOptions(options =>
{
Assert.Same(providedOptions, options);
return returnedOptions;
})
.Use(innerGenerator);
var embeddings = await generator.GenerateAsync([], providedOptions, cts.Token);
Assert.Same(expectedEmbeddings, embeddings);
}
}

Просмотреть файл

@ -0,0 +1,205 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.Tracing;
using Microsoft.Extensions.Caching.Hybrid.Internal;
using Xunit.Abstractions;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
public class HybridCacheEventSourceTests(ITestOutputHelper log, TestEventListener listener) : IClassFixture<TestEventListener>
{
// see notes in TestEventListener for context on fixture usage
[SkippableFact]
public void MatchesNameAndGuid()
{
// Assert
Assert.Equal("Microsoft-Extensions-HybridCache", listener.Source.Name);
Assert.Equal(Guid.Parse("b3aca39e-5dc9-5e21-f669-b72225b66cfc"), listener.Source.Guid); // from name
}
[SkippableFact]
public async Task LocalCacheHit()
{
AssertEnabled();
listener.Reset().Source.LocalCacheHit();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheHit, "LocalCacheHit", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-local-cache-hits", "Total Local Cache Hits", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task LocalCacheMiss()
{
AssertEnabled();
listener.Reset().Source.LocalCacheMiss();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheMiss, "LocalCacheMiss", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-local-cache-misses", "Total Local Cache Misses", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task DistributedCacheGet()
{
AssertEnabled();
listener.Reset().Source.DistributedCacheGet();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheGet, "DistributedCacheGet", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("current-distributed-cache-fetches", "Current Distributed Cache Fetches", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task DistributedCacheHit()
{
AssertEnabled();
listener.Reset().Source.DistributedCacheGet();
listener.Reset(resetCounters: false).Source.DistributedCacheHit();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheHit, "DistributedCacheHit", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-distributed-cache-hits", "Total Distributed Cache Hits", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task DistributedCacheMiss()
{
AssertEnabled();
listener.Reset().Source.DistributedCacheGet();
listener.Reset(resetCounters: false).Source.DistributedCacheMiss();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheMiss, "DistributedCacheMiss", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-distributed-cache-misses", "Total Distributed Cache Misses", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task DistributedCacheFailed()
{
AssertEnabled();
listener.Reset().Source.DistributedCacheGet();
listener.Reset(resetCounters: false).Source.DistributedCacheFailed();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheFailed, "DistributedCacheFailed", EventLevel.Error);
await AssertCountersAsync();
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task UnderlyingDataQueryStart()
{
AssertEnabled();
listener.Reset().Source.UnderlyingDataQueryStart();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryStart, "UnderlyingDataQueryStart", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("current-data-query", "Current Data Queries", 1);
listener.AssertCounter("total-data-query", "Total Data Queries", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task UnderlyingDataQueryComplete()
{
AssertEnabled();
listener.Reset().Source.UnderlyingDataQueryStart();
listener.Reset(resetCounters: false).Source.UnderlyingDataQueryComplete();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryComplete, "UnderlyingDataQueryComplete", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-data-query", "Total Data Queries", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task UnderlyingDataQueryFailed()
{
AssertEnabled();
listener.Reset().Source.UnderlyingDataQueryStart();
listener.Reset(resetCounters: false).Source.UnderlyingDataQueryFailed();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryFailed, "UnderlyingDataQueryFailed", EventLevel.Error);
await AssertCountersAsync();
listener.AssertCounter("total-data-query", "Total Data Queries", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task LocalCacheWrite()
{
AssertEnabled();
listener.Reset().Source.LocalCacheWrite();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheWrite, "LocalCacheWrite", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-local-cache-writes", "Total Local Cache Writes", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task DistributedCacheWrite()
{
AssertEnabled();
listener.Reset().Source.DistributedCacheWrite();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheWrite, "DistributedCacheWrite", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-distributed-cache-writes", "Total Distributed Cache Writes", 1);
listener.AssertRemainingCountersZero();
}
[SkippableFact]
public async Task StampedeJoin()
{
AssertEnabled();
listener.Reset().Source.StampedeJoin();
listener.AssertSingleEvent(HybridCacheEventSource.EventIdStampedeJoin, "StampedeJoin", EventLevel.Verbose);
await AssertCountersAsync();
listener.AssertCounter("total-stampede-joins", "Total Stampede Joins", 1);
listener.AssertRemainingCountersZero();
}
private void AssertEnabled()
{
// including this data for visibility when tests fail - ETW subsystem can be ... weird
log.WriteLine($".NET {Environment.Version} on {Environment.OSVersion}, {IntPtr.Size * 8}-bit");
Skip.IfNot(listener.Source.IsEnabled(), "Event source not enabled");
}
private async Task AssertCountersAsync()
{
var count = await listener.TryAwaitCountersAsync();
// ETW counters timing can be painfully unpredictable; generally
// it'll work fine locally, especially on modern .NET, but:
// CI servers and netfx in particular - not so much. The tests
// can still observe and validate the simple events, though, which
// should be enough to be credible that the eventing system is
// fundamentally working. We're not meant to be testing that
// the counters system *itself* works!
Skip.If(count == 0, "No counters received");
}
}

Просмотреть файл

@ -0,0 +1,84 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.Extensions.Logging;
using Xunit.Abstractions;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
// dummy implementation for collecting test output
internal class LogCollector : ILoggerProvider
{
private readonly List<(string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)> _items = [];
public (string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)[] ToArray()
{
lock (_items)
{
return _items.ToArray();
}
}
public void WriteTo(ITestOutputHelper log)
{
lock (_items)
{
foreach (var logItem in _items)
{
var errSuffix = logItem.exception is null ? "" : $" - {logItem.exception.Message}";
log.WriteLine($"{logItem.categoryName} {logItem.eventId}: {logItem.message}{errSuffix}");
}
}
}
public void AssertErrors(int[] errorIds)
{
lock (_items)
{
bool same;
if (errorIds.Length == _items.Count)
{
int index = 0;
same = true;
foreach (var item in _items)
{
if (item.eventId.Id != errorIds[index++])
{
same = false;
break;
}
}
}
else
{
same = false;
}
if (!same)
{
// we expect this to fail, then
Assert.Equal(string.Join(",", errorIds), string.Join(",", _items.Select(static x => x.eventId.Id)));
}
}
}
ILogger ILoggerProvider.CreateLogger(string categoryName) => new TypedLogCollector(this, categoryName);
void IDisposable.Dispose()
{
// nothing to do
}
private sealed class TypedLogCollector(LogCollector parent, string categoryName) : ILogger
{
IDisposable? ILogger.BeginScope<TState>(TState state) => null;
bool ILogger.IsEnabled(LogLevel logLevel) => true;
void ILogger.Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
lock (parent._items)
{
parent._items.Add((categoryName, logLevel, eventId, exception, formatter(state, exception)));
}
}
}
}

Просмотреть файл

@ -12,13 +12,15 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" />
<PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" /> <PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" />
<PackageReference Include="Microsoft.Extensions.Caching.SqlServer" /> <PackageReference Include="Microsoft.Extensions.Caching.SqlServer" />
<PackageReference Include="Microsoft.Extensions.Configuration" /> <PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" /> <PackageReference Include="Microsoft.Extensions.Configuration.Binder" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" /> <PackageReference Include="Microsoft.Extensions.Configuration.Json" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" /> <PackageReference Include="Microsoft.Extensions.DependencyInjection" />
<PackageReference Include="Microsoft.Data.SqlClient" /> <PackageReference Include="Microsoft.Extensions.Logging" />
<PackageReference Include="Xunit.SkippableFact" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

Просмотреть файл

@ -0,0 +1,31 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.Extensions.Caching.Distributed;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
// dummy L2 that doesn't actually store anything
internal class NullDistributedCache : IDistributedCache
{
byte[]? IDistributedCache.Get(string key) => null;
Task<byte[]?> IDistributedCache.GetAsync(string key, CancellationToken token) => Task.FromResult<byte[]?>(null);
void IDistributedCache.Refresh(string key)
{
// nothing to do
}
Task IDistributedCache.RefreshAsync(string key, CancellationToken token) => Task.CompletedTask;
void IDistributedCache.Remove(string key)
{
// nothing to do
}
Task IDistributedCache.RemoveAsync(string key, CancellationToken token) => Task.CompletedTask;
void IDistributedCache.Set(string key, byte[] value, DistributedCacheEntryOptions options)
{
// nothing to do
}
Task IDistributedCache.SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token) => Task.CompletedTask;
}

Просмотреть файл

@ -1,31 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements. // Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license. // The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.ComponentModel;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Hybrid.Internal; using Microsoft.Extensions.Caching.Hybrid.Internal;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit.Abstractions;
namespace Microsoft.Extensions.Caching.Hybrid.Tests; namespace Microsoft.Extensions.Caching.Hybrid.Tests;
public class SizeTests public class SizeTests(ITestOutputHelper log)
{ {
[Theory] [Theory]
[InlineData(null, true)] // does not enforce size limits [InlineData("abc", null, true, null, null)] // does not enforce size limits
[InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
[InlineData(1024L, true)] // reasonable size limit [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1) [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
[InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
[InlineData("abc", 1024L, true, null, null)] // reasonable size limit
[InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota
[InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded
[InlineData("a\u0000c", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key
[InlineData("a\u001Fc", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key
[InlineData("a\u0020c", null, true, null, null)] // fine (this is just space)
public async Task ValidateSizeLimit_Immutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength,
params int[] errorIds)
{ {
using var collector = new LogCollector();
var services = new ServiceCollection(); var services = new ServiceCollection();
services.AddMemoryCache(options => options.SizeLimit = sizeLimit); services.AddMemoryCache(options => options.SizeLimit = sizeLimit);
services.AddHybridCache(); services.AddHybridCache(options =>
{
if (maximumKeyLength.HasValue)
{
options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault();
}
if (maximumPayloadBytes.HasValue)
{
options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault();
}
});
services.AddLogging(options =>
{
options.ClearProviders();
options.AddProvider(collector);
});
using ServiceProvider provider = services.BuildServiceProvider(); using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>()); var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>());
const string Key = "abc";
// this looks weird; it is intentionally not a const - we want to check // this looks weird; it is intentionally not a const - we want to check
// same instance without worrying about interning from raw literals // same instance without worrying about interning from raw literals
string expected = new("simple value".ToArray()); string expected = new("simple value".ToArray());
var actual = await cache.GetOrCreateAsync<string>(Key, ct => new(expected)); var actual = await cache.GetOrCreateAsync<string>(key!, ct => new(expected));
// expect same contents // expect same contents
Assert.Equal(expected, actual); Assert.Equal(expected, actual);
@ -35,7 +64,7 @@ public class SizeTests
Assert.Same(expected, actual); Assert.Same(expected, actual);
// rinse and repeat, to check we get the value from L1 // rinse and repeat, to check we get the value from L1
actual = await cache.GetOrCreateAsync<string>(Key, ct => new(Guid.NewGuid().ToString())); actual = await cache.GetOrCreateAsync<string>(key!, ct => new(Guid.NewGuid().ToString()));
if (expectFromL1) if (expectFromL1)
{ {
@ -51,30 +80,54 @@ public class SizeTests
// L1 cache not used // L1 cache not used
Assert.NotEqual(expected, actual); Assert.NotEqual(expected, actual);
} }
collector.WriteTo(log);
collector.AssertErrors(errorIds);
} }
[Theory] [Theory]
[InlineData(null, true)] // does not enforce size limits [InlineData("abc", null, true, null, null)] // does not enforce size limits
[InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
[InlineData(1024L, true)] // reasonable size limit [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
public async Task ValidateSizeLimit_Mutable(long? sizeLimit, bool expectFromL1) [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
[InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
[InlineData("abc", 1024L, true, null, null)] // reasonable size limit
[InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota
[InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded
public async Task ValidateSizeLimit_Mutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength,
params int[] errorIds)
{ {
using var collector = new LogCollector();
var services = new ServiceCollection(); var services = new ServiceCollection();
services.AddMemoryCache(options => options.SizeLimit = sizeLimit); services.AddMemoryCache(options => options.SizeLimit = sizeLimit);
services.AddHybridCache(); services.AddHybridCache(options =>
{
if (maximumKeyLength.HasValue)
{
options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault();
}
if (maximumPayloadBytes.HasValue)
{
options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault();
}
});
services.AddLogging(options =>
{
options.ClearProviders();
options.AddProvider(collector);
});
using ServiceProvider provider = services.BuildServiceProvider(); using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>()); var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>());
const string Key = "abc";
string expected = "simple value"; string expected = "simple value";
var actual = await cache.GetOrCreateAsync<MutablePoco>(Key, ct => new(new MutablePoco { Value = expected })); var actual = await cache.GetOrCreateAsync<MutablePoco>(key!, ct => new(new MutablePoco { Value = expected }));
// expect same contents // expect same contents
Assert.Equal(expected, actual.Value); Assert.Equal(expected, actual.Value);
// rinse and repeat, to check we get the value from L1 // rinse and repeat, to check we get the value from L1
actual = await cache.GetOrCreateAsync<MutablePoco>(Key, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() })); actual = await cache.GetOrCreateAsync<MutablePoco>(key!, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() }));
if (expectFromL1) if (expectFromL1)
{ {
@ -86,10 +139,217 @@ public class SizeTests
// L1 cache not used // L1 cache not used
Assert.NotEqual(expected, actual.Value); Assert.NotEqual(expected, actual.Value);
} }
collector.WriteTo(log);
collector.AssertErrors(errorIds);
}
[Theory]
[InlineData("some value", false, 1, 1, 2, false)]
[InlineData("read fail", false, 1, 1, 1, true, Log.IdDeserializationFailure)]
[InlineData("write fail", true, 1, 1, 0, true, Log.IdSerializationFailure)]
public async Task BrokenSerializer_Mutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, params int[] errorIds)
{
using var collector = new LogCollector();
var services = new ServiceCollection();
services.AddMemoryCache();
services.AddSingleton<IDistributedCache, NullDistributedCache>();
var serializer = new MutablePoco.Serializer();
services.AddHybridCache().AddSerializer(serializer);
services.AddLogging(options =>
{
options.ClearProviders();
options.AddProvider(collector);
});
using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>());
int actualRunCount = 0;
Func<CancellationToken, ValueTask<MutablePoco>> func = _ =>
{
Interlocked.Increment(ref actualRunCount);
return new(new MutablePoco { Value = value });
};
if (expectKnownFailure)
{
await Assert.ThrowsAsync<KnownFailureException>(async () => await cache.GetOrCreateAsync("key", func));
}
else
{
var first = await cache.GetOrCreateAsync("key", func);
var second = await cache.GetOrCreateAsync("key", func);
Assert.Equal(value, first.Value);
Assert.Equal(value, second.Value);
if (same)
{
Assert.Same(first, second);
}
else
{
Assert.NotSame(first, second);
}
}
Assert.Equal(runCount, Volatile.Read(ref actualRunCount));
Assert.Equal(serializeCount, serializer.WriteCount);
Assert.Equal(deserializeCount, serializer.ReadCount);
collector.WriteTo(log);
collector.AssertErrors(errorIds);
}
[Theory]
[InlineData("some value", true, 1, 1, 0, false, true)]
[InlineData("read fail", true, 1, 1, 0, false, true)]
[InlineData("write fail", true, 1, 1, 0, true, true, Log.IdSerializationFailure)]
// without L2, we only need the serializer for sizing purposes (L1), not used for deserialize
[InlineData("some value", true, 1, 1, 0, false, false)]
[InlineData("read fail", true, 1, 1, 0, false, false)]
[InlineData("write fail", true, 1, 1, 0, true, false, Log.IdSerializationFailure)]
[System.Diagnostics.CodeAnalysis.SuppressMessage("Major Code Smell", "S107:Methods should not have too many parameters", Justification = "Test scenario range; reducing duplication")]
public async Task BrokenSerializer_Immutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, bool withL2,
params int[] errorIds)
{
using var collector = new LogCollector();
var services = new ServiceCollection();
services.AddMemoryCache();
if (withL2)
{
services.AddSingleton<IDistributedCache, NullDistributedCache>();
}
var serializer = new ImmutablePoco.Serializer();
services.AddHybridCache().AddSerializer(serializer);
services.AddLogging(options =>
{
options.ClearProviders();
options.AddProvider(collector);
});
using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType<DefaultHybridCache>(provider.GetRequiredService<HybridCache>());
int actualRunCount = 0;
Func<CancellationToken, ValueTask<ImmutablePoco>> func = _ =>
{
Interlocked.Increment(ref actualRunCount);
return new(new ImmutablePoco(value));
};
if (expectKnownFailure)
{
await Assert.ThrowsAsync<KnownFailureException>(async () => await cache.GetOrCreateAsync("key", func));
}
else
{
var first = await cache.GetOrCreateAsync("key", func);
var second = await cache.GetOrCreateAsync("key", func);
Assert.Equal(value, first.Value);
Assert.Equal(value, second.Value);
if (same)
{
Assert.Same(first, second);
}
else
{
Assert.NotSame(first, second);
}
}
Assert.Equal(runCount, Volatile.Read(ref actualRunCount));
Assert.Equal(serializeCount, serializer.WriteCount);
Assert.Equal(deserializeCount, serializer.ReadCount);
collector.WriteTo(log);
collector.AssertErrors(errorIds);
}
public class KnownFailureException : Exception
{
public KnownFailureException(string message)
: base(message)
{
}
} }
public class MutablePoco public class MutablePoco
{ {
public string Value { get; set; } = ""; public string Value { get; set; } = "";
public sealed class Serializer : IHybridCacheSerializer<MutablePoco>
{
private int _readCount;
private int _writeCount;
public int ReadCount => Volatile.Read(ref _readCount);
public int WriteCount => Volatile.Read(ref _writeCount);
public MutablePoco Deserialize(ReadOnlySequence<byte> source)
{
Interlocked.Increment(ref _readCount);
var value = InbuiltTypeSerializer.DeserializeString(source);
if (value == "read fail")
{
throw new KnownFailureException("read failure");
}
return new MutablePoco { Value = value };
}
public void Serialize(MutablePoco value, IBufferWriter<byte> target)
{
Interlocked.Increment(ref _writeCount);
if (value.Value == "write fail")
{
throw new KnownFailureException("write failure");
}
InbuiltTypeSerializer.SerializeString(value.Value, target);
}
}
}
[ImmutableObject(true)]
public sealed class ImmutablePoco
{
public ImmutablePoco(string value)
{
Value = value;
}
public string Value { get; }
public sealed class Serializer : IHybridCacheSerializer<ImmutablePoco>
{
private int _readCount;
private int _writeCount;
public int ReadCount => Volatile.Read(ref _readCount);
public int WriteCount => Volatile.Read(ref _writeCount);
public ImmutablePoco Deserialize(ReadOnlySequence<byte> source)
{
Interlocked.Increment(ref _readCount);
var value = InbuiltTypeSerializer.DeserializeString(source);
if (value == "read fail")
{
throw new KnownFailureException("read failure");
}
return new ImmutablePoco(value);
}
public void Serialize(ImmutablePoco value, IBufferWriter<byte> target)
{
Interlocked.Increment(ref _writeCount);
if (value.Value == "write fail")
{
throw new KnownFailureException("write failure");
}
InbuiltTypeSerializer.SerializeString(value.Value, target);
}
}
} }
} }

Просмотреть файл

@ -0,0 +1,189 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.Globalization;
using Microsoft.Extensions.Caching.Hybrid.Internal;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
public sealed class TestEventListener : EventListener
{
// captures both event and counter data
// this is used as a class fixture from HybridCacheEventSourceTests, because there
// seems to be some unpredictable behaviours if multiple event sources/listeners are
// casually created etc
private const double EventCounterIntervalSec = 0.25;
private readonly List<(int id, string name, EventLevel level)> _events = [];
private readonly Dictionary<string, (string? displayName, double value)> _counters = [];
private object SyncLock => _events;
internal HybridCacheEventSource Source { get; } = new();
public TestEventListener Reset(bool resetCounters = true)
{
lock (SyncLock)
{
_events.Clear();
_counters.Clear();
if (resetCounters)
{
Source.ResetCounters();
}
}
Assert.True(Source.IsEnabled(), "should report as enabled");
return this;
}
protected override void OnEventSourceCreated(EventSource eventSource)
{
if (ReferenceEquals(eventSource, Source))
{
var args = new Dictionary<string, string?>
{
["EventCounterIntervalSec"] = EventCounterIntervalSec.ToString("G", CultureInfo.InvariantCulture),
};
EnableEvents(Source, EventLevel.Verbose, EventKeywords.All, args);
}
base.OnEventSourceCreated(eventSource);
}
protected override void OnEventWritten(EventWrittenEventArgs eventData)
{
if (ReferenceEquals(eventData.EventSource, Source))
{
// capture counters/events
lock (SyncLock)
{
if (eventData.EventName == "EventCounters"
&& eventData.Payload is { Count: > 0 })
{
foreach (var payload in eventData.Payload)
{
if (payload is IDictionary<string, object> map)
{
string? name = null;
string? displayName = null;
double? value = null;
bool isIncrement = false;
foreach (var pair in map)
{
switch (pair.Key)
{
case "Name" when pair.Value is string:
name = (string)pair.Value;
break;
case "DisplayName" when pair.Value is string s:
displayName = s;
break;
case "Mean":
isIncrement = false;
value = Convert.ToDouble(pair.Value);
break;
case "Increment":
isIncrement = true;
value = Convert.ToDouble(pair.Value);
break;
}
}
if (name is not null && value is not null)
{
if (isIncrement && _counters.TryGetValue(name, out var oldPair))
{
value += oldPair.value; // treat as delta from old
}
Debug.WriteLine($"{name}={value}");
_counters[name] = (displayName, value.Value);
}
}
}
}
else
{
_events.Add((eventData.EventId, eventData.EventName ?? "", eventData.Level));
}
}
}
base.OnEventWritten(eventData);
}
public (int id, string name, EventLevel level) SingleEvent()
{
(int id, string name, EventLevel level) evt;
lock (SyncLock)
{
evt = Assert.Single(_events);
}
return evt;
}
public void AssertSingleEvent(int id, string name, EventLevel level)
{
var evt = SingleEvent();
Assert.Equal(name, evt.name);
Assert.Equal(id, evt.id);
Assert.Equal(level, evt.level);
}
public double AssertCounter(string name, string displayName)
{
lock (SyncLock)
{
Assert.True(_counters.TryGetValue(name, out var pair), $"counter not found: {name}");
Assert.Equal(displayName, pair.displayName);
_counters.Remove(name); // count as validated
return pair.value;
}
}
public void AssertCounter(string name, string displayName, double expected)
{
var actual = AssertCounter(name, displayName);
if (!Equals(expected, actual))
{
Assert.Fail($"{name}: expected {expected}, actual {actual}");
}
}
[System.Diagnostics.CodeAnalysis.SuppressMessage("Major Bug", "S1244:Floating point numbers should not be tested for equality", Justification = "Test expects exact zero")]
public void AssertRemainingCountersZero()
{
lock (SyncLock)
{
foreach (var pair in _counters)
{
if (pair.Value.value != 0)
{
Assert.Fail($"{pair.Key}: expected 0, actual {pair.Value.value}");
}
}
}
}
[System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "Clarity and usability")]
public async Task<int> TryAwaitCountersAsync()
{
// allow 2 cycles because if we only allow 1, we run the risk of a
// snapshot being captured mid-cycle when we were setting up the test
// (ok, that's an unlikely race condition, but!)
await Task.Delay(TimeSpan.FromSeconds(EventCounterIntervalSec * 2));
lock (SyncLock)
{
return _counters.Count;
}
}
}

Просмотреть файл

@ -0,0 +1,251 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Hybrid.Internal;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit.Abstractions;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
// validate HC stability when the L2 is unreliable
public class UnreliableL2Tests(ITestOutputHelper testLog)
{
[Theory]
[InlineData(BreakType.None)]
[InlineData(BreakType.Synchronous, Log.IdCacheBackendWriteFailure)]
[InlineData(BreakType.Asynchronous, Log.IdCacheBackendWriteFailure)]
[InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendWriteFailure)]
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
public async Task WriteFailureInvisible(BreakType writeBreak, params int[] errorIds)
{
using (GetServices(out var hc, out var l1, out var l2, out var log))
using (log)
{
// normal behaviour when working fine
var x = await hc.GetOrCreateAsync("x", NewGuid);
Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
Assert.NotNull(l2.Tail.Get("x")); // exists
l2.WriteBreak = writeBreak;
var y = await hc.GetOrCreateAsync("y", NewGuid);
Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
if (writeBreak == BreakType.None)
{
Assert.NotNull(l2.Tail.Get("y")); // exists
}
else
{
Assert.Null(l2.Tail.Get("y")); // does not exist
}
await l2.LastWrite; // allows out-of-band write to complete
await Task.Delay(150); // even then: thread jitter can cause problems
log.WriteTo(testLog);
log.AssertErrors(errorIds);
}
}
[Theory]
[InlineData(BreakType.None)]
[InlineData(BreakType.Synchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
[InlineData(BreakType.Asynchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
[InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
public async Task ReadFailureInvisible(BreakType readBreak, params int[] errorIds)
{
using (GetServices(out var hc, out var l1, out var l2, out var log))
using (log)
{
// create two new values via HC; this should go down to l2
var x = await hc.GetOrCreateAsync("x", NewGuid);
var y = await hc.GetOrCreateAsync("y", NewGuid);
// this should be reliable and repeatable
Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
// even if we clean L1, causing new L2 fetches
l1.Clear();
Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
// now we break L2 in some predictable way, *without* clearing L1 - the
// values should still be available via L1
l2.ReadBreak = readBreak;
Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
// but if we clear L1 to force L2 hits, we anticipate problems
l1.Clear();
if (readBreak == BreakType.None)
{
Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
}
else
{
// because L2 is unavailable and L1 is empty, we expect the callback
// to be used again, generating new values
var a = await hc.GetOrCreateAsync("x", NewGuid, NoL2Write);
var b = await hc.GetOrCreateAsync("y", NewGuid, NoL2Write);
Assert.NotEqual(x, a);
Assert.NotEqual(y, b);
// but those *new* values are at least reliable inside L1
Assert.Equal(a, await hc.GetOrCreateAsync("x", NewGuid));
Assert.Equal(b, await hc.GetOrCreateAsync("y", NewGuid));
}
log.WriteTo(testLog);
log.AssertErrors(errorIds);
}
}
private static HybridCacheEntryOptions NoL2Write { get; } = new HybridCacheEntryOptions { Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite };
public enum BreakType
{
None, // async API works correctly
Synchronous, // async API faults directly rather than return a faulted task
Asynchronous, // async API returns a completed asynchronous fault
AsynchronousYield, // async API returns an incomplete asynchronous fault
}
private static ValueTask<Guid> NewGuid(CancellationToken cancellationToken) => new(Guid.NewGuid());
private static IDisposable GetServices(out HybridCache hc, out MemoryCache l1,
out UnreliableDistributedCache l2, out LogCollector log)
{
// we need an entirely separate MC for the dummy backend, not connected to our
// "real" services
var services = new ServiceCollection();
services.AddDistributedMemoryCache();
var backend = services.BuildServiceProvider().GetRequiredService<IDistributedCache>();
// now create the "real" services
l2 = new UnreliableDistributedCache(backend);
var collector = new LogCollector();
log = collector;
services = new ServiceCollection();
services.AddSingleton<IDistributedCache>(l2);
services.AddHybridCache();
services.AddLogging(options =>
{
options.ClearProviders();
options.AddProvider(collector);
});
var lifetime = services.BuildServiceProvider();
hc = lifetime.GetRequiredService<HybridCache>();
l1 = Assert.IsType<MemoryCache>(lifetime.GetRequiredService<IMemoryCache>());
return lifetime;
}
private sealed class UnreliableDistributedCache : IDistributedCache
{
public UnreliableDistributedCache(IDistributedCache tail)
{
Tail = tail;
}
public IDistributedCache Tail { get; }
public BreakType ReadBreak { get; set; }
public BreakType WriteBreak { get; set; }
public Task LastWrite { get; private set; } = Task.CompletedTask;
public byte[]? Get(string key) => throw new NotSupportedException(); // only async API in use
public Task<byte[]?> GetAsync(string key, CancellationToken token = default)
=> TrackLast(ThrowIfBrokenAsync<byte[]?>(ReadBreak) ?? Tail.GetAsync(key, token));
public void Refresh(string key) => throw new NotSupportedException(); // only async API in use
public Task RefreshAsync(string key, CancellationToken token = default)
=> TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RefreshAsync(key, token));
public void Remove(string key) => throw new NotSupportedException(); // only async API in use
public Task RemoveAsync(string key, CancellationToken token = default)
=> TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RemoveAsync(key, token));
public void Set(string key, byte[] value, DistributedCacheEntryOptions options) => throw new NotSupportedException(); // only async API in use
public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default)
=> TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.SetAsync(key, value, options, token));
[DoesNotReturn]
private static void Throw() => throw new IOException("L2 offline");
private static async Task<T> ThrowAsync<T>(bool yield)
{
if (yield)
{
await Task.Yield();
}
Throw();
return default; // never reached
}
private static Task? ThrowIfBrokenAsync(BreakType breakType) => ThrowIfBrokenAsync<int>(breakType);
[SuppressMessage("Critical Bug", "S4586:Non-async \"Task/Task<T>\" methods should not return null", Justification = "Intentional for propagation")]
private static Task<T>? ThrowIfBrokenAsync<T>(BreakType breakType)
{
switch (breakType)
{
case BreakType.Asynchronous:
return ThrowAsync<T>(false);
case BreakType.AsynchronousYield:
return ThrowAsync<T>(true);
case BreakType.None:
return null;
default:
// includes BreakType.Synchronous and anything unknown
Throw();
break;
}
return null;
}
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We don't need the failure type - just the timing")]
private static Task IgnoreFailure(Task task)
{
return task.Status == TaskStatus.RanToCompletion
? Task.CompletedTask : IgnoreAsync(task);
static async Task IgnoreAsync(Task task)
{
try
{
await task;
}
catch
{
// we only care about the "when"; failure is fine
}
}
}
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
private Task TrackLast(Task lastWrite)
{
LastWrite = IgnoreFailure(lastWrite);
return lastWrite;
}
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
private Task<T> TrackLast<T>(Task<T> lastWrite)
{
LastWrite = IgnoreFailure(lastWrite);
return lastWrite;
}
}
}

Просмотреть файл

@ -12,4 +12,8 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection" /> <PackageReference Include="Microsoft.Extensions.DependencyInjection" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition="'$(TargetFramework)'=='net462'">
<PackageReference Include="System.ComponentModel.Annotations" />
</ItemGroup>
</Project> </Project>

Просмотреть файл

@ -0,0 +1,35 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Text.Json.Schema;
using Xunit;
namespace Microsoft.Extensions.AI.JsonSchemaExporter;
public static class JsonSchemaExporterConfigurationTests
{
[Theory]
[InlineData(false)]
[InlineData(true)]
public static void JsonSchemaExporterOptions_DefaultValues(bool useSingleton)
{
JsonSchemaExporterOptions configuration = useSingleton ? JsonSchemaExporterOptions.Default : new();
Assert.False(configuration.TreatNullObliviousAsNonNullable);
Assert.Null(configuration.TransformSchemaNode);
}
[Fact]
public static void JsonSchemaExporterOptions_Singleton_ReturnsSameInstance()
{
Assert.Same(JsonSchemaExporterOptions.Default, JsonSchemaExporterOptions.Default);
}
[Theory]
[InlineData(false)]
[InlineData(true)]
public static void JsonSchemaExporterOptions_TreatNullObliviousAsNonNullable(bool treatNullObliviousAsNonNullable)
{
JsonSchemaExporterOptions configuration = new() { TreatNullObliviousAsNonNullable = treatNullObliviousAsNonNullable };
Assert.Equal(treatNullObliviousAsNonNullable, configuration.TreatNullObliviousAsNonNullable);
}
}

Просмотреть файл

@ -0,0 +1,147 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
#if !NET9_0_OR_GREATER
using System.Xml.Linq;
#endif
using Xunit;
#pragma warning disable SA1402 // File may only contain a single type
namespace Microsoft.Extensions.AI.JsonSchemaExporter;
public abstract class JsonSchemaExporterTests
{
protected abstract JsonSerializerOptions Options { get; }
[Theory]
[MemberData(nameof(TestTypes.GetTestData), MemberType = typeof(TestTypes))]
public void TestTypes_GeneratesExpectedJsonSchema(ITestData testData)
{
JsonSerializerOptions options = testData.Options is { } opts
? new(opts) { TypeInfoResolver = Options.TypeInfoResolver }
: Options;
JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions);
SchemaTestHelpers.AssertEqualJsonSchema(testData.ExpectedJsonSchema, schema);
}
[Theory]
[MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))]
public void TestTypes_SerializedValueMatchesGeneratedSchema(ITestData testData)
{
JsonSerializerOptions options = testData.Options is { } opts
? new(opts) { TypeInfoResolver = Options.TypeInfoResolver }
: Options;
JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions);
JsonNode? instance = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options);
SchemaTestHelpers.AssertDocumentMatchesSchema(schema, instance);
}
[Theory]
[InlineData(typeof(string), "string")]
[InlineData(typeof(int[]), "array")]
[InlineData(typeof(Dictionary<string, int>), "object")]
[InlineData(typeof(TestTypes.SimplePoco), "object")]
public void TreatNullObliviousAsNonNullable_True_MarksAllReferenceTypesAsNonNullable(Type referenceType, string expectedType)
{
Assert.True(!referenceType.IsValueType);
var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config);
JsonValue type = Assert.IsAssignableFrom<JsonValue>(schema["type"]);
Assert.Equal(expectedType, (string)type!);
}
[Theory]
[InlineData(typeof(int), "integer")]
[InlineData(typeof(double), "number")]
[InlineData(typeof(bool), "boolean")]
[InlineData(typeof(ImmutableArray<int>), "array")]
[InlineData(typeof(TestTypes.StructDictionary<string, int>), "object")]
[InlineData(typeof(TestTypes.SimpleRecordStruct), "object")]
public void TreatNullObliviousAsNonNullable_True_DoesNotImpactNonReferenceTypes(Type referenceType, string expectedType)
{
Assert.True(referenceType.IsValueType);
var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config);
JsonValue value = Assert.IsAssignableFrom<JsonValue>(schema["type"]);
Assert.Equal(expectedType, (string)value!);
}
#if !NET9_0 // Disable until https://github.com/dotnet/runtime/pull/108764 gets backported
[Fact]
public void CanGenerateXElementSchema()
{
JsonNode schema = Options.GetJsonSchemaAsNode(typeof(XElement));
Assert.True(schema.ToJsonString().Length < 100_000);
}
#endif
[Fact]
public void TreatNullObliviousAsNonNullable_True_DoesNotImpactObjectType()
{
var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
JsonNode schema = Options.GetJsonSchemaAsNode(typeof(object), config);
Assert.False(schema is JsonObject jObj && jObj.ContainsKey("type"));
}
[Fact]
public void TypeWithDisallowUnmappedMembers_AdditionalPropertiesFailValidation()
{
JsonNode schema = Options.GetJsonSchemaAsNode(typeof(TestTypes.PocoDisallowingUnmappedMembers));
JsonNode? jsonWithUnmappedProperties = JsonNode.Parse("""{ "UnmappedProperty" : {} }""");
SchemaTestHelpers.AssertDoesNotMatchSchema(schema, jsonWithUnmappedProperties);
}
[Fact]
public void GetJsonSchema_NullInputs_ThrowsArgumentNullException()
{
Assert.Throws<ArgumentNullException>(() => ((JsonSerializerOptions)null!).GetJsonSchemaAsNode(typeof(int)));
Assert.Throws<ArgumentNullException>(() => Options.GetJsonSchemaAsNode(type: null!));
Assert.Throws<ArgumentNullException>(() => ((JsonTypeInfo)null!).GetJsonSchemaAsNode());
}
[Fact]
public void GetJsonSchema_NoResolver_ThrowInvalidOperationException()
{
var options = new JsonSerializerOptions();
Assert.Throws<InvalidOperationException>(() => options.GetJsonSchemaAsNode(typeof(int)));
}
[Fact]
public void MaxDepth_SetToZero_NonTrivialSchema_ThrowsInvalidOperationException()
{
JsonSerializerOptions options = new(Options) { MaxDepth = 1 };
var ex = Assert.Throws<InvalidOperationException>(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco)));
Assert.Contains("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting.", ex.Message);
}
[Fact]
public void ReferenceHandlePreserve_Enabled_ThrowsNotSupportedException()
{
var options = new JsonSerializerOptions(Options) { ReferenceHandler = ReferenceHandler.Preserve };
options.MakeReadOnly();
var ex = Assert.Throws<NotSupportedException>(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco)));
Assert.Contains("ReferenceHandler.Preserve", ex.Message);
}
}
public sealed class ReflectionJsonSchemaExporterTests : JsonSchemaExporterTests
{
protected override JsonSerializerOptions Options => JsonSerializerOptions.Default;
}
public sealed class SourceGenJsonSchemaExporterTests : JsonSchemaExporterTests
{
protected override JsonSerializerOptions Options => TestTypes.TestTypesContext.Default.Options;
}

Просмотреть файл

@ -0,0 +1,82 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Json.Schema;
using Xunit.Sdk;
namespace Microsoft.Extensions.AI.JsonSchemaExporter;
internal static partial class SchemaTestHelpers
{
public static void AssertEqualJsonSchema(JsonNode expectedJsonSchema, JsonNode actualJsonSchema)
{
if (!JsonNode.DeepEquals(expectedJsonSchema, actualJsonSchema))
{
throw new XunitException($"""
Generated schema does not match the expected specification.
Expected:
{FormatJson(expectedJsonSchema)}
Actual:
{FormatJson(actualJsonSchema)}
""");
}
}
public static void AssertDocumentMatchesSchema(JsonNode schema, JsonNode? instance)
{
EvaluationResults results = EvaluateSchemaCore(schema, instance);
if (!results.IsValid)
{
IEnumerable<string> errors = results.Details
.Where(d => d.HasErrors)
.SelectMany(d => d.Errors!.Select(error => $"Path:${d.InstanceLocation} {error.Key}:{error.Value}"));
throw new XunitException($"""
Instance JSON document does not match the specified schema.
Schema:
{FormatJson(schema)}
Instance:
{FormatJson(instance)}
Errors:
{string.Join(Environment.NewLine, errors)}
""");
}
}
public static void AssertDoesNotMatchSchema(JsonNode schema, JsonNode? instance)
{
EvaluationResults results = EvaluateSchemaCore(schema, instance);
if (results.IsValid)
{
throw new XunitException($"""
Instance JSON document matches the specified schema.
Schema:
{FormatJson(schema)}
Instance:
{FormatJson(instance)}
""");
}
}
private static EvaluationResults EvaluateSchemaCore(JsonNode schema, JsonNode? instance)
{
JsonSchema jsonSchema = JsonSerializer.Deserialize(schema, Context.Default.JsonSchema)!;
EvaluationOptions options = new() { OutputFormat = OutputFormat.List };
return jsonSchema.Evaluate(instance, options);
}
private static string FormatJson(JsonNode? node) =>
JsonSerializer.Serialize(node, Context.Default.JsonNode!);
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(JsonNode))]
[JsonSerializable(typeof(JsonSchema))]
[JsonSourceGenerationOptions(WriteIndented = true)]
private partial class Context : JsonSerializerContext;
}

Просмотреть файл

@ -0,0 +1,67 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
namespace Microsoft.Extensions.AI.JsonSchemaExporter;
internal sealed record TestData<T>(
T? Value,
[StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
IEnumerable<T?>? AdditionalValues = null,
JsonSchemaExporterOptions? ExporterOptions = null,
JsonSerializerOptions? Options = null,
bool WritesNumbersAsStrings = false)
: ITestData
{
private static readonly JsonDocumentOptions _schemaParseOptions = new() { CommentHandling = JsonCommentHandling.Skip };
public Type Type => typeof(T);
object? ITestData.Value => Value;
object? ITestData.ExporterOptions => ExporterOptions;
JsonNode ITestData.ExpectedJsonSchema { get; } =
JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
?? throw new ArgumentNullException("schema must not be null");
IEnumerable<ITestData> ITestData.GetTestDataForAllValues()
{
yield return this;
if (default(T) is null &&
ExporterOptions is { TreatNullObliviousAsNonNullable: false } &&
Value is not null)
{
yield return this with { Value = default };
}
if (AdditionalValues != null)
{
foreach (T? value in AdditionalValues)
{
yield return this with { Value = value, AdditionalValues = null };
}
}
}
}
public interface ITestData
{
Type Type { get; }
object? Value { get; }
JsonNode ExpectedJsonSchema { get; }
object? ExporterOptions { get; }
JsonSerializerOptions? Options { get; }
bool WritesNumbersAsStrings { get; }
IEnumerable<ITestData> GetTestDataForAllValues();
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -2,19 +2,26 @@
<PropertyGroup> <PropertyGroup>
<RootNamespace>Microsoft.Shared.Test</RootNamespace> <RootNamespace>Microsoft.Shared.Test</RootNamespace>
<Description>Unit tests for Microsoft.Shared</Description> <Description>Unit tests for Microsoft.Shared</Description>
<DefineConstants>$(DefineConstants);TESTS_JSON_SCHEMA_EXPORTER_POLYFILL</DefineConstants>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<NoWarn>$(NoWarn);CA1716</NoWarn> <NoWarn>$(NoWarn);CA1716;S104</NoWarn>
<TargetFrameworks>$(TestNetCoreTargetFrameworks)</TargetFrameworks> <TargetFrameworks>$(TestNetCoreTargetFrameworks)</TargetFrameworks>
<TargetFrameworks Condition=" '$(IsWindowsBuild)' == 'true' ">$(TestNetCoreTargetFrameworks)$(ConditionalNet462)</TargetFrameworks> <TargetFrameworks Condition=" '$(IsWindowsBuild)' == 'true' ">$(TestNetCoreTargetFrameworks)$(ConditionalNet462)</TargetFrameworks>
</PropertyGroup> </PropertyGroup>
<PropertyGroup>
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
</PropertyGroup>
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\..\src\Shared\Shared.csproj" ProjectUnderTest="true" /> <ProjectReference Include="..\..\src\Shared\Shared.csproj" ProjectUnderTest="true" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection" /> <PackageReference Include="Microsoft.Extensions.DependencyInjection" />
<PackageReference Include="JsonSchema.Net" />
</ItemGroup> </ItemGroup>
</Project> </Project>