Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingChatClient (#5616)

* Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingChatClient

Adds a ToStreamingChatCompletionUpdates method that's the counterpart to the recently added ToChatCompletion.

Then uses both from CachingChatClient instead of its now bespoke coalescing implementation. When coalescing is enabled (the default), CachingChatClient caches everything as a ChatCompletion, rather than distinguishing streaming and non-streaming.

* Address PR feedback
This commit is contained in:
Stephen Toub 2024-11-11 10:12:10 -05:00 коммит произвёл GitHub
Родитель c163960f9e
Коммит 148e221539
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
7 изменённых файлов: 283 добавлений и 133 удалений

Просмотреть файл

@ -87,4 +87,53 @@ public class ChatCompletion
/// <inheritdoc />
public override string ToString() =>
Choices is { Count: > 0 } choices ? string.Join(Environment.NewLine, choices) : string.Empty;
/// <summary>Creates an array of <see cref="StreamingChatCompletionUpdate" /> instances that represent this <see cref="ChatCompletion" />.</summary>
/// <returns>An array of <see cref="StreamingChatCompletionUpdate" /> instances that may be used to represent this <see cref="ChatCompletion" />.</returns>
public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates()
{
StreamingChatCompletionUpdate? extra = null;
if (AdditionalProperties is not null || Usage is not null)
{
extra = new StreamingChatCompletionUpdate
{
AdditionalProperties = AdditionalProperties
};
if (Usage is { } usage)
{
extra.Contents.Add(new UsageContent(usage));
}
}
int choicesCount = Choices.Count;
var updates = new StreamingChatCompletionUpdate[choicesCount + (extra is null ? 0 : 1)];
for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++)
{
ChatMessage choice = Choices[choiceIndex];
updates[choiceIndex] = new StreamingChatCompletionUpdate
{
ChoiceIndex = choiceIndex,
AdditionalProperties = choice.AdditionalProperties,
AuthorName = choice.AuthorName,
Contents = choice.Contents,
RawRepresentation = choice.RawRepresentation,
Role = choice.Role,
CompletionId = CompletionId,
CreatedAt = CreatedAt,
FinishReason = FinishReason,
ModelId = ModelId
};
}
if (extra is not null)
{
updates[choicesCount] = extra;
}
return updates;
}
}

Просмотреть файл

