From 0577c2110a6dffc9ab98e232cef3e769395d3be6 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 14 Nov 2024 00:05:51 -0500 Subject: [PATCH] Rework cache key handling in caching client / generator (#5641) * Rework cache key handling in caching client / generator - Expose the default cache key helper so that customization doesn't require re-implementing the whole thing. - Make it easy to incorporate additional state into the cache key. - Avoid serializing all of the values for the key into a new byte[], at least on .NET 8+. There, we can serialize directly into a stream that targets an IncrementalHash. - Include Chat/EmbeddingGenerationOptions in the cache key by default. * Update test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs Co-authored-by: Shyam N --------- Co-authored-by: Shyam N --- .../Microsoft.Extensions.AI/CachingHelpers.cs | 133 +++++++++++++----- .../DistributedCachingChatClient.cs | 28 ++-- .../DistributedCachingEmbeddingGenerator.cs | 17 ++- .../DistributedCachingChatClientTest.cs | 21 ++- ...istributedCachingEmbeddingGeneratorTest.cs | 27 +++- .../TestJsonSerializerContext.cs | 2 + 6 files changed, 173 insertions(+), 55 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs index 13637dc522..102fc86b13 100644 --- a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs +++ b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs @@ -2,9 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Diagnostics; +using System.IO; using System.Security.Cryptography; using System.Text.Json; -using Microsoft.Shared.Diagnostics; +#if NET +using System.Threading; +using System.Threading.Tasks; +#endif + +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable SA1202 // Elements should be ordered by access +#pragma warning disable SA1502 // Element should not be on a single line namespace Microsoft.Extensions.AI; @@ -12,50 +21,110 @@ namespace Microsoft.Extensions.AI; internal static class CachingHelpers { /// Computes a default cache key for the specified parameters. - /// Specifies the type of the data being used to compute the key. - /// The data with which to compute the key. + /// The data with which to compute the key. /// The . /// A string that will be used as a cache key. - public static string GetCacheKey(TValue value, JsonSerializerOptions serializerOptions) - => GetCacheKey(value, false, serializerOptions); - - /// Computes a default cache key for the specified parameters. - /// Specifies the type of the data being used to compute the key. - /// The data with which to compute the key. - /// Another data item that causes the key to vary. - /// The . - /// A string that will be used as a cache key. - public static string GetCacheKey(TValue value, bool flag, JsonSerializerOptions serializerOptions) + public static string GetCacheKey(ReadOnlySpan values, JsonSerializerOptions serializerOptions) { - _ = Throw.IfNull(value); - _ = Throw.IfNull(serializerOptions); - serializerOptions.MakeReadOnly(); - - var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue))); - - if (flag && jsonKeyBytes.Length > 0) - { - // Make an arbitrary change to the hash input based on the flag - // The alternative would be including the flag in "value" in the - // first place, but that's likely to require an extra allocation - // or the inclusion of another type in the JsonSerializerContext. - // This is a micro-optimization we can change at any time. - jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]); - } + Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null"); + Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only."); // The complete JSON representation is excessively long for a cache key, duplicating much of the content // from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes. // If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information // disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit // invalidating any existing cache entries that may exist in whatever IDistributedCache was in use. -#if NET8_0_OR_GREATER + +#if NET + IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new(); + IncrementalHashStream.ThreadStaticInstance = null; + + foreach (object? value in values) + { + JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object))); + } + Span hashData = stackalloc byte[SHA256.HashSizeInBytes]; - SHA256.HashData(jsonKeyBytes, hashData); + stream.GetHashAndReset(hashData); + IncrementalHashStream.ThreadStaticInstance = stream; + return Convert.ToHexString(hashData); #else + MemoryStream stream = new(); + foreach (object? value in values) + { + JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object))); + } + using var sha256 = SHA256.Create(); - var hashData = sha256.ComputeHash(jsonKeyBytes); - return BitConverter.ToString(hashData).Replace("-", string.Empty); + stream.Position = 0; + var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length); + + var chars = new char[hashData.Length * 2]; + int destPos = 0; + foreach (byte b in hashData) + { + int div = Math.DivRem(b, 16, out int rem); + chars[destPos++] = ToHexChar(div); + chars[destPos++] = ToHexChar(rem); + + static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A'); + } + + Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array."); + + return new string(chars); #endif } + +#if NET + /// Provides a stream that writes to an . + private sealed class IncrementalHashStream : Stream + { + /// A per-thread instance of . + /// An instance stored must be in a reset state ready to be used by another consumer. + [ThreadStatic] + public static IncrementalHashStream? ThreadStaticInstance; + + /// Gets the current hash and resets. + public void GetHashAndReset(Span bytes) => _hash.GetHashAndReset(bytes); + + /// The used by this instance. + private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256); + + protected override void Dispose(bool disposing) + { + _hash.Dispose(); + base.Dispose(disposing); + } + + public override void WriteByte(byte value) => Write(new ReadOnlySpan(in value)); + public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count); + public override void Write(ReadOnlySpan buffer) => _hash.AppendData(buffer); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Write(buffer, offset, count); + return Task.CompletedTask; + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + Write(buffer.Span); + return ValueTask.CompletedTask; + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + } +#endif } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs index 6ea79f9f73..678e9bd652 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Text.Json; using System.Threading; @@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI; /// public class DistributedCachingChatClient : CachingChatClient { + /// A boxed value. + private static readonly object _boxedTrue = true; + + /// A boxed value. + private static readonly object _boxedFalse = false; + + /// The instance that will be used as the backing store for the cache. private readonly IDistributedCache _storage; - private JsonSerializerOptions _jsonSerializerOptions; + + /// The to use when serializing cache data. + private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions; /// Initializes a new instance of the class. /// The underlying . @@ -29,7 +39,6 @@ public class DistributedCachingChatClient : CachingChatClient : base(innerClient) { _storage = Throw.IfNull(storage); - _jsonSerializerOptions = AIJsonUtilities.DefaultOptions; } /// Gets or sets JSON serialization options to use when serializing cache data. @@ -90,13 +99,16 @@ public class DistributedCachingChatClient : CachingChatClient } /// - protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) + protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) => + GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]); + + /// Gets a cache key based on the supplied values. + /// The values to inform the key. + /// The computed key. + /// This provides the default implementation for . + protected string GetCacheKey(ReadOnlySpan values) { - // While it might be desirable to include ChatOptions in the cache key, it's not always possible, - // since ChatOptions can contain types that are not guaranteed to be serializable or have a stable - // hashcode across multiple calls. So the default cache key is simply the JSON representation of - // the chat contents. Developers may subclass and override this to provide custom rules. _jsonSerializerOptions.MakeReadOnly(); - return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions); + return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index ecec409a1b..6482ed8ed2 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using System.Threading; @@ -74,12 +75,16 @@ public class DistributedCachingEmbeddingGenerator : CachingE } /// - protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) + protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) => + GetCacheKey([value, options]); + + /// Gets a cache key based on the supplied values. + /// The values to inform the key. + /// The computed key. + /// This provides the default implementation for . + protected string GetCacheKey(ReadOnlySpan values) { - // While it might be desirable to include options in the cache key, it's not always possible, - // since options can contain types that are not guaranteed to be serializable or have a stable - // hashcode across multiple calls. So the default cache key is simply the JSON representation of - // the value. Developers may subclass and override this to provide custom rules. - return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions); + _jsonSerializerOptions.MakeReadOnly(); + return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index 67e23ec495..772bb9cf7d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -527,7 +527,7 @@ public class DistributedCachingChatClientTest } [Fact] - public async Task CacheKeyDoesNotVaryByChatOptionsAsync() + public async Task CacheKeyVariesByChatOptionsAsync() { // Arrange var innerCallCount = 0; @@ -546,20 +546,35 @@ public class DistributedCachingChatClientTest JsonSerializerOptions = TestJsonSerializerContext.Default.Options }; - // Act: Call with two different ChatOptions + // Act: Call with two different ChatOptions that have the same values var result1 = await outer.CompleteAsync([], new ChatOptions { AdditionalProperties = new() { { "someKey", "value 1" } } }); var result2 = await outer.CompleteAsync([], new ChatOptions { - AdditionalProperties = new() { { "someKey", "value 2" } } + AdditionalProperties = new() { { "someKey", "value 1" } } }); // Assert: Same result Assert.Equal(1, innerCallCount); Assert.Equal("value 1", result1.Message.Text); Assert.Equal("value 1", result2.Message.Text); + + // Act: Call with two different ChatOptions that have different values + var result3 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 1" } } + }); + var result4 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 2" } } + }); + + // Assert: Different results + Assert.Equal(2, innerCallCount); + Assert.Equal("value 1", result3.Message.Text); + Assert.Equal("value 2", result4.Message.Text); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs index a2818c7c3e..f9356ef45c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -221,7 +221,7 @@ public class DistributedCachingEmbeddingGeneratorTest } [Fact] - public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync() + public async Task CacheKeyVariesByEmbeddingOptionsAsync() { // Arrange var innerCallCount = 0; @@ -232,7 +232,7 @@ public class DistributedCachingEmbeddingGeneratorTest { innerCallCount++; await Task.Yield(); - return [_expectedEmbedding]; + return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())]; } }; using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) @@ -240,20 +240,35 @@ public class DistributedCachingEmbeddingGeneratorTest JsonSerializerOptions = TestJsonSerializerContext.Default.Options, }; - // Act: Call with two different options + // Act: Call with two different EmbeddingGenerationOptions that have the same values var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions { AdditionalProperties = new() { ["someKey"] = "value 1" } }); var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions { - AdditionalProperties = new() { ["someKey"] = "value 2" } + AdditionalProperties = new() { ["someKey"] = "value 1" } }); // Assert: Same result Assert.Equal(1, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, result1); - AssertEmbeddingsEqual(_expectedEmbedding, result2); + AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1); + AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2); + + // Act: Call with two different EmbeddingGenerationOptions that have different values + var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 1" } + }); + var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 2" } + }); + + // Assert: Different result + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3); + AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs index e376da86da..b077542c17 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs @@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI; [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(DayOfWeek[]))] [JsonSerializable(typeof(Guid))] +[JsonSerializable(typeof(ChatOptions))] +[JsonSerializable(typeof(EmbeddingGenerationOptions))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;