[GenAI] Add generateEmbedding API to CausalLMPipeline (#7227)

* add embedding

* add FromPretrained API to Phi3 model

* fix bug

* Update CausalLMPipeline.cs
Xiaoyun Zhang authored 2024-08-30 10:30:46 -07:00; committed by GitHub
Parent 1d1cc997d5
Commit 7c937bf81a
No key found matching this signature
GPG key ID: B5690EEEBB952194
6 changed files: 93 additions and 111 deletions

View file

@@ -9,6 +9,7 @@ using static TorchSharp.torch;
 using TorchSharp;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.Tokenizers;
 namespace Microsoft.ML.GenAI.Samples.Phi3Mini;
@@ -26,12 +27,15 @@ public class AutoGenSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
-        var question = @"write a C# program to calculate the factorial of a number";
         // agent
         var agent = new Phi3Agent(pipeline, "assistant")
             .RegisterPrintMessage();
+        var question = @"write a C# program to calculate the factorial of a number";
         // chat with the assistant
         await agent.SendAsync(question);

View file

@@ -1,4 +1,7 @@
-using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Phi;
+using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.Tokenizers;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using TorchSharp;
@@ -20,8 +23,10 @@ public class SemanticKernelSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
         var kernel = Kernel.CreateBuilder()
             .AddGenAIChatCompletion(pipeline)
@@ -49,8 +54,10 @@ public class SemanticKernelSample
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
         var kernel = Kernel.CreateBuilder()
             .AddGenAITextGeneration(pipeline)

View file

@@ -1,103 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using Microsoft.ML.GenAI.Core;
-using Microsoft.ML.GenAI.Phi;
-using Tensorboard;
-using static TorchSharp.torch;
-using TorchSharp;
-using Microsoft.ML.GenAI.Core.Extension;
-using System.Text.Json;
-using Microsoft.ML.Tokenizers;
-namespace Microsoft.ML.GenAI.Samples.Phi3Mini;
-internal static class Utils
-{
-    public static ICausalLMPipeline<Tokenizer, Phi3ForCasualLM> LoadPhi3Mini4KFromFolder(
-        string weightFolder,
-        string configName = "config.json",
-        string device = "cuda",
-        int modelSizeOnCudaInGB = 55,
-        int modelSizeOnMemoryInGB = 64,
-        int modelSizeOnDiskInGB = 200,
-        bool quantizeToInt8 = false,
-        bool quantizeToInt4 = false)
-    {
-        Console.WriteLine("Loading Phi3 from huggingface model weight folder");
-        torch.set_default_device("meta");
-        var configPath = System.IO.Path.Combine(weightFolder, configName);
-        var config = JsonSerializer.Deserialize<Phi3Config>(System.IO.File.ReadAllText(configPath)) ?? throw new ArgumentNullException(nameof(configPath));
-        var timer = System.Diagnostics.Stopwatch.StartNew();
-        var model = new Phi3ForCasualLM(config);
-        var tokenzierPath = System.IO.Path.Combine(weightFolder, "tokenizer.model");
-        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenzierPath);
-        if (quantizeToInt8)
-        {
-            model.ToInt8QuantizeModule();
-        }
-        else if (quantizeToInt4)
-        {
-            model.ToInt4QuantizeModule();
-        }
-        var deviceSizeMap = new Dictionary<string, long>
-        {
-            ["cuda"] = modelSizeOnCudaInGB * 1L * 1024 * 1024 * 1024,
-            ["cpu"] = modelSizeOnMemoryInGB * 1L * 1024 * 1024 * 1024,
-            ["disk"] = modelSizeOnDiskInGB * 1L * 1024 * 1024 * 1024,
-        };
-        var deviceMap = model.InferDeviceMapForEachLayer(
-            devices: ["cuda", "cpu", "disk"],
-            deviceSizeMapInByte: deviceSizeMap);
-        var deviceMapJson = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true });
-        Console.WriteLine($"Device map:");
-        Console.WriteLine(deviceMapJson);
-        // load weight
-        torch.set_default_device("cpu");
-        Console.WriteLine("Start loading");
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        model = new Phi3ForCasualLM(config);
-        timer.Stop();
-        Console.WriteLine($"Phi3 model created in {timer.ElapsedMilliseconds / 1000} s");
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        model.LoadSafeTensors(weightFolder);
-        timer.Stop();
-        Console.WriteLine($"Phi3 weight loaded in {timer.ElapsedMilliseconds / 1000} s");
-        if (quantizeToInt8 || quantizeToInt4)
-        {
-            timer = System.Diagnostics.Stopwatch.StartNew();
-            Console.WriteLine("Start quantizing if needed");
-            if (quantizeToInt8)
-            {
-                model.ToInt8QuantizeModule();
-            }
-            else if (quantizeToInt4)
-            {
-                model.ToInt4QuantizeModule();
-            }
-            Console.WriteLine("Quantizing done");
-            timer.Stop();
-            Console.WriteLine($"Quantizing done in {timer.ElapsedMilliseconds / 1000} s");
-        }
-        timer = System.Diagnostics.Stopwatch.StartNew();
-        Console.WriteLine($"Start loading to device: {device}");
-        model = model.ToDynamicLoadingModel(deviceMap, "cuda");
-        timer.Stop();
-        Console.WriteLine($"Phi3 loaded to device: {device} in {timer.ElapsedMilliseconds / 1000} s");
-        var pipeline = new CausalLMPipeline<Tokenizer, Phi3ForCasualLM>(tokenizer, model, device);
-        torch.set_default_device(device);
-        return pipeline;
-    }
-}
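
With this helper deleted, the samples construct the pipeline directly, as the two updated sample files above show. One behavioral difference worth noting: the deleted helper budgeted device placement in bytes (modelSizeOnCudaInGB and friends), while the new Phi3ForCasualLM.FromPretrained introduced later in this PR counts layers. A minimal migration sketch, assuming the same weight-folder layout; the folder path and layer count are placeholders:

// Before (deleted): byte budgets per device decided where layers went.
// var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, modelSizeOnCudaInGB: 16);

// After: state directly how many layers live on the target device.
var weightFolder = @"C:\models\Phi-3-mini-4k-instruct"; // placeholder path
var tokenizer = Phi3TokenizerHelper.FromPretrained(Path.Combine(weightFolder, "tokenizer.model"));
var model = Phi3ForCasualLM.FromPretrained(weightFolder, quantizeToInt8: true, layersOnTargetDevice: 16);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, "cuda");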

View file

@@ -1,4 +1,4 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Phi3Mini;
-await SemanticKernelSample.RunChatCompletionSample();
+await AutoGenSample.RunAsync();

View file

@@ -32,6 +32,11 @@ public interface ICausalLMPipeline
         float topP = CausalLMPipeline.Defaults.TopP,
         string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);
+    /// <summary>
+    /// Generate the embedding (the last hidden state of the last token) for the prompt. The embedding is normalized by its L2 norm.
+    /// </summary>
+    float[] GenerateEmbeddingFromLastTokenPool(string prompt);
     IEnumerable<string> GenerateStreaming(
         string prompt,
         int maxLen = CausalLMPipeline.Defaults.MaxLen,
@@ -281,4 +286,23 @@ public class CausalLMPipeline : ICausalLMPipeline
         nextToken = torch.gather(probsIndex, dim: -1, index: nextToken);
         return nextToken;
     }
+    public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
+    {
+        using var scope = NewDisposeScope();
+        using var noGrad = torch.no_grad();
+        var inputIds = this.Tokenizer.EncodeToIds(prompt);
+        var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
+        var attentionMask = torch.ones_like(inputTensor, device: this.Device);
+        var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
+        var output = this.Model.forward(input);
+        var lastTokenHiddenState = output.LastHiddenState[0, ^1];
+        // shape of lastTokenHiddenState: [hidden_size]
+        // L2 norm
+        var norm = lastTokenHiddenState.norm();
+        var normalized = lastTokenHiddenState / norm;
+        return normalized.to_type(ScalarType.Float32).data<float>().ToArray();
+    }
 }
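
A usage sketch of the new embedding API (the pipeline variable and the prompts are illustrative, not part of this change). Because GenerateEmbeddingFromLastTokenPool returns L2-normalized vectors, a plain dot product already yields the cosine similarity:

// pipeline is an ICausalLMPipeline built as in the samples above.
float[] a = pipeline.GenerateEmbeddingFromLastTokenPool("The quick brown fox");
float[] b = pipeline.GenerateEmbeddingFromLastTokenPool("A fast auburn fox");

// Both vectors are unit length, so the dot product equals cosine similarity.
float cosine = 0f;
for (int i = 0; i < a.Length; i++)
{
    cosine += a[i] * b[i];
}
Console.WriteLine($"cosine similarity: {cosine}");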

View file

@@ -9,6 +9,7 @@ using System.Text;
 using System.Text.Json;
 using System.Threading.Tasks;
 using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.ML.GenAI.Phi.Module;
 using TorchSharp;
 using TorchSharp.Modules;
@@ -66,6 +67,55 @@ public class Phi3ForCasualLM : nn.Module<CausalLMModelInput, CausalLMModelOutput
         return phi;
     }
+    public static Phi3ForCasualLM FromPretrained(
+        string modelFolder,
+        string configName = "config.json",
+        string checkPointName = "model.safetensors.index.json",
+        bool quantizeToInt8 = false,
+        bool quantizeToInt4 = false,
+        int layersOnTargetDevice = -1,
+        ScalarType torchDtype = ScalarType.BFloat16,
+        string targetDevice = "cuda")
+    {
+        if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+        {
+            return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+        }
+        var originalDefaultDevice = torch.get_default_device();
+        torch.set_default_device("meta");
+        var config = Path.Join(modelFolder, configName);
+        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+        modelConfig.DType = torchDtype;
+        var model = new Phi3ForCasualLM(modelConfig);
+        if (quantizeToInt8)
+        {
+            model.ToInt8QuantizeModule();
+        }
+        else if (quantizeToInt4)
+        {
+            model.ToInt4QuantizeModule();
+        }
+        var deviceMap = model.InferDeviceMapForEachLayer(
+            [
+                KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+                KeyValuePair.Create("cpu", -1)
+            ]);
+        torch.set_default_device("cpu");
+        model = new Phi3ForCasualLM(modelConfig);
+        model.LoadSafeTensors(modelFolder, checkPointName);
+        model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+        torch.set_default_device(originalDefaultDevice);
+        return model;
+    }
     public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
     {
         this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
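
To recap the control flow above: with no quantization and the default layersOnTargetDevice of -1, the method delegates to the plain FromPretrained overload; otherwise it builds the model on the meta device to size it, infers a per-layer device map, reloads the real weights on the CPU, and wraps the result for cross-device execution via ToDynamicLoadingModel. A sketch of both call shapes (the folder path and layer count are illustrative):

// Fast path: everything is loaded straight onto the target device.
var full = Phi3ForCasualLM.FromPretrained(@"C:\models\Phi-3-mini-4k-instruct");

// Layered path: int8-quantize and keep 10 layers on cuda; the rest go to cpu
// per the KeyValuePair.Create("cpu", -1) entry in the device map above.
var partial = Phi3ForCasualLM.FromPretrained(
    @"C:\models\Phi-3-mini-4k-instruct",
    quantizeToInt8: true,
    layersOnTargetDevice: 10);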