@ -9,14 +9,35 @@ using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI;
// Conceptually this combines the roles of ChatCompletion and ChatMessage in streaming output.
// For ease of consumption, it also flattens the nested structure you see on streaming chunks in
// the OpenAI/Gemini APIs, so instead of a dictionary of choices, each update represents a single
// choice (and hence has its own role, choice ID, etc.).
/// <summary>
/// Represents a single response chunk from an <see cref="IChatClient"/>.
/// Represents a single streaming response chunk from an <see cref="IChatClient"/>.
/// </summary>
/// <remarks>
/// <para>
/// Conceptually, this combines the roles of <see cref="ChatCompletion"/> and <see cref="ChatMessage"/>
/// in streaming output. For ease of consumption, it also flattens the nested structure you see on
/// streaming chunks in some AI service, so instead of a dictionary of choices, each update represents a
/// single choice (and hence has its own role, choice ID, etc.).
/// </para>
/// <para>
/// <see cref="StreamingChatCompletionUpdate"/> is so named because it represents streaming updates
/// to a single chat completion. As such, it is considered erroneous for multiple updates that are part
/// of the same completion to contain competing values. For example, some updates that are part of
/// the same completion may have a <see langword="null"/> <see cref="StreamingChatCompletionUpdate.Role"/>
/// value, and others may have a non-<see langword="null"/> value, but all of those with a non-<see langword="null"/>
/// value must have the same value (e.g. <see cref="ChatRole.Assistant"/>. It should never be the case, for example,
/// that one <see cref="StreamingChatCompletionUpdate"/> in a completion has a role of <see cref="ChatRole.Assistant"/>
/// while another has a role of "AI".
/// </para>
/// <para>
/// The relationship between <see cref="ChatCompletion"/> and <see cref="StreamingChatCompletionUpdate"/> is
/// codified in the <see cref="StreamingChatCompletionUpdateExtensions.ToChatCompletionAsync"/> and
/// <see cref="ChatCompletion.ToStreamingChatCompletionUpdates"/>, which enable bidirectional conversions
/// between the two. Note, however, that the conversion may be slightly lossy, for example if multiple updates
/// all have different <see cref="StreamingChatCompletionUpdate.RawRepresentation"/> objects whereas there's
/// only one slot for such an object available in <see cref="ChatCompletion.RawRepresentation"/>.
/// </para>
/// </remarks>
public class StreamingChatCompletionUpdate
{
/// <summary>The completion update content items.</summary>

Просмотреть файл

@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Linq;
#if NET
using System.Runtime.InteropServices;
#endif
@ -133,7 +134,22 @@ public static class StreamingChatCompletionUpdateExtensions
/// <param name="coalesceContent">The corresponding option value provided to <see cref="ToChatCompletion"/> or <see cref="ToChatCompletionAsync"/>.</param>
private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> messages, ChatCompletion completion, bool coalesceContent)
{
foreach (var entry in messages)
if (messages.Count <= 1)
{
foreach (var entry in messages)
{
AddMessage(completion, coalesceContent, entry);
}
}
else
{
foreach (var entry in messages.OrderBy(entry => entry.Key))
{
AddMessage(completion, coalesceContent, entry);
}
}
static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValuePair<int, ChatMessage> entry)
{
if (entry.Value.Role == default)
{
@ -154,6 +170,8 @@ public static class StreamingChatCompletionUpdateExtensions
if (content is UsageContent c)
{
completion.Usage = c.Details;
entry.Value.Contents = entry.Value.Contents.ToList();
_ = entry.Value.Contents.Remove(c);
break;
}
}

Просмотреть файл

@ -3,7 +3,6 @@
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@ -48,13 +47,12 @@ public abstract class CachingChatClient : DelegatingChatClient
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(false, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing)
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
{
return existing;
result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
}
var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
return result;
}
@ -64,127 +62,59 @@ public abstract class CachingChatClient : DelegatingChatClient
{
_ = Throw.IfNull(chatMessages);
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
if (CoalesceStreamingUpdates)
{
// Yield all of the cached items.
foreach (var chunk in existingChunks)
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
// we make a streaming request, yielding those results, but then convert those into a non-streaming
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion)
{
yield return chunk;
// Yield all of the cached items.
foreach (var chunk in chatCompletion.ToStreamingChatCompletionUpdates())
{
yield return chunk;
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
capturedItems.Add(chunk);
yield return chunk;
}
// Write the captured items to the cache as a non-streaming result.
await WriteCacheAsync(cacheKey, capturedItems.ToChatCompletion(), cancellationToken).ConfigureAwait(false);
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
capturedItems.Add(chunk);
yield return chunk;
}
// If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list.
if (CoalesceStreamingUpdates)
{
StringBuilder coalescedText = new();
// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++)
// Yield all of the cached items.
foreach (var chunk in existingChunks)
{
// If an item isn't generally coalescable, skip it.
StreamingChatCompletionUpdate update = capturedItems[startInclusive];
if (update.ChoiceIndex != 0 ||
update.Contents.Count != 1 ||
update.Contents[0] is not TextContent textContent)
{
continue;
}
// We found a coalescable item. Look for more contiguous items that are also coalescable with it.
int endExclusive = startInclusive + 1;
for (; endExclusive < capturedItems.Count; endExclusive++)
{
StreamingChatCompletionUpdate next = capturedItems[endExclusive];
if (next.ChoiceIndex != 0 ||
next.Contents.Count != 1 ||
next.Contents[0] is not TextContent ||
// changing role or author would be really strange, but check anyway
(update.Role is not null && next.Role is not null && update.Role != next.Role) ||
(update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName))
{
break;
}
}
// If we couldn't find anything to coalesce, there's nothing to do.
if (endExclusive - startInclusive <= 1)
{
continue;
}
// We found a coalescable run of items. Create a new node to represent the run. We create a new one
// rather than reappropriating one of the existing ones so as not to mutate an item already yielded.
_ = coalescedText.Clear().Append(capturedItems[startInclusive].Text);
TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
};
StreamingChatCompletionUpdate coalesced = new()
{
AdditionalProperties = update.AdditionalProperties?.Clone(),
AuthorName = update.AuthorName,
CompletionId = update.CompletionId,
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
ModelId = update.ModelId,
Role = update.Role,
// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
// to represent multiple, and it won't be serialized anyway.
};
// Replace the starting node with the coalesced node.
capturedItems[startInclusive] = coalesced;
// Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties,
// and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation.
// We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of
// the nulls in a single O(N) pass.
for (int i = startInclusive + 1; i < endExclusive; i++)
{
// Grab the next item.
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;
var nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);
coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.ModelId ??= next.ModelId;
coalesced.Role ??= next.Role;
}
// Complete the coalescing by patching the text of the coalesced node.
coalesced.Text = coalescedText.ToString();
// Jump to the last update in the run, so that when we loop around and bump ahead,
// we're at the next update just after the run.
startInclusive = endExclusive - 1;
yield return chunk;
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
capturedItems.Add(chunk);
yield return chunk;
}
// Remove all of the null slots left over from the coalescing process.
_ = capturedItems.RemoveAll(u => u is null);
// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
}

