Distillation and Synthetic data gen notebook (#3307)
* Distillation and Synthetic data gen notebook * Black
This commit is contained in:
Родитель
ef4f57a7fa
Коммит
08542cec5e
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Synthetic Data Generation with Large Language Models\n",
|
||||
"## Notebook details\n",
|
||||
"This notebook generates synthetic data with an LLM on a sample NLI dataset.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Step 1: Install the dependencies in your environment\n",
|
||||
"\n",
|
||||
"Install the libraries/dependencies required to run the python code."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"%pip install azure-ai-ml\n",
|
||||
"%pip install azure-identity\n",
|
||||
"%pip install datasets\n",
|
||||
"%pip install mlflow\n",
|
||||
"%pip install azureml-mlflow\n",
|
||||
"%pip install fsspec"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# TASK : NLI Synthetic Data generation"
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Natural Language Inference (NLI)\n",
|
||||
"\n",
|
||||
"Synthetic data generation is targeted towards cases where user does not have labeled data, so teacher LLM is used to create high quality, synthetic labels for the data.\n",
|
||||
"\n",
|
||||
"This notebook assumes the data to have the above three fields: 'premise', 'hypothesis'. The 'label' can optionally be used to compute metrics based on original ground truth. However, the purpose of synthetic data generation is to replace the labels with the high quality labels generated by a large, capable LLM.\n",
|
||||
"\n",
|
||||
"Natural Language Inference or Recognizing Textual Entailment (RTE) is the task of classifying a pair of premise and hypothesis sentences into three classes: **contradiction, neutral, and entailment**. For example:\n",
|
||||
"\n",
|
||||
"| premise | hypothesis | label |\n",
|
||||
"|---------------------------------------------------|--------------------------------------------------------|---------------|\n",
|
||||
"| A man inspects the uniform of a figure in some East Asian country. | The man is sleeping. | contradiction |\n",
|
||||
"| An older and younger man smiling. | Two men are smiling and laughing at the cats playing on the floor. | neutral |\n",
|
||||
"| A soccer game with multiple males playing. | Some men are playing a sport. | entailment |\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Step 2: Consume input dataset\n",
|
||||
"\n",
|
||||
"The classes in this cell handle the responsibility of ingesting the input dataset. Dataset can be anything, HuggingFace, Locally hosted, JSON, string etc. For our NLI example, we have written a `NLIHuggingFaceInputDataset` class to ingests input from HuggingFace datasets.\n",
|
||||
"\n",
|
||||
"Example NLI Dataset looks like the following:\n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"premise\": \"Aside from the Indigenous population, nearly all Argentines or their ancestors immigrated within the past five centuries.\",\n",
|
||||
" \"hypothesis\": \"Aside from the Indigenous population, some Argentines or their ancestors immigrated within the past five centuries.\",\n",
|
||||
" \"label\": 0\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"Labels 0, 1, 2 correspond to entailment, neutral and contradiction respectively."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from utils import NLIHuggingFaceInputDataset\n",
|
||||
"\n",
|
||||
"# We can define train and test sample sizes here.\n",
|
||||
"train_sample_size = 2\n",
|
||||
"val_sample_size = 2\n",
|
||||
"test_sample_size = 2\n",
|
||||
"\n",
|
||||
"# Sample notebook using the dataset: https://huggingface.co/datasets/cestwc/conjnli\n",
|
||||
"dataset_name = \"cestwc/conjnli\"\n",
|
||||
"input_dataset = NLIHuggingFaceInputDataset()\n",
|
||||
"\n",
|
||||
"# Note: train_split_name and test_split_name can vary by dataset. They are passed as arguments in load_hf_dataset.\n",
|
||||
"# If val_split_name is None, the below function will split the train set to create the specified sized validation set.\n",
|
||||
"train, val, test = input_dataset.load_hf_dataset(\n",
|
||||
" dataset_name=dataset_name,\n",
|
||||
" train_sample_size=train_sample_size,\n",
|
||||
" val_sample_size=val_sample_size,\n",
|
||||
" test_sample_size=test_sample_size,\n",
|
||||
" train_split_name=\"adversarial\",\n",
|
||||
" val_split_name=None,\n",
|
||||
" test_split_name=\"dev\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Len of train data sample is \" + str(len(train)))\n",
|
||||
"print(\"Len of validation data sample is \" + str(len(val)))\n",
|
||||
"print(\"Len of test data sample is \" + str(len(test)))"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1721738954273
|
||||
},
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#### Check format of data"
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"train[0]"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1721738970221
|
||||
},
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Step 3: Generate prompt for inference\n",
|
||||
"\n",
|
||||
"We generate the prompts in the required format to be able to output a desired answer.\n",
|
||||
"\n",
|
||||
"So the previous cell prompt \n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"premise\": \"Aside from the Indigenous population, nearly all Argentines or their ancestors immigrated within the past five centuries.\",\n",
|
||||
" \"hypothesis\": \"Aside from the Indigenous population, some Argentines or their ancestors immigrated within the past five centuries.\",\n",
|
||||
" \"label\": 0\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"**transforms to**\n",
|
||||
"\n",
|
||||
"```json\n",
|
||||
"\n",
|
||||
"{\n",
|
||||
" \"messages\": [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": \"You are a helpful assistant. Write out in a step by step manner your reasoning about the answer using no more than 80 words. Based on the reasoning, produce the final answer. Your response should be in JSON format without using any backticks. The JSON is a dictionary whose keys are 'reason' and 'answer_choice'.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: Aside from the Indigenous population, nearly all Argentines or their ancestors immigrated within the past five centuries.\\nHypothesis:Aside from the Indigenous population, some Argentines or their ancestors immigrated within the past five centuries.\\n\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"}\n"
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
" #### We have abstracted out this functionality in a separate class which you can use as follows."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# An example of how a final NLI prompt looks like\n",
|
||||
"from utils import NLIPromptGenerator\n",
|
||||
"\n",
|
||||
"# You can set the enable chain of thought flag to True to enable CoT prompting\n",
|
||||
"\n",
|
||||
"nli_prompt_generator = NLIPromptGenerator(enable_chain_of_thought=True)\n",
|
||||
"nli_prompt_generator.generate_prompt(train[0])"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1721739016231
|
||||
},
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Step 4: Setup inference with Azure ML endpoints\n",
|
||||
"\n",
|
||||
"### First deploy the teacher model in Azure AI Studio\n",
|
||||
"* Go to Azure AI Studio (ai.azure.com)\n",
|
||||
"* Select Meta-Llama-3.1-405B-Instruct model from Model catalog.\n",
|
||||
"* Deploy with \"Pay-as-you-go\"\n",
|
||||
"* Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.\n",
|
||||
"\n",
|
||||
"The following cell builds the Azure ML endpoints to be able to get outputs from the LLama endpoint set up in Azure. You can directly use the `AzureInference` class that handles this."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from utils import AzureInference\n",
|
||||
"\n",
|
||||
"url = \"<Chat completion teacher model endpoint URL>\"\n",
|
||||
"key = \"<API key>\"\n",
|
||||
"\n",
|
||||
"az_llama_405b_model_inf = AzureInference(url=url, key=key)"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1721739068307
|
||||
},
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Step 5: Build the final dataset with synthetic labels\n",
|
||||
"\n",
|
||||
"In the following cell, we utilize the previously built classes to get input dataset, prompt engineer it, call the LLM from Azure ML endpoints, generate the output and write it to a file.\n",
|
||||
"Sample final output: \n",
|
||||
"\n",
|
||||
"```json\n",
|
||||
"\n",
|
||||
"{\n",
|
||||
" \"messages\": [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": \"You are a helpful assistant. Write out in a step by step manner your reasoning about the answer using no more than 80 words. Based on the reasoning, produce the final answer. Your response should be in JSON format without using any backticks. The JSON is a dictionary whose keys are 'reason' and 'answer_choice'.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: None but Jake managed to win their game.\\nHypothesis: Jake managed to win their game.\",\n",
|
||||
" \"role\": \"user\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"assistant\",\n",
|
||||
" \"content\": \"entailment\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The answer \"entailment\" in the above sample JSON is generated as a response by the LLM. We wrap it as a response generated by the \"assistant\"."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"##### We have abstracted out the above functionality in `NLISyntheticDatasetBuilder` which builds prompts, calls Llama endpoint, and then writes the final dataset in your local directory."
|
||||
],
|
||||
"metadata": {
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from utils import NLISyntheticDatasetBuilder\n",
|
||||
"\n",
|
||||
"nli_dataset_builder = NLISyntheticDatasetBuilder(\n",
|
||||
" nli_prompt_generator, inference_pointer=az_llama_405b_model_inf\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Write synthetic training and validation data to local directory.\n",
|
||||
"nli_dataset_builder.build_dataset(train, file_name=\"train_nli\")\n",
|
||||
"nli_dataset_builder.build_dataset(val, file_name=\"valid_nli\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1721739108156
|
||||
},
|
||||
"jupyter": {
|
||||
"outputs_hidden": false,
|
||||
"source_hidden": false
|
||||
},
|
||||
"nteract": {
|
||||
"transient": {
|
||||
"deleting": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernel_info": {
|
||||
"name": "python310-sdkv2"
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python310-sdkv2",
|
||||
"language": "python",
|
||||
"display_name": "Python 3.10 - SDK v2"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.14",
|
||||
"mimetype": "text/x-python",
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"pygments_lexer": "ipython3",
|
||||
"nbconvert_exporter": "python",
|
||||
"file_extension": ".py"
|
||||
},
|
||||
"microsoft": {
|
||||
"host": {
|
||||
"AzureML": {
|
||||
"notebookHasBeenCompleted": true
|
||||
}
|
||||
},
|
||||
"ms_spell_check": {
|
||||
"ms_spell_check_language": "en"
|
||||
}
|
||||
},
|
||||
"nteract": {
|
||||
"version": "nteract-front-end@1.0.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,443 @@
|
|||
"""
|
||||
THIS FILE CONTAINS.
|
||||
|
||||
1. DATASETS INGESTION FUCNTIONS
|
||||
2. PROMPT GENERATOR CLASSES
|
||||
3. AZURE MODEL INFERENCE CONNECTOR HELPER CLASSES
|
||||
4. FINAL DATASET PUBLISHING CLASSES
|
||||
|
||||
"""
|
||||
from abc import ABC
|
||||
from datasets import load_dataset
|
||||
import json
|
||||
import requests
|
||||
|
||||
|
||||
"""
|
||||
DATASETS INGESTION FUNCTIONS
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InputDataset(ABC):
|
||||
"""Input Dataset class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
(
|
||||
self.train_data_file_name,
|
||||
self.test_data_file_name,
|
||||
self.eval_data_file_name,
|
||||
) = (None, None, None)
|
||||
|
||||
|
||||
class QALocalInputDataset(InputDataset):
|
||||
"""
|
||||
Loads the input dataset if its in local.
|
||||
|
||||
The directory is left blank if your dataset is in the same directory as your notebook.
|
||||
The input dataset is divided as train, eval and test based on availaibility
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dir="",
|
||||
dataset_name="cqa",
|
||||
train_samples="512",
|
||||
test_samples="256",
|
||||
eval_samples=None,
|
||||
):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
self.dir = dir
|
||||
if train_samples is not None:
|
||||
self.train_data_file_name = (
|
||||
dir + dataset_name + "_train_" + str(train_samples) + ".jsonl"
|
||||
)
|
||||
if test_samples is not None:
|
||||
self.test_data_file_name = (
|
||||
dir + dataset_name + "_test_" + str(test_samples) + ".jsonl"
|
||||
)
|
||||
if eval_samples is not None:
|
||||
self.eval_data_file_name = (
|
||||
dir + dataset_name + "_eval_" + str(eval_samples) + ".jsonl"
|
||||
)
|
||||
|
||||
def load_local_dataset(self, sample_size=10):
|
||||
"""Load the local dataset."""
|
||||
train_data, val_data, test_data = [], [], []
|
||||
|
||||
if self.train_data_file_name:
|
||||
with open(self.train_data_file_name, "r") as f:
|
||||
for line in f:
|
||||
train_data.append(json.loads(line))
|
||||
|
||||
if self.test_data_file_name:
|
||||
with open(self.test_data_file_name, "r") as f:
|
||||
for line in f:
|
||||
test_data.append(json.loads(line))
|
||||
|
||||
if self.eval_data_file_name:
|
||||
with open(self.eval_data_file_name, "r") as f:
|
||||
for line in f:
|
||||
val_data.append(json.loads(line))
|
||||
return train_data[:sample_size], val_data[:sample_size], test_data[:sample_size]
|
||||
|
||||
|
||||
class QAHuggingFaceInputDataset(InputDataset):
|
||||
"""Loads the HuggingFace dataset."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
|
||||
def load_hf_dataset(
|
||||
self,
|
||||
dataset_name,
|
||||
train_sample_size=10,
|
||||
val_sample_size=10,
|
||||
test_sample_size=10,
|
||||
train_split_name="train",
|
||||
val_split_name="validation",
|
||||
test_split_name="test",
|
||||
):
|
||||
"""Load the HuggingFace dataset."""
|
||||
full_dataset = load_dataset(dataset_name)
|
||||
|
||||
if val_split_name is not None:
|
||||
train_data = full_dataset[train_split_name].select(range(train_sample_size))
|
||||
val_data = full_dataset[val_split_name].select(range(val_sample_size))
|
||||
test_data = full_dataset[test_split_name].select(range(test_sample_size))
|
||||
else:
|
||||
train_val_data = full_dataset[train_split_name].select(
|
||||
range(train_sample_size + val_sample_size)
|
||||
)
|
||||
train_data = train_val_data.select(range(train_sample_size))
|
||||
val_data = train_val_data.select(
|
||||
range(train_sample_size, train_sample_size + val_sample_size)
|
||||
)
|
||||
test_data = full_dataset[test_split_name].select(range(test_sample_size))
|
||||
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
class NLIHuggingFaceInputDataset(InputDataset):
|
||||
"""Loads the HuggingFace dataset."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
|
||||
def load_hf_dataset(
|
||||
self,
|
||||
dataset_name,
|
||||
train_sample_size=10,
|
||||
val_sample_size=10,
|
||||
test_sample_size=10,
|
||||
train_split_name="train",
|
||||
val_split_name="validation",
|
||||
test_split_name="test",
|
||||
):
|
||||
"""Load the HuggingFace dataset."""
|
||||
full_dataset = load_dataset(dataset_name)
|
||||
|
||||
if val_split_name is not None:
|
||||
train_data = full_dataset[train_split_name].select(range(train_sample_size))
|
||||
val_data = full_dataset[val_split_name].select(range(val_sample_size))
|
||||
test_data = full_dataset[test_split_name].select(range(test_sample_size))
|
||||
else:
|
||||
train_val_data = full_dataset[train_split_name].select(
|
||||
range(train_sample_size + val_sample_size)
|
||||
)
|
||||
train_data = train_val_data.select(range(train_sample_size))
|
||||
val_data = train_val_data.select(
|
||||
range(train_sample_size, train_sample_size + val_sample_size)
|
||||
)
|
||||
test_data = full_dataset[test_split_name].select(range(test_sample_size))
|
||||
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
"""
|
||||
PROMPT GENERATOR CLASSES
|
||||
"""
|
||||
|
||||
|
||||
class PromptGenerator(ABC):
|
||||
"""Prompt Generator class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
|
||||
def generate_prompt(self):
|
||||
"""Generate the prompt."""
|
||||
pass
|
||||
|
||||
|
||||
class QAPromptGenerator(PromptGenerator):
|
||||
"""Prompt format each data for inference."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
self.qa_system_prompt = (
|
||||
"You are a helpful assistant. Write out in a step by step manner "
|
||||
"your reasoning about the answer using no more than 80 words. "
|
||||
"Based on the reasoning, produce the final answer. "
|
||||
"Your response should be in JSON format without using any backticks. "
|
||||
"The JSON is a dictionary whose keys are 'reason' and 'answer_choice'."
|
||||
)
|
||||
|
||||
self.qa_user_prompt_template = (
|
||||
"Answer the following multiple-choice question by selecting the correct option.\n\n"
|
||||
"Question: {question}\n"
|
||||
"Answer Choices:\n"
|
||||
"{answer_choices}"
|
||||
)
|
||||
|
||||
def generate_prompt(self, qa_input):
|
||||
"""Generate the prompt."""
|
||||
_, choices, _ = qa_input["question"], qa_input["choices"], qa_input["answerKey"]
|
||||
|
||||
labels, choice_list = choices["label"], choices["text"]
|
||||
answer_choices = [
|
||||
"({}) {}".format(labels[i], choice_list[i]) for i in range(len(labels))
|
||||
]
|
||||
answer_choices = "\n".join(answer_choices)
|
||||
|
||||
self.qa_user_prompt = self.qa_user_prompt_template.format(
|
||||
question=qa_input["question"], answer_choices=answer_choices
|
||||
)
|
||||
|
||||
final_prompt = {
|
||||
"messages": [
|
||||
{"role": "system", "content": self.qa_system_prompt},
|
||||
{"role": "user", "content": self.qa_user_prompt},
|
||||
]
|
||||
}
|
||||
|
||||
return final_prompt
|
||||
|
||||
|
||||
class NLIPromptGenerator(PromptGenerator):
|
||||
"""Prompt format each data for inference."""
|
||||
|
||||
def __init__(self, enable_chain_of_thought=False):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
self.nli_user_prompt_template = (
|
||||
"Given the following two texts, your task is to determine the logical "
|
||||
"relationship between them. The first text is the 'premise' and the second "
|
||||
"text is the 'hypothesis'. The relationship should be labeled as one of the "
|
||||
"following: 'entailment' if the premise entails the hypothesis, 'contradiction' "
|
||||
"if the premise contradicts the hypothesis, or 'neutral' if the premise neither "
|
||||
"entails nor contradicts the hypothesis.\n\n"
|
||||
"Premise: {premise}\n"
|
||||
"Hypothesis: {hypothesis}"
|
||||
)
|
||||
if enable_chain_of_thought:
|
||||
self.nli_system_prompt = (
|
||||
"You are a helpful assistant. Write out in a step by step manner "
|
||||
"your reasoning about the answer using no more than 80 words. "
|
||||
"Based on the reasoning, produce the final answer. "
|
||||
"Your response should be in JSON format without using any backticks. "
|
||||
"The JSON is a dictionary whose keys are 'reason' and 'answer_choice'."
|
||||
)
|
||||
else:
|
||||
self.nli_system_prompt = (
|
||||
"You are a helpful assistant. "
|
||||
"Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'."
|
||||
)
|
||||
|
||||
def generate_prompt(self, nli_input):
|
||||
"""Generate the prompt."""
|
||||
premise, hypothesis = nli_input["premise"], nli_input["hypothesis"]
|
||||
|
||||
self.nli_user_prompt = self.nli_user_prompt_template.format(
|
||||
premise=premise, hypothesis=hypothesis
|
||||
)
|
||||
|
||||
final_prompt = {
|
||||
"messages": [
|
||||
{"role": "system", "content": self.nli_system_prompt},
|
||||
{"role": "user", "content": self.nli_user_prompt},
|
||||
]
|
||||
}
|
||||
|
||||
return final_prompt
|
||||
|
||||
|
||||
"""
|
||||
CONNECTION TO AZURE MODEL INFERENCING ENDPOINT CODE
|
||||
"""
|
||||
|
||||
|
||||
class AzureInference(ABC):
|
||||
"""Azure Inference class."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
|
||||
self.url = kwargs["url"]
|
||||
self.key = kwargs["key"]
|
||||
|
||||
def _invoke_endpoint(self, data):
|
||||
"""Invoke the endpoint."""
|
||||
print(f"inferencing: {self.url}")
|
||||
|
||||
response = requests.post(
|
||||
self.url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": ("Bearer " + self.key),
|
||||
},
|
||||
data=json.dumps(data),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def invoke_inference(self, prompt):
|
||||
"""Invoke the inference."""
|
||||
response = self._invoke_endpoint(prompt)
|
||||
try:
|
||||
response_dict = json.loads(response.text)
|
||||
label = response_dict["choices"][0]["message"]["content"].strip().upper()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
label = "error"
|
||||
return label
|
||||
|
||||
|
||||
"""
|
||||
|
||||
SYNTHETIC DATASET BUILDER CLASSES
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SyntheticDatasetBuilder(ABC):
|
||||
"""Synthetic Dataset Builder class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
|
||||
def _write_to_file(self, data, fname):
|
||||
"""Write to file."""
|
||||
with open(fname, "w") as f:
|
||||
for sample in data:
|
||||
f.write(json.dumps(sample) + "\n")
|
||||
|
||||
def _is_json(self, json_str):
|
||||
"""Check if the string is a valid JSON."""
|
||||
try:
|
||||
json.loads(json_str)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class QASyntheticDatasetBuilder(SyntheticDatasetBuilder):
|
||||
"""Builds dataset with Predicted labels by LLM."""
|
||||
|
||||
def __init__(self, qa_prompt_builder, inference_pointer=None):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
self.valid_answer_choices = ["A", "B", "C", "D", "E"]
|
||||
self.inference_pointer = inference_pointer
|
||||
self.qa_prompt_builder = qa_prompt_builder
|
||||
|
||||
def build_dataset(
|
||||
self, dataset, file_name=None, write_labels_to_separate_file=False
|
||||
):
|
||||
"""Build the dataset."""
|
||||
if self.inference_pointer is not None or self.prompt_builder is not None:
|
||||
final_dataset = []
|
||||
no_label_dataset = []
|
||||
for row in dataset:
|
||||
prompt = self.qa_prompt_builder.generate_prompt(row)
|
||||
label = self.inference_pointer.invoke_inference(prompt)
|
||||
if self._is_json(label):
|
||||
label = json.loads(label.strip()).get("ANSWER_CHOICE")
|
||||
if label not in self.valid_answer_choices:
|
||||
continue
|
||||
new_content = {"role": "assistant", "content": label}
|
||||
no_label_dataset.append(prompt.copy())
|
||||
prompt["messages"].append(new_content)
|
||||
if write_labels_to_separate_file:
|
||||
final_dataset.append(new_content)
|
||||
else:
|
||||
final_dataset.append(prompt)
|
||||
if file_name is not None:
|
||||
if write_labels_to_separate_file:
|
||||
self._write_to_file(
|
||||
data=final_dataset, fname=file_name + "_label_" + ".jsonl"
|
||||
)
|
||||
self._write_to_file(
|
||||
data=no_label_dataset, fname=file_name + ".jsonl"
|
||||
)
|
||||
else:
|
||||
self._write_to_file(data=final_dataset, fname=file_name + ".jsonl")
|
||||
|
||||
print("Write to file complete")
|
||||
else:
|
||||
print("Please specify a valid endpoint first")
|
||||
|
||||
|
||||
class NLISyntheticDatasetBuilder(SyntheticDatasetBuilder):
|
||||
"""Builds dataset with Predicted labels by LLM."""
|
||||
|
||||
def __init__(self, nli_prompt_builder, inference_pointer=None):
|
||||
"""Initialize the class."""
|
||||
super().__init__()
|
||||
self.valid_labels = ["entailment", "contradiction", "neutral"]
|
||||
self.inference_pointer = inference_pointer
|
||||
self.nli_prompt_builder = nli_prompt_builder
|
||||
|
||||
def build_dataset(
|
||||
self, dataset, file_name=None, write_labels_to_separate_file=False
|
||||
):
|
||||
"""Build the dataset."""
|
||||
if self.inference_pointer is not None or self.nli_prompt_builder is not None:
|
||||
final_dataset = []
|
||||
no_label_dataset = []
|
||||
for row in dataset:
|
||||
llm_body, llm_output = self._get_output_from_model(row)
|
||||
llm_body_copy = llm_body.copy()
|
||||
llm_body["messages"].append(llm_output)
|
||||
if write_labels_to_separate_file:
|
||||
final_dataset.append(llm_body_copy)
|
||||
no_label_dataset.append(llm_output)
|
||||
else:
|
||||
final_dataset.append(llm_body)
|
||||
if file_name is not None:
|
||||
if write_labels_to_separate_file:
|
||||
self._write_to_file(
|
||||
data=no_label_dataset, fname=file_name + "_label_" + ".jsonl"
|
||||
)
|
||||
self._write_to_file(data=final_dataset, fname=file_name + ".jsonl")
|
||||
else:
|
||||
self._write_to_file(data=final_dataset, fname=file_name + ".jsonl")
|
||||
print("Write to file complete")
|
||||
else:
|
||||
print("Please specify a valid endpoint first")
|
||||
|
||||
def _get_output_from_model(self, data):
|
||||
"""Get the output from the model."""
|
||||
prompt = self.nli_prompt_builder.generate_prompt(data)
|
||||
# Invoke the LLama endpoint
|
||||
label = self.inference_pointer.invoke_inference(prompt).lower()
|
||||
# Note that the following code block should be commented out if you are disabling CoT
|
||||
# ---- Start ----
|
||||
if self._is_json(label):
|
||||
label = json.loads(label.strip()).get("answer_choice")
|
||||
# --- End ---
|
||||
if label not in self.valid_labels:
|
||||
print("Invalid label generated by model")
|
||||
return
|
||||
new_content = {"role": "assistant", "content": label}
|
||||
|
||||
return prompt, new_content
|
|
@ -0,0 +1,586 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Distillation with Large Language Models\n",
|
||||
" \n",
|
||||
"### Notebook details\n",
|
||||
" \n",
|
||||
"This sample demonstrates how to train the selected student model using the teacher model, resulting in the creation of the distilled model.\n",
|
||||
" \n",
|
||||
"We will use the Meta Llama 3.1 405B Instruct as the teacher model and the Meta Llama 3.1 8B Instruct as the student model.\n",
|
||||
" \n",
|
||||
"**Note :**\n",
|
||||
" \n",
|
||||
"- Distillation offering is only available in **West US 3** regions.\n",
|
||||
"- Distillation should only be used for single turn chat completion format.\n",
|
||||
"- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.\n",
|
||||
"- The Meta Llama 3.1 8B Instruct can only be used as a student (target) model.\n",
|
||||
"- Distllation is currently supported only for Natural Language Inference (NLI) task, which is a standard task in benchmarking for Natural Language Understanding."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install the SDK v2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install azure-ai-ml\n",
|
||||
"# %pip install azure-identity\n",
|
||||
"\n",
|
||||
"# %pip install mlflow\n",
|
||||
"# %pip install azureml-mlflow"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Import the required libraries"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import required libraries\n",
|
||||
"\n",
|
||||
"import base64\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n",
|
||||
"\n",
|
||||
"from azure.ai.ml import MLClient, Input\n",
|
||||
"from azure.ai.ml.constants import AssetTypes\n",
|
||||
"from azure.ai.ml.dsl import pipeline\n",
|
||||
"from azure.ai.ml.entities import Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prerequisites\n",
|
||||
"\n",
|
||||
"An AI Studio project in **West US 3** is required. Please follow [this](https://learn.microsoft.com/azure/ai-studio/how-to/fine-tune-model-llama?tabs=llama-two%2Cchatcompletion#prerequisites) document to setup your AI Studio project"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## AI Studio project settings\n",
|
||||
"\n",
|
||||
"Update following cell with the information of the AI Studio project just created."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SUBSCRIPTION_ID = \"<SUBSCRIPTION>\"\n",
|
||||
"RESOURCE_GROUP = \"<RESOURCE_GROUP>\"\n",
|
||||
"AI_PROJECT_NAME = \"<AI_PROJECT_NAME>\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure credential\n",
|
||||
"\n",
|
||||
"We are using `DefaultAzureCredential` to get access to workspace. \n",
|
||||
"`DefaultAzureCredential` should be capable of handling most Azure SDK authentication scenarios. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" credential = DefaultAzureCredential()\n",
|
||||
" # Check if given credential can get token successfully.\n",
|
||||
" credential.get_token(\"https://management.azure.com/.default\")\n",
|
||||
"except Exception as ex:\n",
|
||||
" # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential not work\n",
|
||||
" credential = InteractiveBrowserCredential()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Get handle to AI Studio project"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ml_client = MLClient(credential, SUBSCRIPTION_ID, RESOURCE_GROUP, AI_PROJECT_NAME)\n",
|
||||
"\n",
|
||||
"ai_project = ml_client._workspaces.get(ml_client.workspace_name)\n",
|
||||
"ai_project._workspace_id"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Pick a teacher model\n",
|
||||
"\n",
|
||||
"We support **Meta-Llama-3.1-405B-Instruct** as the teacher model. \n",
|
||||
"### First deploy the teacher model in Azure AI Studio\n",
|
||||
"* Go to Azure AI Studio (ai.azure.com)\n",
|
||||
"* Select Meta-Llama-3.1-405B-Instruct model from Model catalog.\n",
|
||||
"* Deploy with \"Pay-as-you-go\"\n",
|
||||
"* Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.\n",
|
||||
"\n",
|
||||
"Update the following cell with the information of the deployment you just created."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Llama-3-405B Teacher model endpoint name\n",
|
||||
"TEACHER_MODEL_NAME = \"Meta-Llama-3.1-405B-Instruct\"\n",
|
||||
"TEACHER_MODEL_ENDPOINT_NAME = \"<Please provide Meta Llama 3.1 405B endpoint name>\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Pick a student model\n",
|
||||
"\n",
|
||||
"We will use **Meta-Llama-3.1-8B-Instruct** as student model. We only support chat completion models that are available for PayGo finetuning in Azure AI Studio."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"STUDENT_MODEL_NAME = \"Meta-Llama-3.1-8B-Instruct\"\n",
|
||||
"\n",
|
||||
"# retrieve student model from model registry\n",
|
||||
"mlclient_azureml_meta = MLClient(credential, registry_name=\"azureml-meta\")\n",
|
||||
"student_model = mlclient_azureml_meta.models.get(STUDENT_MODEL_NAME)\n",
|
||||
"\n",
|
||||
"print(\n",
|
||||
" \"\\n\\nUsing model name: {0}, version: {1}, id: {2} for fine tuning\".format(\n",
|
||||
" student_model.name, student_model.version, student_model.id\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Download the dataset from HuggingFace repo\n",
|
||||
"\n",
|
||||
"For our example, we download and use the ConjNLI dataset (https://huggingface.co/datasets/cestwc/conjnli) from HuggingFace."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"from abc import ABC\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class InputDataset(ABC):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" (\n",
|
||||
" self.train_data_file_name,\n",
|
||||
" self.test_data_file_name,\n",
|
||||
" self.eval_data_file_name,\n",
|
||||
" ) = (None, None, None)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class NLIHuggingFaceInputDataset(InputDataset):\n",
|
||||
" \"\"\"\n",
|
||||
" Loads the HuggingFace dataset\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" def load_hf_dataset(\n",
|
||||
" self,\n",
|
||||
" dataset_name,\n",
|
||||
" train_sample_size=10,\n",
|
||||
" val_sample_size=10,\n",
|
||||
" test_sample_size=10,\n",
|
||||
" train_split_name=\"train\",\n",
|
||||
" val_split_name=\"validation\",\n",
|
||||
" test_split_name=\"test\",\n",
|
||||
" ):\n",
|
||||
" full_dataset = load_dataset(dataset_name)\n",
|
||||
"\n",
|
||||
" if val_split_name is not None:\n",
|
||||
" train_data = full_dataset[train_split_name].select(range(train_sample_size))\n",
|
||||
" val_data = full_dataset[val_split_name].select(range(val_sample_size))\n",
|
||||
" test_data = full_dataset[test_split_name].select(range(test_sample_size))\n",
|
||||
" else:\n",
|
||||
" train_val_data = full_dataset[train_split_name].select(\n",
|
||||
" range(train_sample_size + val_sample_size)\n",
|
||||
" )\n",
|
||||
" train_data = train_val_data.select(range(train_sample_size))\n",
|
||||
" val_data = train_val_data.select(\n",
|
||||
" range(train_sample_size, train_sample_size + val_sample_size)\n",
|
||||
" )\n",
|
||||
" test_data = full_dataset[test_split_name].select(range(test_sample_size))\n",
|
||||
"\n",
|
||||
" return train_data, val_data, test_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We can define train and test sample sizes here. Validation size is kept same as test sample size\n",
|
||||
"train_sample_size = 512\n",
|
||||
"val_sample_size = 256\n",
|
||||
"\n",
|
||||
"# Sample notebook using the dataset: https://huggingface.co/datasets/cestwc/conjnli\n",
|
||||
"dataset_name = \"cestwc/conjnli\"\n",
|
||||
"input_dataset = NLIHuggingFaceInputDataset()\n",
|
||||
"\n",
|
||||
"# Note: train_split_name and test_split_name can vary by dataset. They are passed as arguments in load_hf_dataset.\n",
|
||||
"# If val_split_name is None, the below function will split the train set to create the specified sized validation set.\n",
|
||||
"train, val, _ = input_dataset.load_hf_dataset(\n",
|
||||
" dataset_name=dataset_name,\n",
|
||||
" train_sample_size=train_sample_size,\n",
|
||||
" val_sample_size=val_sample_size,\n",
|
||||
" train_split_name=\"adversarial\",\n",
|
||||
" val_split_name=None,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Len of train data sample is \" + str(len(train)))\n",
|
||||
"print(\"Len of validation data sample is \" + str(len(val)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data_path = \"data/train_conjnli_512.jsonl\"\n",
|
||||
"valid_data_path = \"data/valid_conjnli_256.jsonl\"\n",
|
||||
"\n",
|
||||
"for row in train:\n",
|
||||
" data = {\"messages\": []}\n",
|
||||
" data[\"messages\"].append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" data[\"messages\"].append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
|
||||
" + row[\"premise\"]\n",
|
||||
" + \"\\nHypothesis: \"\n",
|
||||
" + row[\"hypothesis\"],\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" with open(train_data_path, \"w\") as f:\n",
|
||||
" f.write(json.dumps(data) + \"\\n\")\n",
|
||||
"\n",
|
||||
"for row in val:\n",
|
||||
" data = {\"messages\": []}\n",
|
||||
" data[\"messages\"].append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": \"You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'.\",\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" data[\"messages\"].append(\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": \"Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\\n\\nPremise: \"\n",
|
||||
" + row[\"premise\"]\n",
|
||||
" + \"\\nHypothesis: \"\n",
|
||||
" + row[\"hypothesis\"],\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" with open(valid_data_path, \"w\") as f:\n",
|
||||
" f.write(json.dumps(data) + \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prepare data inputs\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = None\n",
|
||||
"train_data_name = \"nli_train_70-70\"\n",
|
||||
"\n",
|
||||
"train_data = ml_client.data.create_or_update(\n",
|
||||
" Data(\n",
|
||||
" path=train_data_path,\n",
|
||||
" type=AssetTypes.URI_FILE,\n",
|
||||
" description=\"Training dataset\",\n",
|
||||
" name=train_data_name,\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"train_data_asset_id = f\"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{train_data.name}/versions/{train_data.version}\"\n",
|
||||
"train_data_asset_id"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"valid_data = None\n",
|
||||
"valid_data_name = \"nli_valid_70\"\n",
|
||||
"\n",
|
||||
"valid_data = ml_client.data.create_or_update(\n",
|
||||
" Data(\n",
|
||||
" path=valid_data_path,\n",
|
||||
" type=AssetTypes.URI_FILE,\n",
|
||||
" description=\"validation dataset\",\n",
|
||||
" name=valid_data_name,\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"valid_data_asset_id = f\"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{valid_data.name}/versions/{valid_data.version}\"\n",
|
||||
"valid_data_asset_id"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Distillation strategy settings\n",
|
||||
"\n",
|
||||
"We provide the option to leverage Chain of Thought (CoT) reasoning for distillation. CoT leverages step by step reasoning ability of the teacher model to generate more accurate labels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ENABLE_CHAIN_OF_THOUGHT = \"true\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure distillation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mlclient_azureml = MLClient(credential, registry_name=\"azureml\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"distillation_pipeline_name = \"oss_distillation_pipeline\"\n",
|
||||
"distillation_pipeline_component = mlclient_azureml.components.get(\n",
|
||||
" name=distillation_pipeline_name\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@pipeline\n",
|
||||
"def distillation_pipeline(\n",
|
||||
" teacher_model_endpoint_name: str,\n",
|
||||
" enable_chain_of_thought: str,\n",
|
||||
" system_properties: str,\n",
|
||||
" input_finetune_model: Input,\n",
|
||||
" train_file_path: Input,\n",
|
||||
" validation_file_path: Input = None,\n",
|
||||
"):\n",
|
||||
" oss_distillation = distillation_pipeline_component(\n",
|
||||
" teacher_model_endpoint_name=teacher_model_endpoint_name,\n",
|
||||
" enable_chain_of_thought=enable_chain_of_thought,\n",
|
||||
" train_file_path=train_file_path,\n",
|
||||
" validation_file_path=validation_file_path,\n",
|
||||
" # Finetune\n",
|
||||
" mlflow_model_path=input_finetune_model,\n",
|
||||
" model_asset_id=student_model.id,\n",
|
||||
" system_properties=system_properties,\n",
|
||||
" ## hyperparams\n",
|
||||
" learning_rate=0.00002,\n",
|
||||
" per_device_train_batch_size=1,\n",
|
||||
" num_train_epochs=3,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return {\"output_model\": oss_distillation.outputs.output_model}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"system_properties = {\n",
|
||||
" \"finetune_oss\": \"True\",\n",
|
||||
" \"model_asset_id\": student_model.id,\n",
|
||||
" \"PipelineType\": \"Finetune\",\n",
|
||||
" \"azureml.PipelineType\": \"Finetune\",\n",
|
||||
" \"azureml.ModelName\": student_model.name,\n",
|
||||
" \"azureml.original_model_id\": student_model.id,\n",
|
||||
" \"azureml.trainingData.assetId\": train_data_asset_id,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"json_str = json.dumps(system_properties).replace(\" \", \"\")\n",
|
||||
"\n",
|
||||
"system_properties_b64_encoded = base64.b64encode(json_str.encode(\"utf-8\")).decode(\n",
|
||||
" \"utf-8\"\n",
|
||||
")\n",
|
||||
"print(f\"System properties => {system_properties_b64_encoded}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_file_path_input = Input(type=\"uri_file\", path=train_data.path)\n",
|
||||
"validation_file_path_input = Input(type=\"uri_file\", path=valid_data.path)\n",
|
||||
"input_finetune_model = Input(type=\"mlflow_model\", path=student_model.id)\n",
|
||||
"\n",
|
||||
"finetuning_job = distillation_pipeline(\n",
|
||||
" teacher_model_endpoint_name=TEACHER_MODEL_ENDPOINT_NAME,\n",
|
||||
" enable_chain_of_thought=ENABLE_CHAIN_OF_THOUGHT,\n",
|
||||
" system_properties=system_properties_b64_encoded,\n",
|
||||
" input_finetune_model=input_finetune_model,\n",
|
||||
" train_file_path=train_file_path_input,\n",
|
||||
" validation_file_path=validation_file_path_input,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"finetuning_job.properties.update(system_properties)\n",
|
||||
"print(f\"job property: {finetuning_job.properties}\")\n",
|
||||
"\n",
|
||||
"# pipeline_job.identity = UserIdentityConfiguration()\n",
|
||||
"finetuning_job.display_name = f\"finetune-{student_model.name}\"\n",
|
||||
"finetuning_job.experiment_name = f\"distillation-{TEACHER_MODEL_NAME}\"\n",
|
||||
"finetuning_job.settings.default_compute_type = \"serverless\"\n",
|
||||
"finetuning_job.continue_on_step_failure = False\n",
|
||||
"# pipeline_job.settings.force_rerun = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Submit pipeline job"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Submit pipeline job to workspace\n",
|
||||
"ft_job = ml_client.jobs.create_or_update(finetuning_job)\n",
|
||||
"# ft_job.studio_url\n",
|
||||
"\n",
|
||||
"# build link to ai studio fine-tuning tab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Consuming the distilled model\n",
|
||||
"\n",
|
||||
"Once the above job completes, you should be able to deploy the model and use it for inferencing. To deploy this model, do the following:\n",
|
||||
"\n",
|
||||
"* Go to AI Studio\n",
|
||||
"* Navigate to the Fine-tuning tab on the left menu\n",
|
||||
"* In the list of models you see, click on the model which got created from the distillation\n",
|
||||
"* This should take you to the details page where you can see the model attributes and other details\n",
|
||||
"* Click on the Deploy button on top of the page\n",
|
||||
"* Follow the steps to deploy the model"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
Загрузка…
Ссылка в новой задаче