[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient (#7270)

* leverage MEAI abstraction

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Update src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* fix comments

* Update Microsoft.ML.GenAI.Core.csproj

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
This commit is contained in:
Xiaoyun Zhang 2024-11-05 09:27:05 -08:00 committed by GitHub
Parent 5b4981a134
Commit 3659a485a0
No known key found for this signature
GPG key ID: B5690EEEBB952194
16 changed files with 534 additions and 82 deletions
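The new client plugs CausalLMPipeline into the Microsoft.Extensions.AI (MEAI) IChatClient abstraction, so local Llama 3 / Phi-3 models can be consumed through the same interface as any other chat provider. A minimal non-streaming sketch, assuming a pipeline has already been constructed as in the samples below (the pipeline variable is a placeholder):

using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Phi;

// Phi3CausalLMChatClient implements MEAI's IChatClient, so the rest of
// the code depends only on the abstraction.
IChatClient client = new Phi3CausalLMChatClient(pipeline);

var completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "Hello?")],
    new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });

Console.WriteLine(completion.Message.Text);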


@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Samples.MEAI;
internal class Llama3_1
{
public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}
var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder, "original");
Console.WriteLine("Loading Llama from huggingface model weight folder");
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
var client = new Llama3CausalLMChatClient(pipeline);
var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);
await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}


@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Samples.MEAI;
internal class Phi3
{
public static async Task RunAsync(string weightFolder)
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}
var defaultType = ScalarType.Float16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
var client = new Phi3CausalLMChatClient(pipeline);
var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);
await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}


@@ -1,4 +1,6 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;
await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");


@@ -34,7 +34,7 @@
<SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
<SystemSecurityPrincipalWindows>5.0.0</SystemSecurityPrincipalWindows>
<SystemTextEncodingsWebVersion>8.0.0</SystemTextEncodingsWebVersion>
<SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
<SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
<SystemThreadingChannelsVersion>8.0.0</SystemThreadingChannelsVersion>
<!-- Other product dependencies -->
<ApacheArrowVersion>14.0.2</ApacheArrowVersion>
@@ -47,6 +47,7 @@
<MicrosoftDotNetInteractiveVersion>1.0.0-beta.24375.2</MicrosoftDotNetInteractiveVersion>
<MicrosoftMLOnnxRuntimeVersion>1.18.1</MicrosoftMLOnnxRuntimeVersion>
<MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
<MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
<!--
@("inteltbb.devel", "win", "2021.7.1.15305")
-->


@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Core;
public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
where TTokenizer : Tokenizer
where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;
public CausalLMPipelineChatClient(
ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
IMEAIChatTemplateBuilder chatTemplateBuilder,
ChatClientMetadata? metadata = null)
{
var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
_chatTemplateBuilder = chatTemplateBuilder;
_pipeline = pipeline;
}
public ChatClientMetadata Metadata { get; }
public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();
var output = _pipeline.Generate(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");
var chatMessage = new ChatMessage(ChatRole.Assistant, output);
return Task.FromResult(new ChatCompletion([chatMessage])
{
CreatedAt = DateTime.UtcNow,
FinishReason = ChatFinishReason.Stop,
});
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();
foreach (var output in _pipeline.GenerateStreaming(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()))
{
yield return new StreamingChatCompletionUpdate
{
Role = ChatRole.Assistant,
Text = output,
CreatedAt = DateTime.UtcNow,
};
}
}
public virtual void Dispose()
{
}
public virtual TService? GetService<TService>(object? key = null) where TService : class
{
return null;
}
}
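CompleteAsync and CompleteStreamingAsync translate MEAI ChatOptions into the pipeline's generation arguments, falling back to maxLen: 1024 and temperature: 0.7f when unset. A hedged sketch of overriding those defaults (client and chatMessages are placeholders from the usage example near the top):

var options = new ChatOptions
{
    MaxOutputTokens = 512,        // forwarded to the pipeline as maxLen
    Temperature = 0.5f,           // forwarded as temperature
    StopSequences = ["<|end|>"],  // forwarded as stopSequences
};
var completion = await client.CompleteAsync(chatMessages, options);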


@@ -13,6 +13,7 @@
<ItemGroup>
<PackageReference Include="AutoGen.Core" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />


@@ -8,6 +8,7 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;
namespace Microsoft.ML.GenAI.Core;
@@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
}
public interface IMEAIChatTemplateBuilder
{
string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
}
public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
{
}
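IMEAIChatTemplateBuilder is the extension point a custom model plugs into the new chat client. A hedged sketch of a trivial implementation (the builder name and its plain-text format are hypothetical, for illustration only):

using System.Collections.Generic;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;

// Hypothetical builder: renders each message as "role: text" and ends
// with an assistant turn for the model to complete.
internal sealed class PlainTextTemplateBuilder : IMEAIChatTemplateBuilder
{
    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
    {
        var sb = new StringBuilder();
        foreach (var message in messages)
        {
            sb.AppendLine($"{message.Role.Value}: {message.Text}");
        }
        sb.Append("assistant: ");
        return sb.ToString();
    }
}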


@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
namespace Microsoft.ML.GenAI.LLaMA;
public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
private readonly string _eotToken = "<|eot_id|>";
public Llama3CausalLMChatClient(
ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
ChatClientMetadata? metadata = null)
: base(
pipeline,
chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
{
}
public override Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
return base.CompleteAsync(chatMessages, options, cancellationToken);
}
public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}


@@ -4,13 +4,15 @@
using System.Text;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TextContent = Microsoft.SemanticKernel.TextContent;
namespace Microsoft.ML.GenAI.LLaMA;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private const char Newline = '\n';
@@ -86,5 +88,39 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
return sb.ToString();
}
public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}
if (messages.Any(m => !availableRoles.Contains(m.Role)))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}
var sb = new StringBuilder();
sb.Append("<|begin_of_text|>");
foreach (var message in messages)
{
var role = message.Role.Value;
var content = message.Text!;
sb.Append(message switch
{
_ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}
sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
var input = sb.ToString();
return input;
}
public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
}
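The new MEAI overload of BuildPrompt emits the Llama 3.1 header/eot framing verified by the approval file near the end of this diff. A small hedged sketch of calling it directly:

using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.LLaMA;

var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(
[
    new ChatMessage(ChatRole.System, "You are a helpful AI assistant."),
    new ChatMessage(ChatRole.User, "Hello?"),
]);

// prompt now contains:
// <|begin_of_text|><|start_header_id|>system<|end_header_id|>
// You are a helpful AI assistant.<|eot_id|>
// <|start_header_id|>user<|end_header_id|>
// Hello?<|eot_id|>
// <|start_header_id|>assistant<|end_header_id|>
Console.Write(prompt);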


@@ -19,22 +19,31 @@ public class Phi3Agent : IStreamingAgent
private const char Newline = '\n';
private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;
private readonly string? _systemMessage;
private readonly IAutoGenChatTemplateBuilder _templateBuilder;
public Phi3Agent(
ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
string name,
string? systemMessage = "you are a helpful assistant")
string? systemMessage = "you are a helpful assistant",
IAutoGenChatTemplateBuilder? templateBuilder = null)
{
this.Name = name;
this._pipeline = pipeline;
this._systemMessage = systemMessage;
this._templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
}
public string Name { get; }
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var input = BuildPrompt(messages);
if (_systemMessage != null)
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
messages = messages.Prepend(systemMessage);
}
var input = _templateBuilder.BuildPrompt(messages);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var stopTokenSequence = options?.StopSequence ?? [];
@@ -56,7 +65,13 @@ public class Phi3Agent : IStreamingAgent
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var input = BuildPrompt(messages);
if (_systemMessage != null)
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
messages = messages.Prepend(systemMessage);
}
var input = _templateBuilder.BuildPrompt(messages);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var stopTokenSequence = options?.StopSequence ?? [];
@@ -71,44 +86,4 @@ public class Phi3Agent : IStreamingAgent
yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
}
}
private string BuildPrompt(IEnumerable<IMessage> messages)
{
var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
if (messages.Any(m => m.GetContent() is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}
if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}
// construct template based on instruction from
// https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
var sb = new StringBuilder();
if (_systemMessage is not null)
{
sb.Append($"<|system|>{Newline}{_systemMessage}<|end|>{Newline}");
}
foreach (var message in messages)
{
var role = message.GetRole()!.Value;
var content = message.GetContent()!;
sb.Append(message switch
{
_ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
_ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
_ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}
sb.Append("<|assistant|>");
var input = sb.ToString();
return input;
}
}


@@ -0,0 +1,62 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
namespace Microsoft.ML.GenAI.Phi;
public class Phi3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, Phi3ForCasualLM>
{
private readonly string _eotToken = "<|end|>";
public Phi3CausalLMChatClient(
ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
IMEAIChatTemplateBuilder? templateBuilder = null,
ChatClientMetadata? metadata = null)
: base(
pipeline,
templateBuilder ?? Phi3ChatTemplateBuilder.Instance,
metadata ?? new ChatClientMetadata(modelId: nameof(Phi3CausalLMChatClient)))
{
}
public override Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
return base.CompleteAsync(chatMessages, options, cancellationToken);
}
public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}


@@ -16,12 +16,15 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
{
private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;
private readonly Phi3CausalLMTextGenerationService _textGenerationService;
private const char NewLine = '\n'; // has to be \n; \r\n causes wonky results.
private readonly ISemanticKernelChatTemplateBuilder _templateBuilder;
public Phi3CausalLMChatCompletionService(ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
public Phi3CausalLMChatCompletionService(
ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline,
ISemanticKernelChatTemplateBuilder? templateBuilder = null)
{
_pipeline = pipeline;
_textGenerationService = new Phi3CausalLMTextGenerationService(pipeline);
_templateBuilder = templateBuilder ?? Phi3ChatTemplateBuilder.Instance;
}
public IReadOnlyDictionary<string, object?> Attributes => _textGenerationService.Attributes;
@@ -32,7 +35,7 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
var prompt = BuildPrompt(chatHistory);
var prompt = _templateBuilder.BuildPrompt(chatHistory);
var replies = await _textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken);
return replies.Select(reply => new ChatMessageContent(AuthorRole.Assistant, reply.Text)).ToList();
}
@@ -44,42 +47,11 @@ public class Phi3CausalLMChatCompletionService : IChatCompletionService
[EnumeratorCancellation]
CancellationToken cancellationToken = default)
{
var prompt = BuildPrompt(chatHistory);
var prompt = _templateBuilder.BuildPrompt(chatHistory);
await foreach (var reply in _textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken))
{
yield return new StreamingChatMessageContent(AuthorRole.Assistant, reply.Text);
}
}
private string BuildPrompt(ChatHistory chatHistory)
{
// build prompt from chat history
var sb = new StringBuilder();
foreach (var message in chatHistory)
{
foreach (var item in message.Items)
{
if (item is not TextContent textContent)
{
throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
}
var prompt = message.Role switch
{
_ when message.Role == AuthorRole.System => $"<|system|>{NewLine}{textContent}<|end|>{NewLine}",
_ when message.Role == AuthorRole.User => $"<|user|>{NewLine}{textContent}<|end|>{NewLine}",
_ when message.Role == AuthorRole.Assistant => $"<|assistant|>{NewLine}{textContent}<|end|>{NewLine}",
_ => throw new NotSupportedException($"Unsupported role {message.Role}")
};
sb.Append(prompt);
}
}
sb.Append("<|assistant|>");
return sb.ToString();
}
}


@@ -14,7 +14,8 @@ public class Phi3CausalLMTextGenerationService : ITextGenerationService
{
private readonly ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> _pipeline;
public Phi3CausalLMTextGenerationService(ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
public Phi3CausalLMTextGenerationService(
ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> pipeline)
{
_pipeline = pipeline;
}


@@ -0,0 +1,127 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TextContent = Microsoft.SemanticKernel.TextContent;
namespace Microsoft.ML.GenAI.Phi;
public class Phi3ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
{
private const char Newline = '\n';
public static Phi3ChatTemplateBuilder Instance { get; } = new Phi3ChatTemplateBuilder();
public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
{
var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
if (messages.Any(m => m.GetContent() is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}
if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}
// construct template based on instruction from
// https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
var sb = new StringBuilder();
foreach (var message in messages)
{
var role = message.GetRole()!.Value;
var content = message.GetContent()!;
sb.Append(message switch
{
_ when message.GetRole() == Role.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
_ when message.GetRole() == Role.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
_ when message.GetRole() == Role.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}
sb.Append("<|assistant|>");
var input = sb.ToString();
return input;
}
public string BuildPrompt(ChatHistory chatHistory)
{
// build prompt from chat history
var sb = new StringBuilder();
foreach (var message in chatHistory)
{
foreach (var item in message.Items)
{
if (item is not TextContent textContent)
{
throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
}
var prompt = message.Role switch
{
_ when message.Role == AuthorRole.System => $"<|system|>{Newline}{textContent}<|end|>{Newline}",
_ when message.Role == AuthorRole.User => $"<|user|>{Newline}{textContent}<|end|>{Newline}",
_ when message.Role == AuthorRole.Assistant => $"<|assistant|>{Newline}{textContent}<|end|>{Newline}",
_ => throw new NotSupportedException($"Unsupported role {message.Role}")
};
sb.Append(prompt);
}
}
sb.Append("<|assistant|>");
return sb.ToString();
}
public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}
if (messages.Any(m => !availableRoles.Contains(m.Role)))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}
// construct template based on instruction from
// https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format
var sb = new StringBuilder();
foreach (var message in messages)
{
var role = message.Role.Value;
var content = message.Text;
sb.Append(message switch
{
_ when message.Role == ChatRole.System => $"<|system|>{Newline}{content}<|end|>{Newline}",
_ when message.Role == ChatRole.User => $"<|user|>{Newline}{content}<|end|>{Newline}",
_ when message.Role == ChatRole.Assistant => $"<|assistant|>{Newline}{content}<|end|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}
sb.Append("<|assistant|>");
var input = sb.ToString();
return input;
}
}
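Its MEAI overload follows the Phi-3 chat format documented at https://huggingface.co/microsoft/Phi-3-mini-128k-instruct#chat-format. A hedged sketch of the string it produces:

using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Phi;

var prompt = Phi3ChatTemplateBuilder.Instance.BuildPrompt(
[
    new ChatMessage(ChatRole.System, "You are a helpful AI assistant."),
    new ChatMessage(ChatRole.User, "Hello?"),
]);

// prompt now contains:
// <|system|>
// You are a helpful AI assistant.<|end|>
// <|user|>
// Hello?<|end|>
// <|assistant|>
Console.Write(prompt);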


@@ -0,0 +1,7 @@
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Hello?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
World!<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>


@@ -7,6 +7,7 @@ using ApprovalTests;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
@@ -122,4 +123,21 @@ public class LLaMA3_1Tests
Approvals.Verify(prompt);
}
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("Approvals")]
public void ItBuildChatTemplateFromMEAIChatHistory()
{
var chatHistory = new[]
{
new ChatMessage(ChatRole.System, "You are a helpful AI assistant."),
new ChatMessage(ChatRole.User, "Hello?"),
new ChatMessage(ChatRole.Assistant, "World!"),
};
var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
Approvals.Verify(prompt);
}
}