This commit is contained in:
XiaoYun Zhang 2024-10-28 14:08:42 -07:00
Родитель b5f5e0a993
Коммит 28643d6461
4 изменённых файлов: 14 добавлений и 42 удалений

Просмотреть файл

@ -9,7 +9,6 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;
@ -20,7 +19,6 @@ public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : I
where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
private readonly ChatClientMetadata _metadata;
private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;
public CausalLMPipelineChatClient(
@ -28,15 +26,13 @@ public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : I
IMEAIChatTemplateBuilder chatTemplateBuilder,
ChatClientMetadata? metadata = null)
{
metadata ??= new ChatClientMetadata(modelId: nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>));
var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
Metadata ??= new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
_chatTemplateBuilder = chatTemplateBuilder;
_pipeline = pipeline;
_metadata = metadata;
}
public ChatClientMetadata Metadata => _metadata;
public ChatClientMetadata Metadata { get; }
public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{

Просмотреть файл

@ -43,27 +43,15 @@ public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, Ll
return base.CompleteAsync(chatMessages, options, cancellationToken);
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
if (options.StopSequences != null)
{
options.StopSequences.Add(_eotToken);
}
else
{
options.StopSequences = new List<string> { _eotToken };
}
await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken))
{
yield return update;
}
return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}

Просмотреть файл

@ -48,27 +48,15 @@ public class Phi3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, Phi3
return base.CompleteAsync(chatMessages, options, cancellationToken);
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);
if (options.StopSequences != null)
{
options.StopSequences.Add(_eotToken);
}
else
{
options.StopSequences = new List<string> { _eotToken };
}
await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken))
{
yield return update;
}
return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}

Просмотреть файл

@ -20,7 +20,7 @@ public class Phi3ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBu
{
private const char Newline = '\n';
public static Phi3ChatTemplateBuilder Instance => new Phi3ChatTemplateBuilder();
public static Phi3ChatTemplateBuilder Instance { get; } = new Phi3ChatTemplateBuilder();
public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
{