[GenAI] Add Mistral 7B Instruction V0.3 (#7231)

* add mistral and tests

* add test and sample

* add tool call support

* update autogen to v 0.1.0

* update autogen to 0.1.0

* remove tests on non-x64 machien

* add file header

* update

* update

* update ml tokenizer test version

* fix build error

* remove .receive.txt

* Update docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs

Co-authored-by: Weihan Li <7604648+WeihanLi@users.noreply.github.com>

* update

* set t to 0

* fix test

* Update Microsoft.ML.GenAI.Mistral.csproj

---------

Co-authored-by: Weihan Li <7604648+WeihanLi@users.noreply.github.com>
This commit is contained in:
Xiaoyun Zhang 2024-09-25 20:15:59 -07:00 коммит произвёл GitHub
Родитель e14f22fe0d
Коммит 817a77f89a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
34 изменённых файлов: 1816 добавлений и 110 удалений

Просмотреть файл

@ -188,7 +188,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core.Tes
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA", "src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj", "{0AA6D5CB-195F-457A-8792-4221E76E6C44}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral", "src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj", "{2729CC66-7743-442B-B3A5-1F4F27F044A5}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral.Tests", "test\Microsoft.ML.GenAI.Mistral.Tests\Microsoft.ML.GenAI.Mistral.Tests.csproj", "{49264202-C90A-43F6-8C30-BDAEF2F1465A}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@ -898,6 +902,22 @@ Global
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.Build.0 = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.ActiveCfg = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.Build.0 = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.ActiveCfg = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.Build.0 = Debug|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.Build.0 = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.ActiveCfg = Release|Any CPU
{2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.Build.0 = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.ActiveCfg = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.Build.0 = Debug|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.Build.0 = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.ActiveCfg = Release|Any CPU
{49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -991,6 +1011,8 @@ Global
{14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{0AA6D5CB-195F-457A-8792-4221E76E6C44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D202353D-6FAF-4263-9A01-BDCFBC92391F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{2729CC66-7743-442B-B3A5-1F4F27F044A5} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{49264202-C90A-43F6-8C30-BDAEF2F1465A} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

Просмотреть файл

@ -5,17 +5,20 @@
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,156 @@
using System.Text.Json;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Mistral;
using Microsoft.ML.GenAI.Mistral.Module;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using TorchSharp.PyBridge;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Samples.Mistral;
public partial class Mistral_7B_Instruct
{
private static Mistral_7B_Instruct instance = new Mistral_7B_Instruct();
/// <summary>
/// get weather from city
/// </summary>
/// <param name="city"></param>
[Function]
public Task<string> GetWeather(string city)
{
return Task.FromResult($"The weather in {city} is sunny.");
}
public static async Task RunAsync()
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}
var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);
Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);
var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralForCausalLM>(tokenizer, model, device);
var agent = new MistralCausalLMAgent(pipeline, "assistant")
.RegisterPrintMessage();
var task = """
How are you.
""";
await agent.SendAsync(task);
}
public static void Embedding()
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}
var defaultType = ScalarType.Float32;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\bge-en-icl";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);
Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder, modelName: "tokenizer.model");
var mistralConfig = JsonSerializer.Deserialize<MistralConfig>(File.ReadAllText(Path.Combine(weightFolder, configName))) ?? throw new ArgumentNullException(nameof(configName));
var model = new MistralModel(mistralConfig);
model.load_checkpoint(weightFolder, "model.safetensors.index.json", strict: true, useTqdm: false);
model.to(device);
var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralModel>(tokenizer, model, device);
var query = """
<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>what is a virtual interface
<response>A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.
<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>causes of back pain in female for a week
<response>Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management.
<instruct>Given a web search query, retrieve relevant passages that answer the query.
<query>how much protein should a female eat
<response>
""";
var document = """
As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.
""";
var queryEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(query);
var documentEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(document);
var score = 0f;
foreach (var (q, d) in queryEmbedding.Zip(documentEmbedding))
{
score += q * d * 100;
}
Console.WriteLine($"The similarity score between query and document is {score}");
}
public static async Task WeatherChatAsync()
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}
var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder);
Console.WriteLine("Loading Mistral from huggingface model weight folder");
var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);
var pipeline = new CausalLMPipeline<LlamaTokenizer, MistralForCausalLM>(tokenizer, model, device);
var weatherChatMiddleware = new FunctionCallMiddleware(
functions: [instance.GetWeatherFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ instance.GetWeatherFunctionContract.Name!, instance.GetWeatherWrapper }
});
var agent = new MistralCausalLMAgent(pipeline, "assistant")
.RegisterStreamingMiddleware(weatherChatMiddleware)
.RegisterPrintMessage();
var task = "what is the weather in Seattle";
var userMessage = new TextMessage(Role.User, task);
var reply = await agent.GenerateReplyAsync(messages: [userMessage],
new GenerateReplyOptions
{
Temperature = 0f,
});
// generate further reply using tool call result;
await agent.SendAsync(chatHistory: [userMessage, reply]);
}
}

Просмотреть файл

@ -1,4 +1,5 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Mistral;
using Microsoft.ML.GenAI.Samples.Phi3Mini;
await AutoGenSample.RunAsync();
await Mistral_7B_Instruct.WeatherChatAsync();

Просмотреть файл

@ -68,7 +68,7 @@
<TensorFlowMajorVersion>2</TensorFlowMajorVersion>
<TensorFlowVersion>2.3.1</TensorFlowVersion>
<TorchSharpPyBridgeVersion>1.4.1</TorchSharpPyBridgeVersion>
<AutoGenVersion>0.0.15</AutoGenVersion>
<AutoGenVersion>0.1.0</AutoGenVersion>
<SemanticKernelVersion>1.15.0</SemanticKernelVersion>
<TorchSharpVersion>0.102.7</TorchSharpVersion>
<LibTorchVersion>2.2.1.1</LibTorchVersion>
@ -96,7 +96,7 @@
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24415.1</MicrosoftMLTestTokenizersVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24455.2</MicrosoftMLTestTokenizersVersion>
<SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.6.24</XunitCombinatorialVersion>

Просмотреть файл

@ -20,9 +20,11 @@
<ItemGroup>
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.LLaMA" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.LLaMA.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Phi.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Mistral" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Mistral.Tests" />
<InternalsVisibleTo Include="Microsoft.ML.GenAI.Core.Tests" />
</ItemGroup>

Просмотреть файл

@ -266,8 +266,18 @@ public class CausalLMPipeline : ICausalLMPipeline
foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen))
{
var tokenIds = token[0].to_type(ScalarType.Int32).data<int>().ToArray();
var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids");
var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids");
var duplicateTokenString = this.Tokenizer switch
{
SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"),
};
var tokenString = this.Tokenizer switch
{
SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"),
};
// replace the first occurrence of the token with the duplicate token
tokenString = duplicateTokenString.Substring(tokenString.Length);
@ -294,7 +304,10 @@ public class CausalLMPipeline : ICausalLMPipeline
var inputIds = this.Tokenizer.EncodeToIds(prompt);
var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
var attentionMask = torch.ones_like(inputTensor, device: this.Device);
var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0)
{
OverrideCache = new DynamicKVCache(),
};
var output = this.Model.forward(input);
var lastTokenHiddenState = output.LastHiddenState[0, ^1];

Просмотреть файл

@ -19,7 +19,7 @@ public interface ISemanticKernelChatTemplateBuilder
public interface IAutoGenChatTemplateBuilder
{
string BuildPrompt(IEnumerable<IMessage> messages);
string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
}
public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder

Просмотреть файл

@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
@ -161,4 +162,18 @@ public static class Utils
.reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
}
internal static string GetEmbeddedResource(string resourceName)
{
// read file content from embedded resource
var assembly = Assembly.GetCallingAssembly();
var resourceStream = assembly.GetManifestResourceStream(resourceName);
if (resourceStream == null)
{
throw new ArgumentException("Resource not found", resourceName);
}
using var reader = new System.IO.StreamReader(resourceStream);
return reader.ReadToEnd();
}
}

Просмотреть файл

