diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index c55f5797f..00635886a 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -188,7 +188,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core.Tes
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA", "src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj", "{0AA6D5CB-195F-457A-8792-4221E76E6C44}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral", "src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj", "{2729CC66-7743-442B-B3A5-1F4F27F044A5}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Mistral.Tests", "test\Microsoft.ML.GenAI.Mistral.Tests\Microsoft.ML.GenAI.Mistral.Tests.csproj", "{49264202-C90A-43F6-8C30-BDAEF2F1465A}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -898,6 +902,22 @@ Global
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.Build.0 = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.ActiveCfg = Release|Any CPU
{D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.Build.0 = Release|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Debug|x64.Build.0 = Debug|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.ActiveCfg = Release|Any CPU
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5}.Release|x64.Build.0 = Release|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Debug|x64.Build.0 = Debug|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|Any CPU.Build.0 = Release|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.ActiveCfg = Release|Any CPU
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -991,6 +1011,8 @@ Global
{14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{0AA6D5CB-195F-457A-8792-4221E76E6C44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D202353D-6FAF-4263-9A01-BDCFBC92391F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {2729CC66-7743-442B-B3A5-1F4F27F044A5} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {49264202-C90A-43F6-8C30-BDAEF2F1465A} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
index d9932106d..792391a59 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
@@ -5,17 +5,20 @@
net8.0
enable
enable
+ true
+
+
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs
new file mode 100644
index 000000000..25580090f
--- /dev/null
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Mistral/Mistral_7B_Instruct.cs
@@ -0,0 +1,156 @@
+using System.Text.Json;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Mistral;
+using Microsoft.ML.GenAI.Mistral.Module;
+using Microsoft.ML.Tokenizers;
+using TorchSharp;
+using TorchSharp.PyBridge;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Samples.Mistral;
+
+public partial class Mistral_7B_Instruct
+{
+ private static Mistral_7B_Instruct instance = new Mistral_7B_Instruct();
+
+ ///
+ /// get weather from city
+ ///
+ ///
+ [Function]
+ public Task GetWeather(string city)
+ {
+ return Task.FromResult($"The weather in {city} is sunny.");
+ }
+
+ public static async Task RunAsync()
+ {
+ var device = "cuda";
+ if (device == "cuda")
+ {
+ torch.InitializeDeviceType(DeviceType.CUDA);
+ }
+
+ var defaultType = ScalarType.BFloat16;
+ torch.manual_seed(1);
+ torch.set_default_dtype(defaultType);
+ var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
+ var configName = "config.json";
+ var originalWeightFolder = Path.Combine(weightFolder);
+
+ Console.WriteLine("Loading Mistral from huggingface model weight folder");
+ var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
+ var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);
+
+ var pipeline = new CausalLMPipeline(tokenizer, model, device);
+
+ var agent = new MistralCausalLMAgent(pipeline, "assistant")
+ .RegisterPrintMessage();
+
+ var task = """
+ How are you.
+ """;
+
+ await agent.SendAsync(task);
+ }
+
+ public static void Embedding()
+ {
+ var device = "cuda";
+ if (device == "cuda")
+ {
+ torch.InitializeDeviceType(DeviceType.CUDA);
+ }
+
+ var defaultType = ScalarType.Float32;
+ torch.manual_seed(1);
+ torch.set_default_dtype(defaultType);
+ var weightFolder = @"C:\Users\xiaoyuz\source\repos\bge-en-icl";
+ var configName = "config.json";
+ var originalWeightFolder = Path.Combine(weightFolder);
+
+ Console.WriteLine("Loading Mistral from huggingface model weight folder");
+ var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder, modelName: "tokenizer.model");
+
+ var mistralConfig = JsonSerializer.Deserialize(File.ReadAllText(Path.Combine(weightFolder, configName))) ?? throw new ArgumentNullException(nameof(configName));
+ var model = new MistralModel(mistralConfig);
+ model.load_checkpoint(weightFolder, "model.safetensors.index.json", strict: true, useTqdm: false);
+ model.to(device);
+
+ var pipeline = new CausalLMPipeline(tokenizer, model, device);
+
+ var query = """
+ Given a web search query, retrieve relevant passages that answer the query.
+ what is a virtual interface
+ A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.
+
+ Given a web search query, retrieve relevant passages that answer the query.
+ causes of back pain in female for a week
+ Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management.
+
+ Given a web search query, retrieve relevant passages that answer the query.
+ how much protein should a female eat
+
+ """;
+
+ var document = """
+ As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.
+ """;
+ var queryEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(query);
+ var documentEmbedding = pipeline.GenerateEmbeddingFromLastTokenPool(document);
+
+ var score = 0f;
+ foreach (var (q, d) in queryEmbedding.Zip(documentEmbedding))
+ {
+ score += q * d * 100;
+ }
+
+ Console.WriteLine($"The similarity score between query and document is {score}");
+ }
+
+ public static async Task WeatherChatAsync()
+ {
+ var device = "cuda";
+ if (device == "cuda")
+ {
+ torch.InitializeDeviceType(DeviceType.CUDA);
+ }
+
+ var defaultType = ScalarType.BFloat16;
+ torch.manual_seed(1);
+ torch.set_default_dtype(defaultType);
+ var weightFolder = @"C:\Users\xiaoyuz\source\repos\Mistral-7B-Instruct-v0.3";
+ var configName = "config.json";
+ var originalWeightFolder = Path.Combine(weightFolder);
+
+ Console.WriteLine("Loading Mistral from huggingface model weight folder");
+ var tokenizer = MistralTokenizerHelper.FromPretrained(originalWeightFolder);
+ var model = MistralForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);
+
+ var pipeline = new CausalLMPipeline(tokenizer, model, device);
+
+ var weatherChatMiddleware = new FunctionCallMiddleware(
+ functions: [instance.GetWeatherFunctionContract],
+ functionMap: new Dictionary>>
+ {
+ { instance.GetWeatherFunctionContract.Name!, instance.GetWeatherWrapper }
+ });
+
+ var agent = new MistralCausalLMAgent(pipeline, "assistant")
+ .RegisterStreamingMiddleware(weatherChatMiddleware)
+ .RegisterPrintMessage();
+
+ var task = "what is the weather in Seattle";
+ var userMessage = new TextMessage(Role.User, task);
+
+ var reply = await agent.GenerateReplyAsync(messages: [userMessage],
+ new GenerateReplyOptions
+ {
+ Temperature = 0f,
+ });
+
+ // generate further reply using tool call result;
+ await agent.SendAsync(chatHistory: [userMessage, reply]);
+ }
+}
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
index 5e4355e59..cf166c755 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
@@ -1,4 +1,5 @@
// See https://aka.ms/new-console-template for more information
+using Microsoft.ML.GenAI.Samples.Mistral;
using Microsoft.ML.GenAI.Samples.Phi3Mini;
-await AutoGenSample.RunAsync();
+await Mistral_7B_Instruct.WeatherChatAsync();
diff --git a/eng/Versions.props b/eng/Versions.props
index 2510aa58f..fde5bdbeb 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -68,7 +68,7 @@
2
2.3.1
1.4.1
- 0.0.15
+ 0.1.0
1.15.0
0.102.7
2.2.1.1
@@ -96,7 +96,7 @@
0.0.13-test
0.0.6-test
0.0.7-test
- 2.0.0-beta.24415.1
+ 2.0.0-beta.24455.2
4.8.6
1.0.118
1.6.24
diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
index 8745b81c6..64087de17 100644
--- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
+++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -20,9 +20,11 @@
+
-
+
+
diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
index 33e0bab19..c36837833 100644
--- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
+++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
@@ -266,8 +266,18 @@ public class CausalLMPipeline : ICausalLMPipeline
foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds.ToArray(), temperature: temperature, maxLen: maxLen))
{
var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray();
- var duplicateTokenString = this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids");
- var tokenString = this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids");
+ var duplicateTokenString = this.Tokenizer switch
+ {
+ SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ _ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ };
+
+ var tokenString = this.Tokenizer switch
+ {
+ SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ _ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ };
+
// replace the first occurrence of the token with the duplicate token
tokenString = duplicateTokenString.Substring(tokenString.Length);
@@ -294,7 +304,10 @@ public class CausalLMPipeline : ICausalLMPipeline
var inputIds = this.Tokenizer.EncodeToIds(prompt);
var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
var attentionMask = torch.ones_like(inputTensor, device: this.Device);
- var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
+ var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0)
+ {
+ OverrideCache = new DynamicKVCache(),
+ };
var output = this.Model.forward(input);
var lastTokenHiddenState = output.LastHiddenState[0, ^1];
diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
index a0720694c..4cf5a00ab 100644
--- a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
+++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
@@ -19,7 +19,7 @@ public interface ISemanticKernelChatTemplateBuilder
public interface IAutoGenChatTemplateBuilder
{
- string BuildPrompt(IEnumerable messages);
+ string BuildPrompt(IEnumerable messages, IEnumerable? tools = null);
}
public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
diff --git a/src/Microsoft.ML.GenAI.Core/Utils.cs b/src/Microsoft.ML.GenAI.Core/Utils.cs
index e4e1078d2..dccabad65 100644
--- a/src/Microsoft.ML.GenAI.Core/Utils.cs
+++ b/src/Microsoft.ML.GenAI.Core/Utils.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
@@ -161,4 +162,18 @@ public static class Utils
.reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
}
+ internal static string GetEmbeddedResource(string resourceName)
+ {
+ // read file content from embedded resource
+ var assembly = Assembly.GetCallingAssembly();
+ var resourceStream = assembly.GetManifestResourceStream(resourceName);
+
+ if (resourceStream == null)
+ {
+ throw new ArgumentException("Resource not found", resourceName);
+ }
+
+ using var reader = new System.IO.StreamReader(resourceStream);
+ return reader.ReadToEnd();
+ }
}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
index b96dee6db..29e7fb1da 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
+++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
@@ -15,7 +15,7 @@ public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
{
private const char Newline = '\n';
- public string BuildPrompt(IEnumerable messages)
+ public string BuildPrompt(IEnumerable messages, IEnumerable? tools = null)
{
var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
if (messages.Any(m => m.GetContent() is null))
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
index 5deabd6df..d6593f445 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
@@ -60,7 +60,7 @@ public class LlamaCausalLMAgent : IStreamingAgent
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
- public async IAsyncEnumerable GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable messages,
GenerateReplyOptions? options = null,
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
index 1ba7820a9..ec6512833 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
+++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
@@ -2,13 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
-using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Utils.cs b/src/Microsoft.ML.GenAI.LLaMA/Utils.cs
deleted file mode 100644
index 622aba9ff..000000000
--- a/src/Microsoft.ML.GenAI.LLaMA/Utils.cs
+++ /dev/null
@@ -1,27 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Reflection;
-using TorchSharp;
-using static TorchSharp.torch;
-
-namespace Microsoft.ML.GenAI.LLaMA;
-
-internal static class Utils
-{
- public static string GetEmbeddedResource(string resourceName)
- {
- // read file content from embedded resource
- var assembly = Assembly.GetExecutingAssembly();
- var resourceStream = assembly.GetManifestResourceStream(resourceName);
-
- if (resourceStream == null)
- {
- throw new ArgumentException("Resource not found", resourceName);
- }
-
- using var reader = new System.IO.StreamReader(resourceStream);
- return reader.ReadToEnd();
- }
-}
diff --git a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj
new file mode 100644
index 000000000..896f47e5b
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj
@@ -0,0 +1,25 @@
+
+
+
+ net6.0;net8.0
+ enable
+ enable
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs
new file mode 100644
index 000000000..e20d3b860
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralCausalLMAgent.cs
@@ -0,0 +1,166 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.Mistral;
+
+public class MistralCausalLMAgent : IStreamingAgent
+{
+ private readonly ICausalLMPipeline _pipeline;
+ private readonly string? _systemMessage;
+ private readonly IAutoGenChatTemplateBuilder _templateBuilder;
+ private readonly string _stopSequence = "";
+
+ ///
+ /// Create a new instance of .
+ ///
+ /// pipeline
+ /// agent name
+ /// system message.
+ /// the template builder to build chat prompt. If the value is null, would be used.
+ public MistralCausalLMAgent(
+ ICausalLMPipeline pipeline,
+ string name,
+ string? systemMessage = "you are a helpful assistant",
+ IAutoGenChatTemplateBuilder? templateBuilder = null)
+ {
+ this.Name = name;
+ this._pipeline = pipeline;
+ this._systemMessage = systemMessage;
+ this._templateBuilder = templateBuilder ?? Mistral_7B_0_3ChatTemplateBuilder.Instance;
+ }
+
+ public string Name { get; }
+
+ public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (_systemMessage != null)
+ {
+ var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+ messages = messages.Prepend(systemMessage);
+ }
+ var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
+ var maxLen = options?.MaxToken ?? 1024;
+ var temperature = options?.Temperature ?? 0.7f;
+ var stopTokenSequence = options?.StopSequence ?? [];
+ stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();
+
+ var output = _pipeline.Generate(
+ input,
+ maxLen: maxLen,
+ temperature: temperature,
+ stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply.");
+
+ // post-process the output for tool call
+ if (output.StartsWith("[TOOL_CALLS]"))
+ {
+ return Task.FromResult(ParseAsToolCallMessage(output));
+ }
+
+ return Task.FromResult(new TextMessage(Role.Assistant, output, from: this.Name));
+ }
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+ public async IAsyncEnumerable GenerateStreamingReplyAsync(
+#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
+ IEnumerable messages,
+ GenerateReplyOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ if (_systemMessage != null)
+ {
+ var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+ messages = messages.Prepend(systemMessage);
+ }
+ var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
+ var maxLen = options?.MaxToken ?? 1024;
+ var temperature = options?.Temperature ?? 0.7f;
+ var stopTokenSequence = options?.StopSequence ?? [];
+ stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();
+
+ // only streaming the output when the output is not a tool call
+ // otherwise, we collect all the chunks and convert them to a tool call message at the end of the streaming
+ var sb = new StringBuilder();
+ bool? isToolCall = null;
+ foreach (var output in _pipeline.GenerateStreaming(
+ input,
+ maxLen: maxLen,
+ temperature: temperature,
+ stopSequences: stopTokenSequence))
+ {
+ if (isToolCall is null)
+ {
+ sb.Append(output);
+ var str = sb.ToString();
+ if (!str.StartsWith("[TOOL_CALLS]".Substring(0, str.Length)))
+ {
+ yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
+ isToolCall = false;
+ }
+ else if (str.StartsWith("[TOOL_CALLS]"))
+ {
+ isToolCall = true;
+ }
+ }
+ else if (isToolCall == false)
+ {
+ yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
+ }
+ else
+ {
+ sb.Append(output);
+ }
+ }
+
+ if (isToolCall == true)
+ {
+ var toolCallMessage = ParseAsToolCallMessage(sb.ToString());
+ foreach (var toolCall in toolCallMessage.ToolCalls)
+ {
+ yield return new ToolCallMessageUpdate(toolCall.FunctionName, toolCall.FunctionArguments, from: this.Name);
+ }
+ }
+ }
+
+ private class MistralToolCall
+ {
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("arguments")]
+ public JsonObject? Arguments { get; set; }
+ }
+
+ private ToolCallMessage ParseAsToolCallMessage(string content)
+ {
+ var json = content.Substring("[TOOL_CALLS]".Length).Trim();
+
+ // the json string should be a list of tool call messages
+ // e.g. [{"name": "get_current_weather", "parameters": {"location": "Seattle"}}]
+ var mistralToolCalls = JsonSerializer.Deserialize>(json) ?? throw new InvalidOperationException("Failed to deserialize tool calls.");
+ var toolCalls = mistralToolCalls
+ .Select(tc => new ToolCall(tc.Name!, JsonSerializer.Serialize(tc.Arguments)) { ToolCallId = this.GenerateToolCallId() });
+
+ return new ToolCallMessage(toolCalls, from: this.Name);
+ }
+
+ ///
+ /// 9 random alphanumeric characters
+ ///
+ private string GenerateToolCallId(int length = 9)
+ {
+ const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
+ var random = new Random();
+ return new string(Enumerable.Repeat(chars, length)
+ .Select(s => s[random.Next(s.Length)]).ToArray());
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs
new file mode 100644
index 000000000..c2240f957
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralConfig.cs
@@ -0,0 +1,112 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Microsoft.ML.GenAI.Core;
+using TorchSharp;
+
+namespace Microsoft.ML.GenAI.Mistral;
+
+public class MistralConfig
+{
+ public MistralConfig()
+ {
+ this.AttentionBias = false;
+ this.AttentionDropout = 0.0;
+ this.HiddenAct = "silu";
+ this.HiddenSize = 4096;
+ this.InitializerRange = 0.02;
+ this.IntermediateSize = 14336;
+ this.MaxPositionEmbeddings = 131072;
+ this.MlpBias = false;
+ this.NumAttentionHeads = 32;
+ this.NumHiddenLayers = 32;
+ this.NumKeyValueHeads = 8;
+ this.RmsNormEps = 1e-05f;
+ this.RopeScaling = new RopeScalingConfig();
+ this.RopeTheta = 500000.0;
+ this.TieWordEmbeddings = false;
+ this.VocabSize = 128256;
+ this.AttnImplementation = "eager";
+ this.DType = torch.ScalarType.BFloat16;
+ this.HeadDim = this.HiddenSize / this.NumAttentionHeads;
+ this.SlidingWindow ??= 4096;
+ }
+
+ static MistralConfig()
+ {
+#pragma warning disable MSML_ParameterLocalVarName // Parameter or local variable name not standard
+ var mistral7BInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Mistral.Resource.Config.mistral-7B-instruct-v0.3.json");
+#pragma warning restore MSML_ParameterLocalVarName // Parameter or local variable name not standard
+
+ Mistral_7B_Instruct_v0_3 = JsonSerializer.Deserialize(mistral7BInstructContent) ?? throw new ArgumentNullException(nameof(mistral7BInstructContent));
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ ///
+ /// The mistral-7b-instruct-v0.3 configuration created from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/tree/main.
+ ///
+ public static MistralConfig Mistral_7B_Instruct_v0_3 { get; }
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+
+ [JsonPropertyName("attention_bias")]
+ public bool AttentionBias { get; set; }
+
+ [JsonPropertyName("attention_dropout")]
+ public double AttentionDropout { get; set; }
+
+ [JsonPropertyName("hidden_act")]
+ public string HiddenAct { get; set; }
+
+ [JsonPropertyName("hidden_size")]
+ public int HiddenSize { get; set; }
+
+ [JsonPropertyName("initializer_range")]
+ public double InitializerRange { get; set; }
+
+ [JsonPropertyName("intermediate_size")]
+ public int IntermediateSize { get; set; }
+
+ [JsonPropertyName("max_position_embeddings")]
+ public int MaxPositionEmbeddings { get; set; }
+
+ [JsonPropertyName("mlp_bias")]
+ public bool MlpBias { get; set; }
+
+ [JsonPropertyName("num_attention_heads")]
+ public int NumAttentionHeads { get; set; }
+
+ [JsonPropertyName("num_hidden_layers")]
+ public int NumHiddenLayers { get; set; }
+
+ [JsonPropertyName("num_key_value_heads")]
+ public int NumKeyValueHeads { get; set; }
+
+ [JsonPropertyName("head_dim")]
+ public int HeadDim { get; set; }
+
+ [JsonPropertyName("rms_norm_eps")]
+ public float RmsNormEps { get; set; }
+
+ public RopeScalingConfig RopeScaling { get; set; }
+
+ [JsonPropertyName("rope_theta")]
+ public double RopeTheta { get; set; }
+
+ [JsonPropertyName("tie_word_embeddings")]
+ public bool TieWordEmbeddings { get; set; }
+
+ [JsonPropertyName("vocab_size")]
+ public int VocabSize { get; set; }
+
+ [JsonPropertyName("sliding_window")]
+ public int? SlidingWindow { get; set; }
+
+ public int? PadTokenId { get; set; }
+
+ public torch.ScalarType DType { get; set; }
+
+ public string AttnImplementation { get; set; }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs b/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs
new file mode 100644
index 000000000..7f17991b5
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralDecoderLayer.cs
@@ -0,0 +1,148 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.GenAI.Core;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Mistral.Module;
+
+internal class DecoderLayerInput
+{
+ public DecoderLayerInput(
+ Tensor hiddenStates,
+ Tensor attentionMask,
+ Tensor positionIds,
+ RotaryEmbeddingOutput positionEmbeddings, // cos, sin
+ IKVCache? pastKeyValue = null,
+ bool outputAttentions = false)
+ {
+ this.HiddenStates = hiddenStates;
+ this.AttentionMask = attentionMask;
+ this.PositionIds = positionIds;
+ this.PastKeyValue = pastKeyValue;
+ this.OutputAttentions = outputAttentions;
+ this.PositionalEmbeddings = positionEmbeddings;
+ }
+
+ public Tensor HiddenStates { get; set; }
+
+ public Tensor AttentionMask { get; set; }
+
+ public Tensor PositionIds { get; set; }
+
+ public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }
+
+ public IKVCache? PastKeyValue { get; set; }
+
+ public bool OutputAttentions { get; set; }
+}
+
+internal class DecoderLayerOutput
+{
+ public DecoderLayerOutput(
+ Tensor hiddenStates,
+ Tensor? attentions = null,
+ IKVCache? pastKeyValue = null)
+ {
+ this.HiddenStates = hiddenStates;
+ this.Attentions = attentions;
+ this.PastKeyValue = pastKeyValue;
+ }
+
+ public Tensor HiddenStates { get; set; }
+
+ public Tensor? Attentions { get; set; }
+
+ public IKVCache? PastKeyValue { get; set; }
+}
+internal class MistralDecoderLayer : nn.Module, IDynamicLoadModule
+{
+ private readonly MistralConfig _llamaConfig;
+ private readonly int _layerIndex;
+ private readonly int _hiddenSize;
+
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly MistralMLP mlp;
+ private readonly Core.RMSNorm input_layernorm;
+ private readonly Core.RMSNorm post_attention_layernorm;
+ private readonly Attention self_attn;
+
+ public Action? LoadToDeviceFunc { get; set; }
+ public Action? UnloadFromDeviceFunc { get; set; }
+
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public MistralDecoderLayer(MistralConfig config, int layerIndex)
+ : base(nameof(MistralDecoderLayer))
+ {
+ _llamaConfig = config;
+ _layerIndex = layerIndex;
+ _hiddenSize = config.HiddenSize;
+
+ this.self_attn = CreateAttention(config, layerIndex);
+ this.mlp = new MistralMLP(config);
+ this.input_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
+ this.post_attention_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
+ }
+
+ private Attention CreateAttention(MistralConfig config, int layerIndex)
+ {
+ var headDim = config.HiddenSize / config.NumAttentionHeads;
+ return new Attention(
+ attentionDropout: config.AttentionDropout,
+ hiddenSize: config.HiddenSize,
+ numHeads: config.NumAttentionHeads,
+ headDim: headDim,
+ numKeyValueHeads: config.NumKeyValueHeads,
+ numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads,
+ maxPositionEmbeddings: config.MaxPositionEmbeddings,
+ originalMaxPositionEmbeddings: config.MaxPositionEmbeddings,
+ layerIdx: layerIndex,
+ useQkvProj: false,
+ dtype: config.DType,
+ attentionBias: config.AttentionBias);
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override DecoderLayerOutput forward(DecoderLayerInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ if (LoadToDeviceFunc != null)
+ {
+ LoadToDeviceFunc(this);
+ }
+
+ using var disposeScope = NewDisposeScope();
+ var residual = input.HiddenStates;
+ var hiddenStates = this.input_layernorm.forward(input.HiddenStates);
+
+ var selfAttnInput = new AttentionInput(
+ hiddenStates: hiddenStates,
+ attentionMask: input.AttentionMask,
+ positionIds: input.PositionIds,
+ cache: input.PastKeyValue,
+ positionalEmbeddings: input.PositionalEmbeddings,
+ outputAttentions: input.OutputAttentions);
+
+ var selfAttnOutput = this.self_attn.forward(selfAttnInput);
+
+ hiddenStates = residual + selfAttnOutput.HiddenStates;
+
+ // Fully connected
+ residual = hiddenStates;
+ hiddenStates = this.post_attention_layernorm.forward(hiddenStates);
+ hiddenStates = this.mlp.forward(hiddenStates);
+ hiddenStates = residual + hiddenStates;
+
+ if (UnloadFromDeviceFunc != null)
+ {
+ UnloadFromDeviceFunc(this);
+ }
+
+ return new DecoderLayerOutput(
+ hiddenStates: hiddenStates.MoveToOuterDisposeScope(),
+ attentions: input.OutputAttentions ? selfAttnOutput.Attentions?.MoveToOuterDisposeScope() : null,
+ pastKeyValue: selfAttnOutput.Cache);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs
new file mode 100644
index 000000000..18d43e531
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs
@@ -0,0 +1,130 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Diagnostics;
+using System.Text.Json;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.GenAI.Mistral.Module;
+using TorchSharp;
+using TorchSharp.PyBridge;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Mistral;
+
+public class MistralForCausalLM : nn.Module
+{
+ private readonly MistralConfig _config;
+ private readonly int _vocabSize;
+
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly GenAILinear lm_head;
+ private readonly MistralModel model;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public MistralForCausalLM(MistralConfig config)
+ : base(nameof(MistralForCausalLM))
+ {
+ _config = config;
+ _vocabSize = config.VocabSize;
+
+ model = new MistralModel(config);
+ lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, hasBias: false);
+
+ this.RegisterComponents();
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ var outputs = this.model.forward(input);
+ var logits = this.lm_head.forward(outputs.LastHiddenState);
+ logits = logits.to_type(ScalarType.Float32);
+ outputs.Logits = logits;
+
+ return outputs;
+ }
+
+ public static MistralForCausalLM FromPretrained(
+ string modelFolder,
+ string configName = "config.json",
+ string checkPointName = "model.safetensors.index.json",
+ ScalarType torchDtype = ScalarType.BFloat16,
+ string device = "cpu")
+ {
+ var config = Path.Join(modelFolder, configName);
+ var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+ modelConfig.DType = torchDtype;
+ var model = new MistralForCausalLM(modelConfig);
+
+ model.LoadSafeTensors(modelFolder, checkPointName);
+ model = model.to(device);
+
+ return model;
+ }
+
+ public static MistralForCausalLM FromPretrained(
+ string modelFolder,
+ string configName = "config.json",
+ string checkPointName = "model.safetensors.index.json",
+ bool quantizeToInt8 = false,
+ bool quantizeToInt4 = false,
+ int layersOnTargetDevice = -1,
+ ScalarType torchDtype = ScalarType.BFloat16,
+ string targetDevice = "cuda")
+ {
+ if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+ {
+ return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+ }
+
+ var originalDefaultDevice = torch.get_default_device();
+ torch.set_default_device("meta");
+ var config = Path.Join(modelFolder, configName);
+ var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+ modelConfig.DType = torchDtype;
+ var model = new MistralForCausalLM(modelConfig);
+
+ if (quantizeToInt8)
+ {
+ model.ToInt8QuantizeModule();
+ }
+ else if (quantizeToInt4)
+ {
+ model.ToInt4QuantizeModule();
+ }
+
+ var deviceMap = model.InferDeviceMapForEachLayer(
+ [
+ KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+ KeyValuePair.Create("cpu", -1)
+ ]);
+
+ torch.set_default_device("cpu");
+ model = new MistralForCausalLM(modelConfig);
+
+ model.LoadSafeTensors(modelFolder, checkPointName);
+
+ if (quantizeToInt8)
+ {
+ model.ToInt8QuantizeModule();
+ }
+ else if (quantizeToInt4)
+ {
+ model.ToInt4QuantizeModule();
+ }
+
+ model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+
+ torch.set_default_device(originalDefaultDevice);
+
+ return model;
+ }
+
+ public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
+ {
+ this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs b/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs
new file mode 100644
index 000000000..347ee625e
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralMLP.cs
@@ -0,0 +1,45 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.GenAI.Core;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Mistral.Module;
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+internal class MistralMLP : torch.nn.Module
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+{
+ private readonly int _intermediateSize;
+ private readonly int _hiddenSize;
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly QuantizedLinear gate_proj;
+ private readonly QuantizedLinear up_proj;
+ private readonly QuantizedLinear down_proj;
+ private readonly torch.nn.Module act_fn;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public MistralMLP(MistralConfig config)
+ : base(nameof(MistralMLP))
+ {
+ this._hiddenSize = config.HiddenSize;
+ this._intermediateSize = config.IntermediateSize;
+ var hiddenAct = config.HiddenAct;
+ this.gate_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType);
+ this.up_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: false, dtype: config.DType);
+ this.down_proj = new QuantizedLinear(this._intermediateSize, this._hiddenSize, hasBias: false, dtype: config.DType);
+ this.RegisterComponents();
+ this.act_fn = Core.Utils.GetActivation(hiddenAct);
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override Tensor forward(Tensor input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ using var input1 = this.gate_proj.forward(input);
+ using var input2 = this.act_fn.forward(input1);
+ using var input3 = input2 * this.up_proj.forward(input);
+ return this.down_proj.forward(input3);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs
new file mode 100644
index 000000000..cab7e6cc5
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralModel.cs
@@ -0,0 +1,148 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.GenAI.Core;
+using TorchSharp;
+using TorchSharp.Modules;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Mistral.Module;
+
+public class MistralModel : nn.Module
+{
+ private readonly MistralConfig _config;
+ private readonly int? _paddingIdx;
+ private readonly int _vocabSize;
+ private IKVCache _cache;
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly Embedding embed_tokens;
+ private readonly ModuleList layers;
+ private readonly RMSNorm norm;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly nn.Module _rotaryEmb;
+
+
+ public MistralModel(MistralConfig config)
+ : base(nameof(MistralModel))
+ {
+ this._config = config;
+ this._paddingIdx = config.PadTokenId;
+ this._vocabSize = config.VocabSize;
+ var headDim = config.HeadDim;
+ this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType);
+ this.layers = new ModuleList();
+
+ for (int i = 0; i < config.NumHiddenLayers; i++)
+ {
+ this.layers.Add(new MistralDecoderLayer(config, i));
+ }
+ this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
+ this._cache = new DynamicKVCache();
+ this.RegisterComponents();
+ this._rotaryEmb = config.RopeScaling switch
+ {
+ null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim),
+ _ => new RotaryEmbedding(config.RopeTheta, headDim, config.RopeScaling),
+ };
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ if (input.OverrideCache is not null)
+ {
+ this._cache = input.OverrideCache;
+ }
+
+ var outputAttentions = input.OutputAttentions;
+ var outputHiddenStates = input.OutputHiddenStates;
+ var attentionMask = input.AttentionMask;
+ Device device;
+ var inputIds = input.InputIds;
+ var positionIds = input.PositionIds;
+ var inputsEmbeds = input.InputEmbeddings;
+ int batchSize;
+ int seqLength;
+ if (inputIds is not null && inputsEmbeds is not null)
+ {
+ throw new ArgumentException("Only one of input_ids or inputs_embeds may be set");
+ }
+ else if (inputIds is not null)
+ {
+ batchSize = inputIds.IntShape()[0];
+ seqLength = inputIds.IntShape()[1];
+ inputsEmbeds = this.embed_tokens.forward(inputIds);
+ device = inputIds.device;
+ }
+ else if (inputsEmbeds is not null)
+ {
+ batchSize = inputsEmbeds.IntShape()[0];
+ seqLength = inputsEmbeds.IntShape()[1];
+ device = inputsEmbeds.device;
+ }
+ else
+ {
+ throw new ArgumentException("Either input_ids or inputs_embeds must be set");
+ }
+
+ var pastKeyValuesLength = input.PastKeyValuesLength;
+
+ if (positionIds is null)
+ {
+ positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device);
+ positionIds = positionIds.unsqueeze(0).view(-1, seqLength);
+ }
+ else
+ {
+ positionIds = ((long)positionIds.view(-1, seqLength));
+ }
+
+ if (this._config.AttnImplementation == "flash_attention_2")
+ {
+ throw new NotImplementedException();
+ }
+ else
+ {
+ // the following behavior of creating 4d causal mask doesn't match python's, remember to look into it when there's time.
+ attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength, slidingWindow: _config.SlidingWindow);
+ }
+
+ var hiddenStates = inputsEmbeds;
+
+ var allHiddenStates = new List();
+ var allAttentions = new List();
+
+ var embOutput = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, pastKeyValuesLength));
+ foreach (var layer in this.layers)
+ {
+ if (outputHiddenStates)
+ {
+ allHiddenStates.Add(hiddenStates);
+ }
+
+ var decoderInput = new DecoderLayerInput(
+ hiddenStates: hiddenStates,
+ attentionMask: attentionMask!,
+ positionIds: positionIds,
+ pastKeyValue: this._cache,
+ positionEmbeddings: embOutput,
+ outputAttentions: outputAttentions);
+ var layerOutput = layer.forward(decoderInput);
+ hiddenStates = layerOutput.HiddenStates;
+ if (outputAttentions && layerOutput.Attentions is not null)
+ {
+ allAttentions.Add(layerOutput.Attentions);
+ }
+ }
+
+ hiddenStates = this.norm.forward(hiddenStates);
+ if (outputHiddenStates)
+ {
+ allHiddenStates.Add(hiddenStates);
+ }
+
+ return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs
new file mode 100644
index 000000000..3ed9a7978
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/MistralTokenizerHelper.cs
@@ -0,0 +1,107 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.Mistral;
+
+public class MistralTokenizerHelper
+{
+ private const string UnknownSymbol = "";
+ private const int UnknownSymbolId = 0;
+ private const string StartSymbol = "";
+ private const int StartSymbolId = 1;
+ private const string EndSymbol = "";
+ private const int EndSymbolId = 2;
+ private const string StartInstructionSymbol = "[INST]";
+ private const int StartInstructionSymbolId = 3;
+ private const string EndInstructionSymbol = "[/INST]";
+ private const int EndInstructionSymbolId = 4;
+ private const string ToolCallSymbol = "[TOOL_CALLS]";
+ private const int ToolCallSymbolId = 5;
+ private const string StartAvailableToolsSymbol = "[AVAILABLE_TOOLS]";
+ private const int StartAvailableToolsSymbolId = 6;
+ private const string EndAvailableToolsSymbol = "[/AVAILABLE_TOOLS]";
+ private const int EndAvailableToolsSymbolId = 7;
+ private const string StartToolResultSymbol = "[TOOL_RESULTS]";
+ private const int StartToolResultSymbolId = 8;
+ private const string EndToolResultSymbol = "[/TOOL_RESULTS]";
+ private const int EndToolResultSymbolId = 9;
+
+ public static LlamaTokenizer FromPretrained(
+ string modelWeightFolder,
+ string modelName = "tokenizer.model.v3",
+ string unknownSymbol = UnknownSymbol,
+ int unknownSymbolId = 0,
+ string startSymbol = StartSymbol,
+ int startSymbolId = 1,
+ string endSymbol = EndSymbol,
+ int endSymbolId = 2,
+ string startInstructionSymbol = StartInstructionSymbol,
+ int startInstructionSymbolId = 3,
+ string endInstructionSymbol = EndInstructionSymbol,
+ int endInstructionSymbolId = 4,
+ string toolCallSymbol = ToolCallSymbol,
+ int toolCallSymbolId = 5,
+ string startAvailableToolsSymbol = StartAvailableToolsSymbol,
+ int startAvailableToolsSymbolId = 6,
+ string endAvailableToolsSymbol = EndAvailableToolsSymbol,
+ int endAvailableToolsSymbolId = 7,
+ string startToolResultSymbol = StartToolResultSymbol,
+ int startToolResultSymbolId = 8,
+ string endToolResultSymbol = EndToolResultSymbol,
+ int endToolResultSymbolId = 9,
+ bool addPrecedingSpace = true,
+ Dictionary? additionalSpecialTokens = null)
+ {
+ var specialTokens = new Dictionary
+ {
+ { startSymbol, startSymbolId },
+ { endSymbol, endSymbolId },
+ { startInstructionSymbol, startInstructionSymbolId },
+ { endInstructionSymbol, endInstructionSymbolId },
+ { toolCallSymbol, toolCallSymbolId },
+ { startAvailableToolsSymbol, startAvailableToolsSymbolId },
+ { endAvailableToolsSymbol, endAvailableToolsSymbolId },
+ { startToolResultSymbol, startToolResultSymbolId },
+ { endToolResultSymbol, endToolResultSymbolId }
+ };
+
+ if (additionalSpecialTokens != null)
+ {
+ foreach (var (key, value) in additionalSpecialTokens)
+ {
+ specialTokens[key] = value;
+ }
+ }
+
+ return FromPretrained(
+ modelWeightFolder,
+ modelName,
+ specialTokens,
+ addPrecedingSpace);
+ }
+
+ public static LlamaTokenizer FromPretrained(
+ string modelWeightFolder,
+ string modelName,
+ Dictionary specialTokens,
+ bool addPrecedingSpace = true)
+ {
+ var modelPath = Path.Combine(modelWeightFolder, modelName);
+ var modelStream = File.OpenRead(modelPath);
+
+ var llamaTokenizer = LlamaTokenizer.Create(
+ modelStream,
+ addPrecedingSpace,
+ specialTokens: specialTokens);
+
+ return llamaTokenizer;
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs
new file mode 100644
index 000000000..8852f62da
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/Mistral_7B_0_3ChatTemplateBuilder.cs
@@ -0,0 +1,202 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using AutoGen.Core;
+using Json.Schema;
+using Json.Schema.Generation;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.ML.GenAI.Mistral;
+
+///
+/// the chat template builder for Mistral 7B v0.3
+///
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+public class Mistral_7B_0_3ChatTemplateBuilder : IChatTemplateBuilder
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+{
+ private const string Newline = "\r\n";
+
+ public static Mistral_7B_0_3ChatTemplateBuilder Instance { get; } = new Mistral_7B_0_3ChatTemplateBuilder();
+
+ public string BuildPrompt(IEnumerable messages, IEnumerable? tools = null)
+ {
+ // can only contain at most one system message
+ if (messages.Where(m => m.GetRole() == Role.System).Count() > 1)
+ {
+ throw new InvalidOperationException("Please provide at most one system message.");
+ }
+
+ var systemMessage = messages.FirstOrDefault(m => m.GetRole() == Role.System)?.GetContent();
+
+ // split the messages into two sequences by the last user message
+ // e.g [user, assistant, user, assistant, user] -> [[user, assistant, user, assistant], [user]]
+
+ var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
+ var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
+
+ var sb = new StringBuilder();
+ foreach (var message in firstSequence)
+ {
+ // skip system
+ if (message.GetRole() == Role.System)
+ {
+ continue;
+ }
+
+ var content = message.GetContent()!;
+ sb.Append(message switch
+ {
+ ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
+ ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
+ ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
+ TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
+ TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}",
+ _ => throw new InvalidOperationException("Invalid role.")
+ });
+ }
+
+ // insert [AVAILABLE TOOLS] section if tools are provided
+ if (tools?.Any() == true)
+ {
+ var schemas = tools.Select(t => new
+ {
+ type = "function",
+ function = new
+ {
+ name = t.Name,
+ description = t.Description,
+ parameters = BuildJsonSchemaFromFunctionContract(t)
+ }
+ });
+ var schemaPrompt = JsonSerializer.Serialize(schemas);
+
+ // add a space after the colon in json string so mistral can correctly generate the stop token after [TOOL_CALLS] symbol.
+ // This is probably because in the training data, all the tool call samples are separated by a space after the colon.
+ // e.g. [AVAILABLE_TOOLS][{"type": "function", "function": {....
+ // instead of [AVAILABLE_TOOLS][{"type":"function","function":{....
+ // Therefore when inferencing, we need to add a space after the colon in the json string to match with the training data.
+ schemaPrompt = schemaPrompt.Replace(":", ": ");
+ schemaPrompt = schemaPrompt.Replace(",", ", ");
+ sb.Append($"[AVAILABLE_TOOLS]{schemaPrompt}[/AVAILABLE_TOOLS]");
+ }
+
+ foreach (var message in secondSequence)
+ {
+ var content = message.GetContent()!;
+ sb.Append(message switch
+ {
+ ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
+ ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
+ ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
+ TextMessage when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST]{systemMessage}{Newline}{Newline}{content.Trim()}[/INST]",
+ TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
+ TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}",
+ _ => throw new InvalidOperationException("Invalid role.")
+ });
+ }
+
+ return sb.ToString();
+ }
+
+ public string BuildPrompt(ChatHistory chatHistory)
+ {
+ throw new NotImplementedException();
+ }
+
+ private string BuildFromToolCallMessage(ToolCallMessage message)
+ {
+ var toolCalls = message.ToolCalls;
+ if (toolCalls.Count() == 0)
+ {
+ return string.Empty;
+ }
+ else
+ {
+ var toolCallObjects = toolCalls.Select(tc =>
+ new
+ {
+ name = tc.FunctionName,
+ arguments = JsonObject.Parse(tc.FunctionArguments),
+ id = tc.ToolCallId,
+ }
+ );
+
+ var toolCallJson = JsonSerializer.Serialize(toolCallObjects);
+ return $"[TOOL_CALLS]{toolCallJson}";
+ }
+ }
+
+ private string BuildFromToolCallResultMessage(ToolCallResultMessage message)
+ {
+ var toolCallResults = message.ToolCalls;
+ if (toolCallResults.Count() == 0)
+ {
+ return string.Empty;
+ }
+ else
+ {
+ var toolCallResultObjects = toolCallResults.Select(tc =>
+ new
+ {
+ id = tc.ToolCallId,
+ content = tc.Result,
+ }
+ );
+
+ var toolCallResultJson = JsonSerializer.Serialize(toolCallResultObjects);
+ return $"[TOOL_RESULTS]{toolCallResultJson}[/TOOL_RESULTS]";
+ }
+ }
+
+ private string BuildFromAggregrateToolCallMessage(ToolCallAggregateMessage message)
+ {
+ var toolCallMessage = message.Message1;
+ var toolCallResultMessage = message.Message2;
+
+ var toolCall = BuildFromToolCallMessage(toolCallMessage);
+ var toolCallResult = BuildFromToolCallResultMessage(toolCallResultMessage);
+
+ return $"{toolCall}{toolCallResult}";
+ }
+
+ private JsonSchema BuildJsonSchemaFromFunctionContract(FunctionContract contract)
+ {
+ var requiredParameterNames = new List();
+ var propertiesSchemas = new Dictionary();
+ var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
+ foreach (var param in contract.Parameters ?? [])
+ {
+ if (param.Name is null)
+ {
+ throw new InvalidOperationException("Parameter name cannot be null");
+ }
+
+ var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
+ if (param.Description != null)
+ {
+ schemaBuilder = schemaBuilder.Description(param.Description);
+ }
+
+ if (param.IsRequired)
+ {
+ requiredParameterNames.Add(param.Name);
+ }
+
+ var schema = schemaBuilder.Build();
+ propertiesSchemas[param.Name] = schema;
+
+ }
+ propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
+ propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
+
+ var jsonSchema = propertySchemaBuilder.Build();
+
+ return jsonSchema;
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json b/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json
new file mode 100644
index 000000000..1da2dde41
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Mistral/Resource/Config/mistral-7B-instruct-v0.3.json
@@ -0,0 +1,21 @@
+{
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "use_cache": true,
+ "vocab_size": 32768
+}
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs
index fdba74ba7..580bde9b1 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Config.cs
@@ -41,7 +41,7 @@ public class Phi2Config
static Phi2Config()
{
- var phi2ConfigContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json");
+ var phi2ConfigContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-2-config.json");
var phi2Config = JsonSerializer.Deserialize(phi2ConfigContent) ?? throw new ArgumentNullException(nameof(phi2ConfigContent));
Phi2 = phi2Config;
}
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs
index def5ab344..0a020d672 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Config.cs
@@ -42,10 +42,10 @@ public class Phi3Config
static Phi3Config()
{
- var phi3Mini4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json");
- var phi3Mini128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json");
- var phi3Medium4kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json");
- var phi3Medium128kInstructContent = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json");
+ var phi3Mini4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-4k-instruct-config.json");
+ var phi3Mini128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-mini-128k-instruct-config.json");
+ var phi3Medium4kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-4k-instruct-config.json");
+ var phi3Medium128kInstructContent = Core.Utils.GetEmbeddedResource("Microsoft.ML.GenAI.Phi.Resource.Config.phi-3-medium-128k-instruct-config.json");
Phi3Mini4kInstruct = JsonSerializer.Deserialize(phi3Mini4kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini4kInstructContent));
Phi3Mini128kInstruct = JsonSerializer.Deserialize(phi3Mini128kInstructContent) ?? throw new ArgumentNullException(nameof(phi3Mini128kInstructContent));
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
index abe1e9271..2b9e93a4a 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMAgent.cs
@@ -50,7 +50,7 @@ public class Phi3Agent : IStreamingAgent
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
- public async IAsyncEnumerable GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IEnumerable messages,
GenerateReplyOptions? options = null,
diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs
index aa5a71719..5be880a06 100644
--- a/src/Microsoft.ML.GenAI.Phi/Utils.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs
@@ -16,49 +16,6 @@ namespace Microsoft.ML.GenAI.Phi;
internal static class Utils
{
- public static string GetEmbeddedResource(string resourceName)
- {
- // read file content from embedded resource
- var assembly = Assembly.GetExecutingAssembly();
- var resourceStream = assembly.GetManifestResourceStream(resourceName);
-
- if (resourceStream == null)
- {
- throw new ArgumentException("Resource not found", nameof(resourceName));
- }
-
- using var reader = new System.IO.StreamReader(resourceStream);
- return reader.ReadToEnd();
- }
-
- public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex)
- {
- // Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
- // Two consecutive values will become a single complex number
- // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
- var inputComplex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex();
- freqsComplex = freqsComplex.to(input.device);
-
- // Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension
- // (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
- var freqsComplexReshaped = freqsComplex.unsqueeze(0).unsqueeze(2);
-
- // Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
- // Which results in the rotation of the complex number as shown in the Figure 1 of the paper
- // (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
- var rotatedComplex = inputComplex * freqsComplexReshaped;
- // Console.WriteLine(rotated_complex.mean().ToSingle());
-
- // Convert the complex number back to the real number
- // (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
- var rotated = rotatedComplex.view_as_real();
-
- // (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
- var rotatedReshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1);
-
- return rotatedReshaped.type_as(input);
- }
-
public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f)
{
// As written in the paragraph 3.2.2 of the paper
@@ -147,21 +104,4 @@ internal static class Utils
.expand(batchSize, seqLen, nKVHeads, nRep, headDim)
.view(batchSize, seqLen, nKVHeads * nRep, headDim);
}
-
- public static Tensor Phi3RepeatKV(Tensor x, int nRep)
- {
- var batchSize = x.shape[0];
- var nKVHeads = x.shape[1];
- var seqLen = x.shape[2];
- var headDim = x.shape[3];
- if (nRep == 1)
- {
- return x;
- }
-
- return x.unsqueeze(3)
- .expand(batchSize, nKVHeads, nRep, seqLen, headDim)
- .view(batchSize, nKVHeads * nRep, seqLen, headDim);
- }
-
}
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt
new file mode 100644
index 000000000..493b07d9e
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt
@@ -0,0 +1,3 @@
+[INST]You are a helpful AI assistant.
+
+Hello?[/INST]World!
\ No newline at end of file
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt
new file mode 100644
index 000000000..4731561ae
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.ItBuildChatTemplateWithToolsFromAutoGenChatHistory.approved.txt
@@ -0,0 +1,3 @@
+[INST]What's the weather in Seattle?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in Seattle is 22.0 degrees celsius.[INST]What's the weather in New York?[/INST][TOOL_CALLS][{"name":"get_current_weather","arguments":{"location":"Seattle, WA"},"id":"9Ae3bDc2F"}][TOOL_RESULTS][{"id":"9Ae3bDc2F","content":"sunny"}][/TOOL_RESULTS]The current temperature in New York is 22.0 degrees celsius.[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]You are a helpful AI assistant.
+
+What's the weather in Paris?[/INST]
\ No newline at end of file
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt
new file mode 100644
index 000000000..4bad35f7d
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.Mistral_7B_Instruct_V0_3_ShapeTest.approved.txt
@@ -0,0 +1,291 @@
+0: lm_head.weight shape: [32768, 4096]
+1: model.embed_tokens.weight shape: [32768, 4096]
+2: model.layers.0.input_layernorm.weight shape: [4096]
+3: model.layers.0.mlp.down_proj.weight shape: [4096, 14336]
+4: model.layers.0.mlp.gate_proj.weight shape: [14336, 4096]
+5: model.layers.0.mlp.up_proj.weight shape: [14336, 4096]
+6: model.layers.0.post_attention_layernorm.weight shape: [4096]
+7: model.layers.0.self_attn.k_proj.weight shape: [1024, 4096]
+8: model.layers.0.self_attn.o_proj.weight shape: [4096, 4096]
+9: model.layers.0.self_attn.q_proj.weight shape: [4096, 4096]
+10: model.layers.0.self_attn.v_proj.weight shape: [1024, 4096]
+11: model.layers.1.input_layernorm.weight shape: [4096]
+12: model.layers.1.mlp.down_proj.weight shape: [4096, 14336]
+13: model.layers.1.mlp.gate_proj.weight shape: [14336, 4096]
+14: model.layers.1.mlp.up_proj.weight shape: [14336, 4096]
+15: model.layers.1.post_attention_layernorm.weight shape: [4096]
+16: model.layers.1.self_attn.k_proj.weight shape: [1024, 4096]
+17: model.layers.1.self_attn.o_proj.weight shape: [4096, 4096]
+18: model.layers.1.self_attn.q_proj.weight shape: [4096, 4096]
+19: model.layers.1.self_attn.v_proj.weight shape: [1024, 4096]
+20: model.layers.10.input_layernorm.weight shape: [4096]
+21: model.layers.10.mlp.down_proj.weight shape: [4096, 14336]
+22: model.layers.10.mlp.gate_proj.weight shape: [14336, 4096]
+23: model.layers.10.mlp.up_proj.weight shape: [14336, 4096]
+24: model.layers.10.post_attention_layernorm.weight shape: [4096]
+25: model.layers.10.self_attn.k_proj.weight shape: [1024, 4096]
+26: model.layers.10.self_attn.o_proj.weight shape: [4096, 4096]
+27: model.layers.10.self_attn.q_proj.weight shape: [4096, 4096]
+28: model.layers.10.self_attn.v_proj.weight shape: [1024, 4096]
+29: model.layers.11.input_layernorm.weight shape: [4096]
+30: model.layers.11.mlp.down_proj.weight shape: [4096, 14336]
+31: model.layers.11.mlp.gate_proj.weight shape: [14336, 4096]
+32: model.layers.11.mlp.up_proj.weight shape: [14336, 4096]
+33: model.layers.11.post_attention_layernorm.weight shape: [4096]
+34: model.layers.11.self_attn.k_proj.weight shape: [1024, 4096]
+35: model.layers.11.self_attn.o_proj.weight shape: [4096, 4096]
+36: model.layers.11.self_attn.q_proj.weight shape: [4096, 4096]
+37: model.layers.11.self_attn.v_proj.weight shape: [1024, 4096]
+38: model.layers.12.input_layernorm.weight shape: [4096]
+39: model.layers.12.mlp.down_proj.weight shape: [4096, 14336]
+40: model.layers.12.mlp.gate_proj.weight shape: [14336, 4096]
+41: model.layers.12.mlp.up_proj.weight shape: [14336, 4096]
+42: model.layers.12.post_attention_layernorm.weight shape: [4096]
+43: model.layers.12.self_attn.k_proj.weight shape: [1024, 4096]
+44: model.layers.12.self_attn.o_proj.weight shape: [4096, 4096]
+45: model.layers.12.self_attn.q_proj.weight shape: [4096, 4096]
+46: model.layers.12.self_attn.v_proj.weight shape: [1024, 4096]
+47: model.layers.13.input_layernorm.weight shape: [4096]
+48: model.layers.13.mlp.down_proj.weight shape: [4096, 14336]
+49: model.layers.13.mlp.gate_proj.weight shape: [14336, 4096]
+50: model.layers.13.mlp.up_proj.weight shape: [14336, 4096]
+51: model.layers.13.post_attention_layernorm.weight shape: [4096]
+52: model.layers.13.self_attn.k_proj.weight shape: [1024, 4096]
+53: model.layers.13.self_attn.o_proj.weight shape: [4096, 4096]
+54: model.layers.13.self_attn.q_proj.weight shape: [4096, 4096]
+55: model.layers.13.self_attn.v_proj.weight shape: [1024, 4096]
+56: model.layers.14.input_layernorm.weight shape: [4096]
+57: model.layers.14.mlp.down_proj.weight shape: [4096, 14336]
+58: model.layers.14.mlp.gate_proj.weight shape: [14336, 4096]
+59: model.layers.14.mlp.up_proj.weight shape: [14336, 4096]
+60: model.layers.14.post_attention_layernorm.weight shape: [4096]
+61: model.layers.14.self_attn.k_proj.weight shape: [1024, 4096]
+62: model.layers.14.self_attn.o_proj.weight shape: [4096, 4096]
+63: model.layers.14.self_attn.q_proj.weight shape: [4096, 4096]
+64: model.layers.14.self_attn.v_proj.weight shape: [1024, 4096]
+65: model.layers.15.input_layernorm.weight shape: [4096]
+66: model.layers.15.mlp.down_proj.weight shape: [4096, 14336]
+67: model.layers.15.mlp.gate_proj.weight shape: [14336, 4096]
+68: model.layers.15.mlp.up_proj.weight shape: [14336, 4096]
+69: model.layers.15.post_attention_layernorm.weight shape: [4096]
+70: model.layers.15.self_attn.k_proj.weight shape: [1024, 4096]
+71: model.layers.15.self_attn.o_proj.weight shape: [4096, 4096]
+72: model.layers.15.self_attn.q_proj.weight shape: [4096, 4096]
+73: model.layers.15.self_attn.v_proj.weight shape: [1024, 4096]
+74: model.layers.16.input_layernorm.weight shape: [4096]
+75: model.layers.16.mlp.down_proj.weight shape: [4096, 14336]
+76: model.layers.16.mlp.gate_proj.weight shape: [14336, 4096]
+77: model.layers.16.mlp.up_proj.weight shape: [14336, 4096]
+78: model.layers.16.post_attention_layernorm.weight shape: [4096]
+79: model.layers.16.self_attn.k_proj.weight shape: [1024, 4096]
+80: model.layers.16.self_attn.o_proj.weight shape: [4096, 4096]
+81: model.layers.16.self_attn.q_proj.weight shape: [4096, 4096]
+82: model.layers.16.self_attn.v_proj.weight shape: [1024, 4096]
+83: model.layers.17.input_layernorm.weight shape: [4096]
+84: model.layers.17.mlp.down_proj.weight shape: [4096, 14336]
+85: model.layers.17.mlp.gate_proj.weight shape: [14336, 4096]
+86: model.layers.17.mlp.up_proj.weight shape: [14336, 4096]
+87: model.layers.17.post_attention_layernorm.weight shape: [4096]
+88: model.layers.17.self_attn.k_proj.weight shape: [1024, 4096]
+89: model.layers.17.self_attn.o_proj.weight shape: [4096, 4096]
+90: model.layers.17.self_attn.q_proj.weight shape: [4096, 4096]
+91: model.layers.17.self_attn.v_proj.weight shape: [1024, 4096]
+92: model.layers.18.input_layernorm.weight shape: [4096]
+93: model.layers.18.mlp.down_proj.weight shape: [4096, 14336]
+94: model.layers.18.mlp.gate_proj.weight shape: [14336, 4096]
+95: model.layers.18.mlp.up_proj.weight shape: [14336, 4096]
+96: model.layers.18.post_attention_layernorm.weight shape: [4096]
+97: model.layers.18.self_attn.k_proj.weight shape: [1024, 4096]
+98: model.layers.18.self_attn.o_proj.weight shape: [4096, 4096]
+99: model.layers.18.self_attn.q_proj.weight shape: [4096, 4096]
+100: model.layers.18.self_attn.v_proj.weight shape: [1024, 4096]
+101: model.layers.19.input_layernorm.weight shape: [4096]
+102: model.layers.19.mlp.down_proj.weight shape: [4096, 14336]
+103: model.layers.19.mlp.gate_proj.weight shape: [14336, 4096]
+104: model.layers.19.mlp.up_proj.weight shape: [14336, 4096]
+105: model.layers.19.post_attention_layernorm.weight shape: [4096]
+106: model.layers.19.self_attn.k_proj.weight shape: [1024, 4096]
+107: model.layers.19.self_attn.o_proj.weight shape: [4096, 4096]
+108: model.layers.19.self_attn.q_proj.weight shape: [4096, 4096]
+109: model.layers.19.self_attn.v_proj.weight shape: [1024, 4096]
+110: model.layers.2.input_layernorm.weight shape: [4096]
+111: model.layers.2.mlp.down_proj.weight shape: [4096, 14336]
+112: model.layers.2.mlp.gate_proj.weight shape: [14336, 4096]
+113: model.layers.2.mlp.up_proj.weight shape: [14336, 4096]
+114: model.layers.2.post_attention_layernorm.weight shape: [4096]
+115: model.layers.2.self_attn.k_proj.weight shape: [1024, 4096]
+116: model.layers.2.self_attn.o_proj.weight shape: [4096, 4096]
+117: model.layers.2.self_attn.q_proj.weight shape: [4096, 4096]
+118: model.layers.2.self_attn.v_proj.weight shape: [1024, 4096]
+119: model.layers.20.input_layernorm.weight shape: [4096]
+120: model.layers.20.mlp.down_proj.weight shape: [4096, 14336]
+121: model.layers.20.mlp.gate_proj.weight shape: [14336, 4096]
+122: model.layers.20.mlp.up_proj.weight shape: [14336, 4096]
+123: model.layers.20.post_attention_layernorm.weight shape: [4096]
+124: model.layers.20.self_attn.k_proj.weight shape: [1024, 4096]
+125: model.layers.20.self_attn.o_proj.weight shape: [4096, 4096]
+126: model.layers.20.self_attn.q_proj.weight shape: [4096, 4096]
+127: model.layers.20.self_attn.v_proj.weight shape: [1024, 4096]
+128: model.layers.21.input_layernorm.weight shape: [4096]
+129: model.layers.21.mlp.down_proj.weight shape: [4096, 14336]
+130: model.layers.21.mlp.gate_proj.weight shape: [14336, 4096]
+131: model.layers.21.mlp.up_proj.weight shape: [14336, 4096]
+132: model.layers.21.post_attention_layernorm.weight shape: [4096]
+133: model.layers.21.self_attn.k_proj.weight shape: [1024, 4096]
+134: model.layers.21.self_attn.o_proj.weight shape: [4096, 4096]
+135: model.layers.21.self_attn.q_proj.weight shape: [4096, 4096]
+136: model.layers.21.self_attn.v_proj.weight shape: [1024, 4096]
+137: model.layers.22.input_layernorm.weight shape: [4096]
+138: model.layers.22.mlp.down_proj.weight shape: [4096, 14336]
+139: model.layers.22.mlp.gate_proj.weight shape: [14336, 4096]
+140: model.layers.22.mlp.up_proj.weight shape: [14336, 4096]
+141: model.layers.22.post_attention_layernorm.weight shape: [4096]
+142: model.layers.22.self_attn.k_proj.weight shape: [1024, 4096]
+143: model.layers.22.self_attn.o_proj.weight shape: [4096, 4096]
+144: model.layers.22.self_attn.q_proj.weight shape: [4096, 4096]
+145: model.layers.22.self_attn.v_proj.weight shape: [1024, 4096]
+146: model.layers.23.input_layernorm.weight shape: [4096]
+147: model.layers.23.mlp.down_proj.weight shape: [4096, 14336]
+148: model.layers.23.mlp.gate_proj.weight shape: [14336, 4096]
+149: model.layers.23.mlp.up_proj.weight shape: [14336, 4096]
+150: model.layers.23.post_attention_layernorm.weight shape: [4096]
+151: model.layers.23.self_attn.k_proj.weight shape: [1024, 4096]
+152: model.layers.23.self_attn.o_proj.weight shape: [4096, 4096]
+153: model.layers.23.self_attn.q_proj.weight shape: [4096, 4096]
+154: model.layers.23.self_attn.v_proj.weight shape: [1024, 4096]
+155: model.layers.24.input_layernorm.weight shape: [4096]
+156: model.layers.24.mlp.down_proj.weight shape: [4096, 14336]
+157: model.layers.24.mlp.gate_proj.weight shape: [14336, 4096]
+158: model.layers.24.mlp.up_proj.weight shape: [14336, 4096]
+159: model.layers.24.post_attention_layernorm.weight shape: [4096]
+160: model.layers.24.self_attn.k_proj.weight shape: [1024, 4096]
+161: model.layers.24.self_attn.o_proj.weight shape: [4096, 4096]
+162: model.layers.24.self_attn.q_proj.weight shape: [4096, 4096]
+163: model.layers.24.self_attn.v_proj.weight shape: [1024, 4096]
+164: model.layers.25.input_layernorm.weight shape: [4096]
+165: model.layers.25.mlp.down_proj.weight shape: [4096, 14336]
+166: model.layers.25.mlp.gate_proj.weight shape: [14336, 4096]
+167: model.layers.25.mlp.up_proj.weight shape: [14336, 4096]
+168: model.layers.25.post_attention_layernorm.weight shape: [4096]
+169: model.layers.25.self_attn.k_proj.weight shape: [1024, 4096]
+170: model.layers.25.self_attn.o_proj.weight shape: [4096, 4096]
+171: model.layers.25.self_attn.q_proj.weight shape: [4096, 4096]
+172: model.layers.25.self_attn.v_proj.weight shape: [1024, 4096]
+173: model.layers.26.input_layernorm.weight shape: [4096]
+174: model.layers.26.mlp.down_proj.weight shape: [4096, 14336]
+175: model.layers.26.mlp.gate_proj.weight shape: [14336, 4096]
+176: model.layers.26.mlp.up_proj.weight shape: [14336, 4096]
+177: model.layers.26.post_attention_layernorm.weight shape: [4096]
+178: model.layers.26.self_attn.k_proj.weight shape: [1024, 4096]
+179: model.layers.26.self_attn.o_proj.weight shape: [4096, 4096]
+180: model.layers.26.self_attn.q_proj.weight shape: [4096, 4096]
+181: model.layers.26.self_attn.v_proj.weight shape: [1024, 4096]
+182: model.layers.27.input_layernorm.weight shape: [4096]
+183: model.layers.27.mlp.down_proj.weight shape: [4096, 14336]
+184: model.layers.27.mlp.gate_proj.weight shape: [14336, 4096]
+185: model.layers.27.mlp.up_proj.weight shape: [14336, 4096]
+186: model.layers.27.post_attention_layernorm.weight shape: [4096]
+187: model.layers.27.self_attn.k_proj.weight shape: [1024, 4096]
+188: model.layers.27.self_attn.o_proj.weight shape: [4096, 4096]
+189: model.layers.27.self_attn.q_proj.weight shape: [4096, 4096]
+190: model.layers.27.self_attn.v_proj.weight shape: [1024, 4096]
+191: model.layers.28.input_layernorm.weight shape: [4096]
+192: model.layers.28.mlp.down_proj.weight shape: [4096, 14336]
+193: model.layers.28.mlp.gate_proj.weight shape: [14336, 4096]
+194: model.layers.28.mlp.up_proj.weight shape: [14336, 4096]
+195: model.layers.28.post_attention_layernorm.weight shape: [4096]
+196: model.layers.28.self_attn.k_proj.weight shape: [1024, 4096]
+197: model.layers.28.self_attn.o_proj.weight shape: [4096, 4096]
+198: model.layers.28.self_attn.q_proj.weight shape: [4096, 4096]
+199: model.layers.28.self_attn.v_proj.weight shape: [1024, 4096]
+200: model.layers.29.input_layernorm.weight shape: [4096]
+201: model.layers.29.mlp.down_proj.weight shape: [4096, 14336]
+202: model.layers.29.mlp.gate_proj.weight shape: [14336, 4096]
+203: model.layers.29.mlp.up_proj.weight shape: [14336, 4096]
+204: model.layers.29.post_attention_layernorm.weight shape: [4096]
+205: model.layers.29.self_attn.k_proj.weight shape: [1024, 4096]
+206: model.layers.29.self_attn.o_proj.weight shape: [4096, 4096]
+207: model.layers.29.self_attn.q_proj.weight shape: [4096, 4096]
+208: model.layers.29.self_attn.v_proj.weight shape: [1024, 4096]
+209: model.layers.3.input_layernorm.weight shape: [4096]
+210: model.layers.3.mlp.down_proj.weight shape: [4096, 14336]
+211: model.layers.3.mlp.gate_proj.weight shape: [14336, 4096]
+212: model.layers.3.mlp.up_proj.weight shape: [14336, 4096]
+213: model.layers.3.post_attention_layernorm.weight shape: [4096]
+214: model.layers.3.self_attn.k_proj.weight shape: [1024, 4096]
+215: model.layers.3.self_attn.o_proj.weight shape: [4096, 4096]
+216: model.layers.3.self_attn.q_proj.weight shape: [4096, 4096]
+217: model.layers.3.self_attn.v_proj.weight shape: [1024, 4096]
+218: model.layers.30.input_layernorm.weight shape: [4096]
+219: model.layers.30.mlp.down_proj.weight shape: [4096, 14336]
+220: model.layers.30.mlp.gate_proj.weight shape: [14336, 4096]
+221: model.layers.30.mlp.up_proj.weight shape: [14336, 4096]
+222: model.layers.30.post_attention_layernorm.weight shape: [4096]
+223: model.layers.30.self_attn.k_proj.weight shape: [1024, 4096]
+224: model.layers.30.self_attn.o_proj.weight shape: [4096, 4096]
+225: model.layers.30.self_attn.q_proj.weight shape: [4096, 4096]
+226: model.layers.30.self_attn.v_proj.weight shape: [1024, 4096]
+227: model.layers.31.input_layernorm.weight shape: [4096]
+228: model.layers.31.mlp.down_proj.weight shape: [4096, 14336]
+229: model.layers.31.mlp.gate_proj.weight shape: [14336, 4096]
+230: model.layers.31.mlp.up_proj.weight shape: [14336, 4096]
+231: model.layers.31.post_attention_layernorm.weight shape: [4096]
+232: model.layers.31.self_attn.k_proj.weight shape: [1024, 4096]
+233: model.layers.31.self_attn.o_proj.weight shape: [4096, 4096]
+234: model.layers.31.self_attn.q_proj.weight shape: [4096, 4096]
+235: model.layers.31.self_attn.v_proj.weight shape: [1024, 4096]
+236: model.layers.4.input_layernorm.weight shape: [4096]
+237: model.layers.4.mlp.down_proj.weight shape: [4096, 14336]
+238: model.layers.4.mlp.gate_proj.weight shape: [14336, 4096]
+239: model.layers.4.mlp.up_proj.weight shape: [14336, 4096]
+240: model.layers.4.post_attention_layernorm.weight shape: [4096]
+241: model.layers.4.self_attn.k_proj.weight shape: [1024, 4096]
+242: model.layers.4.self_attn.o_proj.weight shape: [4096, 4096]
+243: model.layers.4.self_attn.q_proj.weight shape: [4096, 4096]
+244: model.layers.4.self_attn.v_proj.weight shape: [1024, 4096]
+245: model.layers.5.input_layernorm.weight shape: [4096]
+246: model.layers.5.mlp.down_proj.weight shape: [4096, 14336]
+247: model.layers.5.mlp.gate_proj.weight shape: [14336, 4096]
+248: model.layers.5.mlp.up_proj.weight shape: [14336, 4096]
+249: model.layers.5.post_attention_layernorm.weight shape: [4096]
+250: model.layers.5.self_attn.k_proj.weight shape: [1024, 4096]
+251: model.layers.5.self_attn.o_proj.weight shape: [4096, 4096]
+252: model.layers.5.self_attn.q_proj.weight shape: [4096, 4096]
+253: model.layers.5.self_attn.v_proj.weight shape: [1024, 4096]
+254: model.layers.6.input_layernorm.weight shape: [4096]
+255: model.layers.6.mlp.down_proj.weight shape: [4096, 14336]
+256: model.layers.6.mlp.gate_proj.weight shape: [14336, 4096]
+257: model.layers.6.mlp.up_proj.weight shape: [14336, 4096]
+258: model.layers.6.post_attention_layernorm.weight shape: [4096]
+259: model.layers.6.self_attn.k_proj.weight shape: [1024, 4096]
+260: model.layers.6.self_attn.o_proj.weight shape: [4096, 4096]
+261: model.layers.6.self_attn.q_proj.weight shape: [4096, 4096]
+262: model.layers.6.self_attn.v_proj.weight shape: [1024, 4096]
+263: model.layers.7.input_layernorm.weight shape: [4096]
+264: model.layers.7.mlp.down_proj.weight shape: [4096, 14336]
+265: model.layers.7.mlp.gate_proj.weight shape: [14336, 4096]
+266: model.layers.7.mlp.up_proj.weight shape: [14336, 4096]
+267: model.layers.7.post_attention_layernorm.weight shape: [4096]
+268: model.layers.7.self_attn.k_proj.weight shape: [1024, 4096]
+269: model.layers.7.self_attn.o_proj.weight shape: [4096, 4096]
+270: model.layers.7.self_attn.q_proj.weight shape: [4096, 4096]
+271: model.layers.7.self_attn.v_proj.weight shape: [1024, 4096]
+272: model.layers.8.input_layernorm.weight shape: [4096]
+273: model.layers.8.mlp.down_proj.weight shape: [4096, 14336]
+274: model.layers.8.mlp.gate_proj.weight shape: [14336, 4096]
+275: model.layers.8.mlp.up_proj.weight shape: [14336, 4096]
+276: model.layers.8.post_attention_layernorm.weight shape: [4096]
+277: model.layers.8.self_attn.k_proj.weight shape: [1024, 4096]
+278: model.layers.8.self_attn.o_proj.weight shape: [4096, 4096]
+279: model.layers.8.self_attn.q_proj.weight shape: [4096, 4096]
+280: model.layers.8.self_attn.v_proj.weight shape: [1024, 4096]
+281: model.layers.9.input_layernorm.weight shape: [4096]
+282: model.layers.9.mlp.down_proj.weight shape: [4096, 14336]
+283: model.layers.9.mlp.gate_proj.weight shape: [14336, 4096]
+284: model.layers.9.mlp.up_proj.weight shape: [14336, 4096]
+285: model.layers.9.post_attention_layernorm.weight shape: [4096]
+286: model.layers.9.self_attn.k_proj.weight shape: [1024, 4096]
+287: model.layers.9.self_attn.o_proj.weight shape: [4096, 4096]
+288: model.layers.9.self_attn.q_proj.weight shape: [4096, 4096]
+289: model.layers.9.self_attn.v_proj.weight shape: [1024, 4096]
+290: model.norm.weight shape: [4096]
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt
new file mode 100644
index 000000000..0287bd2f2
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_7B_Instruct_V0_3Tests.TokenizerTest.approved.txt
@@ -0,0 +1,2 @@
+[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}][TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius.
+1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt
new file mode 100644
index 000000000..fc8562c9e
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Approvals/Mistral_V0_3Tests.TokenizerTest.approved.txt
@@ -0,0 +1,2 @@
+ [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}] What's the weather like in Paris? [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}] {"content": 22.0, "call_id": "9Ae3bDc2F"} The current temperature in Paris is 22.0 degrees celsius.
+1, 1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 14879, 29561, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 1065, 6233, 29572, 4, 5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 29475, 17329, 1316, 1113, 17452, 2032, 10598, 3501, 2032, 1113, 4684, 1046, 29493, 5611, 1316, 1113, 6074, 2032, 1113, 29485, 1958, 3938, 8474, 1113, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 29507, 10925, 2, 8, 10598, 4557, 2032, 29473, 29518, 29518, 29491, 29502, 29493, 1113, 3613, 29498, 1081, 2032, 1113, 29542, 29509, 29474, 29538, 29494, 29525, 29485, 29518, 29533, 18163, 9, 1183, 2636, 8409, 1065, 6233, 1117, 29473, 29518, 29518, 29491, 29502, 11950, 1045, 1958, 3938, 29491, 2
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj
new file mode 100644
index 000000000..471594743
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj
@@ -0,0 +1,45 @@
+
+
+
+ net6.0
+ enable
+ $(NoWarn);MSML_ExtendBaseTestClass
+ enable
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs
new file mode 100644
index 000000000..0aa80e888
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Mistral_7B_Instruct_V0_3Tests.cs
@@ -0,0 +1,137 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using ApprovalTests;
+using ApprovalTests.Namers;
+using ApprovalTests.Reporters;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using TorchSharp;
+using Xunit;
+
+namespace Microsoft.ML.GenAI.Mistral.Tests;
+
+[Collection("NoParallelization")]
+public class Mistral_7B_Instruct_V0_3Tests
+{
+ public Mistral_7B_Instruct_V0_3Tests()
+ {
+ if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
+ {
+ Approvals.UseAssemblyLocationForApprovedFiles();
+ }
+
+ torch.set_default_device("meta");
+ }
+
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void Mistral_7B_Instruct_V0_3_ShapeTest()
+ {
+ var model = new MistralForCausalLM(MistralConfig.Mistral_7B_Instruct_v0_3);
+ var stateDictStr = model.PeekShape();
+ Approvals.Verify(stateDictStr);
+ }
+
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void ItBuildChatTemplateFromAutoGenChatHistory()
+ {
+ var chatHistory = new List
+ {
+ new TextMessage(Role.System, "You are a helpful AI assistant."),
+ new TextMessage(Role.User, "Hello?"),
+ new TextMessage(Role.Assistant, "World!"),
+ };
+
+ var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
+
+ Approvals.Verify(prompt);
+ }
+
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void ItBuildChatTemplateWithToolsFromAutoGenChatHistory()
+ {
+ var getWeatherTool = new FunctionContract
+ {
+ Name = "get_current_weather",
+ Namespace = "weather",
+ Description = "Get the current weather",
+ Parameters = [
+ new FunctionParameterContract
+ {
+ Name = "location",
+ ParameterType = typeof(string),
+ Description = "The city and state, e.g. San Francisco, CA",
+ IsRequired = true
+ }
+ ]
+ };
+
+ var getWeatherToolCall = new ToolCall("get_current_weather", "{\"location\": \"Seattle, WA\"}") { ToolCallId = "9Ae3bDc2F" };
+ var getWeatherToolCallResult = new ToolCall("get_current_weather", "{\"temperature\": 22.0}", "sunny") { ToolCallId = "9Ae3bDc2F" };
+ var toolCallMessage = new ToolCallMessage([getWeatherToolCall]);
+ var toolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult]);
+ var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage);
+
+ var chatHistory = new List
+ {
+ new TextMessage(Role.System, "You are a helpful AI assistant."),
+ new TextMessage(Role.User, "What's the weather in Seattle?"),
+ toolCallMessage,
+ toolCallResultMessage,
+ new TextMessage(Role.Assistant, "The current temperature in Seattle is 22.0 degrees celsius."),
+
+ // test tool call aggregate message for immediate tool call execution
+ new TextMessage(Role.User, "What's the weather in New York?"),
+ aggregateToolCallMessage,
+ new TextMessage(Role.Assistant, "The current temperature in New York is 22.0 degrees celsius."),
+
+ new TextMessage(Role.User, "What's the weather in Paris?"),
+ };
+
+ var prompt = Mistral_7B_0_3ChatTemplateBuilder.Instance.BuildPrompt(chatHistory, [getWeatherTool]);
+
+ Approvals.Verify(prompt);
+ }
+
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void TokenizerTest()
+ {
+ var modelWeightFolder = "Mistral";
+ var tokenizer = MistralTokenizerHelper.FromPretrained(modelWeightFolder);
+
+ var messages = new string[]
+ {
+ // system : You are a helpful assistant that can answer questions about the weather.
+ // tool: [get-weather-tool-call]
+ // user : What's the weather like in Paris?
+ // assistant: // get-weather-tool-call
+ // tool: get-weather-tool-call-result
+ // assistant: The current temperature in Paris is 22.0 degrees celsius.
+ """
+ [AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}][/AVAILABLE_TOOLS][INST] What's the weather like in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "9Ae3bDc2F"}][TOOL_RESULTS] {"content": 22.0, "call_id": "9Ae3bDc2F"}[/TOOL_RESULTS] The current temperature in Paris is 22.0 degrees celsius.
+ """
+ };
+
+ var sb = new StringBuilder();
+ foreach (var message in messages)
+ {
+ var tokenizeIds = tokenizer.EncodeToIds(message, false, false);
+ var decodeToString = tokenizer.Decode(tokenizeIds, considerSpecialTokens: true);
+ sb.AppendLine(decodeToString);
+ var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString()));
+
+ sb.AppendLine(tokenizedStr);
+ }
+ Approvals.Verify(sb.ToString());
+ }
+}