diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs index 89182e2616..2cebeb71c2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs @@ -87,4 +87,53 @@ public class ChatCompletion /// public override string ToString() => Choices is { Count: > 0 } choices ? string.Join(Environment.NewLine, choices) : string.Empty; + + /// Creates an array of instances that represent this . + /// An array of instances that may be used to represent this . + public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates() + { + StreamingChatCompletionUpdate? extra = null; + if (AdditionalProperties is not null || Usage is not null) + { + extra = new StreamingChatCompletionUpdate + { + AdditionalProperties = AdditionalProperties + }; + + if (Usage is { } usage) + { + extra.Contents.Add(new UsageContent(usage)); + } + } + + int choicesCount = Choices.Count; + var updates = new StreamingChatCompletionUpdate[choicesCount + (extra is null ? 0 : 1)]; + + for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++) + { + ChatMessage choice = Choices[choiceIndex]; + updates[choiceIndex] = new StreamingChatCompletionUpdate + { + ChoiceIndex = choiceIndex, + + AdditionalProperties = choice.AdditionalProperties, + AuthorName = choice.AuthorName, + Contents = choice.Contents, + RawRepresentation = choice.RawRepresentation, + Role = choice.Role, + + CompletionId = CompletionId, + CreatedAt = CreatedAt, + FinishReason = FinishReason, + ModelId = ModelId + }; + } + + if (extra is not null) + { + updates[choicesCount] = extra; + } + + return updates; + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs index 278d875258..f63381c575 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs @@ -9,14 +9,35 @@ using System.Text.Json.Serialization; namespace Microsoft.Extensions.AI; -// Conceptually this combines the roles of ChatCompletion and ChatMessage in streaming output. -// For ease of consumption, it also flattens the nested structure you see on streaming chunks in -// the OpenAI/Gemini APIs, so instead of a dictionary of choices, each update represents a single -// choice (and hence has its own role, choice ID, etc.). - /// -/// Represents a single response chunk from an . +/// Represents a single streaming response chunk from an . /// +/// +/// +/// Conceptually, this combines the roles of and +/// in streaming output. For ease of consumption, it also flattens the nested structure you see on +/// streaming chunks in some AI service, so instead of a dictionary of choices, each update represents a +/// single choice (and hence has its own role, choice ID, etc.). +/// +/// +/// is so named because it represents streaming updates +/// to a single chat completion. As such, it is considered erroneous for multiple updates that are part +/// of the same completion to contain competing values. For example, some updates that are part of +/// the same completion may have a +/// value, and others may have a non- value, but all of those with a non- +/// value must have the same value (e.g. . It should never be the case, for example, +/// that one in a completion has a role of +/// while another has a role of "AI". +/// +/// +/// The relationship between and is +/// codified in the and +/// , which enable bidirectional conversions +/// between the two. Note, however, that the conversion may be slightly lossy, for example if multiple updates +/// all have different objects whereas there's +/// only one slot for such an object available in . +/// +/// public class StreamingChatCompletionUpdate { /// The completion update content items. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs index 05ac80dd68..928b9366a2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq; #if NET using System.Runtime.InteropServices; #endif @@ -133,7 +134,22 @@ public static class StreamingChatCompletionUpdateExtensions /// The corresponding option value provided to or . private static void AddMessagesToCompletion(Dictionary messages, ChatCompletion completion, bool coalesceContent) { - foreach (var entry in messages) + if (messages.Count <= 1) + { + foreach (var entry in messages) + { + AddMessage(completion, coalesceContent, entry); + } + } + else + { + foreach (var entry in messages.OrderBy(entry => entry.Key)) + { + AddMessage(completion, coalesceContent, entry); + } + } + + static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValuePair entry) { if (entry.Value.Role == default) { @@ -154,6 +170,8 @@ public static class StreamingChatCompletionUpdateExtensions if (content is UsageContent c) { completion.Usage = c.Details; + entry.Value.Contents = entry.Value.Contents.ToList(); + _ = entry.Value.Contents.Remove(c); break; } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index ad62034617..770ffa60cf 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -48,13 +47,12 @@ public abstract class CachingChatClient : DelegatingChatClient // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(false, chatMessages, options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing) + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) { - return existing; + result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); } - var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); - await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); return result; } @@ -64,127 +62,59 @@ public abstract class CachingChatClient : DelegatingChatClient { _ = Throw.IfNull(chatMessages); - var cacheKey = GetCacheKey(true, chatMessages, options); - if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) + if (CoalesceStreamingUpdates) { - // Yield all of the cached items. - foreach (var chunk in existingChunks) + // When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means + // we make a streaming request, yielding those results, but then convert those into a non-streaming + // result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one. + + var cacheKey = GetCacheKey(true, chatMessages, options); + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion) { - yield return chunk; + // Yield all of the cached items. + foreach (var chunk in chatCompletion.ToStreamingChatCompletionUpdates()) + { + yield return chunk; + } + } + else + { + // Yield and store all of the items. + List capturedItems = []; + await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + capturedItems.Add(chunk); + yield return chunk; + } + + // Write the captured items to the cache as a non-streaming result. + await WriteCacheAsync(cacheKey, capturedItems.ToChatCompletion(), cancellationToken).ConfigureAwait(false); } } else { - // Yield and store all of the items. - List capturedItems = []; - await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + var cacheKey = GetCacheKey(true, chatMessages, options); + if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) { - capturedItems.Add(chunk); - yield return chunk; - } - - // If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list. - if (CoalesceStreamingUpdates) - { - StringBuilder coalescedText = new(); - - // Iterate through all of the items in the list looking for contiguous items that can be coalesced. - for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++) + // Yield all of the cached items. + foreach (var chunk in existingChunks) { - // If an item isn't generally coalescable, skip it. - StreamingChatCompletionUpdate update = capturedItems[startInclusive]; - if (update.ChoiceIndex != 0 || - update.Contents.Count != 1 || - update.Contents[0] is not TextContent textContent) - { - continue; - } - - // We found a coalescable item. Look for more contiguous items that are also coalescable with it. - int endExclusive = startInclusive + 1; - for (; endExclusive < capturedItems.Count; endExclusive++) - { - StreamingChatCompletionUpdate next = capturedItems[endExclusive]; - if (next.ChoiceIndex != 0 || - next.Contents.Count != 1 || - next.Contents[0] is not TextContent || - - // changing role or author would be really strange, but check anyway - (update.Role is not null && next.Role is not null && update.Role != next.Role) || - (update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName)) - { - break; - } - } - - // If we couldn't find anything to coalesce, there's nothing to do. - if (endExclusive - startInclusive <= 1) - { - continue; - } - - // We found a coalescable run of items. Create a new node to represent the run. We create a new one - // rather than reappropriating one of the existing ones so as not to mutate an item already yielded. - _ = coalescedText.Clear().Append(capturedItems[startInclusive].Text); - - TextContent coalescedContent = new(null) // will patch the text after examining all items in the run - { - AdditionalProperties = textContent.AdditionalProperties?.Clone(), - }; - - StreamingChatCompletionUpdate coalesced = new() - { - AdditionalProperties = update.AdditionalProperties?.Clone(), - AuthorName = update.AuthorName, - CompletionId = update.CompletionId, - Contents = [coalescedContent], - CreatedAt = update.CreatedAt, - FinishReason = update.FinishReason, - ModelId = update.ModelId, - Role = update.Role, - - // Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used - // to represent multiple, and it won't be serialized anyway. - }; - - // Replace the starting node with the coalesced node. - capturedItems[startInclusive] = coalesced; - - // Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties, - // and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation. - // We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of - // the nulls in a single O(N) pass. - for (int i = startInclusive + 1; i < endExclusive; i++) - { - // Grab the next item. - StreamingChatCompletionUpdate next = capturedItems[i]; - capturedItems[i] = null!; - - var nextContent = (TextContent)next.Contents[0]; - _ = coalescedText.Append(nextContent.Text); - - coalesced.AuthorName ??= next.AuthorName; - coalesced.CompletionId ??= next.CompletionId; - coalesced.CreatedAt ??= next.CreatedAt; - coalesced.FinishReason ??= next.FinishReason; - coalesced.ModelId ??= next.ModelId; - coalesced.Role ??= next.Role; - } - - // Complete the coalescing by patching the text of the coalesced node. - coalesced.Text = coalescedText.ToString(); - - // Jump to the last update in the run, so that when we loop around and bump ahead, - // we're at the next update just after the run. - startInclusive = endExclusive - 1; + yield return chunk; + } + } + else + { + // Yield and store all of the items. + List capturedItems = []; + await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + capturedItems.Add(chunk); + yield return chunk; } - // Remove all of the null slots left over from the coalescing process. - _ = capturedItems.RemoveAll(u => u is null); + // Write the captured items to the cache. + await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); } - - // Write the captured items to the cache. - await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs index a695e686f6..35184f3ee5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs @@ -167,4 +167,97 @@ public class ChatCompletionTests Assert.IsType(value); Assert.Equal("value", ((JsonElement)value!).GetString()); } + + [Fact] + public void ToStreamingChatCompletionUpdates_SingleChoice() + { + ChatCompletion completion = new(new ChatMessage(new ChatRole("customRole"), "Text")) + { + CompletionId = "12345", + ModelId = "someModel", + FinishReason = ChatFinishReason.ContentFilter, + CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), + AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 }, + }; + + StreamingChatCompletionUpdate[] updates = completion.ToStreamingChatCompletionUpdates(); + Assert.NotNull(updates); + Assert.Equal(2, updates.Length); + + StreamingChatCompletionUpdate update0 = updates[0]; + Assert.Equal("12345", update0.CompletionId); + Assert.Equal("someModel", update0.ModelId); + Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason); + Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt); + Assert.Equal("customRole", update0.Role?.Value); + Assert.Equal("Text", update0.Text); + + StreamingChatCompletionUpdate update1 = updates[1]; + Assert.Equal("value1", update1.AdditionalProperties?["key1"]); + Assert.Equal(42, update1.AdditionalProperties?["key2"]); + } + + [Fact] + public void ToStreamingChatCompletionUpdates_MultiChoice() + { + ChatCompletion completion = new( + [ + new ChatMessage(ChatRole.Assistant, + [ + new TextContent("Hello, "), + new ImageContent("http://localhost/image.png"), + new TextContent("world!"), + ]) + { + AdditionalProperties = new() { ["choice1Key"] = "choice1Value" }, + }, + + new ChatMessage(ChatRole.System, + [ + new FunctionCallContent("call123", "name"), + new FunctionResultContent("call123", "name", 42), + ]) + { + AdditionalProperties = new() { ["choice2Key"] = "choice2Value" }, + }, + ]) + { + CompletionId = "12345", + ModelId = "someModel", + FinishReason = ChatFinishReason.ContentFilter, + CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), + AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 }, + Usage = new UsageDetails { TotalTokenCount = 123 }, + }; + + StreamingChatCompletionUpdate[] updates = completion.ToStreamingChatCompletionUpdates(); + Assert.NotNull(updates); + Assert.Equal(3, updates.Length); + + StreamingChatCompletionUpdate update0 = updates[0]; + Assert.Equal("12345", update0.CompletionId); + Assert.Equal("someModel", update0.ModelId); + Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason); + Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt); + Assert.Equal("assistant", update0.Role?.Value); + Assert.Equal("Hello, ", Assert.IsType(update0.Contents[0]).Text); + Assert.IsType(update0.Contents[1]); + Assert.Equal("world!", Assert.IsType(update0.Contents[2]).Text); + Assert.Equal("choice1Value", update0.AdditionalProperties?["choice1Key"]); + + StreamingChatCompletionUpdate update1 = updates[1]; + Assert.Equal("12345", update1.CompletionId); + Assert.Equal("someModel", update1.ModelId); + Assert.Equal(ChatFinishReason.ContentFilter, update1.FinishReason); + Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update1.CreatedAt); + Assert.Equal("system", update1.Role?.Value); + Assert.IsType(update1.Contents[0]); + Assert.IsType(update1.Contents[1]); + Assert.Equal("choice2Value", update1.AdditionalProperties?["choice2Key"]); + + StreamingChatCompletionUpdate update2 = updates[2]; + Assert.Equal("value1", update2.AdditionalProperties?["key1"]); + Assert.Equal(42, update2.AdditionalProperties?["key2"]); + Assert.Equal(123, Assert.IsType(Assert.Single(update2.Contents)).Details.TotalTokenCount); + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs index bb0f08325d..33eca7dcaa 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs @@ -189,6 +189,26 @@ public class StreamingChatCompletionUpdateExtensionsTests } } + [Fact] + public async Task ToChatCompletion_UsageContentExtractedFromContents() + { + StreamingChatCompletionUpdate[] updates = + { + new() { Text = "Hello, " }, + new() { Text = "world!" }, + new() { Contents = [new UsageContent(new() { TotalTokenCount = 42 })] }, + }; + + ChatCompletion completion = await YieldAsync(updates).ToChatCompletionAsync(); + + Assert.NotNull(completion); + + Assert.NotNull(completion.Usage); + Assert.Equal(42, completion.Usage.TotalTokenCount); + + Assert.Equal("Hello, world!", Assert.IsType(Assert.Single(completion.Message.Contents)).Text); + } + private static async IAsyncEnumerable YieldAsync(IEnumerable updates) { foreach (StreamingChatCompletionUpdate update in updates) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index 7f6ca20915..67e23ec495 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -214,19 +214,18 @@ public class DistributedCachingChatClientTest // Verify that all the expected properties will round-trip through the cache, // even if this involves serialization - List expectedCompletion = + List actualCompletion = [ new() { Role = new ChatRole("fakeRole1"), - ChoiceIndex = 3, + ChoiceIndex = 1, AdditionalProperties = new() { ["a"] = "b" }, Contents = [new TextContent("Chunk1")] }, new() { Role = new ChatRole("fakeRole2"), - Text = "Chunk2", Contents = [ new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" }), @@ -235,13 +234,33 @@ public class DistributedCachingChatClientTest } ]; + List expectedCachedCompletion = + [ + new() + { + Role = new ChatRole("fakeRole2"), + Contents = [new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" })], + }, + new() + { + Role = new ChatRole("fakeRole1"), + ChoiceIndex = 1, + AdditionalProperties = new() { ["a"] = "b" }, + Contents = [new TextContent("Chunk1")] + }, + new() + { + Contents = [new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 })], + }, + ]; + var innerCallCount = 0; using var testClient = new TestChatClient { CompleteStreamingAsyncCallback = delegate { innerCallCount++; - return ToAsyncEnumerableAsync(expectedCompletion); + return ToAsyncEnumerableAsync(actualCompletion); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -251,7 +270,7 @@ public class DistributedCachingChatClientTest // Make the initial request and do a quick sanity check var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); - await AssertCompletionsEqualAsync(expectedCompletion, result1); + await AssertCompletionsEqualAsync(actualCompletion, result1); Assert.Equal(1, innerCallCount); // Act @@ -259,7 +278,7 @@ public class DistributedCachingChatClientTest // Assert Assert.Equal(1, innerCallCount); - await AssertCompletionsEqualAsync(expectedCompletion, result2); + await AssertCompletionsEqualAsync(expectedCachedCompletion, result2); // Act/Assert 2: Cache misses do not return cached results await ToListAsync(outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some modified input")])); @@ -306,10 +325,11 @@ public class DistributedCachingChatClientTest // Assert if (coalesce is null or true) { - Assert.Collection(await ToListAsync(result2), - c => Assert.Equal("This becomes one chunk", c.Text), - c => Assert.IsType(Assert.Single(c.Contents)), - c => Assert.Equal("... and this becomes another one.", c.Text)); + StreamingChatCompletionUpdate update = Assert.Single(await ToListAsync(result2)); + Assert.Collection(update.Contents, + c => Assert.Equal("This becomes one chunk", Assert.IsType(c).Text), + c => Assert.IsType(c), + c => Assert.Equal("... and this becomes another one.", Assert.IsType(c).Text)); } else { @@ -396,7 +416,6 @@ public class DistributedCachingChatClientTest List expectedCompletion = [ new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, - new() { Role = ChatRole.System, Text = "Chunk 2" }, ]; using var testClient = new TestChatClient {