@ -15,7 +15,7 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
{
private const char Newline = '\n';
public string BuildPrompt(IEnumerable<IMessage> messages)
public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
{
var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
if (messages.Any(m => m.GetContent() is null))

Просмотреть файл

@ -60,7 +60,7 @@ public class LlamaCausalLMAgent : IStreamingAgent
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,

Просмотреть файл

@ -2,13 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

Просмотреть файл

@ -1,27 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Reflection;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.LLaMA;
internal static class Utils
{
public static string GetEmbeddedResource(string resourceName)
{
// read file content from embedded resource
var assembly = Assembly.GetExecutingAssembly();
var resourceStream = assembly.GetManifestResourceStream(resourceName);
if (resourceStream == null)
{
throw new ArgumentException("Resource not found", resourceName);
}
using var reader = new System.IO.StreamReader(resourceStream);
return reader.ReadToEnd();
}
}

Просмотреть файл

@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>true</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="TorchSharp.PyBridge" Version="$(TorchSharpPyBridgeVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj" PrivateAssets="all" />
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" PrivateAssets="all" />
<ProjectReference Include="..\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj" />
</ItemGroup>
<ItemGroup>
<EmbeddedResource Include="Resource\Config\*.json" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,166 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
namespace Microsoft.ML.GenAI.Mistral;
public class MistralCausalLMAgent : IStreamingAgent
{
private readonly ICausalLMPipeline<Tokenizer, MistralForCausalLM> _pipeline;
private readonly string? _systemMessage;
private readonly IAutoGenChatTemplateBuilder _templateBuilder;
private readonly string _stopSequence = "</s>";
/// <summary>
/// Create a new instance of <see cref="MistralCausalLMAgent"/>.
/// </summary>
/// <param name="pipeline">pipeline</param>
/// <param name="name">agent name</param>
/// <param name="systemMessage">system message.</param>
/// <param name="templateBuilder">the template builder to build chat prompt. If the value is null, <see cref="Mistral_7B_0_3ChatTemplateBuilder.Instance"/> would be used.</param>
public MistralCausalLMAgent(
ICausalLMPipeline<Tokenizer, MistralForCausalLM> pipeline,
string name,
string? systemMessage = "you are a helpful assistant",
IAutoGenChatTemplateBuilder? templateBuilder = null)
{
this.Name = name;
this._pipeline = pipeline;
this._systemMessage = systemMessage;
this._templateBuilder = templateBuilder ?? Mistral_7B_0_3ChatTemplateBuilder.Instance;
}
public string Name { get; }
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (_systemMessage != null)
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
messages = messages.Prepend(systemMessage);
}
var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var stopTokenSequence = options?.StopSequence ?? [];
stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();
var output = _pipeline.Generate(
input,
maxLen: maxLen,
temperature: temperature,
stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply.");
// post-process the output for tool call
if (output.StartsWith("[TOOL_CALLS]"))
{
return Task.FromResult<IMessage>(ParseAsToolCallMessage(output));
}
return Task.FromResult<IMessage>(new TextMessage(Role.Assistant, output, from: this.Name));
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (_systemMessage != null)
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
messages = messages.Prepend(systemMessage);
}
var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var stopTokenSequence = options?.StopSequence ?? [];
stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();
// only streaming the output when the output is not a tool call
// otherwise, we collect all the chunks and convert them to a tool call message at the end of the streaming
var sb = new StringBuilder();
bool? isToolCall = null;
foreach (var output in _pipeline.GenerateStreaming(
input,
maxLen: maxLen,
temperature: temperature,
stopSequences: stopTokenSequence))
{
if (isToolCall is null)
{
sb.Append(output);
var str = sb.ToString();
if (!str.StartsWith("[TOOL_CALLS]".Substring(0, str.Length)))
{
yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
isToolCall = false;
}
else if (str.StartsWith("[TOOL_CALLS]"))
{
isToolCall = true;
}
}
else if (isToolCall == false)
{
yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
}
else
{
sb.Append(output);
}
}
if (isToolCall == true)
{
var toolCallMessage = ParseAsToolCallMessage(sb.ToString());
foreach (var toolCall in toolCallMessage.ToolCalls)
{
yield return new ToolCallMessageUpdate(toolCall.FunctionName, toolCall.FunctionArguments, from: this.Name);
}
}
}
private class MistralToolCall
{
[JsonPropertyName("name")]
public string? Name { get; set; }
[JsonPropertyName("arguments")]
public JsonObject? Arguments { get; set; }
}
private ToolCallMessage ParseAsToolCallMessage(string content)
{
var json = content.Substring("[TOOL_CALLS]".Length).Trim();
// the json string should be a list of tool call messages
// e.g. [{"name": "get_current_weather", "parameters": {"location": "Seattle"}}]
var mistralToolCalls = JsonSerializer.Deserialize<List<MistralToolCall>>(json) ?? throw new InvalidOperationException("Failed to deserialize tool calls.");
var toolCalls = mistralToolCalls
.Select(tc => new ToolCall(tc.Name!, JsonSerializer.Serialize(tc.Arguments)) { ToolCallId = this.GenerateToolCallId() });
return new ToolCallMessage(toolCalls, from: this.Name);
}
/// <summary>
/// 9 random alphanumeric characters
/// </summary>
private string GenerateToolCallId(int length = 9)
{
const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
var random = new Random();
return new string(Enumerable.Repeat(chars, length)
.Select(s => s[random.Next(s.Length)]).ToArray());
}
}

Просмотреть файл

