fix simulator calling (#34678)
* Reverting models to make sure calls to the simulator work
* quotes
* Spellcheck fixes
* ignore the models for doc generation
* Fixed the quotes on f strings
* pylint skip file
This commit is contained in:
Parent
35461714e3
Commit
bda301f49e
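The functional fix in this revert is easiest to see in the `AsyncHTTPClientWithRetry` hunks below: the reverted `on_request_start`/`on_request_end` callbacks take back a leading `session` parameter. aiohttp invokes every `TraceConfig` handler as `handler(session, trace_config_ctx, params)`, so a bound method without that parameter raises a `TypeError` on each traced request, which appears to be why simulator calls broke. A minimal standalone sketch of the expected callback shape (illustrative names and logging only, not the SDK's code):

import asyncio
import logging

from aiohttp import ClientSession, TraceConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("retry_client_sketch")  # hypothetical logger name


# aiohttp calls each TraceConfig hook as hook(session, trace_config_ctx, params),
# so `session` must be the first parameter (after `self` on a bound method).
async def on_request_start(session, trace_config_ctx, params):
    logger.info("Sending %s request to %s", params.method, params.url)


async def on_request_end(session, trace_config_ctx, params):
    logger.info("Received status %s from %s", params.response.status, params.url)


async def main():
    trace_config = TraceConfig()
    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)
    async with ClientSession(trace_configs=[trace_config]) as session:
        async with session.get("https://example.org") as response:
            await response.text()


asyncio.run(main())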
@@ -107,7 +107,7 @@ release = '2.0.0'
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
-exclude_patterns = ['_build']
+exclude_patterns = ['_build', '*/synthetic/simulator/_model_tools/models.py']
 
 # The reST default role (used for this markup: `text`) to use for all
 # documents.
@@ -1,8 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-
-from ast import literal_eval
+# pylint: skip-file
 import copy
 import time
 import asyncio
@@ -10,12 +9,12 @@ import uuid
 import logging
 from urllib.parse import urlparse
 from abc import ABC, abstractmethod
-from typing import Deque, Dict, List, Optional, Union
+from typing import Deque, Dict, List, Optional, Union, Sized
 from collections import deque
 
-from aiohttp import TraceConfig  # pylint: disable=networking-import-outside-azure-core-transport
-from aiohttp.web import HTTPException  # pylint: disable=networking-import-outside-azure-core-transport
-from aiohttp_retry import RetryClient, RandomRetry  # pylint: disable=networking-import-outside-azure-core-transport
+from aiohttp import TraceConfig
+from aiohttp.web import HTTPException
+from aiohttp_retry import RetryClient, RandomRetry
 
 from .identity_manager import APITokenManager
 from .images import replace_prompt_captions, format_multimodal_prompt
@@ -25,24 +24,18 @@ MIN_ERRORS_TO_FAIL = 3
 MAX_TIME_TAKEN_RECORDS = 20_000
 
 
-def get_model_class_from_url(endpoint_url: str) -> type:
-    """
-    Convert an endpoint URL to the appropriate model class.
-
-    :param endpoint_url: The URL of the endpoint.
-    :type endpoint_url: str
-    :return: The model class corresponding to the endpoint URL.
-    :rtype: type
-    """
+def get_model_class_from_url(endpoint_url: str):
+    """Convert an endpoint URL to the appropriate model class."""
    endpoint_path = urlparse(endpoint_url).path  # remove query params
 
     if endpoint_path.endswith("chat/completions"):
         return OpenAIChatCompletionsModel
-    if "/rainbow" in endpoint_path:
+    elif "/rainbow" in endpoint_path:
         return OpenAIMultiModalCompletionsModel
-    if endpoint_path.endswith("completions"):
+    elif endpoint_path.endswith("completions"):
         return OpenAICompletionsModel
-    raise ValueError(f"Unknown API type for endpoint {endpoint_url}")
+    else:
+        raise ValueError(f"Unknown API type for endpoint {endpoint_url}")
 
 
 # ===================== HTTP Retry ======================
@@ -58,44 +51,43 @@ class AsyncHTTPClientWithRetry:
         trace_config.on_request_end.append(self.on_request_end)
         if retry_options is None:
             retry_options = RandomRetry(  # set up retry configuration
-                statuses=[104, 408, 409, 424, 429, 500, 502, 503, 504],  # on which statuses to retry
+                statuses=[104, 408, 409, 424, 429, 500, 502,
+                          503, 504],  # on which statuses to retry
                 attempts=n_retry,
                 min_timeout=retry_timeout,
                 max_timeout=retry_timeout,
             )
 
-        self.client = RetryClient(trace_configs=[trace_config], retry_options=retry_options)
+        self.client = RetryClient(
+            trace_configs=[trace_config], retry_options=retry_options)
 
-    async def on_request_start(self, trace_config_ctx, params):
+    async def on_request_start(self, session, trace_config_ctx, params):
         current_attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
-        self.logger.info("[ATTEMPT %s] Sending %s request to %s" % (current_attempt, params.method, params.url))
+        self.logger.info("[ATTEMPT %s] Sending %s request to %s" % (
+            current_attempt, params.method, params.url
+        ))
 
-    async def on_request_end(self, trace_config_ctx, params):
+    async def on_request_end(self, session, trace_config_ctx, params):
         current_attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
         request_headers = dict(params.response.request_info.headers)
         if "Authorization" in request_headers:
             del request_headers["Authorization"]  # hide auth token from logs
         if "api-key" in request_headers:
             del request_headers["api-key"]
-        self.logger.info(
-            "[ATTEMPT %s] For %s request to %s, received response with status %s and request headers: %s"
-            % (current_attempt, params.method, params.url, params.response.status, request_headers)
-        )
+        self.logger.info("[ATTEMPT %s] For %s request to %s, received response with status %s and request headers: %s" % (
+            current_attempt, params.method, params.url, params.response.status, request_headers
+        ))
 
 # ===========================================================
 # ===================== LLMBase Class =======================
 # ===========================================================
 
 
 class LLMBase(ABC):
     """
     Base class for all LLM models.
     """
 
-    def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = None):
-        if additional_headers is None:
-            additional_headers = {}
+    def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = {}):
         self.endpoint_url = endpoint_url
         self.name = name
         self.additional_headers = additional_headers
@@ -103,7 +95,7 @@ class LLMBase(ABC):
 
         # Metric tracking
         self.lock = asyncio.Lock()
-        self.response_times: Deque[Union[int, float]] = deque(maxlen=MAX_TIME_TAKEN_RECORDS)
+        self.response_times: Deque[Union[int, float]] = deque(maxlen=MAX_TIME_TAKEN_RECORDS)
         self.step = 0
         self.error_count = 0
 
@@ -124,13 +116,11 @@ class LLMBase(ABC):
         """
         Query the model a single time with a prompt.
 
-        :param prompt: Prompt str to query model with.
-        :type prompt: str
-        :param session: aiohttp RetryClient object to use for the request.
-        :type session: RetryClient
-        :keyword **request_params: Additional parameters to pass to the request.
-        :return: Dictionary containing the completion response from the model.
-        :rtype: dict
+        Parameters
+        ----------
+        prompt: Prompt str to query model with.
+        session: aiohttp RetryClient object to use for the request.
+        **request_params: Additional parameters to pass to the request.
         """
         request_data = self.format_request_data(prompt, **request_params)
         return await self.request_api(
@@ -180,7 +170,7 @@ class LLMBase(ABC):
         pass
 
     def _log_request(self, request: dict) -> None:
-        self.logger.info("Request: %s", request)
+        self.logger.info(f"Request: {request}")
 
     async def _add_successful_response(self, time_taken: Union[int, float]) -> None:
         async with self.lock:
@@ -220,37 +210,28 @@ class LLMBase(ABC):
 # ================== OpenAICompletions ======================
 # ===========================================================
 
 
-class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
+class OpenAICompletionsModel(LLMBase):
     """
     Object for calling a Completions-style API for OpenAI models.
     """
 
     prompt_idx_key = "__prompt_idx__"
 
     max_stop_tokens = 4
     stop_tokens = ["<|im_end|>", "<|endoftext|>"]
 
     model_param_names = [
-        "model",
-        "temperature",
-        "max_tokens",
-        "top_p",
-        "n",
-        "frequency_penalty",
-        "presence_penalty",
-        "stop",
+        "model", "temperature", "max_tokens", "top_p", "n",
+        "frequency_penalty", "presence_penalty", "stop"
     ]
 
     CHAT_START_TOKEN = "<|im_start|>"
     CHAT_END_TOKEN = "<|im_end|>"
 
     def __init__(
-        self,
-        *,
+        self, *,
         endpoint_url: str,
         name: str = "OpenAICompletionsModel",
-        additional_headers: Optional[dict] = None,
+        additional_headers: Optional[dict] = {},
         api_version: Optional[str] = "2023-03-15-preview",
         token_manager: APITokenManager,
         azureml_model_deployment: Optional[str] = None,
@@ -262,12 +243,9 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         frequency_penalty: Optional[float] = 0,
         presence_penalty: Optional[float] = 0,
         stop: Optional[Union[List[str], str]] = None,
-        image_captions: Optional[Dict[str, str]] = None,
-        # pylint: disable=unused-argument
+        image_captions: Dict[str, str] = {},
         images_dir: Optional[str] = None,  # Note: unused, kept for class compatibility
     ):
-        if additional_headers is None:
-            additional_headers = {}
         super().__init__(endpoint_url=endpoint_url, name=name, additional_headers=additional_headers)
         self.api_version = api_version
         self.token_manager = token_manager
@@ -279,15 +257,15 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         self.n = n
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
-        self.image_captions = image_captions if image_captions is not None else {}
+        self.image_captions = image_captions
 
         # Default stop to end token if not provided
         if not stop:
             stop = []
         # Else if stop sequence is given as a string (Ex: "["\n", "<im_end>"]"), convert
-        elif isinstance(stop, str) and stop.startswith("[") and stop.endswith("]"):
-            stop = literal_eval(stop)
-        elif isinstance(stop, str):
+        elif type(stop) is str and stop.startswith("[") and stop.endswith("]"):
+            stop = eval(stop)
+        elif type(stop) is str:
             stop = [stop]
         self.stop: List = stop  # type: ignore[assignment]
 
@@ -299,24 +277,18 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
             self.stop.append(token)
 
         if top_p not in [None, 1.0] and temperature is not None:
-            self.logger.warning(
-                "Both top_p and temperature are set. OpenAI advises against using both at the same time."
-            )
+            self.logger.warning("Both top_p and temperature are set. OpenAI advises against using both at the same time.")
 
-        self.logger.info("Default model settings: %s", self.get_model_params())
+        self.logger.info(f"Default model settings: {self.get_model_params()}")
 
     def get_model_params(self):
         return {param: getattr(self, param) for param in self.model_param_names if getattr(self, param) is not None}
 
     def format_request_data(self, prompt: str, **request_params) -> Dict[str, str]:
-        """
-        Format the request data for the OpenAI API.
-
-        :param prompt: The prompt string.
-        :type prompt: str
-        :keyword request_params: Additional parameters to pass to the model.
-        :return: The formatted request data.
-        :rtype: Dict[str, str]
-        """
         # Caption images if available
         if len(self.image_captions.keys()):
@@ -329,6 +301,7 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         request_data.update(request_params)
         return request_data
 
+
     async def get_conversation_completion(
         self,
         messages: List[dict],
@@ -339,16 +312,12 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         """
         Query the model a single time with a message.
 
-        :param messages: List of messages to query the model with.
-            Expected format: [{"role": "user", "content": "Hello!"}, ...]
-        :type messages: List[dict]
-        :param session: aiohttp RetryClient object to query the model with.
-        :type session: RetryClient
-        :param role: Role of the user sending the message.
-        :type role: str
-        :keyword request_params: Additional parameters to pass to the model.
-        :return: Dictionary containing the completion response from the model.
-        :rtype: dict
+        Parameters
+        ----------
+        messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
+        session: aiohttp RetryClient object to query the model with.
+        role: Role of the user sending the message.
+        request_params: Additional parameters to pass to the model.
         """
         prompt = []
         for message in messages:
@@ -362,6 +331,7 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
             **request_params,
         )
 
+
     async def get_all_completions(  # type: ignore[override]
         self,
         prompts: List[Dict[str, str]],
@@ -374,29 +344,22 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         """
         Run a batch of prompts through the model and return the results in the order given.
 
-        :param prompts: List of prompts to query the model with.
-        :type prompts: List[Dict[str, str]]
-        :param session: aiohttp RetryClient to use for the request.
-        :type session: RetryClient
-        :param api_call_max_parallel_count: Number of parallel requests to make to the API.
-        :type api_call_max_parallel_count: int
-        :param api_call_delay_seconds: Number of seconds to wait between API requests.
-        :type api_call_delay_seconds: float
-        :param request_error_rate_threshold: Maximum error rate allowed before raising an error.
-        :type request_error_rate_threshold: float
-        :keyword request_params: Additional parameters to pass to the API.
-        :return: List of completion results.
-        :rtype: List[dict]
+        Parameters
+        ----------
+        prompts: List of prompts to query the model with.
+        session: aiohttp RetryClient to use for the request.
+        api_call_max_parallel_count: Number of parallel requests to make to the API.
+        api_call_delay_seconds: Number of seconds to wait between API requests.
+        request_error_rate_threshold: Maximum error rate allowed before raising an error.
+        request_params: Additional parameters to pass to the API.
         """
         if api_call_max_parallel_count > 1:
-            self.logger.info("Using %s parallel workers to query the API..", api_call_max_parallel_count)
+            self.logger.info(f"Using {api_call_max_parallel_count} parallel workers to query the API..")
 
         # Format prompts and tag with index
         request_datas: List[Dict] = []
         for idx, prompt in enumerate(prompts):
-            prompt: Dict[str, str] = self.format_request_data(  # type: ignore[no-redef]
-                prompt, **request_params  # type: ignore[arg-type]
-            )
+            prompt: Dict[str, str] = self.format_request_data(prompt, **request_params)  # type: ignore[no-redef,arg-type]
             prompt[self.prompt_idx_key] = idx  # type: ignore[assignment]
             request_datas.append(prompt)
 
@@ -406,22 +369,21 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
 
         output_collector: List = []
         tasks = [  # create a set of worker-tasks to query inference endpoint in parallel
-            asyncio.create_task(
-                self.request_api_parallel(
-                    request_datas=request_datas,
-                    output_collector=output_collector,
-                    session=session,
-                    api_call_delay_seconds=api_call_delay_seconds,
-                    request_error_rate_threshold=request_error_rate_threshold,
-                )
-            )
+            asyncio.create_task(self.request_api_parallel(
+                request_datas=request_datas,
+                output_collector=output_collector,
+                session=session,
+                api_call_delay_seconds=api_call_delay_seconds,
+                request_error_rate_threshold=request_error_rate_threshold,
+            ))
             for _ in range(api_call_max_parallel_count)
         ]
 
         # Await the completion of all tasks, and propagate any exceptions
         await asyncio.gather(*tasks, return_exceptions=False)
-        if request_datas:
-            raise RuntimeError("All inference tasks were finished, but the queue is not empty")
+        if len(request_datas):
+            raise RuntimeError(
+                "All inference tasks were finished, but the queue is not empty")
 
         # Output results back to the caller
         output_collector.sort(key=lambda x: x[self.prompt_idx_key])
@@ -429,6 +391,7 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
             output.pop(self.prompt_idx_key)
         return output_collector
 
+
     async def request_api_parallel(
         self,
         request_datas: List[dict],
@@ -439,21 +402,11 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
     ) -> None:
-        """
-        Query the model for all prompts given as a list and append the output to output_collector.
-
-        :param request_datas: List of request data dictionaries.
-        :type request_datas: List[dict]
-        :param output_collector: List to store the output.
-        :type output_collector: List
-        :param session: RetryClient session.
-        :type session: RetryClient
-        :param api_call_delay_seconds: Delay between consecutive API calls in seconds.
-        :type api_call_delay_seconds: float, optional
-        :param request_error_rate_threshold: Threshold for request error rate.
-        :type request_error_rate_threshold: float, optional
-        No return value, output_collector is modified in place.
-        """
         logger_tasks: List = []  # to await for logging to finish
 
-        while True:  # process data from queue until it's empty
+        while True:  # process data from queue until it"s empty
             try:
                 request_data = request_datas.pop()
                 prompt_idx = request_data.pop(self.prompt_idx_key)
@@ -464,26 +417,24 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
                     request_data=request_data,
                 )
                 await self._add_successful_response(response["time_taken"])
-            except HTTPException as e:
+            except Exception as e:
                 response = {
                     "request": request_data,
                     "response": {
                         "finish_reason": "error",
                         "error": str(e),
                     },
-                }
+                }
                 await self._add_error()
 
-                self.logger.exception("Errored on prompt #%s", str(prompt_idx))
+                self.logger.exception(f"Errored on prompt #{prompt_idx}")
 
                 # if we count too many errors, we stop and raise an exception
                 response_count = await self.get_response_count()
                 error_rate = await self.get_error_rate()
                 if response_count >= MIN_ERRORS_TO_FAIL and error_rate >= request_error_rate_threshold:
-                    error_msg = (
-                        f"Error rate is more than {request_error_rate_threshold:.0%} -- something is broken!"
-                    )
-                    raise Exception(error_msg) from e
+                    error_msg = f"Error rate is more than {request_error_rate_threshold:.0%} -- something is broken!"
+                    raise Exception(error_msg)
 
             response[self.prompt_idx_key] = prompt_idx
             output_collector.append(response)
@@ -496,6 +447,7 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         await asyncio.gather(*logger_tasks)
         return
 
+
     async def request_api(
         self,
         session: RetryClient,
@@ -504,18 +456,16 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
         """
         Request the model with a body of data.
 
-        :param session: HTTPS Session for invoking the endpoint.
-        :type session: RetryClient
-        :param request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.)
-        :type request_data: dict
-        :return: Response from the model.
-        :rtype: dict
+        Parameters
+        ----------
+        session: HTTPS Session for invoking the endpoint.
+        request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.)
         """
 
         self._log_request(request_data)
 
         token = await self.token_manager.get_token()
 
         headers = {
             "Content-Type": "application/json",
             "X-CV": f"{uuid.uuid4()}",
@@ -542,21 +492,24 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
 
         time_start = time.time()
         full_response = None
-        async with session.post(url=self.endpoint_url, headers=headers, json=request_data, params=params) as response:
+        async with session.post(
+            url=self.endpoint_url,
+            headers=headers,
+            json=request_data,
+            params=params
+        ) as response:
             if response.status == 200:
                 response_data = await response.json()
-                self.logger.info("Response: %s", response_data)
+                self.logger.info(f"Response: {response_data}")
 
                 # Copy the full response and return it to be saved in jsonl.
                 full_response = copy.copy(response_data)
 
                 time_taken = time.time() - time_start
 
-                parsed_response = self._parse_response(response_data)
+                parsed_response = self._parse_response(response_data, request_data=request_data)
             else:
-                raise HTTPException(
-                    reason="Received unexpected HTTP status: {} {}".format(response.status, await response.text())
-                )
+                raise HTTPException(reason=f"Received unexpected HTTP status: {response.status} {await response.text()}")
 
         return {
             "request": request_data,
@@ -565,7 +518,7 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
             "full_response": full_response,
         }
 
-    def _parse_response(self, response_data: dict) -> dict:
+    def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
         # https://platform.openai.com/docs/api-reference/completions
         samples = []
         finish_reason = []
@@ -575,36 +528,40 @@ class OpenAICompletionsModel(LLMBase):  # pylint: disable=too-many-instance-attributes
             if "finish_reason" in choice:
                 finish_reason.append(choice["finish_reason"])
 
-        return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
+        return {
+            "samples": samples,
+            "finish_reason": finish_reason,
+            "id": response_data["id"]
+        }
 
 # ===========================================================
 # ============== OpenAIChatCompletionsModel =================
 # ===========================================================
 
 
 class OpenAIChatCompletionsModel(OpenAICompletionsModel):
     """
     OpenAIChatCompletionsModel is a wrapper around OpenAICompletionsModel that
     formats the prompt for chat completion.
     """
-    # pylint: disable=keyword-arg-before-vararg
+
     def __init__(self, name="OpenAIChatCompletionsModel", *args, **kwargs):
         super().__init__(name=name, *args, **kwargs)
 
-    def format_request_data(self, prompt: List[dict], **request_params):  # type: ignore[override]
+    def format_request_data(self, messages: List[dict], **request_params):  # type: ignore[override]
         # Caption images if available
         if len(self.image_captions.keys()):
-            for message in prompt:
+            for message in messages:
                 message["content"] = replace_prompt_captions(
                     message["content"],
                     captions=self.image_captions,
                 )
 
-        request_data = {"messages": prompt, **self.get_model_params()}
+        request_data = {"messages": messages, **self.get_model_params()}
         request_data.update(request_params)
         return request_data
 
     async def get_conversation_completion(
         self,
         messages: List[dict],
@@ -615,16 +572,12 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
         """
         Query the model a single time with a message.
 
-        :param messages: List of messages to query the model with.
-            Expected format: [{"role": "user", "content": "Hello!"}, ...]
-        :type messages: List[dict]
-        :param session: aiohttp RetryClient object to query the model with.
-        :type session: RetryClient
-        :param role: Not used for this model, since it is a chat model.
-        :type role: str
-        :keyword **request_params: Additional parameters to pass to the model.
-        :return: Dictionary containing the completion response.
-        :rtype: dict
+        Parameters
+        ----------
+        messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
+        session: aiohttp RetryClient object to query the model with.
+        role: Not used for this model, since it is a chat model.
+        request_params: Additional parameters to pass to the model.
         """
         request_data = self.format_request_data(
             messages=messages,
@@ -635,6 +588,7 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
             request_data=request_data,
         )
 
+
     async def get_completion(
         self,
         prompt: str,
@@ -642,24 +596,26 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
         **request_params,
     ) -> dict:
         """
-        Query a ChatCompletions model with a single prompt.
+        Query a ChatCompletions model with a single prompt. Note: entire message will be inserted into a "system" call.
 
-        :param prompt: Prompt str to query model with.
-        :type prompt: str
-        :param session: aiohttp RetryClient object to use for the request.
-        :type session: RetryClient
-        :keyword **request_params: Additional parameters to pass to the request.
-        :return: Dictionary containing the completion response.
-        :rtype: dict
+        Parameters
+        ----------
+        prompt: Prompt str to query model with.
+        session: aiohttp RetryClient object to use for the request.
+        **request_params: Additional parameters to pass to the request.
         """
         messages = [{"role": "system", "content": prompt}]
 
-        request_data = self.format_request_data(messages=messages, **request_params)
+        request_data = self.format_request_data(
+            messages=messages,
+            **request_params
+        )
         return await self.request_api(
             session=session,
             request_data=request_data,
         )
 
+
     async def get_all_completions(
         self,
         prompts: List[str],  # type: ignore[override]
@@ -680,7 +636,8 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
             **request_params,
         )
 
-    def _parse_response(self, response_data: dict) -> dict:
+
+    def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
         # https://platform.openai.com/docs/api-reference/chat
         samples = []
         finish_reason = []
@@ -691,22 +648,23 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
             if "message" in choice and "finish_reason" in choice["message"]:
                 finish_reason.append(choice["message"]["finish_reason"])
 
-        return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
+        return {
+            "samples": samples,
+            "finish_reason": finish_reason,
+            "id": response_data["id"]
+        }
 
 # ===========================================================
 # =========== OpenAIMultiModalCompletionsModel ==============
 # ===========================================================
 
 
 class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
     """
     Wrapper around OpenAICompletionsModel that formats the prompt for multimodal
     completions containing images.
     """
 
     model_param_names = ["temperature", "max_tokens", "top_p", "n", "stop"]
-    # pylint: disable=keyword-arg-before-vararg
+
     def __init__(self, name="OpenAIMultiModalCompletionsModel", images_dir: Optional[str] = None, *args, **kwargs):
         self.images_dir = images_dir
@@ -723,18 +681,15 @@ class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
         request.update(request_params)
         return request
 
-    def _log_request(self, request: dict) -> None:
-        """
-        Log prompt, ignoring image data if multimodal.
-
-        :param request: The request dictionary.
-        :type request: dict
-        """
+    def _log_request(self, request: dict) -> None:
+        """Log prompt, ignoring image data if multimodal."""
         loggable_prompt_transcript = {
             "transcript": [
-                (c if c["type"] != "image" else {"type": "image", "data": "..."}) for c in request["transcript"]
+                (c if c["type"] != "image" else {"type": "image", "data": "..."})
+                for c in request["transcript"]
             ],
-            **{k: v for k, v in request.items() if k != "transcript"},
+            **{k: v for k, v in request.items() if k != "transcript"}
         }
         super()._log_request(loggable_prompt_transcript)
 
@@ -743,13 +698,13 @@ class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
 # ============== LLAMA CompletionsModel =====================
 # ===========================================================
 
 
 class LLAMACompletionsModel(OpenAICompletionsModel):
     """
     Object for calling a Completions-style API for LLAMA models.
     """
-    # pylint: disable=keyword-arg-before-vararg
-    def __init__(self, name: str = "LLAMACompletionsModel", *args, **kwargs):
+    def __init__(
+            self, name: str = "LLAMACompletionsModel", *args, **kwargs):
         super().__init__(name=name, *args, **kwargs)
         # set authentication header to Bearer, as llama apis always uses the bearer auth_header
         self.token_manager.auth_header = "Bearer"
@@ -757,12 +712,6 @@ class LLAMACompletionsModel(OpenAICompletionsModel):
     def format_request_data(self, prompt: str, **request_params):
-        """
-        Format the request data for the OpenAI API.
-
-        :param prompt: The prompt string.
-        :type prompt: str
-        :keyword request_params: Additional request parameters.
-        :return: The formatted request data.
-        :rtype: dict
-        """
         # Caption images if available
         if len(self.image_captions.keys()):
@@ -774,20 +723,19 @@ class LLAMACompletionsModel(OpenAICompletionsModel):
         request_data = {
             "input_data": {
                 "input_string": [prompt],
-                "parameters": {"temperature": self.temperature, "max_gen_len": self.max_tokens},
+                "parameters": {"temperature": self.temperature, "max_gen_len": self.max_tokens}
             }
         }
 
         request_data.update(request_params)
         return request_data
 
-    # pylint: disable=arguments-differ
     def _parse_response(self, response_data: dict, request_data: dict) -> dict:  # type: ignore[override]
         prompt = request_data["input_data"]["input_string"][0]
 
         # remove prompt text from each response as llama model returns prompt + completion instead of only completion
-        # remove any text after the stop tokens, since llama doesn't support stop token
-        for idx, _ in enumerate(response_data["samples"]):
+        # remove any text after the stop tokens, since llama does not support stop token
+        for idx, response in enumerate(response_data["samples"]):
             response_data["samples"][idx] = response_data["samples"][idx].replace(prompt, "").strip()
             for stop_token in self.stop:
                 if stop_token in response_data["samples"][idx]:
@@ -813,71 +761,65 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
     """
     LLaMa ChatCompletionsModel is a wrapper around LLaMaCompletionsModel that
     formats the prompt for chat completion.
-    This chat completion model should be only used as assistant,
-    and shouldn't be used to simulate user. It is not possible
-    to pass a system prompt do describe how the model would behave,
-    So we only use the model as assistant to reply for questions made by GPT simulated users.
+    This chat completion model should be only used as assistant, and should not be used to simulate user. It is not possible
+    to pass a system prompt do describe how the model would behave, So we only use the model as assistant to reply for questions
+    made by GPT simulated users.
     """
-    # pylint: disable=keyword-arg-before-vararg
+
     def __init__(self, name="LLAMAChatCompletionsModel", *args, **kwargs):
         super().__init__(name=name, *args, **kwargs)
         # set authentication header to Bearer, as llama apis always uses the bearer auth_header
         self.token_manager.auth_header = "Bearer"
 
-    def format_request_data(self, prompt: List[dict], **request_params):  # type: ignore[override]
+    def format_request_data(self, messages: List[dict], **request_params):  # type: ignore[override]
         # Caption images if available
         if len(self.image_captions.keys()):
-            for message in prompt:
+            for message in messages:
                 message["content"] = replace_prompt_captions(
                     message["content"],
                     captions=self.image_captions,
                 )
 
-        # For LLaMa we don't pass the prompt (user persona) as a system message
-        # since LLama doesn't support system message
-        # LLama only supports user, and assistant messages.
-        # The messages sequence has to start with User message/ It can't have two user or
+        # For LLaMa we do not pass the prompt (user persona) as a system message since LLama does not support system message
+        # LLama only supports user, and assistant messages.
+        # The messages sequence has to start with User message/ It can not have two user or
         # two assistant consecutive messages.
-        # so if we set the system meta prompt as a user message,
-        # and if we have the first two messages made by user then we
+        # so if we set the system meta prompt as a user message, and if we have the first two messages made by user then we
         # combine the two messages in one message.
-        for _, x in enumerate(prompt):
+        for idx, x in enumerate(messages):
             if x["role"] == "system":
                 x["role"] = "user"
-        if len(prompt) > 1 and prompt[0]["role"] == "user" and prompt[1]["role"] == "user":
-            prompt[0] = {"role": "user", "content": prompt[0]["content"] + "\n" + prompt[1]["content"]}
-            del prompt[1]
+        if len(messages) > 1 and messages[0]["role"] == "user" and messages[1]["role"] == "user":
+            messages[0] = {"role": "user", "content": messages[0]["content"] + "\n" + messages[1]["content"]}
+            del messages[1]
 
+        # request_data = {"messages": messages, **self.get_model_params()}
         request_data = {
-            "input_data": {
-                "input_string": prompt,
-                "parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens},
+            "input_data":
+            {
+                "input_string": messages,
+                "parameters": {"temperature": self.temperature, "max_new_tokens": self.max_tokens}
             },
         }
         request_data.update(request_params)
         return request_data
 
     async def get_conversation_completion(
-        self,
-        messages: List[dict],
-        session: RetryClient,
-        role: str = "assistant",
-        **request_params,
+            self,
+            messages: List[dict],
+            session: RetryClient,
+            role: str = "assistant",
+            **request_params,
     ) -> dict:
         """
         Query the model a single time with a message.
 
-        :param messages: List of messages to query the model with.
-            Expected format: [{"role": "user", "content": "Hello!"}, ...]
-        :type messages: List[dict]
-        :param session: aiohttp RetryClient object to query the model with.
-        :type session: RetryClient
-        :param role: Not used for this model, since it is a chat model.
-        :type role: str
-        :keyword request_params: Additional parameters to pass to the model.
-        :return: Dictionary containing the response from the model.
-        :rtype: dict
+        Parameters
+        ----------
+        messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...]
+        session: aiohttp RetryClient object to query the model with.
+        role: Not used for this model, since it is a chat model.
+        request_params: Additional parameters to pass to the model.
         """
 
         request_data = self.format_request_data(
@@ -889,7 +831,6 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
             request_data=request_data,
         )
 
-    # pylint: disable=arguments-differ
     def _parse_response(self, response_data: dict) -> dict:  # type: ignore[override]
         # https://platform.openai.com/docs/api-reference/chat
         samples = []
@@ -903,4 +844,4 @@ class LLAMAChatCompletionsModel(LLAMACompletionsModel):
             "samples": samples,
             "finish_reason": finish_reason,
             # "id": response_data["id"]
-        }
+        }
@@ -1,6 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
+#pylint: skip-file
 from typing import List
 import uuid
 import time
@@ -30,8 +31,8 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
 
         super().__init__(name=name, *args, **kwargs)
 
-    def format_request_data(self, prompt: List[dict], **request_params):  # type: ignore[override]
-        request_data = {"messages": prompt, **self.get_model_params()}
+    def format_request_data(self, messages: List[dict], **request_params):  # type: ignore[override]
+        request_data = {"messages": messages, **self.get_model_params()}
         request_data.update(request_params)
         return request_data
 
@@ -1,7 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-
+# pylint: disable=E0401
 # needed for 'list' type annotations on 3.8
 from __future__ import annotations
 
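The `# pylint: skip-file` markers restored above also silence the mutable-default-argument warning (W0102) that reverted signatures such as `additional_headers: Optional[dict] = {}` and `image_captions: Dict[str, str] = {}` would otherwise trigger. Pylint flags these because a default `{}` is created once when the function is defined and is then shared by every call. A small self-contained illustration with hypothetical helper names, not code from this PR:

def add_header_shared(key, value, headers={}):  # default dict created once, shared across calls
    headers[key] = value
    return headers


def add_header_fresh(key, value, headers=None):  # the lint-clean pattern the revert undoes
    if headers is None:
        headers = {}
    headers[key] = value
    return headers


print(add_header_shared("Authorization", "token-a"))  # {'Authorization': 'token-a'}
print(add_header_shared("api-key", "key-b"))          # {'Authorization': 'token-a', 'api-key': 'key-b'}  <- leaked state
print(add_header_fresh("Authorization", "token-a"))   # {'Authorization': 'token-a'}
print(add_header_fresh("api-key", "key-b"))           # {'api-key': 'key-b'}

In these model classes the shared dict is harmless as long as callers never mutate `self.additional_headers`, which is presumably why the revert favors matching the original runtime behavior over lint cleanliness.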