[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient (#7270)
* leverage MEAI abstraction
* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs (Co-authored-by: Stephen Toub <stoub@microsoft.com>)
* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs (Co-authored-by: Stephen Toub <stoub@microsoft.com>)
* Update src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs (Co-authored-by: Stephen Toub <stoub@microsoft.com>)
* fix comments
* Update Microsoft.ML.GenAI.Core.csproj

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
Parent: 5b4981a134
Commit: 3659a485a0
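Net effect: Llama and Phi models hosted in-process through a CausalLMPipeline can now be consumed through the Microsoft.Extensions.AI IChatClient abstraction, so calling code no longer depends on a concrete model client. A minimal consumption sketch (pipeline construction elided; see the two samples below), using only the surface added in this commit:

    IChatClient client = new Phi3CausalLMChatClient(pipeline); // or Llama3CausalLMChatClient
    var message = new ChatMessage(ChatRole.User, "Hello");
    await foreach (var update in client.CompleteStreamingAsync([message]))
    {
        Console.Write(update.Text);
    }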
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Llama3_1
{
    public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        var device = "cuda";
        if (device == "cuda")
        {
            torch.InitializeDeviceType(DeviceType.CUDA);
        }

        var defaultType = ScalarType.BFloat16;
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var configName = "config.json";
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var stopWatch = System.Diagnostics.Stopwatch.StartNew();
        stopWatch.Start();
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);

        var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

        var client = new Llama3CausalLMChatClient(pipeline);

        var task = """
            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
            """;
        var chatMessage = new ChatMessage(ChatRole.User, task);

        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
        {
            Console.Write(response.Text);
        }
    }
}
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.Extensions.AI;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Phi3
{
    public static async Task RunAsync(string weightFolder)
    {
        var device = "cuda";
        if (device == "cuda")
        {
            torch.InitializeDeviceType(DeviceType.CUDA);
        }

        var defaultType = ScalarType.Float16;
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
        var client = new Phi3CausalLMChatClient(pipeline);

        var task = """
            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
            """;
        var chatMessage = new ChatMessage(ChatRole.User, task);

        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
        {
            Console.Write(response.Text);
        }
    }
}
@@ -1,4 +1,6 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Llama;
+using Microsoft.ML.GenAI.Samples.MEAI;

-await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
+//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
+await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
@@ -34,7 +34,7 @@
     <SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
     <SystemSecurityPrincipalWindows>5.0.0</SystemSecurityPrincipalWindows>
     <SystemTextEncodingsWebVersion>8.0.0</SystemTextEncodingsWebVersion>
-    <SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
+    <SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
     <SystemThreadingChannelsVersion>8.0.0</SystemThreadingChannelsVersion>
     <!-- Other product dependencies -->
     <ApacheArrowVersion>14.0.2</ApacheArrowVersion>
@@ -47,6 +47,7 @@
     <MicrosoftDotNetInteractiveVersion>1.0.0-beta.24375.2</MicrosoftDotNetInteractiveVersion>
     <MicrosoftMLOnnxRuntimeVersion>1.18.1</MicrosoftMLOnnxRuntimeVersion>
     <MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
+    <MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
     <!--
       @("inteltbb.devel", "win", "2021.7.1.15305")
     -->
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;

public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
    where TTokenizer : Tokenizer
    where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
    private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

    public CausalLMPipelineChatClient(
        ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
        IMEAIChatTemplateBuilder chatTemplateBuilder,
        ChatClientMetadata? metadata = null)
    {
        var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
        // Prefer caller-supplied metadata; otherwise synthesize it from the generic type arguments.
        Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
        _chatTemplateBuilder = chatTemplateBuilder;
        _pipeline = pipeline;
    }

    public ChatClientMetadata Metadata { get; }

    public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        var output = _pipeline.Generate(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

        var chatMessage = new ChatMessage(ChatRole.Assistant, output);
        return Task.FromResult(new ChatCompletion([chatMessage])
        {
            CreatedAt = DateTime.UtcNow,
            FinishReason = ChatFinishReason.Stop,
        });
    }

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        foreach (var output in _pipeline.GenerateStreaming(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()))
        {
            yield return new StreamingChatCompletionUpdate
            {
                Role = ChatRole.Assistant,
                Text = output,
                CreatedAt = DateTime.UtcNow,
            };
        }
    }

    public virtual void Dispose()
    {
    }

    public virtual TService? GetService<TService>(object? key = null) where TService : class
    {
        return null;
    }
}
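The base class maps ChatOptions onto the pipeline's generation parameters: MaxOutputTokens becomes maxLen (default 1024), Temperature defaults to 0.7f, and StopSequences is forwarded to the pipeline. An illustrative call, given a client and message list as in the samples above:

    var options = new ChatOptions
    {
        MaxOutputTokens = 256,            // forwarded as maxLen
        Temperature = 0.3f,               // sampling temperature
        StopSequences = ["<|eot_id|>"],   // merged with model-specific defaults by subclasses
    };
    var completion = await client.CompleteAsync(chatMessages, options);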
@@ -13,6 +13,7 @@

   <ItemGroup>
     <PackageReference Include="AutoGen.Core" Version="$(AutoGenVersion)" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
     <PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
     <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
     <PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
@@ -8,6 +8,7 @@ using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel.ChatCompletion;

 namespace Microsoft.ML.GenAI.Core;

@@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
     string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
 }

+public interface IMEAIChatTemplateBuilder
+{
+    string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
+}
+
 public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
 {
 }
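IMEAIChatTemplateBuilder is the seam that lets CausalLMPipelineChatClient stay model-agnostic: each model supplies its own prompt rendering. A minimal hypothetical implementation (the "role: text" format is illustrative, not a real model template; assumes using System.Text and Microsoft.Extensions.AI):

    public class PlainTextChatTemplateBuilder : IMEAIChatTemplateBuilder
    {
        public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
        {
            var sb = new StringBuilder();
            foreach (var message in messages)
            {
                // Render each turn as "role: text" -- purely illustrative.
                sb.AppendLine($"{message.Role.Value}: {message.Text}");
            }

            sb.Append("assistant: "); // generation cue for the model's reply
            return sb.ToString();
        }
    }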
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.LLaMA;

public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
    private readonly string _eotToken = "<|eot_id|>";

    public Llama3CausalLMChatClient(
        ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
        IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
        ChatClientMetadata? metadata = null)
        : base(
            pipeline,
            chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
            metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
    {
    }

    public override Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();

        if (options.StopSequences != null)
        {
            options.StopSequences.Add(_eotToken);
        }
        else
        {
            options.StopSequences = new List<string> { _eotToken };
        }

        return base.CompleteAsync(chatMessages, options, cancellationToken);
    }

    public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();
        options.StopSequences ??= [];
        options.StopSequences.Add(_eotToken);

        return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
    }
}
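Both overrides ensure the Llama 3 end-of-turn token "<|eot_id|>" is always among the stop sequences while preserving any caller-supplied entries, e.g. (illustrative):

    var options = new ChatOptions { StopSequences = ["<|end_of_text|>"] };
    // After the override runs, generation stops on either "<|end_of_text|>" or "<|eot_id|>".
    var completion = await client.CompleteAsync(chatMessages, options);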
@@ -4,13 +4,15 @@

 using System.Text;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
+using TextContent = Microsoft.SemanticKernel.TextContent;

 namespace Microsoft.ML.GenAI.LLaMA;
 #pragma warning disable MSML_GeneralName // This name should be PascalCased
-public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
+public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
 #pragma warning restore MSML_GeneralName // This name should be PascalCased
 {
     private const char Newline = '\n';

@@ -86,5 +88,39 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
         return sb.ToString();
     }
+
+    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
+    {
+        var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
+        if (messages.Any(m => m.Text is null))
+        {
+            throw new InvalidOperationException("Please provide a message with content.");
+        }
+
+        if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
+        {
+            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+        }
+
+        var sb = new StringBuilder();
+        sb.Append("<|begin_of_text|>");
+        foreach (var message in messages)
+        {
+            var role = message.Role.Value;
+            var content = message.Text!;
+            sb.Append(message switch
+            {
+                _ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ => throw new InvalidOperationException("Invalid role.")
+            });
+        }
+
+        sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
+        var input = sb.ToString();
+
+        return input;
+    }

     public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
 }
@@ -19,22 +19,31 @@ public class Phi3Agent : IStreamingAgent
     private const char Newline = '\n';
     private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;
     private readonly string? _systemMessage;
+    private readonly IAutoGenChatTemplateBuilder _templateBuilder;

     public Phi3Agent(
         ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
         string name,
-        string? systemMessage = "you are a helpful assistant")
+        string? systemMessage = "you are a helpful assistant",
+        IAutoGenChatTemplateBuilder? templateBuilder = null)
     {
         this.Name = name;
         this._pipeline = pipeline;
         this._systemMessage = systemMessage;
+        this._templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
     }

     public string Name { get; }

     public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
     {
-        var input = BuildPrompt(messages);
+        if (_systemMessage != null)
+        {
+            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+            messages = messages.Prepend(systemMessage);
+        }
+
+        var input = _templateBuilder.BuildPrompt(messages);
         var maxLen = options?.MaxToken ?? 1024;
         var temperature = options?.Temperature ?? 0.7f;
         var stopTokenSequence = options?.StopSequence ?? [];

@@ -56,7 +65,13 @@ public class Phi3Agent : IStreamingAgent
         GenerateReplyOptions? options = null,
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
-        var input = BuildPrompt(messages);
+        if (_systemMessage != null)
+        {
+            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+            messages = messages.Prepend(systemMessage);
+        }
+
+        var input = _templateBuilder.BuildPrompt(messages);
         var maxLen = options?.MaxToken ?? 1024;
         var temperature = options?.Temperature ?? 0.7f;
         var stopTokenSequence = options?.StopSequence ?? [];

@@ -71,44 +86,4 @@ public class Phi3Agent : IStreamingAgent
             yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
         }
     }
-
-    private string BuildPrompt(IEnumerable<IMessage> messages)
-    {
-        var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
-        if (messages.Any(m => m.GetContent() is null))
-        {
-            throw new InvalidOperationException("Please provide a message with content.");
-        }
-
-        if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
-        {
-            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
-        }
-
-        // construct template based on instruction from
-        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
-
-        var sb = new StringBuilder();
-        if (_systemMessage is not null)
-        {
-            sb.Append($"<|system|>{Newline}{_systemMessage}<|end|>{Newline}");
-        }
-        foreach (var message in messages)
-        {
-            var role = message.GetRole()!.Value;
-            var content = message.GetContent()!;
-            sb.Append(message switch
-            {
-                _ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
-                _ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
-                _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
-                _ => throw new InvalidOperationException("Invalid role.")
-            });
-        }
-
-        sb.Append("<|assistant|>");
-        var input = sb.ToString();
-
-        return input;
-    }
 }
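With this change Phi3Agent prepends its system message as a regular TextMessage and delegates prompt rendering to the injectable template builder (defaulting to Phi3ChatTemplateBuilder.Instance). Hypothetical wiring, assuming an already-constructed pipeline:

    var agent = new Phi3Agent(
        pipeline,
        name: "assistant",
        systemMessage: "you are a helpful assistant",
        templateBuilder: Phi3ChatTemplateBuilder.Instance); // same as passing null
    var reply = await agent.GenerateReplyAsync([new TextMessage(Role.User, "Hello")]);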
@@ -0,0 +1,62 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.Phi;

public class Phi3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, Phi3ForCasualLM>
{
    private readonly string _eotToken = "<|end|>";

    public Phi3CausalLMChatClient(
        ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
        IMEAIChatTemplateBuilder? templateBuilder = null,
        ChatClientMetadata? metadata = null)
        : base(
            pipeline,
            templateBuilder ?? Phi3ChatTemplateBuilder.Instance,
            metadata ?? new ChatClientMetadata(modelId: nameof(Phi3CausalLMChatClient)))
    {
    }

    public override Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();

        if (options.StopSequences != null)
        {
            options.StopSequences.Add(_eotToken);
        }
        else
        {
            options.StopSequences = [_eotToken];
        }

        return base.CompleteAsync(chatMessages, options, cancellationToken);
    }

    public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();
        options.StopSequences ??= [];
        options.StopSequences.Add(_eotToken);

        return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
    }
}
@@ -16,12 +16,15 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
 {
     private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;
     private readonly Phi3CausalLMTextGenerationService _textGenerationService;
-    private const char NewLine = '\n'; // has to be \n, \r\n will cause wanky result.
+    private readonly ISemanticKernelChatTemplateBuilder _templateBuilder;

-    public Phi3CausalLMChatCompletionService(ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
+    public Phi3CausalLMChatCompletionService(
+        ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
+        ISemanticKernelChatTemplateBuilder? templateBuilder = null)
     {
         _pipeline = pipeline;
         _textGenerationService = new Phi3CausalLMTextGenerationService(pipeline);
+        _templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
     }

     public IReadOnlyDictionary<string, object?> Attributes => _textGenerationService.Attributes;

@@ -32,7 +35,7 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
         Kernel? kernel = null,
         CancellationToken cancellationToken = default)
     {
-        var prompt = BuildPrompt(chatHistory);
+        var prompt = _templateBuilder.BuildPrompt(chatHistory);
         var replies = await _textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken);
         return replies.Select(reply => new ChatMessageContent(AuthorRole.Assistant, reply.Text)).ToList();
     }

@@ -44,42 +47,11 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
         [EnumeratorCancellation]
         CancellationToken cancellationToken = default)
     {
-        var prompt = BuildPrompt(chatHistory);
+        var prompt = _templateBuilder.BuildPrompt(chatHistory);

         await foreach (var reply in _textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken))
         {
             yield return new StreamingChatMessageContent(AuthorRole.Assistant, reply.Text);
         }
     }
-
-    private string BuildPrompt(ChatHistory chatHistory)
-    {
-        // build prompt from chat history
-        var sb = new StringBuilder();
-
-        foreach (var message in chatHistory)
-        {
-            foreach (var item in message.Items)
-            {
-                if (item is not TextContent textContent)
-                {
-                    throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
-                }
-
-                var prompt = message.Role switch
-                {
-                    _ when message.Role == AuthorRole.System => $"<|system|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ when message.Role == AuthorRole.User => $"<|user|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ when message.Role == AuthorRole.Assistant => $"<|assistant|>{NewLine}{textContent}<|end|>{NewLine}",
-                    _ => throw new NotSupportedException($"Unsupported role {message.Role}")
-                };
-
-                sb.Append(prompt);
-            }
-        }
-
-        sb.Append("<|assistant|>");
-
-        return sb.ToString();
-    }
 }
@@ -14,7 +14,8 @@ public class Phi3CausalLMTextGenerationService : ITextGenerationService
 {
     private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;

-    public Phi3CausalLMTextGenerationService(ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
+    public Phi3CausalLMTextGenerationService(
+        ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
     {
         _pipeline = pipeline;
     }
@@ -0,0 +1,127 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TextContent = Microsoft.SemanticKernel.TextContent;

namespace Microsoft.ML.GenAI.Phi;

public class Phi3ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
{
    private const char Newline = '\n';

    public static Phi3ChatTemplateBuilder Instance { get; } = new Phi3ChatTemplateBuilder();

    public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
    {
        var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
        if (messages.Any(m => m.GetContent() is null))
        {
            throw new InvalidOperationException("Please provide a message with content.");
        }

        if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
        {
            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
        }

        // construct template based on instruction from
        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format

        var sb = new StringBuilder();
        foreach (var message in messages)
        {
            var role = message.GetRole()!.Value;
            var content = message.GetContent()!;
            sb.Append(message switch
            {
                _ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
                _ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
                _ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
                _ => throw new InvalidOperationException("Invalid role.")
            });
        }

        sb.Append("<|assistant|>");
        var input = sb.ToString();

        return input;
    }

    public string BuildPrompt(ChatHistory chatHistory)
    {
        // build prompt from chat history
        var sb = new StringBuilder();

        foreach (var message in chatHistory)
        {
            foreach (var item in message.Items)
            {
                if (item is not TextContent textContent)
                {
                    throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
                }

                var prompt = message.Role switch
                {
                    _ when message.Role == AuthorRole.System => $"<|system|>{Newline}{textContent}<|end|>{Newline}",
                    _ when message.Role == AuthorRole.User => $"<|user|>{Newline}{textContent}<|end|>{Newline}",
                    _ when message.Role == AuthorRole.Assistant => $"<|assistant|>{Newline}{textContent}<|end|>{Newline}",
                    _ => throw new NotSupportedException($"Unsupported role {message.Role}")
                };

                sb.Append(prompt);
            }
        }

        sb.Append("<|assistant|>");

        return sb.ToString();
    }

    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
    {
        var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
        if (messages.Any(m => m.Text is null))
        {
            throw new InvalidOperationException("Please provide a message with content.");
        }

        if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
        {
            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
        }

        // construct template based on instruction from
        // https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format

        var sb = new StringBuilder();
        foreach (var message in messages)
        {
            var role = message.Role.Value;
            var content = message.Text;
            sb.Append(message switch
            {
                _ when message.Role == ChatRole.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
                _ when message.Role == ChatRole.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
                _ when message.Role == ChatRole.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
                _ => throw new InvalidOperationException("Invalid role.")
            });
        }

        sb.Append("<|assistant|>");
        var input = sb.ToString();

        return input;
    }
}
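For reference, the new MEAI overload renders a System/User/Assistant history into the Phi-3 chat format; worked out by hand from the template strings above, a three-message history would produce (analogous to the Llama approval file below):

    <|system|>
    You are a helpful AI assistant.<|end|>
    <|user|>
    Hello?<|end|>
    <|assistant|>
    World!<|end|>
    <|assistant|>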
@@ -0,0 +1,7 @@
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Hello?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
World!<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
@@ -7,6 +7,7 @@ using ApprovalTests;
 using ApprovalTests.Namers;
 using ApprovalTests.Reporters;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;

@@ -122,4 +123,21 @@ public class LLaMA3_1Tests

         Approvals.Verify(prompt);
     }
+
+    [Fact]
+    [UseReporter(typeof(DiffReporter))]
+    [UseApprovalSubdirectory("Approvals")]
+    public void ItBuildChatTemplateFromMEAIChatHistory()
+    {
+        var chatHistory = new[]
+        {
+            new ChatMessage(ChatRole.System, "You are a helpful AI assistant."),
+            new ChatMessage(ChatRole.User, "Hello?"),
+            new ChatMessage(ChatRole.Assistant, "World!"),
+        };
+
+        var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
+
+        Approvals.Verify(prompt);
+    }
 }