@ -0,0 +1,112 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.ML.GenAI.Core;
using TorchSharp;
namespace Microsoft.ML.GenAI.Mistral;
public class MistralConfig
{
public MistralConfig()
{
this.AttentionBias = false;
this.AttentionDropout = 0.0;
this.HiddenAct = "silu";
this.HiddenSize = 4096;
this.InitializerRange = 0.02;
this.IntermediateSize = 14336;
this.MaxPositionEmbeddings = 131072;
this.MlpBias = false;
this.NumAttentionHeads = 32;
this.NumHiddenLayers = 32;
this.NumKeyValueHeads = 8;
this.RmsNormEps = 1e-05f;
this.RopeScaling = new RopeScalingConfig();
this.RopeTheta = 500000.0;
this.TieWordEmbeddings = false;
this.VocabSize = 128256;
this.AttnImplementation = "eager";
this.DType = torch.ScalarType.BFloat16;
this.HeadDim = this.HiddenSize / this.NumAttentionHeads;
this.SlidingWindow ??= 4096;
}
static MistralConfig()
{
#pragma warning disable MSML_ParameterLocalVarName // Parameter or local variable name not standard
var mistral7BInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Mistral.Resource.Config.mistral-7B-instruct-v0.3.json");
#pragma warning restore MSML_ParameterLocalVarName // Parameter or local variable name not standard
Mistral_7B_Instruct_v0_3 = JsonSerializer.Deserialize<MistralConfig>(mistral7BInstructContent) ?? throw new ArgumentNullException(nameof(mistral7BInstructContent));
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
/// <summary>
/// The mistral-7b-instruct-v0.3 configuration created from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/tree/main.
/// </summary>
public static MistralConfig Mistral_7B_Instruct_v0_3 { get; }
#pragma warning restore MSML_GeneralName // This name should be PascalCased
[JsonPropertyName("attention_bias")]
public bool AttentionBias { get; set; }
[JsonPropertyName("attention_dropout")]
public double AttentionDropout { get; set; }
[JsonPropertyName("hidden_act")]
public string HiddenAct { get; set; }
[JsonPropertyName("hidden_size")]
public int HiddenSize { get; set; }
[JsonPropertyName("initializer_range")]
public double InitializerRange { get; set; }
[JsonPropertyName("intermediate_size")]
public int IntermediateSize { get; set; }
[JsonPropertyName("max_position_embeddings")]
public int MaxPositionEmbeddings { get; set; }
[JsonPropertyName("mlp_bias")]
public bool MlpBias { get; set; }
[JsonPropertyName("num_attention_heads")]
public int NumAttentionHeads { get; set; }
[JsonPropertyName("num_hidden_layers")]
public int NumHiddenLayers { get; set; }
[JsonPropertyName("num_key_value_heads")]
public int NumKeyValueHeads { get; set; }
[JsonPropertyName("head_dim")]
public int HeadDim { get; set; }
[JsonPropertyName("rms_norm_eps")]
public float RmsNormEps { get; set; }
public RopeScalingConfig RopeScaling { get; set; }
[JsonPropertyName("rope_theta")]
public double RopeTheta { get; set; }
[JsonPropertyName("tie_word_embeddings")]
public bool TieWordEmbeddings { get; set; }
[JsonPropertyName("vocab_size")]
public int VocabSize { get; set; }
[JsonPropertyName("sliding_window")]
public int? SlidingWindow { get; set; }
public int? PadTokenId { get; set; }
public torch.ScalarType DType { get; set; }
public string AttnImplementation { get; set; }
}

Просмотреть файл

@ -0,0 +1,148 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Mistral.Module;
internal class DecoderLayerInput
{
public DecoderLayerInput(
Tensor hiddenStates,
Tensor attentionMask,
Tensor positionIds,
RotaryEmbeddingOutput positionEmbeddings, // cos, sin
IKVCache? pastKeyValue = null,
bool outputAttentions = false)
{
this.HiddenStates = hiddenStates;
this.AttentionMask = attentionMask;
this.PositionIds = positionIds;
this.PastKeyValue = pastKeyValue;
this.OutputAttentions = outputAttentions;
this.PositionalEmbeddings = positionEmbeddings;
}
public Tensor HiddenStates { get; set; }
public Tensor AttentionMask { get; set; }
public Tensor PositionIds { get; set; }
public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }
public IKVCache? PastKeyValue { get; set; }
public bool OutputAttentions { get; set; }
}
internal class DecoderLayerOutput
{
public DecoderLayerOutput(
Tensor hiddenStates,
Tensor? attentions = null,
IKVCache? pastKeyValue = null)
{
this.HiddenStates = hiddenStates;
this.Attentions = attentions;
this.PastKeyValue = pastKeyValue;
}
public Tensor HiddenStates { get; set; }
public Tensor? Attentions { get; set; }
public IKVCache? PastKeyValue { get; set; }
}
internal class MistralDecoderLayer : nn.Module<DecoderLayerInput, DecoderLayerOutput>, IDynamicLoadModule
{
private readonly MistralConfig _llamaConfig;
private readonly int _layerIndex;
private readonly int _hiddenSize;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly MistralMLP mlp;
private readonly Core.RMSNorm input_layernorm;
private readonly Core.RMSNorm post_attention_layernorm;
private readonly Attention self_attn;
public Action<nn.Module>? LoadToDeviceFunc { get; set; }
public Action<nn.Module>? UnloadFromDeviceFunc { get; set; }
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
public MistralDecoderLayer(MistralConfig config, int layerIndex)
: base(nameof(MistralDecoderLayer))
{
_llamaConfig = config;
_layerIndex = layerIndex;
_hiddenSize = config.HiddenSize;
this.self_attn = CreateAttention(config, layerIndex);
this.mlp = new MistralMLP(config);
this.input_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
this.post_attention_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
}
private Attention CreateAttention(MistralConfig config, int layerIndex)
{
var headDim = config.HiddenSize / config.NumAttentionHeads;
return new Attention(
attentionDropout: config.AttentionDropout,
hiddenSize: config.HiddenSize,
numHeads: config.NumAttentionHeads,
headDim: headDim,
numKeyValueHeads: config.NumKeyValueHeads,
numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads,
maxPositionEmbeddings: config.MaxPositionEmbeddings,
originalMaxPositionEmbeddings: config.MaxPositionEmbeddings,
layerIdx: layerIndex,
useQkvProj: false,
dtype: config.DType,
attentionBias: config.AttentionBias);
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override DecoderLayerOutput forward(DecoderLayerInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
if (LoadToDeviceFunc != null)
{
LoadToDeviceFunc(this);
}
using var disposeScope = NewDisposeScope();
var residual = input.HiddenStates;
var hiddenStates = this.input_layernorm.forward(input.HiddenStates);
var selfAttnInput = new AttentionInput(
hiddenStates: hiddenStates,
attentionMask: input.AttentionMask,
positionIds: input.PositionIds,
cache: input.PastKeyValue,
positionalEmbeddings: input.PositionalEmbeddings,
outputAttentions: input.OutputAttentions);
var selfAttnOutput = this.self_attn.forward(selfAttnInput);
hiddenStates = residual + selfAttnOutput.HiddenStates;
// Fully connected
residual = hiddenStates;
hiddenStates = this.post_attention_layernorm.forward(hiddenStates);
hiddenStates = this.mlp.forward(hiddenStates);
hiddenStates = residual + hiddenStates;
if (UnloadFromDeviceFunc != null)
{
UnloadFromDeviceFunc(this);
}
return new DecoderLayerOutput(
hiddenStates: hiddenStates.MoveToOuterDisposeScope(),
attentions: input.OutputAttentions ? selfAttnOutput.Attentions?.MoveToOuterDisposeScope() : null,
pastKeyValue: selfAttnOutput.Cache);
}
}

Просмотреть файл

@ -0,0 +1,130 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
using System.Text.Json;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.Mistral.Module;
using TorchSharp;
using TorchSharp.PyBridge;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Mistral;
public class MistralForCausalLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly MistralConfig _config;
private readonly int _vocabSize;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly GenAILinear lm_head;
private readonly MistralModel model;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
public MistralForCausalLM(MistralConfig config)
: base(nameof(MistralForCausalLM))
{
_config = config;
_vocabSize = config.VocabSize;
model = new MistralModel(config);
lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, hasBias: false);
this.RegisterComponents();
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
var outputs = this.model.forward(input);
var logits = this.lm_head.forward(outputs.LastHiddenState);
logits = logits.to_type(ScalarType.Float32);
outputs.Logits = logits;
return outputs;
}
public static MistralForCausalLM FromPretrained(
string modelFolder,
string configName = "config.json",
string checkPointName = "model.safetensors.index.json",
ScalarType torchDtype = ScalarType.BFloat16,
string device = "cpu")
{
var config = Path.Join(modelFolder, configName);
var modelConfig = JsonSerializer.Deserialize<MistralConfig>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
modelConfig.DType = torchDtype;
var model = new MistralForCausalLM(modelConfig);
model.LoadSafeTensors(modelFolder, checkPointName);
model = model.to(device);
return model;
}
public static MistralForCausalLM FromPretrained(
string modelFolder,
string configName = "config.json",
string checkPointName = "model.safetensors.index.json",
bool quantizeToInt8 = false,
bool quantizeToInt4 = false,
int layersOnTargetDevice = -1,
ScalarType torchDtype = ScalarType.BFloat16,
string targetDevice = "cuda")
{
if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
{
return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
}
var originalDefaultDevice = torch.get_default_device();
torch.set_default_device("meta");
var config = Path.Join(modelFolder, configName);
var modelConfig = JsonSerializer.Deserialize<MistralConfig>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
modelConfig.DType = torchDtype;
var model = new MistralForCausalLM(modelConfig);
if (quantizeToInt8)
{
model.ToInt8QuantizeModule();
}
else if (quantizeToInt4)
{
model.ToInt4QuantizeModule();
}
var deviceMap = model.InferDeviceMapForEachLayer(
[
KeyValuePair.Create(targetDevice, layersOnTargetDevice),
KeyValuePair.Create("cpu", -1)
]);
torch.set_default_device("cpu");
model = new MistralForCausalLM(modelConfig);
model.LoadSafeTensors(modelFolder, checkPointName);
if (quantizeToInt8)
{
model.ToInt8QuantizeModule();
}
else if (quantizeToInt4)
{
model.ToInt4QuantizeModule();
}
model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
torch.set_default_device(originalDefaultDevice);
return model;
}
public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
{
this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
}
}

Просмотреть файл

@ -0,0 +1,45 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Mistral.Module;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
internal class MistralMLP : torch.nn.Module<Tensor, Tensor>
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private readonly int _intermediateSize;
private readonly int _hiddenSize;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly QuantizedLinear gate_proj;
private readonly QuantizedLinear up_proj;
private readonly QuantizedLinear down_proj;
private readonly torch.nn.Module<Tensor, Tensor> act_fn;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
public MistralMLP(MistralConfig config)
: base(nameof(MistralMLP))
{
this._hiddenSize = config.HiddenSize;
this._intermediateSize = config.IntermediateSize;
var hiddenAct = config.HiddenAct;
this.gate_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType);
this.up_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType);
this.down_proj = new QuantizedLinear(this._intermediateSize, this._hiddenSize, hasBias: false, dtype: config.DType);
this.RegisterComponents();
this.act_fn = Core.Utils.GetActivation(hiddenAct);
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
using var input1 = this.gate_proj.forward(input);
using var input2 = this.act_fn.forward(input1);
using var input3 = input2 * this.up_proj.forward(input);
return this.down_proj.forward(input3);
}
}

Просмотреть файл

