* Reverting models to make sure calls to the simulator work

* quotes

* Spellcheck fixes

* ignore the models for doc generation

* Fixed the quotes on f strings

* pylint skip file
Nagkumar Arkalgud 2024-03-07 11:56:47 -08:00 committed by GitHub
Parent 35461714e3
Commit bda301f49e
No key found matching this signature
GPG key ID: B5690EEEBB952194
4 changed files with 181 additions and 239 deletions

View file

@@ -107,7 +107,7 @@ release = '2.0.0'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
exclude_patterns = ['_build', '*/synthetic/simulator/_model_tools/models.py']
# The reST default role (used for this markup: `text`) to use for all
# documents.
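
For context on the hunk above: Sphinx's exclude_patterns in conf.py takes glob-style patterns, relative to the source directory, that are skipped when collecting source files. A minimal sketch of the idiom (same pattern as the one excluded above):

    # docs/conf.py -- sketch of excluding a module from doc generation.
    # Patterns are glob-style and relative to the Sphinx source directory.
    exclude_patterns = [
        '_build',                                        # build output
        '*/synthetic/simulator/_model_tools/models.py',  # skip the reverted module
    ]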

View file

@@ -1,8 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from ast import literal_eval
# pylint: skip-file
import copy
import time
import asyncio
@@ -10,12 +9,12 @@ import uuid
import logging
from urllib.parse import urlparse
from abc import ABC, abstractmethod
from typing import Deque, Dict, List, Optional, Union
from typing import Deque, Dict, List, Optional, Union, Sized
from collections import deque
from aiohttp import TraceConfig # pylint: disable=networking-import-outside-azure-core-transport
from aiohttp.web import HTTPException # pylint: disable=networking-import-outside-azure-core-transport
from aiohttp_retry import RetryClient, RandomRetry # pylint: disable=networking-import-outside-azure-core-transport
from aiohttp import TraceConfig
from aiohttp.web import HTTPException
from aiohttp_retry import RetryClient, RandomRetry
from .identity_manager import APITokenManager
from .images import replace_prompt_captions, format_multimodal_prompt
@@ -25,24 +24,18 @@ MIN_ERRORS_TO_FAIL = 3
MAX_TIME_TAKEN_RECORDS = 20_000
def get_model_class_from_url(endpoint_url: str) -> type:
"""
Convert an endpoint URL to the appropriate model class.
:param endpoint_url: The URL of the endpoint.
:type endpoint_url: str
:return: The model class corresponding to the endpoint URL.
:rtype: type
"""
def get_model_class_from_url(endpoint_url: str):
"""Convert an endpoint URL to the appropriate model class."""
endpoint_path = urlparse(endpoint_url).path # remove query params
if endpoint_path.endswith("chat/completions"):
return OpenAIChatCompletionsModel
if "/rainbow" in endpoint_path:
elif "/rainbow" in endpoint_path:
return OpenAIMultiModalCompletionsModel
if endpoint_path.endswith("completions"):
elif endpoint_path.endswith("completions"):
return OpenAICompletionsModel
raise ValueError(f"Unknown API type for endpoint {endpoint_url}")
else:
raise ValueError(f"Unknown API type for endpoint {endpoint_url}")
# ===================== HTTP Retry ======================
@@ -58,44 +51,43 @@ class AsyncHTTPClientWithRetry:
trace_config.on_request_end.append(self.on_request_end)
if retry_options is None:
retry_options = RandomRetry( # set up retry configuration
statuses=[104, 408, 409, 424, 429, 500, 502, 503, 504], # on which statuses to retry
statuses=[104, 408, 409, 424, 429, 500, 502,
503, 504], # on which statuses to retry
attempts=n_retry,
min_timeout=retry_timeout,
max_timeout=retry_timeout,
)
self.client = RetryClient(trace_configs=[trace_config], retry_options=retry_options)
self.client = RetryClient(
trace_configs=[trace_config], retry_options=retry_options)
async def on_request_start(self, trace_config_ctx, params):
async def on_request_start(self, session, trace_config_ctx, params):
current_attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
self.logger.info("[ATTEMPT %s] Sending %s request to %s" % (current_attempt, params.method, params.url))
self.logger.info("[ATTEMPT %s] Sending %s request to %s" % (
current_attempt, params.method, params.url
))
async def on_request_end(self, trace_config_ctx, params):
async def on_request_end(self, session, trace_config_ctx, params):
current_attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
request_headers = dict(params.response.request_info.headers)
if "Authorization" in request_headers:
del request_headers["Authorization"] # hide auth token from logs
if "api-key" in request_headers:
del request_headers["api-key"]
self.logger.info(
"[ATTEMPT %s] For %s request to %s, received response with status %s and request headers: %s"
% (current_attempt, params.method, params.url, params.response.status, request_headers)
)
self.logger.info("[ATTEMPT %s] For %s request to %s, received response with status %s and request headers: %s" % (
current_attempt, params.method, params.url, params.response.status, request_headers
))
# ===========================================================
# ===================== LLMBase Class =======================
# ===========================================================
class LLMBase(ABC):
"""
Base class for all LLM models.
"""
def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = None):
if additional_headers is None:
additional_headers = {}
def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = {}):
self.endpoint_url = endpoint_url
self.name = name
self.additional_headers = additional_headers
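
Note that the revert reintroduces a mutable default argument (additional_headers: Optional[dict] = {}), which the removed lines guarded against with a None sentinel. A minimal, hypothetical sketch of why that pattern matters (not code from this module):

    from typing import Optional

    def add_header_bad(headers: dict = {}):
        # The default dict is created once, at function definition time,
        # so every call that omits `headers` mutates the same shared object.
        headers["x-extra"] = "1"
        return headers

    def add_header_good(headers: Optional[dict] = None):
        # A None sentinel gives each call its own fresh dict.
        if headers is None:
            headers = {}
        headers["x-extra"] = "1"
        return headers

    assert add_header_bad() is add_header_bad()        # same shared object
    assert add_header_good() is not add_header_good()  # independent objects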
@@ -103,7 +95,7 @@ class LLMBase(ABC):
# Metric tracking
self.lock = asyncio.Lock()
self.response_times: Deque[Union[int, float]] = deque(maxlen=MAX_TIME_TAKEN_RECORDS)
self.response_times: Deque[Union[int, float]] = deque(maxlen=MAX_TIME_TAKEN_RECORDS)
self.step = 0
self.error_count = 0
@@ -124,13 +116,11 @@ class LLMBase(ABC):
"""
Query the model a single time with a prompt.
:param prompt: Prompt str to query model with.
:type prompt: str
:param session: aiohttp RetryClient object to use for the request.
:type session: RetryClient
:keyword **request_params: Additional parameters to pass to the request.
:return: Dictionary containing the completion response from the model.
:rtype: dict
Parameters
----------
prompt: Prompt str to query model with.
session: aiohttp RetryClient object to use for the request.
**request_params: Additional parameters to pass to the request.
"""
request_data = self.format_request_data(prompt, **request_params)
return await self.request_api(
@@ -180,7 +170,7 @@ class LLMBase(ABC):
pass
def _log_request(self, request: dict) -> None:
self.logger.info("Request: %s", request)
self.logger.info(f"Request: {request}")
async def _add_successful_response(self, time_taken: Union[int, float]) -> None:
async with self.lock:
@@ -220,37 +210,28 @@ class LLMBase(ABC):
# ================== OpenAICompletions ======================
# ===========================================================
class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attributes
class OpenAICompletionsModel(LLMBase):
"""
Object for calling a Completions-style API for OpenAI models.
"""
prompt_idx_key = "__prompt_idx__"
max_stop_tokens = 4
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
model_param_names = [
"model",
"temperature",
"max_tokens",
"top_p",
"n",
"frequency_penalty",
"presence_penalty",
"stop",
"model", "temperature", "max_tokens", "top_p", "n",
"frequency_penalty", "presence_penalty", "stop"
]
CHAT_START_TOKEN = "<|im_start|>"
CHAT_END_TOKEN = "<|im_end|>"
def __init__(
self,
*,
self, *,
endpoint_url: str,
name: str = "OpenAICompletionsModel",
additional_headers: Optional[dict] = None,
additional_headers: Optional[dict] = {},
api_version: Optional[str] = "2023-03-15-preview",
token_manager: APITokenManager,
azureml_model_deployment: Optional[str] = None,
@@ -262,12 +243,9 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
frequency_penalty: Optional[float] = 0,
presence_penalty: Optional[float] = 0,
stop: Optional[Union[List[str], str]] = None,
image_captions: Optional[Dict[str, str]] = None,
# pylint: disable=unused-argument
image_captions: Dict[str, str] = {},
images_dir: Optional[str] = None, # Note: unused, kept for class compatibility
):
if additional_headers is None:
additional_headers = {}
super().__init__(endpoint_url=endpoint_url, name=name, additional_headers=additional_headers)
self.api_version = api_version
self.token_manager = token_manager
@@ -279,15 +257,15 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
self.n = n
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.image_captions = image_captions if image_captions is not None else {}
self.image_captions = image_captions
# Default stop to end token if not provided
if not stop:
stop = []
# Else if stop sequence is given as a string (Ex: "["\n", "<im_end>"]"), convert
elif isinstance(stop, str) and stop.startswith("[") and stop.endswith("]"):
stop = literal_eval(stop)
elif isinstance(stop, str):
elif type(stop) is str and stop.startswith("[") and stop.endswith("]"):
stop = eval(stop)
elif type(stop) is str:
stop = [stop]
self.stop: List = stop # type: ignore[assignment]
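
The reverted branch parses a bracketed stop string with bare eval, where the removed line used ast.literal_eval; the latter accepts only Python literals and cannot execute arbitrary expressions. An illustrative sketch of the difference (example values only):

    from ast import literal_eval

    stop = '["\\n", "<im_end>"]'  # stop sequence supplied as a string
    print(literal_eval(stop))     # ['\n', '<im_end>'] -- parses literals only
    print(eval(stop))             # same result here, but eval would also
                                  # execute arbitrary expressions in the string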
@@ -299,24 +277,18 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
self.stop.append(token)
if top_p not in [None, 1.0] and temperature is not None:
self.logger.warning(
"Both top_p and temperature are set. OpenAI advises against using both at the same time."
)
self.logger.warning("Both top_p and temperature are set. OpenAI advises against using both at the same time.")
self.logger.info(f"Default model settings: {self.get_model_params()}")
self.logger.info("Default model settings: %s", self.get_model_params())
def get_model_params(self):
return {param: getattr(self, param) for param in self.model_param_names if getattr(self, param) is not None}
def format_request_data(self, prompt: str, **request_params) -> Dict[str, str]:
"""
Format the request data for the OpenAI API.
:param prompt: The prompt string.
:type prompt: str
:keyword request_params: Additional parameters to pass to the model.
:return: The formatted request data.
:rtype: Dict[str, str]
"""
# Caption images if available
if len(self.image_captions.keys()):
@@ -329,6 +301,7 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
request_data.update(request_params)
return request_data
async def get_conversation_completion(
self,
messages: List[dict],
@@ -339,16 +312,12 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
"""
Query the model a single time with a message.
:param messages: List of messages to query the model with.
Expected format: [{"role": "user", "content": "Hello!"}, ...]
:type messages: List[dict]
:param session: aiohttp RetryClient object to query the model with.
:type session: RetryClient
:param role: Role of the user sending the message.
:type role: str
:keyword request_params: Additional parameters to pass to the model.
:return: Dictionary containing the completion response from the model.
:rtype: dict
Parameters
----------
messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
session: aiohttp RetryClient object to query the model with.
role: Role of the user sending the message.
request_params: Additional parameters to pass to the model.
"""
prompt = []
for message in messages:
@@ -362,6 +331,7 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
**request_params,
)
async def get_all_completions( # type: ignore[override]
self,
prompts: List[Dict[str, str]],
@@ -374,29 +344,22 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
"""
Run a batch of prompts through the model and return the results in the order given.
:param prompts: List of prompts to query the model with.
:type prompts: List[Dict[str, str]]
:param session: aiohttp RetryClient to use for the request.
:type session: RetryClient
:param api_call_max_parallel_count: Number of parallel requests to make to the API.
:type api_call_max_parallel_count: int
:param api_call_delay_seconds: Number of seconds to wait between API requests.
:type api_call_delay_seconds: float
:param request_error_rate_threshold: Maximum error rate allowed before raising an error.
:type request_error_rate_threshold: float
:keyword request_params: Additional parameters to pass to the API.
:return: List of completion results.
:rtype: List[dict]
Parameters
----------
prompts: List of prompts to query the model with.
session: aiohttp RetryClient to use for the request.
api_call_max_parallel_count: Number of parallel requests to make to the API.
api_call_delay_seconds: Number of seconds to wait between API requests.
request_error_rate_threshold: Maximum error rate allowed before raising an error.
request_params: Additional parameters to pass to the API.
"""
if api_call_max_parallel_count > 1:
self.logger.info("Using %s parallel workers to query the API..", api_call_max_parallel_count)
self.logger.info(f"Using {api_call_max_parallel_count} parallel workers to query the API..")
# Format prompts and tag with index
request_datas: List[Dict] = []
for idx, prompt in enumerate(prompts):
prompt: Dict[str, str] = self.format_request_data( # type: ignore[no-redef]
prompt, **request_params # type: ignore[arg-type]
)
prompt: Dict[str, str] = self.format_request_data(prompt, **request_params) # type: ignore[no-redef,arg-type]
prompt[self.prompt_idx_key] = idx # type: ignore[assignment]
request_datas.append(prompt)
@@ -406,22 +369,21 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
output_collector: List = []
tasks = [ # create a set of worker-tasks to query inference endpoint in parallel
asyncio.create_task(
self.request_api_parallel(
request_datas=request_datas,
output_collector=output_collector,
session=session,
api_call_delay_seconds=api_call_delay_seconds,
request_error_rate_threshold=request_error_rate_threshold,
)
)
asyncio.create_task(self.request_api_parallel(
request_datas=request_datas,
output_collector=output_collector,
session=session,
api_call_delay_seconds=api_call_delay_seconds,
request_error_rate_threshold=request_error_rate_threshold,
))
for _ in range(api_call_max_parallel_count)
]
# Await the completion of all tasks, and propagate any exceptions
await asyncio.gather(*tasks, return_exceptions=False)
if request_datas:
raise RuntimeError("All inference tasks were finished, but the queue is not empty")
if len(request_datas):
raise RuntimeError(
"All inference tasks were finished, but the queue is not empty")
# Output results back to the caller
output_collector.sort(key=lambda x: x[self.prompt_idx_key])
@@ -429,6 +391,7 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
output.pop(self.prompt_idx_key)
return output_collector
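
The fan-out in get_all_completions is a simple worker-pool idiom: N tasks share one work list, each popping items until the list is empty, and the caller restores order via a tagged index. A self-contained sketch of the same idiom (the "work" is a hypothetical stand-in for the API call):

    import asyncio

    async def worker(queue: list, results: list) -> None:
        while True:
            try:
                idx, item = queue.pop()   # shared queue, popped until empty
            except IndexError:
                return
            await asyncio.sleep(0)        # stand-in for the real API request
            results.append((idx, item.upper()))

    async def main() -> None:
        queue = list(enumerate(["a", "b", "c", "d"]))
        results: list = []
        await asyncio.gather(*(worker(queue, results) for _ in range(2)))
        results.sort(key=lambda pair: pair[0])  # restore original order
        print([s for _, s in results])          # ['A', 'B', 'C', 'D']

    asyncio.run(main())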
async def request_api_parallel(
self,
request_datas: List[dict],
@@ -439,21 +402,11 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
) -> None:
"""
Query the model for all prompts given as a list and append the output to output_collector.
:param request_datas: List of request data dictionaries.
:type request_datas: List[dict]
:param output_collector: List to store the output.
:type output_collector: List
:param session: RetryClient session.
:type session: RetryClient
:param api_call_delay_seconds: Delay between consecutive API calls in seconds.
:type api_call_delay_seconds: float, optional
:param request_error_rate_threshold: Threshold for request error rate.
:type request_error_rate_threshold: float, optional
No return value, output_collector is modified in place.
"""
logger_tasks: List = [] # to await for logging to finish
while True: # process data from queue until it's empty
while True: # process data from queue until it"s empty
try:
request_data = request_datas.pop()
prompt_idx = request_data.pop(self.prompt_idx_key)
@@ -464,26 +417,24 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
request_data=request_data,
)
await self._add_successful_response(response["time_taken"])
except HTTPException as e:
except Exception as e:
response = {
"request": request_data,
"response": {
"finish_reason": "error",
"error": str(e),
},
}
}
await self._add_error()
self.logger.exception("Errored on prompt #%s", str(prompt_idx))
self.logger.exception(f"Errored on prompt #{prompt_idx}")
# if we count too many errors, we stop and raise an exception
response_count = await self.get_response_count()
error_rate = await self.get_error_rate()
if response_count >= MIN_ERRORS_TO_FAIL and error_rate >= request_error_rate_threshold:
error_msg = (
f"Error rate is more than {request_error_rate_threshold:.0%} -- something is broken!"
)
raise Exception(error_msg) from e
error_msg = f"Error rate is more than {request_error_rate_threshold:.0%} -- something is broken!"
raise Exception(error_msg)
response[self.prompt_idx_key] = prompt_idx
output_collector.append(response)
@@ -496,6 +447,7 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
await asyncio.gather(*logger_tasks)
return
async def request_api(
self,
session: RetryClient,
@@ -504,18 +456,16 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
"""
Request the model with a body of data.
:param session: HTTPS Session for invoking the endpoint.
:type session: RetryClient
:param request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.)
:type request_data: dict
:return: Response from the model.
:rtype: dict
Parameters
----------
session: HTTPS Session for invoking the endpoint.
request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.)
"""
self._log_request(request_data)
token = await self.token_manager.get_token()
headers = {
"Content-Type": "application/json",
"X-CV": f"{uuid.uuid4()}",
@@ -542,21 +492,24 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
time_start = time.time()
full_response = None
async with session.post(url=self.endpoint_url, headers=headers, json=request_data, params=params) as response:
async with session.post(
url=self.endpoint_url,
headers=headers,
json=request_data,
params=params
) as response:
if response.status == 200:
response_data = await response.json()
self.logger.info("Response: %s", response_data)
self.logger.info(f"Response: {response_data}")
# Copy the full response and return it to be saved in jsonl.
full_response = copy.copy(response_data)
time_taken = time.time() - time_start
parsed_response = self._parse_response(response_data)
parsed_response = self._parse_response(response_data, request_data=request_data)
else:
raise HTTPException(
reason="Received unexpected HTTP status: {} {}".format(response.status, await response.text())
)
raise HTTPException(reason=f"Received unexpected HTTP status: {response.status} {await response.text()}")
return {
"request": request_data,
@@ -565,7 +518,7 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
"full_response": full_response,
}
def _parse_response(self, response_data: dict) -> dict:
def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
# https://platform.openai.com/docs/api-reference/completions
samples = []
finish_reason = []
@@ -575,36 +528,40 @@ class OpenAICompletionsModel(LLMBase): # pylint: disable=too-many-instance-attr
if "finish_reason" in choice:
finish_reason.append(choice["finish_reason"])
return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
return {
"samples": samples,
"finish_reason": finish_reason,
"id": response_data["id"]
}
# ===========================================================
# ============== OpenAIChatCompletionsModel =================
# ===========================================================
class OpenAIChatCompletionsModel(OpenAICompletionsModel):
"""
OpenAIChatCompletionsModel is a wrapper around OpenAICompletionsModel that
formats the prompt for chat completion.
"""
# pylint: disable=keyword-arg-before-vararg
def __init__(self, name="OpenAIChatCompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)
def format_request_data(self, prompt: List[dict], **request_params): # type: ignore[override]
def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override]
# Caption images if available
if len(self.image_captions.keys()):
for message in prompt:
for message in messages:
message["content"] = replace_prompt_captions(
message["content"],
captions=self.image_captions,
)
request_data = {"messages": prompt, **self.get_model_params()}
request_data = {"messages": messages, **self.get_model_params()}
request_data.update(request_params)
return request_data
async def get_conversation_completion(
self,
messages: List[dict],
@@ -615,16 +572,12 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
"""
Query the model a single time with a message.
:param messages: List of messages to query the model with.
Expected format: [{"role": "user", "content": "Hello!"}, ...]
:type messages: List[dict]
:param session: aiohttp RetryClient object to query the model with.
:type session: RetryClient
:param role: Not used for this model, since it is a chat model.
:type role: str
:keyword **request_params: Additional parameters to pass to the model.
:return: Dictionary containing the completion response.
:rtype: dict
Parameters
----------
messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
session: aiohttp RetryClient object to query the model with.
role: Not used for this model, since it is a chat model.
request_params: Additional parameters to pass to the model.
"""
request_data = self.format_request_data(
messages=messages,
@@ -635,6 +588,7 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
request_data=request_data,
)
async def get_completion(
self,
prompt: str,
@@ -642,24 +596,26 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
**request_params,
) -> dict:
"""
Query a ChatCompletions model with a single prompt.
Query a ChatCompletions model with a single prompt. Note: entire message will be inserted into a "system" call.
:param prompt: Prompt str to query model with.
:type prompt: str
:param session: aiohttp RetryClient object to use for the request.
:type session: RetryClient
:keyword **request_params: Additional parameters to pass to the request.
:return: Dictionary containing the completion response.
:rtype: dict
Parameters
----------
prompt: Prompt str to query model with.
session: aiohttp RetryClient object to use for the request.
**request_params: Additional parameters to pass to the request.
"""
messages = [{"role": "system", "content": prompt}]
request_data = self.format_request_data(messages=messages, **request_params)
request_data = self.format_request_data(
messages=messages,
**request_params
)
return await self.request_api(
session=session,
request_data=request_data,
)
async def get_all_completions(
self,
prompts: List[str], # type: ignore[override]
@@ -680,7 +636,8 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
**request_params,
)
def _parse_response(self, response_data: dict) -> dict:
def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
# https://platform.openai.com/docs/api-reference/chat
samples = []
finish_reason = []
@@ -691,22 +648,23 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
if "message" in choice and "finish_reason" in choice["message"]:
finish_reason.append(choice["message"]["finish_reason"])
return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
return {
"samples": samples,
"finish_reason": finish_reason,
"id": response_data["id"]
}
# ===========================================================
# =========== OpenAIMultiModalCompletionsModel ==============
# ===========================================================
class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
"""
Wrapper around OpenAICompletionsModel that formats the prompt for multimodal
completions containing images.
"""
model_param_names = ["temperature", "max_tokens", "top_p", "n", "stop"]
# pylint: disable=keyword-arg-before-vararg
def __init__(self, name="OpenAIMultiModalCompletionsModel", images_dir: Optional[str] = None, *args, **kwargs):
self.images_dir = images_dir
@@ -723,18 +681,15 @@ class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
request.update(request_params)
return request
def _log_request(self, request: dict) -> None:
"""
Log prompt, ignoring image data if multimodal.
:param request: The request dictionary.
:type request: dict
"""
def _log_request(self, request: dict) -> None:
"""Log prompt, ignoring image data if multimodal."""
loggable_prompt_transcript = {
"transcript": [
(c if c["type"] != "image" else {"type": "image", "data": "..."}) for c in request["transcript"]
(c if c["type"] != "image" else {"type": "image", "data": "..."})
for c in request["transcript"]
],
**{k: v for k, v in request.items() if k != "transcript"},
**{k: v for k, v in request.items() if k != "transcript"}
}
super()._log_request(loggable_prompt_transcript)
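
The redaction above keeps transcripts loggable by swapping raw image payloads for a placeholder before delegating to the base logger. Roughly, for an illustrative request shape:

    request = {
        "transcript": [
            {"type": "text", "data": "Describe this picture."},
            {"type": "image", "data": "<base64 bytes...>"},
        ],
        "temperature": 0.0,
    }
    # After the transformation above, the logged copy contains:
    # {"transcript": [{"type": "text", "data": "Describe this picture."},
    #                 {"type": "image", "data": "..."}],
    #  "temperature": 0.0}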
@@ -743,13 +698,13 @@ class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
# ============== LLAMA CompletionsModel =====================
# ===========================================================
class LLAMACompletionsModel(OpenAICompletionsModel):
"""
Object for calling a Completions-style API for LLAMA models.
"""
# pylint: disable=keyword-arg-before-vararg
def __init__(self, name: str = "LLAMACompletionsModel", *args, **kwargs):
def __init__(
self, name: str = "LLAMACompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)
# set authentication header to Bearer, as llama apis always uses the bearer auth_header
self.token_manager.auth_header = "Bearer"
@@ -757,12 +712,6 @@ class LLAMACompletionsModel(OpenAICompletionsModel):
def format_request_data(self, prompt: str, **request_params):
"""
Format the request data for the OpenAI API.
:param prompt: The prompt string.
:type prompt: str
:keyword request_params: Additional request parameters.
:return: The formatted request data.
:rtype: dict
"""
# Caption images if available
if len(self.image_captions.keys()):
@@ -774,20 +723,19 @@ class LLAMACompletionsModel(OpenAICompletionsModel):
request_data = {
"input_data": {
"input_string": [prompt],
"parameters": {"temperature": self.temperature, "max_gen_len": self.max_tokens},
"parameters": {"temperature": self.temperature, "max_gen_len": self.max_tokens}
}
}
request_data.update(request_params)
return request_data
# pylint: disable=arguments-differ
def _parse_response(self, response_data: dict, request_data: dict) -> dict: # type: ignore[override]
prompt = request_data["input_data"]["input_string"][0]
# remove prompt text from each response as llama model returns prompt + completion instead of only completion
# remove any text after the stop tokens, since llama doesn't support stop token
for idx, _ in enumerate(response_data["samples"]):
# remove any text after the stop tokens, since llama does not support stop token
for idx, response in enumerate(response_data["samples"]):
response_data["samples"][idx] = response_data["samples"][idx].replace(prompt, "").strip()
for stop_token in self.stop:
if stop_token in response_data["samples"][idx]:
@@ -813,71 +761,65 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
"""
LLaMa ChatCompletionsModel is a wrapper around LLaMaCompletionsModel that
formats the prompt for chat completion.
This chat completion model should be only used as assistant,
and shouldn't be used to simulate user. It is not possible
to pass a system prompt do describe how the model would behave,
So we only use the model as assistant to reply for questions made by GPT simulated users.
This chat completion model should be only used as assistant, and should not be used to simulate user. It is not possible
to pass a system prompt do describe how the model would behave, So we only use the model as assistant to reply for questions
made by GPT simulated users.
"""
# pylint: disable=keyword-arg-before-vararg
def __init__(self, name="LLAMAChatCompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)
# set authentication header to Bearer, as llama apis always uses the bearer auth_header
self.token_manager.auth_header = "Bearer"
def format_request_data(self, prompt: List[dict], **request_params): # type: ignore[override]
def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override]
# Caption images if available
if len(self.image_captions.keys()):
for message in prompt:
for message in messages:
message["content"] = replace_prompt_captions(
message["content"],
captions=self.image_captions,
)
# For LLaMa we don't pass the prompt (user persona) as a system message
# since LLama doesn't support system message
# LLama only supports user, and assistant messages.
# The messages sequence has to start with User message/ It can't have two user or
# For LLaMa we do not pass the prompt (user persona) as a system message since LLama does not support system message
# LLama only supports user, and assistant messages.
# The messages sequence has to start with User message/ It can not have two user or
# two assistant consecutive messages.
# so if we set the system meta prompt as a user message,
# and if we have the first two messages made by user then we
# so if we set the system meta prompt as a user message, and if we have the first two messages made by user then we
# combine the two messages in one message.
for _, x in enumerate(prompt):
for idx, x in enumerate(messages):
if x["role"] == "system":
x["role"] = "user"
if len(prompt) > 1 and prompt[0]["role"] == "user" and prompt[1]["role"] == "user":
prompt[0] = {"role": "user", "content": prompt[0]["content"] + "\n" + prompt[1]["content"]}
del prompt[1]
if len(messages) > 1 and messages[0]["role"] == "user" and messages[1]["role"] == "user":
messages[0] = {"role": "user", "content": messages[0]["content"] + "\n" + messages[1]["content"]}
del messages[1]
# request_data = {"messages": messages, **self.get_model_params()}
request_data = {
"input_data": {
"input_string": prompt,
"parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens},
},
"input_data":
{
"input_string": messages,
"parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens}
},
}
request_data.update(request_params)
return request_data
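
The normalization above maps system messages to the user role and merges a leading pair of user messages, since the LLaMa endpoint accepts only alternating user/assistant turns. For an illustrative conversation:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    # After the normalization above, the payload carries a single turn:
    # [{"role": "user",
    #   "content": "You are a helpful assistant.\nHello!"}]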
async def get_conversation_completion(
self,
messages: List[dict],
session: RetryClient,
role: str = "assistant",
**request_params,
self,
messages: List[dict],
session: RetryClient,
role: str = "assistant",
**request_params,
) -> dict:
"""
Query the model a single time with a message.
:param messages: List of messages to query the model with.
Expected format: [{"role": "user", "content": "Hello!"}, ...]
:type messages: List[dict]
:param session: aiohttp RetryClient object to query the model with.
:type session: RetryClient
:param role: Not used for this model, since it is a chat model.
:type role: str
:keyword request_params: Additional parameters to pass to the model.
:return: Dictionary containing the response from the model.
:rtype: dict
Parameters
----------
messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
session: aiohttp RetryClient object to query the model with.
role: Not used for this model, since it is a chat model.
request_params: Additional parameters to pass to the model.
"""
request_data = self.format_request_data(
@@ -889,7 +831,6 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
request_data=request_data,
)
# pylint: disable=arguments-differ
def _parse_response(self, response_data: dict) -> dict: # type: ignore[override]
# https://platform.openai.com/docs/api-reference/chat
samples = []
@@ -903,4 +844,4 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
"samples": samples,
"finish_reason": finish_reason,
# "id": response_data["id"]
}
}
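
Taken together, the URL dispatch at the top of this module routes an endpoint to a model class by its path suffix. For instance, with a hypothetical endpoint URL:

    url = ("https://example.openai.azure.com/openai/deployments/"
           "gpt-4/chat/completions?api-version=2023-03-15-preview")
    model_cls = get_model_class_from_url(url)
    # -> OpenAIChatCompletionsModel: the query string is stripped and the
    #    path ends with "chat/completions". A bare "completions" suffix
    #    would select OpenAICompletionsModel, and a "/rainbow" path
    #    OpenAIMultiModalCompletionsModel.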

View file

@@ -1,6 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
#pylint: skip-file
from typing import List
import uuid
import time
@@ -30,8 +31,8 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
super().__init__(name=name, *args, **kwargs)
def format_request_data(self, prompt: List[dict], **request_params): # type: ignore[override]
request_data = {"messages": prompt, **self.get_model_params()}
def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override]
request_data = {"messages": messages, **self.get_model_params()}
request_data.update(request_params)
return request_data

View file

@@ -1,7 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=E0401
# needed for 'list' type annotations on 3.8
from __future__ import annotations