Fixed under-count of token usage

This commit is contained in:
Ani 2024-11-04 20:50:38 +01:00
Родитель 12d50cfb58
Коммит 78f1b49e21
5 изменённых файлов: 51 добавлений и 19 удалений

Просмотреть файл

@ -45,8 +45,8 @@ public sealed class KernelServiceIntegrationTests : IDisposable
}
[TestMethod]
[DataRow("Translate to German", "What is that?", "Was ist das?", 600, new[] { PasteFormats.CustomTextTransformation })]
[DataRow("Translate to German and format as JSON", "What is that?", @"[\s*Was ist das\?\s*]", 600, new[] { PasteFormats.CustomTextTransformation, PasteFormats.Json })]
[DataRow("Translate to German", "What is that?", "Was ist das?", 1200, new[] { PasteFormats.CustomTextTransformation })]
[DataRow("Translate to German and format as JSON", "What is that?", @"[\s*Was ist das\?\s*]", 1500, new[] { PasteFormats.CustomTextTransformation, PasteFormats.Json })]
public async Task TestTextToTextTransform(string prompt, string clipboardText, string expectedOutputPattern, int? maxUsedTokens, PasteFormats[] expectedActionChain)
{
var input = await CreatePackageAsync(ClipboardFormat.Text, clipboardText);

Просмотреть файл

@ -53,19 +53,37 @@ internal static class DataPackageHelpers
{
var storageItems = await dataPackageView.GetStorageItemsAsync();
if (storageItems.Count == 1 && storageItems.Single() is StorageFile file && ImageFileTypes.Contains(file.FileType))
if (storageItems.Count == 1 && storageItems.Single() is StorageFile file)
{
availableFormats |= ClipboardFormat.Image;
}
if (availableFormats == ClipboardFormat.None)
{
// Advertise the "generic" File format only if there is no other specific format available; confusing for AI otherwise.
availableFormats |= ClipboardFormat.File;
if (ImageFileTypes.Contains(file.FileType))
{
availableFormats |= ClipboardFormat.Image;
}
}
}
return availableFormats;
return FixFormatsForAI(availableFormats);
}
private static ClipboardFormat FixFormatsForAI(ClipboardFormat formats)
{
var result = formats;
if (result.HasFlag(ClipboardFormat.File) && result != ClipboardFormat.File)
{
// Advertise the "generic" File format only if there is no other specific format available; confusing for AI otherwise.
result &= ~ClipboardFormat.File;
}
if (result == (ClipboardFormat.Image | ClipboardFormat.Html))
{
// The Windows Photo application advertises Image and Html when copying an image; this Html format is not easily usable and is confusing for AI.
result &= ~ClipboardFormat.Html;
}
return result;
}
internal static async Task<bool> HasUsableDataAsync(this DataPackageView dataPackageView)

Просмотреть файл

@ -7,4 +7,12 @@ namespace AdvancedPaste.Models;
public record class AIServiceUsage(int PromptTokens, int CompletionTokens)
{
public static AIServiceUsage None => new(PromptTokens: 0, CompletionTokens: 0);
public bool HasUsage => PromptTokens > 0 || CompletionTokens > 0;
public static AIServiceUsage Add(AIServiceUsage first, AIServiceUsage second) =>
new(first.PromptTokens + second.PromptTokens, first.CompletionTokens + second.CompletionTokens);
public override string ToString() =>
$"{nameof(PromptTokens)}: {PromptTokens}, {nameof(CompletionTokens)}: {CompletionTokens}";
}

Просмотреть файл

@ -33,7 +33,7 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
protected abstract void AddChatCompletionService(IKernelBuilder kernelBuilder);
protected abstract AIServiceUsage GetAIServiceUsage(ChatMessageContent chatResult);
protected abstract AIServiceUsage GetAIServiceUsage(ChatMessageContent chatMessage);
public async Task<DataPackage> TransformClipboardAsync(string prompt, DataPackageView clipboardData, bool isSavedQuery)
{
@ -104,10 +104,14 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
chatHistory.AddSystemMessage($"Available clipboard formats: {await kernel.GetDataFormatsAsync()}");
chatHistory.AddUserMessage(prompt);
var chatResult = await kernel.GetRequiredService<IChatCompletionService>().GetChatMessageContentAsync(chatHistory, PromptExecutionSettings, kernel);
chatHistory.AddAssistantMessage(chatResult.Content);
var chatResult = await kernel.GetRequiredService<IChatCompletionService>()
.GetChatMessageContentAsync(chatHistory, PromptExecutionSettings, kernel);
chatHistory.Add(chatResult);
return (chatHistory, GetAIServiceUsage(chatResult));
var totalUsage = chatHistory.Select(GetAIServiceUsage)
.Aggregate(AIServiceUsage.Add);
return (chatHistory, totalUsage);
}
private async Task<(ChatHistory ChatHistory, AIServiceUsage Usage)> ExecuteCachedActionChain(Kernel kernel, List<ActionChainItem> actionChain)
@ -202,10 +206,10 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
}
}
private static string FormatChatHistory(ChatHistory chatHistory) =>
private string FormatChatHistory(ChatHistory chatHistory) =>
chatHistory.Count == 0 ? "[No chat history]" : string.Join(Environment.NewLine, chatHistory.Select(FormatChatMessage));
private static string FormatChatMessage(ChatMessageContent chatMessage)
private string FormatChatMessage(ChatMessageContent chatMessage)
{
static string Redact(object data) =>
#if DEBUG
@ -230,6 +234,8 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
var role = chatMessage.Role;
var content = string.Join(" / ", chatMessage.Items.Select(FormatKernelContent));
var redactedContent = role == AuthorRole.System || role == AuthorRole.Tool ? content : Redact(content);
return $"-> {role}: {redactedContent}";
var usage = GetAIServiceUsage(chatMessage);
var usageString = usage.HasUsage ? $" [{usage}]" : string.Empty;
return $"-> {role}: {redactedContent}{usageString}";
}
}

Просмотреть файл

@ -27,8 +27,8 @@ public sealed class KernelService(IKernelQueryCacheService queryCacheService, IA
protected override void AddChatCompletionService(IKernelBuilder kernelBuilder) => kernelBuilder.AddOpenAIChatCompletion(ModelName, _aiCredentialsProvider.Key);
protected override AIServiceUsage GetAIServiceUsage(ChatMessageContent chatResult) =>
chatResult.Metadata?.GetValueOrDefault("Usage") is CompletionsUsage completionsUsage
protected override AIServiceUsage GetAIServiceUsage(ChatMessageContent chatMessage) =>
chatMessage.Metadata?.GetValueOrDefault("Usage") is CompletionsUsage completionsUsage
? new(PromptTokens: completionsUsage.PromptTokens, CompletionTokens: completionsUsage.CompletionTokens)
: AIServiceUsage.None;
}