@ -0,0 +1,148 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Mistral.Module;
public class MistralModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly MistralConfig _config;
private readonly int? _paddingIdx;
private readonly int _vocabSize;
private IKVCache _cache;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly Embedding embed_tokens;
private readonly ModuleList<MistralDecoderLayer> layers;
private readonly RMSNorm norm;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly nn.Module<RotaryEmbeddingInput, RotaryEmbeddingOutput> _rotaryEmb;
public MistralModel(MistralConfig config)
: base(nameof(MistralModel))
{
this._config = config;
this._paddingIdx = config.PadTokenId;
this._vocabSize = config.VocabSize;
var headDim = config.HeadDim;
this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType);
this.layers = new ModuleList<MistralDecoderLayer>();
for (int i = 0; i < config.NumHiddenLayers; i++)
{
this.layers.Add(new MistralDecoderLayer(config, i));
}
this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
this._cache = new DynamicKVCache();
this.RegisterComponents();
this._rotaryEmb = config.RopeScaling switch
{
null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim),
_ => new RotaryEmbedding(config.RopeTheta, headDim, config.RopeScaling),
};
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
if (input.OverrideCache is not null)
{
this._cache = input.OverrideCache;
}
var outputAttentions = input.OutputAttentions;
var outputHiddenStates = input.OutputHiddenStates;
var attentionMask = input.AttentionMask;
Device device;
var inputIds = input.InputIds;
var positionIds = input.PositionIds;
var inputsEmbeds = input.InputEmbeddings;
int batchSize;
int seqLength;
if (inputIds is not null && inputsEmbeds is not null)
{
throw new ArgumentException("Only one of input_ids or inputs_embeds may be set");
}
else if (inputIds is not null)
{
batchSize = inputIds.IntShape()[0];
seqLength = inputIds.IntShape()[1];
inputsEmbeds = this.embed_tokens.forward(inputIds);
device = inputIds.device;
}
else if (inputsEmbeds is not null)
{
batchSize = inputsEmbeds.IntShape()[0];
seqLength = inputsEmbeds.IntShape()[1];
device = inputsEmbeds.device;
}
else
{
throw new ArgumentException("Either input_ids or inputs_embeds must be set");
}
var pastKeyValuesLength = input.PastKeyValuesLength;
if (positionIds is null)
{
positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device);
positionIds = positionIds.unsqueeze(0).view(-1, seqLength);
}
else
{
positionIds = ((long)positionIds.view(-1, seqLength));
}
if (this._config.AttnImplementation == "flash_attention_2")
{
throw new NotImplementedException();
}
else
{
// the following behavior of creating 4d causal mask doesn't match python's, remember to look into it when there's time.
attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength, slidingWindow: _config.SlidingWindow);
}
var hiddenStates = inputsEmbeds;
var allHiddenStates = new List<Tensor>();
var allAttentions = new List<Tensor>();
var embOutput = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, pastKeyValuesLength));
foreach (var layer in this.layers)
{
if (outputHiddenStates)
{
allHiddenStates.Add(hiddenStates);
}
var decoderInput = new DecoderLayerInput(
hiddenStates: hiddenStates,
attentionMask: attentionMask!,
positionIds: positionIds,
pastKeyValue: this._cache,
positionEmbeddings: embOutput,
outputAttentions: outputAttentions);
var layerOutput = layer.forward(decoderInput);
hiddenStates = layerOutput.HiddenStates;
if (outputAttentions && layerOutput.Attentions is not null)
{
allAttentions.Add(layerOutput.Attentions);
}
}
hiddenStates = this.norm.forward(hiddenStates);
if (outputHiddenStates)
{
allHiddenStates.Add(hiddenStates);
}
return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
}
}

Просмотреть файл

@ -0,0 +1,107 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;
namespace Microsoft.ML.GenAI.Mistral;
public class MistralTokenizerHelper
{
private const string UnknownSymbol = "<unk>";
private const int UnknownSymbolId = 0;
private const string StartSymbol = "<s>";
private const int StartSymbolId = 1;
private const string EndSymbol = "</s>";
private const int EndSymbolId = 2;
private const string StartInstructionSymbol = "[INST]";
private const int StartInstructionSymbolId = 3;
private const string EndInstructionSymbol = "[/INST]";
private const int EndInstructionSymbolId = 4;
private const string ToolCallSymbol = "[TOOL_CALLS]";
private const int ToolCallSymbolId = 5;
private const string StartAvailableToolsSymbol = "[AVAILABLE_TOOLS]";
private const int StartAvailableToolsSymbolId = 6;
private const string EndAvailableToolsSymbol = "[/AVAILABLE_TOOLS]";
private const int EndAvailableToolsSymbolId = 7;
private const string StartToolResultSymbol = "[TOOL_RESULTS]";
private const int StartToolResultSymbolId = 8;
private const string EndToolResultSymbol = "[/TOOL_RESULTS]";
private const int EndToolResultSymbolId = 9;
public static LlamaTokenizer FromPretrained(
string modelWeightFolder,
string modelName = "tokenizer.model.v3",
string unknownSymbol = UnknownSymbol,
int unknownSymbolId = 0,
string startSymbol = StartSymbol,
int startSymbolId = 1,
string endSymbol = EndSymbol,
int endSymbolId = 2,
string startInstructionSymbol = StartInstructionSymbol,
int startInstructionSymbolId = 3,
string endInstructionSymbol = EndInstructionSymbol,
int endInstructionSymbolId = 4,
string toolCallSymbol = ToolCallSymbol,
int toolCallSymbolId = 5,
string startAvailableToolsSymbol = StartAvailableToolsSymbol,
int startAvailableToolsSymbolId = 6,
string endAvailableToolsSymbol = EndAvailableToolsSymbol,
int endAvailableToolsSymbolId = 7,
string startToolResultSymbol = StartToolResultSymbol,
int startToolResultSymbolId = 8,
string endToolResultSymbol = EndToolResultSymbol,
int endToolResultSymbolId = 9,
bool addPrecedingSpace = true,
Dictionary<string, int>? additionalSpecialTokens = null)
{
var specialTokens = new Dictionary<string, int>
{
{ startSymbol, startSymbolId },
{ endSymbol, endSymbolId },
{ startInstructionSymbol, startInstructionSymbolId },
{ endInstructionSymbol, endInstructionSymbolId },
{ toolCallSymbol, toolCallSymbolId },
{ startAvailableToolsSymbol, startAvailableToolsSymbolId },
{ endAvailableToolsSymbol, endAvailableToolsSymbolId },
{ startToolResultSymbol, startToolResultSymbolId },
{ endToolResultSymbol, endToolResultSymbolId }
};
if (additionalSpecialTokens != null)
{
foreach (var (key, value) in additionalSpecialTokens)
{
specialTokens[key] = value;
}
}
return FromPretrained(
modelWeightFolder,
modelName,
specialTokens,
addPrecedingSpace);
}
public static LlamaTokenizer FromPretrained(
string modelWeightFolder,
string modelName,
Dictionary<string, int> specialTokens,
bool addPrecedingSpace = true)
{
var modelPath = Path.Combine(modelWeightFolder, modelName);
var modelStream = File.OpenRead(modelPath);
var llamaTokenizer = LlamaTokenizer.Create(
modelStream,
addPrecedingSpace,
specialTokens: specialTokens);
return llamaTokenizer;
}
}

Просмотреть файл