Просмотреть файл

@ -167,4 +167,97 @@ public class ChatCompletionTests
Assert.IsType<JsonElement>(value);
Assert.Equal("value", ((JsonElement)value!).GetString());
}
[Fact]
public void ToStreamingChatCompletionUpdates_SingleChoice()
{
ChatCompletion completion = new(new ChatMessage(new ChatRole("customRole"), "Text"))
{
CompletionId = "12345",
ModelId = "someModel",
FinishReason = ChatFinishReason.ContentFilter,
CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero),
AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 },
};
StreamingChatCompletionUpdate[] updates = completion.ToStreamingChatCompletionUpdates();
Assert.NotNull(updates);
Assert.Equal(2, updates.Length);
StreamingChatCompletionUpdate update0 = updates[0];
Assert.Equal("12345", update0.CompletionId);
Assert.Equal("someModel", update0.ModelId);
Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason);
Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt);
Assert.Equal("customRole", update0.Role?.Value);
Assert.Equal("Text", update0.Text);
StreamingChatCompletionUpdate update1 = updates[1];
Assert.Equal("value1", update1.AdditionalProperties?["key1"]);
Assert.Equal(42, update1.AdditionalProperties?["key2"]);
}
[Fact]
public void ToStreamingChatCompletionUpdates_MultiChoice()
{
ChatCompletion completion = new(
[
new ChatMessage(ChatRole.Assistant,
[
new TextContent("Hello, "),
new ImageContent("http://localhost/image.png"),
new TextContent("world!"),
])
{
AdditionalProperties = new() { ["choice1Key"] = "choice1Value" },
},
new ChatMessage(ChatRole.System,
[
new FunctionCallContent("call123", "name"),
new FunctionResultContent("call123", "name", 42),
])
{
AdditionalProperties = new() { ["choice2Key"] = "choice2Value" },
},
])
{
CompletionId = "12345",
ModelId = "someModel",
FinishReason = ChatFinishReason.ContentFilter,
CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero),
AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 },
Usage = new UsageDetails { TotalTokenCount = 123 },
};
StreamingChatCompletionUpdate[] updates = completion.ToStreamingChatCompletionUpdates();
Assert.NotNull(updates);
Assert.Equal(3, updates.Length);
StreamingChatCompletionUpdate update0 = updates[0];
Assert.Equal("12345", update0.CompletionId);
Assert.Equal("someModel", update0.ModelId);
Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason);
Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt);
Assert.Equal("assistant", update0.Role?.Value);
Assert.Equal("Hello, ", Assert.IsType<TextContent>(update0.Contents[0]).Text);
Assert.IsType<ImageContent>(update0.Contents[1]);
Assert.Equal("world!", Assert.IsType<TextContent>(update0.Contents[2]).Text);
Assert.Equal("choice1Value", update0.AdditionalProperties?["choice1Key"]);
StreamingChatCompletionUpdate update1 = updates[1];
Assert.Equal("12345", update1.CompletionId);
Assert.Equal("someModel", update1.ModelId);
Assert.Equal(ChatFinishReason.ContentFilter, update1.FinishReason);
Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update1.CreatedAt);
Assert.Equal("system", update1.Role?.Value);
Assert.IsType<FunctionCallContent>(update1.Contents[0]);
Assert.IsType<FunctionResultContent>(update1.Contents[1]);
Assert.Equal("choice2Value", update1.AdditionalProperties?["choice2Key"]);
StreamingChatCompletionUpdate update2 = updates[2];
Assert.Equal("value1", update2.AdditionalProperties?["key1"]);
Assert.Equal(42, update2.AdditionalProperties?["key2"]);
Assert.Equal(123, Assert.IsType<UsageContent>(Assert.Single(update2.Contents)).Details.TotalTokenCount);
}
}

