Mirror of https://github.com/dotnet/extensions.git
Rework cache key handling in caching client / generator (#5641)
* Rework cache key handling in caching client / generator

  - Expose the default cache key helper so that customization doesn't require re-implementing the whole thing.
  - Make it easy to incorporate additional state into the cache key.
  - Avoid serializing all of the values for the key into a new byte[], at least on .NET 8+. There, we can serialize directly into a stream that targets an IncrementalHash.
  - Include Chat/EmbeddingGenerationOptions in the cache key by default.

* Update test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs

Co-authored-by: Shyam N <shyamnamboodiripad@users.noreply.github.com>
This commit is contained in:
Parent: a1863ea0b6
Commit: 0577c2110a
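The point of exposing the key computation is that adding state to the key no longer requires re-implementing the hashing. A minimal consumer-side sketch, assuming the protected GetCacheKey(ReadOnlySpan<object?>) overload introduced in the diff below; the TenantScopedCachingChatClient type and its tenant id are hypothetical and not part of this commit:

using System.Collections.Generic;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Distributed;

// Hypothetical subclass: varies the cache key by a per-tenant identifier in addition
// to the default inputs (streaming flag, chat messages, options).
public sealed class TenantScopedCachingChatClient : DistributedCachingChatClient
{
    private readonly string _tenantId;

    public TenantScopedCachingChatClient(IChatClient innerClient, IDistributedCache storage, string tenantId)
        : base(innerClient, storage)
    {
        _tenantId = tenantId;
    }

    protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
        // Delegate to the protected helper so the extra value is serialized and hashed
        // the same way as the default inputs.
        GetCacheKey([streaming, _tenantId, chatMessages, options]);
}

Because the tenant id is hashed together with the defaults, keys produced for different tenants differ unless the underlying hash collides.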
@@ -2,9 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.IO;
using System.Security.Cryptography;
using System.Text.Json;
using Microsoft.Shared.Diagnostics;
#if NET
using System.Threading;
using System.Threading.Tasks;
#endif

#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable SA1202 // Elements should be ordered by access
#pragma warning disable SA1502 // Element should not be on a single line

namespace Microsoft.Extensions.AI;

@@ -12,50 +21,110 @@ namespace Microsoft.Extensions.AI;
internal static class CachingHelpers
{
    /// <summary>Computes a default cache key for the specified parameters.</summary>
    /// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
    /// <param name="value">The data with which to compute the key.</param>
    /// <param name="values">The data with which to compute the key.</param>
    /// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
    /// <returns>A string that will be used as a cache key.</returns>
    public static string GetCacheKey<TValue>(TValue value, JsonSerializerOptions serializerOptions)
        => GetCacheKey(value, false, serializerOptions);

    /// <summary>Computes a default cache key for the specified parameters.</summary>
    /// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
    /// <param name="value">The data with which to compute the key.</param>
    /// <param name="flag">Another data item that causes the key to vary.</param>
    /// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
    /// <returns>A string that will be used as a cache key.</returns>
    public static string GetCacheKey<TValue>(TValue value, bool flag, JsonSerializerOptions serializerOptions)
    public static string GetCacheKey(ReadOnlySpan<object?> values, JsonSerializerOptions serializerOptions)
    {
        _ = Throw.IfNull(value);
        _ = Throw.IfNull(serializerOptions);
        serializerOptions.MakeReadOnly();

        var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue)));

        if (flag && jsonKeyBytes.Length > 0)
        {
            // Make an arbitrary change to the hash input based on the flag
            // The alternative would be including the flag in "value" in the
            // first place, but that's likely to require an extra allocation
            // or the inclusion of another type in the JsonSerializerContext.
            // This is a micro-optimization we can change at any time.
            jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]);
        }
        Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null");
        Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only.");

        // The complete JSON representation is excessively long for a cache key, duplicating much of the content
        // from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes.
        // If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information
        // disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
        // invalidating any existing cache entries that may exist in whatever IDistributedCache was in use.
#if NET8_0_OR_GREATER

#if NET
        IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new();
        IncrementalHashStream.ThreadStaticInstance = null;

        foreach (object? value in values)
        {
            JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
        }

        Span<byte> hashData = stackalloc byte[SHA256.HashSizeInBytes];
        SHA256.HashData(jsonKeyBytes, hashData);
        stream.GetHashAndReset(hashData);
        IncrementalHashStream.ThreadStaticInstance = stream;

        return Convert.ToHexString(hashData);
#else
        MemoryStream stream = new();
        foreach (object? value in values)
        {
            JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
        }

        using var sha256 = SHA256.Create();
        var hashData = sha256.ComputeHash(jsonKeyBytes);
        return BitConverter.ToString(hashData).Replace("-", string.Empty);
        stream.Position = 0;
        var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);

        var chars = new char[hashData.Length * 2];
        int destPos = 0;
        foreach (byte b in hashData)
        {
            int div = Math.DivRem(b, 16, out int rem);
            chars[destPos++] = ToHexChar(div);
            chars[destPos++] = ToHexChar(rem);

            static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
        }

        Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");

        return new string(chars);
#endif
    }

#if NET
    /// <summary>Provides a stream that writes to an <see cref="IncrementalHash"/>.</summary>
    private sealed class IncrementalHashStream : Stream
    {
        /// <summary>A per-thread instance of <see cref="IncrementalHashStream"/>.</summary>
        /// <remarks>An instance stored must be in a reset state ready to be used by another consumer.</remarks>
        [ThreadStatic]
        public static IncrementalHashStream? ThreadStaticInstance;

        /// <summary>Gets the current hash and resets.</summary>
        public void GetHashAndReset(Span<byte> bytes) => _hash.GetHashAndReset(bytes);

        /// <summary>The <see cref="IncrementalHash"/> used by this instance.</summary>
        private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);

        protected override void Dispose(bool disposing)
        {
            _hash.Dispose();
            base.Dispose(disposing);
        }

        public override void WriteByte(byte value) => Write(new ReadOnlySpan<byte>(in value));
        public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count);
        public override void Write(ReadOnlySpan<byte> buffer) => _hash.AppendData(buffer);

        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            Write(buffer, offset, count);
            return Task.CompletedTask;
        }

        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
        {
            Write(buffer.Span);
            return ValueTask.CompletedTask;
        }

        public override void Flush() { }
        public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;

        public override bool CanWrite => true;
        public override bool CanRead => false;
        public override bool CanSeek => false;
        public override long Length => throw new NotSupportedException();
        public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
        public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
        public override void SetLength(long value) => throw new NotSupportedException();
    }
#endif
}
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;

@@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
    /// <summary>A boxed <see langword="true"/> value.</summary>
    private static readonly object _boxedTrue = true;

    /// <summary>A boxed <see langword="false"/> value.</summary>
    private static readonly object _boxedFalse = false;

    /// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
    private readonly IDistributedCache _storage;
    private JsonSerializerOptions _jsonSerializerOptions;

    /// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
    private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;

    /// <summary>Initializes a new instance of the <see cref="DistributedCachingChatClient"/> class.</summary>
    /// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>

@@ -29,7 +39,6 @@ public class DistributedCachingChatClient : CachingChatClient
        : base(innerClient)
    {
        _storage = Throw.IfNull(storage);
        _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
    }

    /// <summary>Gets or sets JSON serialization options to use when serializing cache data.</summary>

@@ -90,13 +99,16 @@ public class DistributedCachingChatClient : CachingChatClient
    }

    /// <inheritdoc />
    protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
    protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
        GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);

    /// <summary>Gets a cache key based on the supplied values.</summary>
    /// <param name="values">The values to inform the key.</param>
    /// <returns>The computed key.</returns>
    /// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
    protected string GetCacheKey(ReadOnlySpan<object?> values)
    {
        // While it might be desirable to include ChatOptions in the cache key, it's not always possible,
        // since ChatOptions can contain types that are not guaranteed to be serializable or have a stable
        // hashcode across multiple calls. So the default cache key is simply the JSON representation of
        // the chat contents. Developers may subclass and override this to provide custom rules.
        _jsonSerializerOptions.MakeReadOnly();
        return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions);
        return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
    }
}
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;

@@ -74,12 +75,16 @@ public class DistributedCachingEmbeddingGenerator<TInput, TEmbedding> : CachingE
    }

    /// <inheritdoc />
    protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options)
    protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
        GetCacheKey([value, options]);

    /// <summary>Gets a cache key based on the supplied values.</summary>
    /// <param name="values">The values to inform the key.</param>
    /// <returns>The computed key.</returns>
    /// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
    protected string GetCacheKey(ReadOnlySpan<object?> values)
    {
        // While it might be desirable to include options in the cache key, it's not always possible,
        // since options can contain types that are not guaranteed to be serializable or have a stable
        // hashcode across multiple calls. So the default cache key is simply the JSON representation of
        // the value. Developers may subclass and override this to provide custom rules.
        return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions);
        _jsonSerializerOptions.MakeReadOnly();
        return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
    }
}
@@ -527,7 +527,7 @@ public class DistributedCachingChatClientTest
    }

    [Fact]
    public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
    public async Task CacheKeyVariesByChatOptionsAsync()
    {
        // Arrange
        var innerCallCount = 0;

@@ -546,20 +546,35 @@ public class DistributedCachingChatClientTest
            JsonSerializerOptions = TestJsonSerializerContext.Default.Options
        };

        // Act: Call with two different ChatOptions
        // Act: Call with two different ChatOptions that have the same values
        var result1 = await outer.CompleteAsync([], new ChatOptions
        {
            AdditionalProperties = new() { { "someKey", "value 1" } }
        });
        var result2 = await outer.CompleteAsync([], new ChatOptions
        {
            AdditionalProperties = new() { { "someKey", "value 2" } }
            AdditionalProperties = new() { { "someKey", "value 1" } }
        });

        // Assert: Same result
        Assert.Equal(1, innerCallCount);
        Assert.Equal("value 1", result1.Message.Text);
        Assert.Equal("value 1", result2.Message.Text);

        // Act: Call with two different ChatOptions that have different values
        var result3 = await outer.CompleteAsync([], new ChatOptions
        {
            AdditionalProperties = new() { { "someKey", "value 1" } }
        });
        var result4 = await outer.CompleteAsync([], new ChatOptions
        {
            AdditionalProperties = new() { { "someKey", "value 2" } }
        });

        // Assert: Different results
        Assert.Equal(2, innerCallCount);
        Assert.Equal("value 1", result3.Message.Text);
        Assert.Equal("value 2", result4.Message.Text);
    }

    [Fact]
@@ -221,7 +221,7 @@ public class DistributedCachingEmbeddingGeneratorTest
    }

    [Fact]
    public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
    public async Task CacheKeyVariesByEmbeddingOptionsAsync()
    {
        // Arrange
        var innerCallCount = 0;

@@ -232,7 +232,7 @@ public class DistributedCachingEmbeddingGeneratorTest
            {
                innerCallCount++;
                await Task.Yield();
                return [_expectedEmbedding];
                return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())];
            }
        };
        using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)

@@ -240,20 +240,35 @@ public class DistributedCachingEmbeddingGeneratorTest
            JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
        };

        // Act: Call with two different options
        // Act: Call with two different EmbeddingGenerationOptions that have the same values
        var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
        {
            AdditionalProperties = new() { ["someKey"] = "value 1" }
        });
        var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
        {
            AdditionalProperties = new() { ["someKey"] = "value 2" }
            AdditionalProperties = new() { ["someKey"] = "value 1" }
        });

        // Assert: Same result
        Assert.Equal(1, innerCallCount);
        AssertEmbeddingsEqual(_expectedEmbedding, result1);
        AssertEmbeddingsEqual(_expectedEmbedding, result2);
        AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1);
        AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);

        // Act: Call with two different EmbeddingGenerationOptions that have different values
        var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
        {
            AdditionalProperties = new() { ["someKey"] = "value 1" }
        });
        var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
        {
            AdditionalProperties = new() { ["someKey"] = "value 2" }
        });

        // Assert: Different result
        Assert.Equal(2, innerCallCount);
        AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3);
        AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
    }

    [Fact]
@@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(Dictionary<string, string>))]
[JsonSerializable(typeof(DayOfWeek[]))]
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
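One consumer-facing consequence, hinted at by the test-context additions above: because ChatOptions/EmbeddingGenerationOptions now feed the default cache key, a source-generated JsonSerializerOptions supplied to the caching client or generator must be able to resolve those types, along with whatever is stored in their AdditionalProperties. A minimal sketch, assuming a hypothetical consumer-owned context; the exact set of registrations depends on the message and option types actually used:

using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;

// Hypothetical source-generated context to assign to the caching client/generator's
// JsonSerializerOptions. Mirroring the test change above, ChatOptions and
// EmbeddingGenerationOptions are registered because they now participate in the key.
[JsonSerializable(typeof(object))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
internal sealed partial class MyCachingJsonContext : JsonSerializerContext;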