@ -0,0 +1,202 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using AutoGen.Core;
using Json.Schema;
using Json.Schema.Generation;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel.ChatCompletion;
namespace Microsoft.ML.GenAI.Mistral;
/// <summary>
/// the chat template builder for Mistral 7B v0.3
/// </summary>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public class Mistral_7B_0_3ChatTemplateBuilder : IChatTemplateBuilder
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private const string Newline = "\r\n";
public static Mistral_7B_0_3ChatTemplateBuilder Instance { get; } = new Mistral_7B_0_3ChatTemplateBuilder();
public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
{
// can only contain at most one system message
if (messages.Where(m => m.GetRole() == Role.System).Count() > 1)
{
throw new InvalidOperationException("Please provide at most one system message.");
}
var systemMessage = messages.FirstOrDefault(m => m.GetRole() == Role.System)?.GetContent();
// split the messages into two sequences by the last user message
// e.g [user, assistant, user, assistant, user] -> [[user, assistant, user, assistant], [user]]
var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
var sb = new StringBuilder();
foreach (var message in firstSequence)
{
// skip system
if (message.GetRole() == Role.System)
{
continue;
}
var content = message.GetContent()!;
sb.Append(message switch
{
ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}</s>",
_ => throw new InvalidOperationException("Invalid role.")
});
}
// insert [AVAILABLE TOOLS] section if tools are provided
if (tools?.Any() == true)
{
var schemas = tools.Select(t => new
{
type = "function",
function = new
{
name = t.Name,
description = t.Description,
parameters = BuildJsonSchemaFromFunctionContract(t)
}
});
var schemaPrompt = JsonSerializer.Serialize(schemas);
// add a space after the colon in json string so mistral can correctly generate the stop </s> token after [TOOL_CALLS] symbol.
// This is probably because in the training data, all the tool call samples are separated by a space after the colon.
// e.g. [AVAILABLE_TOOLS][{"type": "function", "function": {....
// instead of [AVAILABLE_TOOLS][{"type":"function","function":{....
// Therefore when inferencing, we need to add a space after the colon in the json string to match with the training data.
schemaPrompt = schemaPrompt.Replace(":", ": ");
schemaPrompt = schemaPrompt.Replace(",", ", ");
sb.Append($"[AVAILABLE_TOOLS]{schemaPrompt}[/AVAILABLE_TOOLS]");
}
foreach (var message in secondSequence)
{
var content = message.GetContent()!;
sb.Append(message switch
{
ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
TextMessage when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST]{systemMessage}{Newline}{Newline}{content.Trim()}[/INST]",
TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}</s>",
_ => throw new InvalidOperationException("Invalid role.")
});
}
return sb.ToString();
}
public string BuildPrompt(ChatHistory chatHistory)
{
throw new NotImplementedException();
}
private string BuildFromToolCallMessage(ToolCallMessage message)
{
var toolCalls = message.ToolCalls;
if (toolCalls.Count() == 0)
{
return string.Empty;
}
else
{
var toolCallObjects = toolCalls.Select(tc =>
new
{
name = tc.FunctionName,
arguments = JsonObject.Parse(tc.FunctionArguments),
id = tc.ToolCallId,
}
);
var toolCallJson = JsonSerializer.Serialize(toolCallObjects);
return $"[TOOL_CALLS]{toolCallJson}</s>";
}
}
private string BuildFromToolCallResultMessage(ToolCallResultMessage message)
{
var toolCallResults = message.ToolCalls;
if (toolCallResults.Count() == 0)
{
return string.Empty;
}
else
{
var toolCallResultObjects = toolCallResults.Select(tc =>
new
{
id = tc.ToolCallId,
content = tc.Result,
}
);
var toolCallResultJson = JsonSerializer.Serialize(toolCallResultObjects);
return $"[TOOL_RESULTS]{toolCallResultJson}[/TOOL_RESULTS]";
}
}
private string BuildFromAggregrateToolCallMessage(ToolCallAggregateMessage message)
{
var toolCallMessage = message.Message1;
var toolCallResultMessage = message.Message2;
var toolCall = BuildFromToolCallMessage(toolCallMessage);
var toolCallResult = BuildFromToolCallResultMessage(toolCallResultMessage);
return $"{toolCall}{toolCallResult}";
}
private JsonSchema BuildJsonSchemaFromFunctionContract(FunctionContract contract)
{
var requiredParameterNames = new List<string>();
var propertiesSchemas = new Dictionary<string, JsonSchema>();
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
foreach (var param in contract.Parameters ?? [])
{
if (param.Name is null)
{
throw new InvalidOperationException("Parameter name cannot be null");
}
var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
if (param.Description != null)
{
schemaBuilder = schemaBuilder.Description(param.Description);
}
if (param.IsRequired)
{
requiredParameterNames.Add(param.Name);
}
var schema = schemaBuilder.Build();
propertiesSchemas[param.Name] = schema;
}
propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
var jsonSchema = propertySchemaBuilder.Build();
return jsonSchema;
}
}

Просмотреть файл

@ -0,0 +1,21 @@
{
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 32768,
"model_type": "mistral",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"use_cache": true,
"vocab_size": 32768
}

Просмотреть файл

@ -41,7 +41,7 @@ public class Phi2Config
static Phi2Config()
{
var phi2ConfigContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json");
var phi2ConfigContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json");
var phi2Config = JsonSerializer.Deserialize<Phi2Config>(phi2ConfigContent) ?? throw new ArgumentNullException(nameof(phi2ConfigContent));
Phi2 = phi2Config;
}

Просмотреть файл

@ -42,10 +42,10 @@ public class Phi3Config
static Phi3Config()
{
var phi3Mini4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json");
var phi3Mini128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json");
var phi3Medium4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json");
var phi3Medium128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json");
var phi3Mini4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json");
var phi3Mini128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json");
var phi3Medium4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json");
var phi3Medium128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json");
Phi3Mini4kInstruct = JsonSerializer.Deserialize<Phi3Config>(phi3Mini4kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini4kInstructContent));
Phi3Mini128kInstruct = JsonSerializer.Deserialize<Phi3Config>(phi3Mini128kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini128kInstructContent));

Просмотреть файл

@ -50,7 +50,7 @@ public class Phi3Agent : IStreamingAgent
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,

Просмотреть файл

@ -16,49 +16,6 @@ namespace Microsoft.ML.GenAI.Phi;
internal static class Utils
{
public static string GetEmbeddedResource(string resourceName)
{
// read file content from embedded resource
var assembly = Assembly.GetExecutingAssembly();
var resourceStream = assembly.GetManifestResourceStream(resourceName);
if (resourceStream == null)
{
throw new ArgumentException("Resource not found", nameof(resourceName));
}
using var reader = new System.IO.StreamReader(resourceStream);
return reader.ReadToEnd();
}
public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex)
{
// Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
// Two consecutive values will become a single complex number
// (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
var inputComplex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex();
freqsComplex = freqsComplex.to(input.device);
// Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension
// (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
var freqsComplexReshaped = freqsComplex.unsqueeze(0).unsqueeze(2);
// Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
// Which results in the rotation of the complex number as shown in the Figure 1 of the paper
// (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
var rotatedComplex = inputComplex * freqsComplexReshaped;
// Console.WriteLine(rotated_complex.mean().ToSingle());
// Convert the complex number back to the real number
// (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
var rotated = rotatedComplex.view_as_real();
// (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
var rotatedReshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1);
return rotatedReshaped.type_as(input);
}
public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f)
{
// As written in the paragraph 3.2.2 of the paper
@ -147,21 +104,4 @@ internal static class Utils
.expand(batchSize, seqLen, nKVHeads, nRep, headDim)
.view(batchSize, seqLen, nKVHeads * nRep, headDim);
}
public static Tensor Phi3RepeatKV(Tensor x, int nRep)
{
var batchSize = x.shape[0];
var nKVHeads = x.shape[1];
var seqLen = x.shape[2];
var headDim = x.shape[3];
if (nRep == 1)
{
return x;
}
return x.unsqueeze(3)
.expand(batchSize, nKVHeads, nRep, seqLen, headDim)
.view(batchSize, nKVHeads * nRep, seqLen, headDim);
}
}

Просмотреть файл

@ -0,0 +1,3 @@
[INST]You are a helpful AI assistant.
Hello?[/INST]World!</s>

Просмотреть файл

@ -0,0 +1,3 @@
[INST]What's the weather in Seattle?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}]</s>[TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in Seattle is 22.0 degrees celsius.</s>[INST]What's the weather in New York?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}]</s>[TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in New York is 22.0 degrees celsius.</s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]You are a helpful AI assistant.
What's the weather in Paris?[/INST]

Просмотреть файл