Просмотреть файл

@ -189,6 +189,26 @@ public class StreamingChatCompletionUpdateExtensionsTests
}
}
[Fact]
public async Task ToChatCompletion_UsageContentExtractedFromContents()
{
StreamingChatCompletionUpdate[] updates =
{
new() { Text = "Hello, " },
new() { Text = "world!" },
new() { Contents = [new UsageContent(new() { TotalTokenCount = 42 })] },
};
ChatCompletion completion = await YieldAsync(updates).ToChatCompletionAsync();
Assert.NotNull(completion);
Assert.NotNull(completion.Usage);
Assert.Equal(42, completion.Usage.TotalTokenCount);
Assert.Equal("Hello, world!", Assert.IsType<TextContent>(Assert.Single(completion.Message.Contents)).Text);
}
private static async IAsyncEnumerable<StreamingChatCompletionUpdate> YieldAsync(IEnumerable<StreamingChatCompletionUpdate> updates)
{
foreach (StreamingChatCompletionUpdate update in updates)

Просмотреть файл

@ -214,19 +214,18 @@ public class DistributedCachingChatClientTest
// Verify that all the expected properties will round-trip through the cache,
// even if this involves serialization
List<StreamingChatCompletionUpdate> expectedCompletion =
List<StreamingChatCompletionUpdate> actualCompletion =
[
new()
{
Role = new ChatRole("fakeRole1"),
ChoiceIndex = 3,
ChoiceIndex = 1,
AdditionalProperties = new() { ["a"] = "b" },
Contents = [new TextContent("Chunk1")]
},
new()
{
Role = new ChatRole("fakeRole2"),
Text = "Chunk2",
Contents =
[
new FunctionCallContent("someCallId", "someFn", new Dictionary<string, object?> { ["arg1"] = "value1" }),
@ -235,13 +234,33 @@ public class DistributedCachingChatClientTest
}
];
List<StreamingChatCompletionUpdate> expectedCachedCompletion =
[
new()
{
Role = new ChatRole("fakeRole2"),
Contents = [new FunctionCallContent("someCallId", "someFn", new Dictionary<string, object?> { ["arg1"] = "value1" })],
},
new()
{
Role = new ChatRole("fakeRole1"),
ChoiceIndex = 1,
AdditionalProperties = new() { ["a"] = "b" },
Contents = [new TextContent("Chunk1")]
},
new()
{
Contents = [new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 })],
},
];
var innerCallCount = 0;
using var testClient = new TestChatClient
{
CompleteStreamingAsyncCallback = delegate
{
innerCallCount++;
return ToAsyncEnumerableAsync(expectedCompletion);
return ToAsyncEnumerableAsync(actualCompletion);
}
};
using var outer = new DistributedCachingChatClient(testClient, _storage)
@ -251,7 +270,7 @@ public class DistributedCachingChatClientTest
// Make the initial request and do a quick sanity check
var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);
await AssertCompletionsEqualAsync(expectedCompletion, result1);
await AssertCompletionsEqualAsync(actualCompletion, result1);
Assert.Equal(1, innerCallCount);
// Act
@ -259,7 +278,7 @@ public class DistributedCachingChatClientTest
// Assert
Assert.Equal(1, innerCallCount);
await AssertCompletionsEqualAsync(expectedCompletion, result2);
await AssertCompletionsEqualAsync(expectedCachedCompletion, result2);
// Act/Assert 2: Cache misses do not return cached results
await ToListAsync(outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some modified input")]));
@ -306,10 +325,11 @@ public class DistributedCachingChatClientTest
// Assert
if (coalesce is null or true)
{
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this becomes another one.", c.Text));
StreamingChatCompletionUpdate update = Assert.Single(await ToListAsync(result2));
Assert.Collection(update.Contents,
c => Assert.Equal("This becomes one chunk", Assert.IsType<TextContent>(c).Text),
c => Assert.IsType<FunctionCallContent>(c),
c => Assert.Equal("... and this becomes another one.", Assert.IsType<TextContent>(c).Text));
}
else
{
@ -396,7 +416,6 @@ public class DistributedCachingChatClientTest
List<StreamingChatCompletionUpdate> expectedCompletion =
[
new() { Role = ChatRole.Assistant, Text = "Chunk 1" },
new() { Role = ChatRole.System, Text = "Chunk 2" },
];
using var testClient = new TestChatClient
{