From 3659a485a0e08dc7fa77152945e31fdf6fccd519 Mon Sep 17 00:00:00 2001
From: Xiaoyun Zhang
Date: Tue, 5 Nov 2024 09:27:05 -0800
Subject: [PATCH] [GenAI] Introduce CausalLMPipelineChatClient for
 MEAI.IChatClient (#7270)

* leverage MEAI abstraction

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub

* Update src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs

Co-authored-by: Stephen Toub

* fix comments

* Update Microsoft.ML.GenAI.Core.csproj

---------

Co-authored-by: Stephen Toub
---
 .../MEAI/Llama3_1.cs                          |  54 ++++++++
 .../Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs   |  44 ++++++
 .../Microsoft.ML.GenAI.Samples/Program.cs     |   4 +-
 eng/Versions.props                            |   3 +-
 .../CausalLMPipelineChatClient.cs             |  89 ++++++++++++
 .../Microsoft.ML.GenAI.Core.csproj            |   1 +
 .../Utility/IChatTemplateBuilder.cs           |   6 +
 .../Llama3CausalLMChatClient.cs               |  57 ++++++++
 .../Llama3_1ChatTemplateBuilder.cs            |  38 +++++-
 .../Phi3/Phi3CausalLMAgent.cs                 |  61 +++------
 .../Phi3/Phi3CausalLMChatClient.cs            |  62 +++++++++
 .../Phi3/Phi3CausalLMChatCompletionService.cs |  42 +----
 .../Phi3/Phi3CausalLMTextGenerationService.cs |   3 +-
 .../Phi3/Phi3ChatTemplateBuilder.cs           | 127 ++++++++++++++++++
 ...atTemplateFromMEAIChatHistory.approved.txt |   7 +
 .../LLaMA3_1Tests.cs                          |  18 +++
 16 files changed, 534 insertions(+), 82 deletions(-)
 create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs
 create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs
 create mode 100644 src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs
 create mode 100644 src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
 create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs
 create mode 100644 src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs
 create mode 100644 test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromMEAIChatHistory.approved.txt

diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs b/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs
new file mode 100644
index 000000000..4416fe757
--- /dev/null
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs
@@ -0,0 +1,54 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using Microsoft.Extensions.AI;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.GenAI.LLaMA;
+using Microsoft.ML.Tokenizers;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Samples.MEAI;
+
+internal class Llama3_1
+{
+    public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
+    {
+        var device = "cuda";
+        if (device == "cuda")
+        {
+            torch.InitializeDeviceType(DeviceType.CUDA);
+        }
+
+        var defaultType = ScalarType.BFloat16;
+        torch.manual_seed(1);
+        torch.set_default_dtype(defaultType);
+        var configName = "config.json";
+        var originalWeightFolder = Path.Combine(weightFolder, "original");
+
+        Console.WriteLine("Loading Llama from huggingface model weight folder");
+        var stopWatch = System.Diagnostics.Stopwatch.StartNew();
+        stopWatch.Start();
+        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
+        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);
+
+        var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
+
+        var client = new Llama3CausalLMChatClient(pipeline);
+
+        var task = """
+            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
+            """;
+        var chatMessage = new ChatMessage(ChatRole.User, task);
+
+        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
+        {
+            Console.Write(response.Text);
+        }
+    }
+}
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs b/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs
new file mode 100644
index 000000000..7bba28bc1
--- /dev/null
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static TorchSharp.torch;
+using TorchSharp;
+using Microsoft.ML.GenAI.Phi;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+using Microsoft.Extensions.AI;
+
+namespace Microsoft.ML.GenAI.Samples.MEAI;
+
+internal class Phi3
+{
+    public static async Task RunAsync(string weightFolder)
+    {
+        var device = "cuda";
+        if (device == "cuda")
+        {
+            torch.InitializeDeviceType(DeviceType.CUDA);
+        }
+
+        var defaultType = ScalarType.Float16;
+        torch.manual_seed(1);
+        torch.set_default_dtype(defaultType);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
+        var client = new Phi3CausalLMChatClient(pipeline);
+
+        var task = """
+            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
+            """;
+        var chatMessage = new ChatMessage(ChatRole.User, task);
+
+        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
+        {
+            Console.Write(response.Text);
+        }
+    }
+}
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
index 769e9f0fb..de091afe4 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
@@ -1,4 +1,6 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Llama;
+using Microsoft.ML.GenAI.Samples.MEAI;
 
-await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
+//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
+await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
diff --git a/eng/Versions.props b/eng/Versions.props
index 48c8bb2e1..e2e60447f 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -34,7 +34,7 @@
     6.0.0
     5.0.0
    8.0.0
-    <SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
+    <SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
     8.0.0
     14.0.2
@@ -47,6 +47,7 @@
     1.0.0-beta.24375.2
     1.18.1
     0.0.0.12
+    <MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
diff --git a/src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs b/src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs
new file mode 100644
index 000000000..c0aa398ed
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.AI;
+using Microsoft.ML.Tokenizers;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Core;
+
+public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
+    where TTokenizer : Tokenizer
+    where TCausalLMModel : nn.Module
+{
+    private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
+    private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;
+
+    public CausalLMPipelineChatClient(
+        ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
+        IMEAIChatTemplateBuilder chatTemplateBuilder,
+        ChatClientMetadata? metadata = null)
+    {
+        var classNameWithType = $"{nameof(CausalLMPipelineChatClient)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
+        Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
+        _chatTemplateBuilder = chatTemplateBuilder;
+        _pipeline = pipeline;
+    }
+
+    public ChatClientMetadata Metadata { get; }
+
+    public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+    {
+        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
+        var stopSequences = options?.StopSequences ?? Array.Empty<string>();
+
+        var output = _pipeline.Generate(
+            prompt,
+            maxLen: options?.MaxOutputTokens ?? 1024,
+            temperature: options?.Temperature ?? 0.7f,
+            stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");
+
+        var chatMessage = new ChatMessage(ChatRole.Assistant, output);
+        return Task.FromResult(new ChatCompletion([chatMessage])
+        {
+            CreatedAt = DateTime.UtcNow,
+            FinishReason = ChatFinishReason.Stop,
+        });
+    }
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+    public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
+        var stopSequences = options?.StopSequences ?? Array.Empty<string>();
+
+        foreach (var output in _pipeline.GenerateStreaming(
+            prompt,
+            maxLen: options?.MaxOutputTokens ?? 1024,
+            temperature: options?.Temperature ?? 0.7f,
+            stopSequences: stopSequences.ToArray()))
+        {
+            yield return new StreamingChatCompletionUpdate
+            {
+                Role = ChatRole.Assistant,
+                Text = output,
+                CreatedAt = DateTime.UtcNow,
+            };
+        }
+    }
+
+    public virtual void Dispose()
+    {
+    }
+
+    public virtual TService? GetService<TService>(object? key = null) where TService : class
+    {
+        return null;
+    }
+}
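Note: the abstract client above is the extension point this change introduces. A model-specific IChatClient reduces to a chat-template builder plus the model's end-of-turn stop token, as the Llama3 and Phi3 clients below demonstrate. A rough sketch of adapting a new model (MyModel and MyModelTemplateBuilder are hypothetical placeholders, not types in this patch):

    // Hypothetical adapter for a new causal LM. MyModel must satisfy the
    // nn.Module constraint, and MyModelTemplateBuilder must implement
    // IMEAIChatTemplateBuilder.
    public class MyModelChatClient : CausalLMPipelineChatClient<LlamaTokenizer, MyModel>
    {
        public MyModelChatClient(ICausalLMPipeline<LlamaTokenizer, MyModel> pipeline)
            : base(pipeline, MyModelTemplateBuilder.Instance)
        {
        }

        // Like the Llama3 and Phi3 clients below, override CompleteAsync and
        // CompleteStreamingAsync to append the model's end-of-turn token to
        // options.StopSequences before delegating to the base implementation.
    }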
diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
index 59cc59edc..0efeeb093 100644
--- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
+++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -13,6 +13,7 @@
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
index 4cf5a00ab..7d9292562 100644
--- a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
+++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
@@ -8,6 +8,7 @@ using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel.ChatCompletion;
 
 namespace Microsoft.ML.GenAI.Core;
@@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
     string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
 }
 
+public interface IMEAIChatTemplateBuilder
+{
+    string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
+}
+
 public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
 {
 }
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
new file mode 100644
index 000000000..ad0b58c3b
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
@@ -0,0 +1,57 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.Extensions.AI;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<TiktokenTokenizer, LlamaForCausalLM>
+{
+    private readonly string _eotToken = "<|eot_id|>";
+
+    public Llama3CausalLMChatClient(
+        ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> pipeline,
+        IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
+        ChatClientMetadata? metadata = null)
+        : base(
+            pipeline,
+            chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
+            metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
+    {
+    }
+
+    public override Task<ChatCompletion> CompleteAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        options ??= new ChatOptions();
+
+        if (options.StopSequences != null)
+        {
+            options.StopSequences.Add(_eotToken);
+        }
+        else
+        {
+            options.StopSequences = new List<string> { _eotToken };
+        }
+
+        return base.CompleteAsync(chatMessages, options, cancellationToken);
+    }
+
+    public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        options ??= new ChatOptions();
+        options.StopSequences ??= [];
+        options.StopSequences.Add(_eotToken);
+
+        return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
+    }
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
index 29e7fb1da..f54e24b9f 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
+++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
@@ -4,13 +4,15 @@
 using System.Text;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
+using TextContent = Microsoft.SemanticKernel.TextContent;
 
 namespace Microsoft.ML.GenAI.LLaMA;
 #pragma warning disable MSML_GeneralName // This name should be PascalCased
-public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
+public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
 #pragma warning restore MSML_GeneralName // This name should be PascalCased
 {
     private const char Newline = '\n';
@@ -86,5 +88,39 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
         return sb.ToString();
     }
 
+    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
+    {
+        var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
+        if (messages.Any(m => m.Text is null))
+        {
+            throw new InvalidOperationException("Please provide a message with content.");
+        }
+
+        if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
+        {
+            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+        }
+
+        var sb = new StringBuilder();
+        sb.Append("<|begin_of_text|>");
+        foreach (var message in messages)
+        {
+            var role = message.Role.Value;
+            var content = message.Text!;
+            sb.Append(message switch
+            {
+                _ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ => throw new InvalidOperationException("Invalid role.")
+            });
+        }
+
+        sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
+        var input = sb.ToString();
+
+        return input;
+    }
+
     public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
 }
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
index 2b9e93a4a..6971ac599 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
@@ -19,22 +19,31 @@ public class Phi3Agent : IStreamingAgent
     private const char Newline = '\n';
     private readonly ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> _pipeline;
     private readonly string? _systemMessage;
+    private readonly IAutoGenChatTemplateBuilder _templateBuilder;
 
     public Phi3Agent(
         ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline,
         string name,
-        string? systemMessage = "you are a helpful assistant")
+        string? systemMessage = "you are a helpful assistant",
+        IAutoGenChatTemplateBuilder? templateBuilder = null)
     {
         this.Name = name;
         this._pipeline = pipeline;
         this._systemMessage = systemMessage;
+        this._templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
     }
 
     public string Name { get; }
 
     public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
     {
-        var input = BuildPrompt(messages);
+        if (_systemMessage != null)
+        {
+            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+            messages = messages.Prepend(systemMessage);
+        }
+
+        var input = _templateBuilder.BuildPrompt(messages);
         var maxLen = options?.MaxToken ?? 1024;
         var temperature = options?.Temperature ?? 0.7f;
         var stopTokenSequence = options?.StopSequence ?? [];
@@ -56,7 +65,13 @@ public class Phi3Agent : IStreamingAgent
         GenerateReplyOptions? options = null,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
-        var input = BuildPrompt(messages);
+        if (_systemMessage != null)
+        {
+            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+            messages = messages.Prepend(systemMessage);
+        }
+
+        var input = _templateBuilder.BuildPrompt(messages);
         var maxLen = options?.MaxToken ?? 1024;
         var temperature = options?.Temperature ?? 0.7f;
         var stopTokenSequence = options?.StopSequence ?? [];
@@ -71,44 +86,4 @@ public class Phi3Agent : IStreamingAgent
             yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
         }
     }
-
-    private string BuildPrompt(IEnumerable<IMessage> messages)
-    {
-        var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
-        if (messages.Any(m => m.GetContent() is null))
-        {
-            throw new InvalidOperationException("Please provide a message with content.");
-        }
-
-        if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
-        {
-            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
-        }
-
-        // construct template based on instruction from
-        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
-
-        var sb = new StringBuilder();
-        if (_systemMessage is not null)
-        {
-            sb.Append($"<|system|>{Newline}{_systemMessage}<|end|>{Newline}");
-        }
-        foreach (var message in messages)
-        {
-            var role = message.GetRole()!.Value;
-            var content = message.GetContent()!;
-            sb.Append(message switch
-            {
-                _ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
-                _ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
-                _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
-                _ => throw new InvalidOperationException("Invalid role.")
-            });
-        }
-
-        sb.Append("<|assistant|>");
-        var input = sb.ToString();
-
-        return input;
-    }
 }
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs
new file mode 100644
index 000000000..9477f423b
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs
@@ -0,0 +1,62 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.Extensions.AI;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.Phi;
+
+public class Phi3CausalLMChatClient : CausalLMPipelineChatClient<LlamaTokenizer, Phi3ForCasualLM>
+{
+    private readonly string _eotToken = "<|end|>";
+
+    public Phi3CausalLMChatClient(
+        ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline,
+        IMEAIChatTemplateBuilder? templateBuilder = null,
+        ChatClientMetadata? metadata = null)
+        : base(
+            pipeline,
+            templateBuilder ?? Phi3ChatTemplateBuilder.Instance,
+            metadata ?? new ChatClientMetadata(modelId: nameof(Phi3CausalLMChatClient)))
+    {
+    }
+
+    public override Task<ChatCompletion> CompleteAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        options ??= new ChatOptions();
+
+        if (options.StopSequences != null)
+        {
+            options.StopSequences.Add(_eotToken);
+        }
+        else
+        {
+            options.StopSequences = [_eotToken];
+        }
+
+        return base.CompleteAsync(chatMessages, options, cancellationToken);
+    }
+
+    public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        options ??= new ChatOptions();
+        options.StopSequences ??= [];
+        options.StopSequences.Add(_eotToken);
+
+        return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
+    }
+}
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
index 480e0d7e0..1d9588265 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
@@ -16,12 +16,15 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
 {
     private readonly ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> _pipeline;
     private readonly Phi3CausalLMTextGenerationService _textGenerationService;
-    private const char NewLine = '\n'; // has to be \n, \r\n will cause wanky result.
+    private readonly ISemanticKernelChatTemplateBuilder _templateBuilder;
 
-    public Phi3CausalLMChatCompletionService(ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline)
+    public Phi3CausalLMChatCompletionService(
+        ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline,
+        ISemanticKernelChatTemplateBuilder? templateBuilder = null)
     {
         _pipeline = pipeline;
         _textGenerationService = new Phi3CausalLMTextGenerationService(pipeline);
+        _templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
     }
 
     public IReadOnlyDictionary<string, object?> Attributes => _textGenerationService.Attributes;
@@ -32,7 +35,7 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
         Kernel? kernel = null,
         CancellationToken cancellationToken = default)
     {
-        var prompt = BuildPrompt(chatHistory);
+        var prompt = _templateBuilder.BuildPrompt(chatHistory);
         var replies = await _textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken);
         return replies.Select(reply => new ChatMessageContent(AuthorRole.Assistant, reply.Text)).ToList();
     }
@@ -44,42 +47,11 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
-        var prompt = BuildPrompt(chatHistory);
+        var prompt = _templateBuilder.BuildPrompt(chatHistory);
         await foreach (var reply in _textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken))
         {
             yield return new StreamingChatMessageContent(AuthorRole.Assistant, reply.Text);
         }
     }
-
-    private string BuildPrompt(ChatHistory chatHistory)
-    {
-        // build prompt from chat history
-        var sb = new StringBuilder();
-
-        foreach (var message in chatHistory)
-        {
-            foreach (var item in message.Items)
-            {
-                if (item is not TextContent textContent)
-                {
-                    throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
-                }
-
-                var prompt = message.Role switch
-                {
-                    _ when message.Role == AuthorRole.System => $"<|system|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ when message.Role == AuthorRole.User => $"<|user|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ when message.Role == AuthorRole.Assistant => $"<|assistant|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ => throw new NotSupportedException($"Unsupported role {message.Role}")
-                };
-
-                sb.Append(prompt);
-            }
-        }
-
-        sb.Append("<|assistant|>");
-
-        return sb.ToString();
-    }
 }
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs
index ac22b4f35..d4c8c34e8 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMTextGenerationService.cs
@@ -14,7 +14,8 @@ public class Phi3CausalLMTextGenerationService : ITextGenerationService
 {
     private readonly ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> _pipeline;
 
-    public Phi3CausalLMTextGenerationService(ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline)
+    public Phi3CausalLMTextGenerationService(
+        ICausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM> pipeline)
     {
         _pipeline = pipeline;
     }
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs
new file mode 100644
index 000000000..213b1f740
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs
@@ -0,0 +1,127 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using Microsoft.Extensions.AI;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using TextContent = Microsoft.SemanticKernel.TextContent;
+
+namespace Microsoft.ML.GenAI.Phi;
+
+public class Phi3ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
+{
+    private const char Newline = '\n';
+
+    public static Phi3ChatTemplateBuilder Instance { get; } = new Phi3ChatTemplateBuilder();
+
+    public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
+    {
+        var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
+        if (messages.Any(m => m.GetContent() is null))
+        {
+            throw new InvalidOperationException("Please provide a message with content.");
+        }
+
+        if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
+        {
+            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+        }
+
+        // construct template based on instruction from
+        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
+
+        var sb = new StringBuilder();
+        foreach (var message in messages)
+        {
+            var role = message.GetRole()!.Value;
+            var content = message.GetContent()!;
+            sb.Append(message switch
+            {
+                _ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
+                _ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
+                _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
+                _ => throw new InvalidOperationException("Invalid role.")
+            });
+        }
+
+        sb.Append("<|assistant|>");
+        var input = sb.ToString();
+
+        return input;
+    }
+
+    public string BuildPrompt(ChatHistory chatHistory)
+    {
+        // build prompt from chat history
+        var sb = new StringBuilder();
+
+        foreach (var message in chatHistory)
+        {
+            foreach (var item in message.Items)
+            {
+                if (item is not TextContent textContent)
+                {
+                    throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
+                }
+
+                var prompt = message.Role switch
+                {
+                    _ when message.Role == AuthorRole.System => $"<|system|>{Newline}{textContent}<|end|>{Newline}",
+                    _ when message.Role == AuthorRole.User => $"<|user|>{Newline}{textContent}<|end|>{Newline}",
+                    _ when message.Role == AuthorRole.Assistant => $"<|assistant|>{Newline}{textContent}<|end|>{Newline}",
+                    _ => throw new NotSupportedException($"Unsupported role {message.Role}")
+                };
+
+                sb.Append(prompt);
+            }
+        }
+
+        sb.Append("<|assistant|>");
+
+        return sb.ToString();
+    }
+
+    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
+    {
+        var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
+        if (messages.Any(m => m.Text is null))
+        {
+            throw new InvalidOperationException("Please provide a message with content.");
+        }
+
+        if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
+        {
+            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+        }
+
+        // construct template based on instruction from
+        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
+
+        var sb = new StringBuilder();
+        foreach (var message in messages)
+        {
+            var role = message.Role.Value;
+            var content = message.Text;
+            sb.Append(message switch
+            {
+                _ when message.Role == ChatRole.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
+                _ when message.Role == ChatRole.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
+                _ when message.Role == ChatRole.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
+                _ => throw new InvalidOperationException("Invalid role.")
+            });
+        }
+
+        sb.Append("<|assistant|>");
+        var input = sb.ToString();
+
+        return input;
+    }
+}
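For reference, all three BuildPrompt overloads above render the same Phi-3 chat format from the model card. Derived from the switch arms above, a system-plus-user history (the message text here is illustrative) produces a prompt of this shape:

    <|system|>
    You are a helpful assistant.<|end|>
    <|user|>
    What is 1 + 2?<|end|>
    <|assistant|>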
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromMEAIChatHistory.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromMEAIChatHistory.approved.txt
new file mode 100644
index 000000000..e4a2466fe
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromMEAIChatHistory.approved.txt
@@ -0,0 +1,7 @@
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+You are a helpful AI assistant.<|eot_id|>
+<|start_header_id|>user<|end_header_id|>
+Hello?<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
+World!<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs
index 7d97150f7..453bbcc28 100644
--- a/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs
@@ -7,6 +7,7 @@ using ApprovalTests;
 using ApprovalTests.Namers;
 using ApprovalTests.Reporters;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
@@ -122,4 +123,21 @@ public class LLaMA3_1Tests
 
         Approvals.Verify(prompt);
     }
+
+    [Fact]
+    [UseReporter(typeof(DiffReporter))]
+    [UseApprovalSubdirectory("Approvals")]
+    public void ItBuildChatTemplateFromMEAIChatHistory()
+    {
+        var chatHistory = new[]
+        {
+            new ChatMessage(ChatRole.System, "You are a helpful AI assistant."),
+            new ChatMessage(ChatRole.User, "Hello?"),
+            new ChatMessage(ChatRole.Assistant, "World!"),
+        };
+
+        var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
+
+        Approvals.Verify(prompt);
+    }
 }
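The samples in this patch exercise only the streaming path (CompleteStreamingAsync). For completeness, a non-streaming call through the same Microsoft.Extensions.AI surface is a small sketch; it assumes `client` is a Phi3CausalLMChatClient or Llama3CausalLMChatClient constructed as in the MEAI samples above, and uses the CompleteAsync/ChatCompletion shape of the preview MEAI package referenced by this patch:

    // Non-streaming completion; MaxOutputTokens and Temperature flow through
    // ChatOptions into the underlying CausalLMPipeline.Generate call, and the
    // model-specific client appends its end-of-turn token to StopSequences.
    var messages = new List<ChatMessage>
    {
        new(ChatRole.System, "You are a helpful assistant."),
        new(ChatRole.User, "What is 1 + 2?"),
    };

    var completion = await client.CompleteAsync(messages, new ChatOptions
    {
        MaxOutputTokens = 256,
        Temperature = 0.7f,
    });

    Console.WriteLine(completion.Message.Text);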