@ -0,0 +1,291 @@
0: lm_head.weight shape: [32768, 4096]
1: model.embed_tokens.weight shape: [32768, 4096]
2: model.layers.0.input_layernorm.weight shape: [4096]
3: model.layers.0.mlp.down_proj.weight shape: [4096, 14336]
4: model.layers.0.mlp.gate_proj.weight shape: [14336, 4096]
5: model.layers.0.mlp.up_proj.weight shape: [14336, 4096]
6: model.layers.0.post_attention_layernorm.weight shape: [4096]
7: model.layers.0.self_attn.k_proj.weight shape: [1024, 4096]
8: model.layers.0.self_attn.o_proj.weight shape: [4096, 4096]
9: model.layers.0.self_attn.q_proj.weight shape: [4096, 4096]
10: model.layers.0.self_attn.v_proj.weight shape: [1024, 4096]
11: model.layers.1.input_layernorm.weight shape: [4096]
12: model.layers.1.mlp.down_proj.weight shape: [4096, 14336]
13: model.layers.1.mlp.gate_proj.weight shape: [14336, 4096]
14: model.layers.1.mlp.up_proj.weight shape: [14336, 4096]
15: model.layers.1.post_attention_layernorm.weight shape: [4096]
16: model.layers.1.self_attn.k_proj.weight shape: [1024, 4096]
17: model.layers.1.self_attn.o_proj.weight shape: [4096, 4096]
18: model.layers.1.self_attn.q_proj.weight shape: [4096, 4096]
19: model.layers.1.self_attn.v_proj.weight shape: [1024, 4096]
20: model.layers.10.input_layernorm.weight shape: [4096]
21: model.layers.10.mlp.down_proj.weight shape: [4096, 14336]
22: model.layers.10.mlp.gate_proj.weight shape: [14336, 4096]
23: model.layers.10.mlp.up_proj.weight shape: [14336, 4096]
24: model.layers.10.post_attention_layernorm.weight shape: [4096]
25: model.layers.10.self_attn.k_proj.weight shape: [1024, 4096]
26: model.layers.10.self_attn.o_proj.weight shape: [4096, 4096]
27: model.layers.10.self_attn.q_proj.weight shape: [4096, 4096]
28: model.layers.10.self_attn.v_proj.weight shape: [1024, 4096]
29: model.layers.11.input_layernorm.weight shape: [4096]
30: model.layers.11.mlp.down_proj.weight shape: [4096, 14336]
31: model.layers.11.mlp.gate_proj.weight shape: [14336, 4096]
32: model.layers.11.mlp.up_proj.weight shape: [14336, 4096]
33: model.layers.11.post_attention_layernorm.weight shape: [4096]
34: model.layers.11.self_attn.k_proj.weight shape: [1024, 4096]
35: model.layers.11.self_attn.o_proj.weight shape: [4096, 4096]
36: model.layers.11.self_attn.q_proj.weight shape: [4096, 4096]
37: model.layers.11.self_attn.v_proj.weight shape: [1024, 4096]
38: model.layers.12.input_layernorm.weight shape: [4096]
39: model.layers.12.mlp.down_proj.weight shape: [4096, 14336]
40: model.layers.12.mlp.gate_proj.weight shape: [14336, 4096]
41: model.layers.12.mlp.up_proj.weight shape: [14336, 4096]
42: model.layers.12.post_attention_layernorm.weight shape: [4096]
43: model.layers.12.self_attn.k_proj.weight shape: [1024, 4096]
44: model.layers.12.self_attn.o_proj.weight shape: [4096, 4096]
45: model.layers.12.self_attn.q_proj.weight shape: [4096, 4096]
46: model.layers.12.self_attn.v_proj.weight shape: [1024, 4096]
47: model.layers.13.input_layernorm.weight shape: [4096]
48: model.layers.13.mlp.down_proj.weight shape: [4096, 14336]
49: model.layers.13.mlp.gate_proj.weight shape: [14336, 4096]
50: model.layers.13.mlp.up_proj.weight shape: [14336, 4096]
51: model.layers.13.post_attention_layernorm.weight shape: [4096]
52: model.layers.13.self_attn.k_proj.weight shape: [1024, 4096]
53: model.layers.13.self_attn.o_proj.weight shape: [4096, 4096]
54: model.layers.13.self_attn.q_proj.weight shape: [4096, 4096]
55: model.layers.13.self_attn.v_proj.weight shape: [1024, 4096]
56: model.layers.14.input_layernorm.weight shape: [4096]
57: model.layers.14.mlp.down_proj.weight shape: [4096, 14336]
58: model.layers.14.mlp.gate_proj.weight shape: [14336, 4096]
59: model.layers.14.mlp.up_proj.weight shape: [14336, 4096]
60: model.layers.14.post_attention_layernorm.weight shape: [4096]
61: model.layers.14.self_attn.k_proj.weight shape: [1024, 4096]
62: model.layers.14.self_attn.o_proj.weight shape: [4096, 4096]
63: model.layers.14.self_attn.q_proj.weight shape: [4096, 4096]
64: model.layers.14.self_attn.v_proj.weight shape: [1024, 4096]
65: model.layers.15.input_layernorm.weight shape: [4096]
66: model.layers.15.mlp.down_proj.weight shape: [4096, 14336]
67: model.layers.15.mlp.gate_proj.weight shape: [14336, 4096]
68: model.layers.15.mlp.up_proj.weight shape: [14336, 4096]
69: model.layers.15.post_attention_layernorm.weight shape: [4096]
70: model.layers.15.self_attn.k_proj.weight shape: [1024, 4096]
71: model.layers.15.self_attn.o_proj.weight shape: [4096, 4096]
72: model.layers.15.self_attn.q_proj.weight shape: [4096, 4096]
73: model.layers.15.self_attn.v_proj.weight shape: [1024, 4096]
74: model.layers.16.input_layernorm.weight shape: [4096]
75: model.layers.16.mlp.down_proj.weight shape: [4096, 14336]
76: model.layers.16.mlp.gate_proj.weight shape: [14336, 4096]
77: model.layers.16.mlp.up_proj.weight shape: [14336, 4096]
78: model.layers.16.post_attention_layernorm.weight shape: [4096]
79: model.layers.16.self_attn.k_proj.weight shape: [1024, 4096]
80: model.layers.16.self_attn.o_proj.weight shape: [4096, 4096]
81: model.layers.16.self_attn.q_proj.weight shape: [4096, 4096]
82: model.layers.16.self_attn.v_proj.weight shape: [1024, 4096]
83: model.layers.17.input_layernorm.weight shape: [4096]
84: model.layers.17.mlp.down_proj.weight shape: [4096, 14336]
85: model.layers.17.mlp.gate_proj.weight shape: [14336, 4096]
86: model.layers.17.mlp.up_proj.weight shape: [14336, 4096]
87: model.layers.17.post_attention_layernorm.weight shape: [4096]
88: model.layers.17.self_attn.k_proj.weight shape: [1024, 4096]
89: model.layers.17.self_attn.o_proj.weight shape: [4096, 4096]
90: model.layers.17.self_attn.q_proj.weight shape: [4096, 4096]
91: model.layers.17.self_attn.v_proj.weight shape: [1024, 4096]
92: model.layers.18.input_layernorm.weight shape: [4096]
93: model.layers.18.mlp.down_proj.weight shape: [4096, 14336]
94: model.layers.18.mlp.gate_proj.weight shape: [14336, 4096]
95: model.layers.18.mlp.up_proj.weight shape: [14336, 4096]
96: model.layers.18.post_attention_layernorm.weight shape: [4096]
97: model.layers.18.self_attn.k_proj.weight shape: [1024, 4096]
98: model.layers.18.self_attn.o_proj.weight shape: [4096, 4096]
99: model.layers.18.self_attn.q_proj.weight shape: [4096, 4096]
100: model.layers.18.self_attn.v_proj.weight shape: [1024, 4096]
101: model.layers.19.input_layernorm.weight shape: [4096]
102: model.layers.19.mlp.down_proj.weight shape: [4096, 14336]
103: model.layers.19.mlp.gate_proj.weight shape: [14336, 4096]
104: model.layers.19.mlp.up_proj.weight shape: [14336, 4096]
105: model.layers.19.post_attention_layernorm.weight shape: [4096]
106: model.layers.19.self_attn.k_proj.weight shape: [1024, 4096]
107: model.layers.19.self_attn.o_proj.weight shape: [4096, 4096]
108: model.layers.19.self_attn.q_proj.weight shape: [4096, 4096]
109: model.layers.19.self_attn.v_proj.weight shape: [1024, 4096]
110: model.layers.2.input_layernorm.weight shape: [4096]
111: model.layers.2.mlp.down_proj.weight shape: [4096, 14336]
112: model.layers.2.mlp.gate_proj.weight shape: [14336, 4096]
113: model.layers.2.mlp.up_proj.weight shape: [14336, 4096]
114: model.layers.2.post_attention_layernorm.weight shape: [4096]
115: model.layers.2.self_attn.k_proj.weight shape: [1024, 4096]
116: model.layers.2.self_attn.o_proj.weight shape: [4096, 4096]
117: model.layers.2.self_attn.q_proj.weight shape: [4096, 4096]
118: model.layers.2.self_attn.v_proj.weight shape: [1024, 4096]
119: model.layers.20.input_layernorm.weight shape: [4096]
120: model.layers.20.mlp.down_proj.weight shape: [4096, 14336]
121: model.layers.20.mlp.gate_proj.weight shape: [14336, 4096]
122: model.layers.20.mlp.up_proj.weight shape: [14336, 4096]
123: model.layers.20.post_attention_layernorm.weight shape: [4096]
124: model.layers.20.self_attn.k_proj.weight shape: [1024, 4096]
125: model.layers.20.self_attn.o_proj.weight shape: [4096, 4096]
126: model.layers.20.self_attn.q_proj.weight shape: [4096, 4096]
127: model.layers.20.self_attn.v_proj.weight shape: [1024, 4096]
128: model.layers.21.input_layernorm.weight shape: [4096]
129: model.layers.21.mlp.down_proj.weight shape: [4096, 14336]
130: model.layers.21.mlp.gate_proj.weight shape: [14336, 4096]
131: model.layers.21.mlp.up_proj.weight shape: [14336, 4096]
132: model.layers.21.post_attention_layernorm.weight shape: [4096]
133: model.layers.21.self_attn.k_proj.weight shape: [1024, 4096]
134: model.layers.21.self_attn.o_proj.weight shape: [4096, 4096]
135: model.layers.21.self_attn.q_proj.weight shape: [4096, 4096]
136: model.layers.21.self_attn.v_proj.weight shape: [1024, 4096]
137: model.layers.22.input_layernorm.weight shape: [4096]
138: model.layers.22.mlp.down_proj.weight shape: [4096, 14336]
139: model.layers.22.mlp.gate_proj.weight shape: [14336, 4096]
140: model.layers.22.mlp.up_proj.weight shape: [14336, 4096]
141: model.layers.22.post_attention_layernorm.weight shape: [4096]
142: model.layers.22.self_attn.k_proj.weight shape: [1024, 4096]
143: model.layers.22.self_attn.o_proj.weight shape: [4096, 4096]
144: model.layers.22.self_attn.q_proj.weight shape: [4096, 4096]
145: model.layers.22.self_attn.v_proj.weight shape: [1024, 4096]
146: model.layers.23.input_layernorm.weight shape: [4096]
147: model.layers.23.mlp.down_proj.weight shape: [4096, 14336]
148: model.layers.23.mlp.gate_proj.weight shape: [14336, 4096]
149: model.layers.23.mlp.up_proj.weight shape: [14336, 4096]
150: model.layers.23.post_attention_layernorm.weight shape: [4096]
151: model.layers.23.self_attn.k_proj.weight shape: [1024, 4096]
152: model.layers.23.self_attn.o_proj.weight shape: [4096, 4096]
153: model.layers.23.self_attn.q_proj.weight shape: [4096, 4096]
154: model.layers.23.self_attn.v_proj.weight shape: [1024, 4096]
155: model.layers.24.input_layernorm.weight shape: [4096]
156: model.layers.24.mlp.down_proj.weight shape: [4096, 14336]
157: model.layers.24.mlp.gate_proj.weight shape: [14336, 4096]
158: model.layers.24.mlp.up_proj.weight shape: [14336, 4096]
159: model.layers.24.post_attention_layernorm.weight shape: [4096]
160: model.layers.24.self_attn.k_proj.weight shape: [1024, 4096]
161: model.layers.24.self_attn.o_proj.weight shape: [4096, 4096]
162: model.layers.24.self_attn.q_proj.weight shape: [4096, 4096]
163: model.layers.24.self_attn.v_proj.weight shape: [1024, 4096]
164: model.layers.25.input_layernorm.weight shape: [4096]
165: model.layers.25.mlp.down_proj.weight shape: [4096, 14336]
166: model.layers.25.mlp.gate_proj.weight shape: [14336, 4096]
167: model.layers.25.mlp.up_proj.weight shape: [14336, 4096]
168: model.layers.25.post_attention_layernorm.weight shape: [4096]
169: model.layers.25.self_attn.k_proj.weight shape: [1024, 4096]
170: model.layers.25.self_attn.o_proj.weight shape: [4096, 4096]
171: model.layers.25.self_attn.q_proj.weight shape: [4096, 4096]
172: model.layers.25.self_attn.v_proj.weight shape: [1024, 4096]
173: model.layers.26.input_layernorm.weight shape: [4096]
174: model.layers.26.mlp.down_proj.weight shape: [4096, 14336]
175: model.layers.26.mlp.gate_proj.weight shape: [14336, 4096]
176: model.layers.26.mlp.up_proj.weight shape: [14336, 4096]
177: model.layers.26.post_attention_layernorm.weight shape: [4096]
178: model.layers.26.self_attn.k_proj.weight shape: [1024, 4096]
179: model.layers.26.self_attn.o_proj.weight shape: [4096, 4096]
180: model.layers.26.self_attn.q_proj.weight shape: [4096, 4096]
181: model.layers.26.self_attn.v_proj.weight shape: [1024, 4096]
182: model.layers.27.input_layernorm.weight shape: [4096]
183: model.layers.27.mlp.down_proj.weight shape: [4096, 14336]
184: model.layers.27.mlp.gate_proj.weight shape: [14336, 4096]
185: model.layers.27.mlp.up_proj.weight shape: [14336, 4096]
186: model.layers.27.post_attention_layernorm.weight shape: [4096]
187: model.layers.27.self_attn.k_proj.weight shape: [1024, 4096]
188: model.layers.27.self_attn.o_proj.weight shape: [4096, 4096]
189: model.layers.27.self_attn.q_proj.weight shape: [4096, 4096]
190: model.layers.27.self_attn.v_proj.weight shape: [1024, 4096]
191: model.layers.28.input_layernorm.weight shape: [4096]
192: model.layers.28.mlp.down_proj.weight shape: [4096, 14336]
193: model.layers.28.mlp.gate_proj.weight shape: [14336, 4096]
194: model.layers.28.mlp.up_proj.weight shape: [14336, 4096]
195: model.layers.28.post_attention_layernorm.weight shape: [4096]
196: model.layers.28.self_attn.k_proj.weight shape: [1024, 4096]
197: model.layers.28.self_attn.o_proj.weight shape: [4096, 4096]
198: model.layers.28.self_attn.q_proj.weight shape: [4096, 4096]
199: model.layers.28.self_attn.v_proj.weight shape: [1024, 4096]
200: model.layers.29.input_layernorm.weight shape: [4096]
201: model.layers.29.mlp.down_proj.weight shape: [4096, 14336]
202: model.layers.29.mlp.gate_proj.weight shape: [14336, 4096]
203: model.layers.29.mlp.up_proj.weight shape: [14336, 4096]
204: model.layers.29.post_attention_layernorm.weight shape: [4096]
205: model.layers.29.self_attn.k_proj.weight shape: [1024, 4096]
206: model.layers.29.self_attn.o_proj.weight shape: [4096, 4096]
207: model.layers.29.self_attn.q_proj.weight shape: [4096, 4096]
208: model.layers.29.self_attn.v_proj.weight shape: [1024, 4096]
209: model.layers.3.input_layernorm.weight shape: [4096]
210: model.layers.3.mlp.down_proj.weight shape: [4096, 14336]
211: model.layers.3.mlp.gate_proj.weight shape: [14336, 4096]
212: model.layers.3.mlp.up_proj.weight shape: [14336, 4096]
213: model.layers.3.post_attention_layernorm.weight shape: [4096]
214: model.layers.3.self_attn.k_proj.weight shape: [1024, 4096]
215: model.layers.3.self_attn.o_proj.weight shape: [4096, 4096]
216: model.layers.3.self_attn.q_proj.weight shape: [4096, 4096]
217: model.layers.3.self_attn.v_proj.weight shape: [1024, 4096]
218: model.layers.30.input_layernorm.weight shape: [4096]
219: model.layers.30.mlp.down_proj.weight shape: [4096, 14336]
220: model.layers.30.mlp.gate_proj.weight shape: [14336, 4096]
221: model.layers.30.mlp.up_proj.weight shape: [14336, 4096]
222: model.layers.30.post_attention_layernorm.weight shape: [4096]
223: model.layers.30.self_attn.k_proj.weight shape: [1024, 4096]
224: model.layers.30.self_attn.o_proj.weight shape: [4096, 4096]
225: model.layers.30.self_attn.q_proj.weight shape: [4096, 4096]
226: model.layers.30.self_attn.v_proj.weight shape: [1024, 4096]
227: model.layers.31.input_layernorm.weight shape: [4096]
228: model.layers.31.mlp.down_proj.weight shape: [4096, 14336]
229: model.layers.31.mlp.gate_proj.weight shape: [14336, 4096]
230: model.layers.31.mlp.up_proj.weight shape: [14336, 4096]
231: model.layers.31.post_attention_layernorm.weight shape: [4096]
232: model.layers.31.self_attn.k_proj.weight shape: [1024, 4096]
233: model.layers.31.self_attn.o_proj.weight shape: [4096, 4096]
234: model.layers.31.self_attn.q_proj.weight shape: [4096, 4096]
235: model.layers.31.self_attn.v_proj.weight shape: [1024, 4096]
236: model.layers.4.input_layernorm.weight shape: [4096]
237: model.layers.4.mlp.down_proj.weight shape: [4096, 14336]
238: model.layers.4.mlp.gate_proj.weight shape: [14336, 4096]
239: model.layers.4.mlp.up_proj.weight shape: [14336, 4096]
240: model.layers.4.post_attention_layernorm.weight shape: [4096]
241: model.layers.4.self_attn.k_proj.weight shape: [1024, 4096]
242: model.layers.4.self_attn.o_proj.weight shape: [4096, 4096]
243: model.layers.4.self_attn.q_proj.weight shape: [4096, 4096]
244: model.layers.4.self_attn.v_proj.weight shape: [1024, 4096]
245: model.layers.5.input_layernorm.weight shape: [4096]
246: model.layers.5.mlp.down_proj.weight shape: [4096, 14336]
247: model.layers.5.mlp.gate_proj.weight shape: [14336, 4096]
248: model.layers.5.mlp.up_proj.weight shape: [14336, 4096]
249: model.layers.5.post_attention_layernorm.weight shape: [4096]
250: model.layers.5.self_attn.k_proj.weight shape: [1024, 4096]
251: model.layers.5.self_attn.o_proj.weight shape: [4096, 4096]
252: model.layers.5.self_attn.q_proj.weight shape: [4096, 4096]
253: model.layers.5.self_attn.v_proj.weight shape: [1024, 4096]
254: model.layers.6.input_layernorm.weight shape: [4096]
255: model.layers.6.mlp.down_proj.weight shape: [4096, 14336]
256: model.layers.6.mlp.gate_proj.weight shape: [14336, 4096]
257: model.layers.6.mlp.up_proj.weight shape: [14336, 4096]
258: model.layers.6.post_attention_layernorm.weight shape: [4096]
259: model.layers.6.self_attn.k_proj.weight shape: [1024, 4096]
260: model.layers.6.self_attn.o_proj.weight shape: [4096, 4096]
261: model.layers.6.self_attn.q_proj.weight shape: [4096, 4096]
262: model.layers.6.self_attn.v_proj.weight shape: [1024, 4096]
263: model.layers.7.input_layernorm.weight shape: [4096]
264: model.layers.7.mlp.down_proj.weight shape: [4096, 14336]
265: model.layers.7.mlp.gate_proj.weight shape: [14336, 4096]
266: model.layers.7.mlp.up_proj.weight shape: [14336, 4096]
267: model.layers.7.post_attention_layernorm.weight shape: [4096]
268: model.layers.7.self_attn.k_proj.weight shape: [1024, 4096]
269: model.layers.7.self_attn.o_proj.weight shape: [4096, 4096]
270: model.layers.7.self_attn.q_proj.weight shape: [4096, 4096]
271: model.layers.7.self_attn.v_proj.weight shape: [1024, 4096]
272: model.layers.8.input_layernorm.weight shape: [4096]
273: model.layers.8.mlp.down_proj.weight shape: [4096, 14336]
274: model.layers.8.mlp.gate_proj.weight shape: [14336, 4096]
275: model.layers.8.mlp.up_proj.weight shape: [14336, 4096]
276: model.layers.8.post_attention_layernorm.weight shape: [4096]
277: model.layers.8.self_attn.k_proj.weight shape: [1024, 4096]
278: model.layers.8.self_attn.o_proj.weight shape: [4096, 4096]
279: model.layers.8.self_attn.q_proj.weight shape: [4096, 4096]
280: model.layers.8.self_attn.v_proj.weight shape: [1024, 4096]
281: model.layers.9.input_layernorm.weight shape: [4096]
282: model.layers.9.mlp.down_proj.weight shape: [4096, 14336]
283: model.layers.9.mlp.gate_proj.weight shape: [14336, 4096]
284: model.layers.9.mlp.up_proj.weight shape: [14336, 4096]
285: model.layers.9.post_attention_layernorm.weight shape: [4096]
286: model.layers.9.self_attn.k_proj.weight shape: [1024, 4096]
287: model.layers.9.self_attn.o_proj.weight shape: [4096, 4096]
288: model.layers.9.self_attn.q_proj.weight shape: [4096, 4096]
289: model.layers.9.self_attn.v_proj.weight shape: [1024, 4096]
290: model.norm.weight shape: [4096]

