[GenAI] Add generateEmbedding API to CausalLMPipeline (#7227)
* add embedding
* add FromPretrained API to Phi-3 model
* fix bug
* Update CausalLMPipeline.cs
Parent: 1d1cc997d5
Commit: 7c937bf81a
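
The new pipeline API returns an L2-normalized embedding taken from the prompt's last hidden state. A minimal usage sketch, assuming a local Phi-3 weight folder (the path is illustrative, not from this commit):

    var weightFolder = @"C:\models\Phi-3-mini-4k-instruct";  // illustrative path
    var tokenizer = Phi3TokenizerHelper.FromPretrained(Path.Combine(weightFolder, "tokenizer.model"));
    var model = Phi3ForCasualLM.FromPretrained(weightFolder); // defaults: no quantization, all layers on "cuda"
    var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, "cuda");

    // L2-normalized embedding via last-token pooling.
    float[] embedding = pipeline.GenerateEmbeddingFromLastTokenPool("What is the capital of France?");
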
@@ -9,6 +9,7 @@ using static TorchSharp.torch;
 using TorchSharp;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.Tokenizers;
 
 namespace Microsoft.ML.GenAI.Samples.Phi3Mini;
 

@@ -26,12 +27,15 @@ public class AutoGenSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
+        var question = @"write a C# program to calculate the factorial of a number";
 
         // agent
         var agent = new Phi3Agent(pipeline, "assistant")
             .RegisterPrintMessage();
-        var question = @"write a C# program to calculate the factorial of a number";
 
         // chat with the assistant
         await agent.SendAsync(question);

@@ -1,4 +1,7 @@
-using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Phi;
+using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.Tokenizers;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using TorchSharp;

@@ -20,8 +23,10 @@ public class SemanticKernelSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
 
         var kernel = Kernel.CreateBuilder()
             .AddGenAIChatCompletion(pipeline)

@@ -49,8 +54,10 @@ public class SemanticKernelSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
 
         var kernel = Kernel.CreateBuilder()
             .AddGenAITextGeneration(pipeline)

@@ -1,103 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using Microsoft.ML.GenAI.Core;
-using Microsoft.ML.GenAI.Phi;
-using Tensorboard;
-using static TorchSharp.torch;
-using TorchSharp;
-using Microsoft.ML.GenAI.Core.Extension;
-using System.Text.Json;
-using Microsoft.ML.Tokenizers;
-
-namespace Microsoft.ML.GenAI.Samples.Phi3Mini;
-
-internal static class Utils
-{
-    public static ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> LoadPhi3Mini4KFromFolder(
-        string weightFolder,
-        string configName = "config.json",
-        string device = "cuda",
-        int modelSizeOnCudaInGB = 55,
-        int modelSizeOnMemoryInGB = 64,
-        int modelSizeOnDiskInGB = 200,
-        bool quantizeToInt8 = false,
-        bool quantizeToInt4 = false)
-    {
-        Console.WriteLine("Loading Phi3 from huggingface model weight folder");
-        torch.set_default_device("meta");
-        var configPath = System.IO.Path.Combine(weightFolder, configName);
-        var config = JsonSerializer.Deserialize<Phi3Config>(System.IO.File.ReadAllText(configPath)) ?? throw new ArgumentNullException(nameof(configPath));
-        var timer = System.Diagnostics.Stopwatch.StartNew();
-        var model = new Phi3ForCasualLM(config);
-        var tokenzierPath = System.IO.Path.Combine(weightFolder, "tokenizer.model");
-        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenzierPath);
-
-        if (quantizeToInt8)
-        {
-            model.ToInt8QuantizeModule();
-        }
-        else if (quantizeToInt4)
-        {
-            model.ToInt4QuantizeModule();
-        }
-
-        var deviceSizeMap = new Dictionary<string, long>
-        {
-            ["cuda"] = modelSizeOnCudaInGB * 1L * 1024 * 1024 * 1024,
-            ["cpu"] = modelSizeOnMemoryInGB * 1L * 1024 * 1024 * 1024,
-            ["disk"] = modelSizeOnDiskInGB * 1L * 1024 * 1024 * 1024,
-        };
-
-        var deviceMap = model.InferDeviceMapForEachLayer(
-            devices: ["cuda", "cpu", "disk"],
-            deviceSizeMapInByte: deviceSizeMap);
-
-        var deviceMapJson = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true });
-        Console.WriteLine($"Device map:");
-        Console.WriteLine(deviceMapJson);
-
-        // load weight
-        torch.set_default_device("cpu");
-
-        Console.WriteLine("Start loading");
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        model = new Phi3ForCasualLM(config);
-        timer.Stop();
-        Console.WriteLine($"Phi3 model created in {timer.ElapsedMilliseconds / 1000} s");
-
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        model.LoadSafeTensors(weightFolder);
-        timer.Stop();
-        Console.WriteLine($"Phi3 weight loaded in {timer.ElapsedMilliseconds / 1000} s");
-
-        if (quantizeToInt8 || quantizeToInt4)
-        {
-            timer = System.Diagnostics.Stopwatch.StartNew();
-            Console.WriteLine("Start quantizing if needed");
-            if (quantizeToInt8)
-            {
-                model.ToInt8QuantizeModule();
-            }
-            else if (quantizeToInt4)
-            {
-                model.ToInt4QuantizeModule();
-            }
-            Console.WriteLine("Quantizing done");
-            timer.Stop();
-            Console.WriteLine($"Quantizing done in {timer.ElapsedMilliseconds / 1000} s");
-        }
-
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        Console.WriteLine($"Start loading to device: {device}");
-        model = model.ToDynamicLoadingModel(deviceMap, "cuda");
-        timer.Stop();
-        Console.WriteLine($"Phi3 loaded to device: {device} in {timer.ElapsedMilliseconds / 1000} s");
-        var pipeline = new CausalLMPipeline<Tokenizer, Phi3ForCasualLM>(tokenizer, model, device);
-        torch.set_default_device(device);
-
-        return pipeline;
-    }
-}

@@ -1,4 +1,4 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Phi3Mini;
 
-await SemanticKernelSample.RunChatCompletionSample();
+await AutoGenSample.RunAsync();

@@ -32,6 +32,11 @@ public interface ICausalLMPipeline
         float topP = CausalLMPipeline.Defaults.TopP,
         string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);
 
+    /// <summary>
+    /// Generate the embedding (last hidden state of the last token) for the prompt. The embedding is normalized by L2 norm.
+    /// </summary>
+    float[] GenerateEmbeddingFromLastTokenPool(string prompt);
+
     IEnumerable<string> GenerateStreaming(
         string prompt,
         int maxLen = CausalLMPipeline.Defaults.MaxLen,

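Because the returned embedding is documented as L2-normalized, cosine similarity between two prompts reduces to a dot product. A sketch, assuming a pipeline built as in the samples above and a using System.Linq; directive:

    float[] a = pipeline.GenerateEmbeddingFromLastTokenPool("The cat sat on the mat.");
    float[] b = pipeline.GenerateEmbeddingFromLastTokenPool("A feline rested on a rug.");
    // Unit-norm vectors: the dot product equals cosine similarity.
    float cosine = a.Zip(b, (x, y) => x * y).Sum();
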
@@ -281,4 +286,23 @@ public class CausalLMPipeline : ICausalLMPipeline
         nextToken = torch.gather(probsIndex, dim: -1, index: nextToken);
         return nextToken;
     }
+
+    public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
+    {
+        using var scope = NewDisposeScope();
+        using var noGrad = torch.no_grad();
+        var inputIds = this.Tokenizer.EncodeToIds(prompt);
+        var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
+        var attentionMask = torch.ones_like(inputTensor, device: this.Device);
+        var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
+        var output = this.Model.forward(input);
+        var lastTokenHiddenState = output.LastHiddenState[0, ^1];
+
+        // shape of lastTokenHiddenState: [hidden_size]
+        // L2 norm
+        var norm = lastTokenHiddenState.norm();
+        var normalized = lastTokenHiddenState / norm;
+
+        return normalized.to_type(ScalarType.Float32).data<float>().ToArray();
+    }
 }

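For orientation, the tensor shapes through the new method, for a prompt of n tokens and model hidden size h (read directly off the code above):

    // inputTensor            : [1, n]    token ids, batched via unsqueeze(0)
    // attentionMask          : [1, n]    all ones; a single prompt has no padding
    // output.LastHiddenState : [1, n, h] hidden states for every position
    // LastHiddenState[0, ^1] : [h]       last token's hidden state (the "last-token pool")
    // normalized             : [h]       divided by its L2 norm, returned as float[]
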
@@ -9,6 +9,7 @@ using System.Text;
 using System.Text.Json;
 using System.Threading.Tasks;
 using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.ML.GenAI.Phi.Module;
 using TorchSharp;
 using TorchSharp.Modules;

@@ -66,6 +67,55 @@ public class Phi3ForCasualLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
         return phi;
     }
 
+    public static Phi3ForCasualLM FromPretrained(
+        string modelFolder,
+        string configName = "config.json",
+        string checkPointName = "model.safetensors.index.json",
+        bool quantizeToInt8 = false,
+        bool quantizeToInt4 = false,
+        int layersOnTargetDevice = -1,
+        ScalarType torchDtype = ScalarType.BFloat16,
+        string targetDevice = "cuda")
+    {
+        if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+        {
+            return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+        }
+
+        var originalDefaultDevice = torch.get_default_device();
+        torch.set_default_device("meta");
+        var config = Path.Join(modelFolder, configName);
+        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+        modelConfig.DType = torchDtype;
+        var model = new Phi3ForCasualLM(modelConfig);
+
+        if (quantizeToInt8)
+        {
+            model.ToInt8QuantizeModule();
+        }
+        else if (quantizeToInt4)
+        {
+            model.ToInt4QuantizeModule();
+        }
+
+        var deviceMap = model.InferDeviceMapForEachLayer(
+        [
+            KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+            KeyValuePair.Create("cpu", -1)
+        ]);
+
+        torch.set_default_device("cpu");
+        model = new Phi3ForCasualLM(modelConfig);
+
+        model.LoadSafeTensors(modelFolder, checkPointName);
+
+        model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+
+        torch.set_default_device(originalDefaultDevice);
+
+        return model;
+    }
+
     public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
     {
         this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);

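A loading sketch under stated assumptions: the weight folder follows the Hugging Face layout (config.json plus safetensors shards listed in model.safetensors.index.json), and layersOnTargetDevice: 16 is an illustrative split, not a recommended value. Note the fast path in the code above: with layersOnTargetDevice: -1 and no quantization the call falls through to the plain FromPretrained overload; otherwise the inferred device map keeps the requested number of layers on targetDevice and assigns the remainder to "cpu".

    // Int8-quantized Phi-3 with 16 transformer layers resident on "cuda";
    // the remaining layers are mapped to "cpu" by the inferred device map.
    var model = Phi3ForCasualLM.FromPretrained(
        @"C:\models\Phi-3-mini-4k-instruct",  // illustrative path
        quantizeToInt8: true,
        layersOnTargetDevice: 16,
        targetDevice: "cuda");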