Просмотреть файл

@ -0,0 +1,2 @@
<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}]</s>[TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius.</s>
1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2

Просмотреть файл

@ -0,0 +1,2 @@
 [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}] What's the weather like in Paris? [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}] {"content": 22.0, "call_id": "9Ae3bDc2F"} The current temperature in Paris is 22.0 degrees celsius.
1, 1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2

Просмотреть файл

@ -0,0 +1,45 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net6.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<NoWarn>$(NoWarn);MSML_ExtendBaseTestClass</NoWarn>
<Nullable>enable</Nullable>
<PreserveCompilationContext>true</PreserveCompilationContext>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestsVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="$(SystemDataSqlClientVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="Moq" Version="$(MoqVersion)" />
<PackageReference Include="Microsoft.ML.TestTokenizers" Version="$(MicrosoftMLTestTokenizersVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
</ItemGroup>
<ItemGroup Condition="'$(TargetArchitecture)' != 'x64'">
<Compile Remove="Mistral_7B_Instruct_V0_3Tests.cs" />
</ItemGroup>
<ItemGroup>
<None Update="Approvals\**\*">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
<ItemGroup Condition="'$(TargetArchitecture)' == 'x64'">
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows')) AND '$(TargetArchitecture)' == 'x64'" />
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux')) AND '$(TargetArchitecture)' == 'x64'" />
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX')) AND '$(TargetArchitecture)' == 'x64'" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,137 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Text;
using ApprovalTests;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using Xunit;
namespace Microsoft.ML.GenAI.Mistral.Tests;
[Collection("NoParallelization")]
public class Mistral_7B_Instruct_V0_3Tests
{
public Mistral_7B_Instruct_V0_3Tests()
{
if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
{
Approvals.UseAssemblyLocationForApprovedFiles();
}
torch.set_default_device("meta");
}
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("Approvals")]
public void Mistral_7B_Instruct_V0_3_ShapeTest()
{
var model = new MistralForCausalLM(MistralConfig.Mistral_7B_Instruct_v0_3);
var stateDictStr = model.PeekShape();
Approvals.Verify(stateDictStr);
}
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("Approvals")]
public void ItBuildChatTemplateFromAutoGenChatHistory()
{
var chatHistory = new List<IMessage>
{
new TextMessage(Role.System, "You are a helpful AI assistant."),
new TextMessage(Role.User, "Hello?"),
new TextMessage(Role.Assistant, "World!"),
};
var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
Approvals.Verify(prompt);
}
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("Approvals")]
public void ItBuildChatTemplateWithToolsFromAutoGenChatHistory()
{
var getWeatherTool = new FunctionContract
{
Name = "get_current_weather",
Namespace = "weather",
Description = "Get the current weather",
Parameters = [
new FunctionParameterContract
{
Name = "location",
ParameterType = typeof(string),
Description = "The city and state, e.g. San Francisco, CA",
IsRequired = true
}
]
};
var getWeatherToolCall = new ToolCall("get_current_weather", "{\"location\": \"Seattle, WA\"}") { ToolCallId = "9Ae3bDc2F" };
var getWeatherToolCallResult = new ToolCall("get_current_weather", "{\"temperature\": 22.0}", "sunny") { ToolCallId = "9Ae3bDc2F" };
var toolCallMessage = new ToolCallMessage([getWeatherToolCall]);
var toolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult]);
var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage);
var chatHistory = new List<IMessage>
{
new TextMessage(Role.System, "You are a helpful AI assistant."),
new TextMessage(Role.User, "What's the weather in Seattle?"),
toolCallMessage,
toolCallResultMessage,
new TextMessage(Role.Assistant, "The current temperature in Seattle is 22.0 degrees celsius."),
// test tool call aggregate message for immediate tool call execution
new TextMessage(Role.User, "What's the weather in New York?"),
aggregateToolCallMessage,
new TextMessage(Role.Assistant, "The current temperature in New York is 22.0 degrees celsius."),
new TextMessage(Role.User, "What's the weather in Paris?"),
};
var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory, [getWeatherTool]);
Approvals.Verify(prompt);
}
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("Approvals")]
public void TokenizerTest()
{
var modelWeightFolder = "Mistral";
var tokenizer = MistralTokenizerHelper.FromPretrained(modelWeightFolder);
var messages = new string[]
{
// system : You are a helpful assistant that can answer questions about the weather.
// tool: [get-weather-tool-call]
// user : What's the weather like in Paris?
// assistant: // get-weather-tool-call
// tool: get-weather-tool-call-result
// assistant: The current temperature in Paris is 22.0 degrees celsius.
"""
<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}]</s>[TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius.</s>
"""
};
var sb = new StringBuilder();
foreach (var message in messages)
{
var tokenizeIds = tokenizer.EncodeToIds(message, false, false);
var decodeToString = tokenizer.Decode(tokenizeIds, considerSpecialTokens: true);
sb.AppendLine(decodeToString);
var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString()));
sb.AppendLine(tokenizedStr);
}
Approvals.Verify(sb.ToString());
}
}