Refresh batch score v2 data schema to align with OpenAI public API (#2768)
* Refresh batch score v2 data schema to align with OpenAI public API
* bump version
* Styling fixes
* Add missing file
* Update git ignore
* Fix styling
* Fix
Parent
27b8fbbb20
Commit
9a7fc2d2fa
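
For orientation, the V2 (OpenAI-style) schema handled by this change looks roughly like the following; the field values are illustrative, but the field names match the input handler and output formatters in this diff:

    Input line (V2 input schema):
    {"custom_id": "task_123", "method": "POST", "url": "/v1/completions", "body": {"model": "...", "max_tokens": 1}}

    Output line (V2 output format):
    {"custom_id": "task_123", "response": {"body": { ... }, "request_id": "", "status_code": 200}, "error": null}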
@@ -139,4 +139,3 @@ mlruns/

# ignore config files
config.json
out*

@@ -2,7 +2,7 @@ $schema: http://azureml/sdk-2-0/ParallelComponent.json
type: parallel

name: batch_score_llm
version: 1.1.5
version: 1.1.6
display_name: Batch Score Large Language Models
is_deterministic: False
@@ -65,6 +65,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
if response_status == 200:
    return self._create_scoring_result(
        status=ScoringResultStatus.SUCCESS,
        model_response_code=response_status,
        scoring_request=scoring_request,
        start=start,
        end=end,

@@ -78,6 +79,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
result = self._create_scoring_result(
    status=ScoringResultStatus.FAILURE,
    model_response_code=response_status,
    scoring_request=scoring_request,
    start=start,
    end=end,

@@ -130,6 +132,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
except Exception:
    return self._create_scoring_result(
        status=ScoringResultStatus.FAILURE,
        model_response_code=http_response.status,
        scoring_request=scoring_request,
        start=start,
        end=end,
@@ -5,7 +5,9 @@
import asyncio

from ...utils.common import convert_result_list
from ...utils.output_formatter import OutputFormatter
from ...utils.v1_output_formatter import V1OutputFormatter
from ...utils.v2_output_formatter import V2OutputFormatter
from ..configuration.configuration import Configuration
from ..post_processing.mini_batch_context import MiniBatchContext
from ..post_processing.result_utils import apply_input_transformer

@@ -48,7 +50,12 @@ class Parallel:
apply_input_transformer(self.__input_to_output_transformer, scoring_results)

results = convert_result_list(
output_formatter: OutputFormatter
if self._configuration.input_schema_version == 1:
    output_formatter = V1OutputFormatter()
else:
    output_formatter = V2OutputFormatter()
results = output_formatter.format_output(
    results=scoring_results,
    batch_size_per_request=self._configuration.batch_size_per_request)
@@ -5,7 +5,9 @@
import traceback

from ...utils.common import convert_result_list
from ...utils.output_formatter import OutputFormatter
from ...utils.v1_output_formatter import V1OutputFormatter
from ...utils.v2_output_formatter import V2OutputFormatter
from ..configuration.configuration import Configuration
from ..scoring.scoring_result import ScoringResult
from ..telemetry import logging_utils as lu

@@ -16,7 +18,7 @@ from .result_utils import (
    apply_input_transformer,
    get_return_value,
)
from .output_handler import SingleFileOutputHandler, SeparateFileOutputHandler
from .output_handler import OutputHandler


def add_callback(callback, cur):

@@ -33,9 +35,11 @@ class CallbackFactory:
def __init__(self,
             configuration: Configuration,
             output_handler: OutputHandler,
             input_to_output_transformer):
    """Initialize CallbackFactory."""
    self._configuration = configuration
    self._output_handler = output_handler
    self.__input_to_output_transformer = input_to_output_transformer

def generate_callback(self):

@@ -46,7 +50,12 @@ class CallbackFactory:
    return callback

def _convert_result_list(self, scoring_results: "list[ScoringResult]", mini_batch_context: MiniBatchContext):
    return convert_result_list(
    output_formatter: OutputFormatter
    if self._configuration.input_schema_version == 1:
        output_formatter = V1OutputFormatter()
    else:
        output_formatter = V2OutputFormatter()
    return output_formatter.format_output(
        results=scoring_results,
        batch_size_per_request=self._configuration.batch_size_per_request)

@@ -64,13 +73,7 @@ class CallbackFactory:
if mini_batch_context.exception is None:
    if self._configuration.save_mini_batch_results == "enabled":
        lu.get_logger().info("save_mini_batch_results is enabled")
        if (self._configuration.split_output):
            output_handler = SeparateFileOutputHandler()
            lu.get_logger().info("Saving successful results and errors to separate files")
        else:
            output_handler = SingleFileOutputHandler()
            lu.get_logger().info("Saving results to single file")
        output_handler.save_mini_batch_results(
        self._output_handler.save_mini_batch_results(
            scoring_results,
            self._configuration.mini_batch_results_out_directory,
            mini_batch_context.raw_mini_batch_context
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""This file contains the data class for Azure OpenAI scoring error."""

from dataclasses import dataclass


@dataclass
class AoaiScoringError:
    """Azure OpenAI scoring error."""

    code: str = None
    message: str = None
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""This file contains the data class for Azure OpenAI scoring response."""

from dataclasses import dataclass


@dataclass
class AoaiScoringResponse:
    """Azure OpenAI scoring response."""

    body: any = None
    request_id: str = None
    status_code: int = None
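
These two dataclasses carry the "response" and "error" halves of a V2 output line. A minimal usage sketch, mirroring how the V2 output formatter later in this diff builds each record (the classes are the ones defined above; field values are illustrative):

    # Successful request: wrap the endpoint response body and emit it as a plain dict.
    response = AoaiScoringResponse(body={"usage": {"total_tokens": 1}}, request_id="", status_code=200)
    success_line = {"custom_id": "task_123", "response": vars(response), "error": None}

    # Failed request: wrap the error payload instead and leave "response" empty.
    error = AoaiScoringError(code=None, message="This model's maximum context length is 8190 tokens...")
    failure_line = {"custom_id": "task_123", "response": None, "error": vars(error)}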
@@ -32,6 +32,7 @@ class HttpResponseHandler:
def _create_scoring_result(
        self,
        status: ScoringResultStatus,
        model_response_code: int,
        scoring_request: ScoringRequest,
        start: float,
        end: float,

@@ -43,6 +44,7 @@ class HttpResponseHandler:
    status=status,
    start=start,
    end=end,
    model_response_code=model_response_code,
    request_obj=scoring_request.original_payload_obj,
    request_metadata=scoring_request.request_metadata,
    response_body=http_post_response.payload,
@@ -17,6 +17,7 @@ class ScoringRequest:
__BATCH_REQUEST_METADATA = "_batch_request_metadata"
__REQUEST_METADATA = "request_metadata"
__CUSTOM_ID = "custom_id"

def __init__(
        self,

@@ -55,6 +56,8 @@ class ScoringRequest:
# These properties do not need to be sent to the model & will be added to the output file directly
self.__request_metadata = self.__cleaned_payload_obj.pop(self.__BATCH_REQUEST_METADATA, None)
self.__request_metadata = self.__cleaned_payload_obj.pop(self.__REQUEST_METADATA, self.__request_metadata)
# If custom_id exists (V2 input schema), make sure it is not sent to MIR endpoint
self.__CUSTOM_ID = self.__cleaned_payload_obj.pop(self.__CUSTOM_ID, None)

self.__cleaned_payload = json.dumps(self.__cleaned_payload_obj, cls=BatchComponentJSONEncoder)
self.__loggable_payload = json.dumps(self.__loggable_payload_obj, cls=BatchComponentJSONEncoder)

@@ -136,6 +139,12 @@ class ScoringRequest:
    """Get the segment id."""
    return self.__segment_id

# read-only
@property
def custom_id(self) -> str:
    """Get the custom id. Only valid for V2 input schema."""
    return self.__CUSTOM_ID

@estimated_cost.setter
def estimated_cost(self, cost: int):
    """Set the estimated cost."""
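
A minimal sketch of the new custom_id behavior, assuming a V2-style payload; the constructor call matches the updated unit tests in this diff and the accessor is the read-only property added above:

    # custom_id is popped from the payload sent to the endpoint and surfaced on the request object.
    request = ScoringRequest(original_payload='{"custom_id": "task_123", "prompt": "Test model"}')
    assert request.custom_id == "task_123"  # later consumed by the V2 output formatter
    # The cleaned payload forwarded to the MIR endpoint no longer contains "custom_id".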
@@ -58,6 +58,7 @@ class ScoringResult:
status: ScoringResultStatus,
start: float,
end: float,
model_response_code: int,
request_obj: any,
request_metadata: any,
response_body: any,

@@ -70,6 +71,7 @@ class ScoringResult:
self.status = status
self.start = start
self.end = end
self.model_response_code = model_response_code
self.request_obj = request_obj  # Normalize to json
self.request_metadata = request_metadata
self.response_body = response_body

@@ -121,6 +123,7 @@ class ScoringResult:
status=ScoringResultStatus.FAILURE,
start=0,
end=0,
model_response_code=None,
request_obj=scoring_request.original_payload_obj if scoring_request else None,
request_metadata=scoring_request.request_metadata if scoring_request else None,
response_body=None,

@@ -140,6 +143,7 @@ class ScoringResult:
self.status,
self.start,
self.end,
self.model_response_code,
self.request_obj,
self.request_metadata,
deepcopy(self.response_body),
@@ -90,6 +90,7 @@ def init():
global par
global configuration
global input_handler
global output_handler

start = time.time()
parser = ConfigurationParserFactory().get_parser()

@@ -103,6 +104,13 @@ def init():
else:
    raise ValueError(f"Invalid input_schema_version: {configuration.input_schema_version}")

if (configuration.split_output):
    output_handler = SeparateFileOutputHandler()
    lu.get_logger().info("Will save successful results and errors to separate files")
else:
    output_handler = SingleFileOutputHandler()
    lu.get_logger().info("Will save all results to a single file")

event_utils.setup_context_vars(configuration, metadata)
setup_geneva_event_handlers()
setup_job_log_event_handlers()

@@ -147,6 +155,7 @@ def init():
if configuration.async_mode:
    callback_factory = CallbackFactory(
        configuration=configuration,
        output_handler=output_handler,
        input_to_output_transformer=input_to_output_transformer)
    finished_callback = callback_factory.generate_callback()

@@ -201,13 +210,6 @@ def run(input_data: pd.DataFrame, mini_batch_context):
try:
    ret = par.run(data_list, mini_batch_context)

    if (configuration.split_output):
        output_handler = SeparateFileOutputHandler()
        lu.get_logger().info("Saving successful results and errors to separate files")
    else:
        output_handler = SingleFileOutputHandler()
        lu.get_logger().info("Saving results to single file")

    if configuration.save_mini_batch_results == "enabled":
        lu.get_logger().info("save_mini_batch_results is enabled")
        output_handler.save_mini_batch_results(
@@ -1,7 +1,7 @@
{
    "component_version": "1.1.5",
    "component_version": "1.1.6",
    "component_directory": "driver/batch_score_llm",
    "component_name": "batch_score_llm",
    "virtual_environment_name": null,
    "registry_name": "azureml"
}
}
@@ -104,6 +104,7 @@ class MirHttpResponseHandler(HttpResponseHandler):
if response_status == 200:
    result = self._create_scoring_result(
        status=ScoringResultStatus.SUCCESS,
        model_response_code=response_status,
        scoring_request=scoring_request,
        start=start,
        end=end,

@@ -124,6 +125,7 @@ class MirHttpResponseHandler(HttpResponseHandler):
else:  # Score failed
    result = self._create_scoring_result(
        status=ScoringResultStatus.FAILURE,
        model_response_code=response_status,
        scoring_request=scoring_request,
        start=start,
        end=end,
@@ -3,14 +3,9 @@
"""Common utilities."""

import json
from argparse import ArgumentParser
from urllib.parse import urlparse

from ..common.scoring.scoring_result import ScoringResult
from . import embeddings_utils as embeddings
from .json_encoder_extensions import BatchComponentJSONEncoder


def get_base_url(url: str) -> str:
    """Get base url."""

@@ -38,39 +33,7 @@ def str2bool(v):
    raise ArgumentParser.ArgumentTypeError('Boolean value expected.')


def convert_result_list(results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
    """Convert scoring results to the result list."""
    output_list: list[dict[str, str]] = []
    for scoringResult in results:
        output: dict[str, str] = {}
        output["status"] = scoringResult.status.name
        output["start"] = scoringResult.start
        output["end"] = scoringResult.end
        output["request"] = scoringResult.request_obj
        output["response"] = scoringResult.response_body

        if scoringResult.segmented_response_bodies is not None and len(scoringResult.segmented_response_bodies) > 0:
            output["segmented_responses"] = scoringResult.segmented_response_bodies

        if scoringResult.request_metadata is not None:
            output["request_metadata"] = scoringResult.request_metadata

        if batch_size_per_request > 1:
            batch_output_list = embeddings._convert_to_list_of_output_items(
                output,
                scoringResult.estimated_token_counts)
            output_list.extend(batch_output_list)
        else:
            output_list.append(output)

    return list(map(__stringify_output, output_list))


def get_mini_batch_id(mini_batch_context: any):
    """Get mini batch id from mini batch context."""
    if mini_batch_context:
        return mini_batch_context.mini_batch_id


def __stringify_output(payload_obj: dict) -> str:
    return json.dumps(payload_obj, cls=BatchComponentJSONEncoder)
@@ -3,15 +3,8 @@
"""Embeddings utilities."""

from copy import deepcopy

import pandas as pd

from ..batch_pool.quota.estimators import EmbeddingsEstimator
from ..common.telemetry import logging_utils as lu

estimator = None


def _convert_to_list_of_input_batches(
        data: pd.DataFrame,
@ -32,145 +25,3 @@ def _convert_to_list_of_input_batches(
|
|||
payload_obj = {"input": list_of_strings}
|
||||
list_of_input_batches.append(payload_obj)
|
||||
return list_of_input_batches
|
||||
|
||||
|
||||
def _convert_to_list_of_output_items(result: dict, token_count_estimates: "tuple[int]") -> "list[dict]":
|
||||
"""Convert results to a list of output items."""
|
||||
"""
|
||||
Only used by the Embeddings API with batched HTTP requests, this method takes
|
||||
a scoring result as dictionary of "request", "response", and (optional) "request_metadata".
|
||||
It returns the batch list within request and response mapped out into a list of dictionaries,
|
||||
each with the correlating request and response from the batch.
|
||||
If the scoring result has request metadata, this is persisted in each of the
|
||||
output dictionaries.
|
||||
|
||||
Args:
|
||||
result: The scoring result containing a batch of inputs and outputs.
|
||||
token_count_estimates: The tuple of tiktoken estimates for each input in the batch.
|
||||
|
||||
Returns:
|
||||
List of output objects, each with "request", "response", (optional) "request_metadata".
|
||||
"""
|
||||
output_list = []
|
||||
response_obj = result["response"]
|
||||
try:
|
||||
response_data = response_obj.pop('data', None)
|
||||
except AttributeError:
|
||||
response_data = None
|
||||
request = result["request"]
|
||||
numrequests = len(request["input"])
|
||||
|
||||
if response_data is not None:
|
||||
# Result has data; response_obj["data"]
|
||||
numresults = len(response_data)
|
||||
__validate_response_data_length(numrequests, numresults)
|
||||
output_index_to_embedding_info_map = __build_output_idx_to_embedding_mapping(response_data)
|
||||
update_prompt_tokens = __tiktoken_estimates_succeeded(token_count_estimates, numrequests)
|
||||
if not update_prompt_tokens:
|
||||
# Single online endpoints will not have computed token estimates, as this occurs in quota client.
|
||||
token_count_estimates = __tiktoken_estimates_retry(request)
|
||||
update_prompt_tokens = __tiktoken_estimates_succeeded(token_count_estimates, numrequests)
|
||||
else:
|
||||
# Result has error; response_obj["error"]. Copy this for each request in batch below.
|
||||
numresults = -1
|
||||
error_message = "The batch request resulted in an error. See job output for the error message."
|
||||
lu.get_logger().error(error_message)
|
||||
|
||||
# Input can be large. Pop and iterate through the batch to avoid copying repeatedly.
|
||||
input_batch = request.pop('input', None)
|
||||
|
||||
for i in range(numrequests):
|
||||
# The large "input" from request and "data" from response have been popped so copy is smaller.
|
||||
# "input" and "data" are set for each below.
|
||||
single_output = {"request": deepcopy(request), "response": deepcopy(response_obj)}
|
||||
single_output["request"]["input"] = input_batch[i]
|
||||
if numresults > -1:
|
||||
single_output["response"]["data"] = [output_index_to_embedding_info_map[i]]
|
||||
if update_prompt_tokens:
|
||||
__override_prompt_tokens(single_output, token_count_estimates[i])
|
||||
output_list.append(single_output)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
def __build_output_idx_to_embedding_mapping(response_data):
|
||||
"""Build a mapping from output index to embedding."""
|
||||
"""
|
||||
Given response data, return a dictionary of the index and embedding info for each element of the batch.
|
||||
Unsure if the responses are always in the correct order by input index, ensure output order by mapping out index.
|
||||
|
||||
Args:
|
||||
response_data: The list of outputs from the 'data' of API response.
|
||||
|
||||
Returns:
|
||||
Dict mapping index to embedding info.
|
||||
"""
|
||||
return {embedding_info['index']: embedding_info for embedding_info in response_data}
|
||||
|
||||
|
||||
def __override_prompt_tokens(output_obj, token_count):
|
||||
"""
|
||||
Set the token_count as the value for `prompt_tokens` in response's usage info.
|
||||
|
||||
Args:
|
||||
output_obj: The dictionary of info for response, request
|
||||
token_count: The tiktoken count for this input string
|
||||
"""
|
||||
try:
|
||||
output_obj["response"]["usage"]["prompt_tokens"] = token_count
|
||||
except Exception as exc:
|
||||
lu.get_logger().exception("Unable to set prompt token override.")
|
||||
raise exc
|
||||
|
||||
|
||||
def __tiktoken_estimates_succeeded(token_count_estimates: "tuple[int]", input_length: int) -> bool:
|
||||
"""
|
||||
Return True if the length of the batch of inputs matches the length of the tiktoken estimates.
|
||||
|
||||
Args:
|
||||
token_count_estimates: The tuple of tiktoken estimates for the inputs in this batch
|
||||
input_length: The length of inputs in this batch
|
||||
"""
|
||||
token_est_length = len(token_count_estimates)
|
||||
length_matches = token_est_length == input_length
|
||||
if not length_matches:
|
||||
lu.get_logger().warn(f"Input length {input_length} does not match token estimate length {token_est_length}. "
|
||||
"Skipping prompt_tokens count overrides.")
|
||||
return length_matches
|
||||
|
||||
|
||||
def __tiktoken_estimates_retry(request_obj: dict) -> "tuple[int]":
|
||||
"""
|
||||
Return token counts for the inputs within a batch.
|
||||
|
||||
Args:
|
||||
request_obj: The request dictionary.
|
||||
"""
|
||||
lu.get_logger().debug("Attempting to calculate tokens for the embedding input batch.")
|
||||
global estimator
|
||||
if estimator is None:
|
||||
estimator = EmbeddingsEstimator()
|
||||
token_counts = estimator.estimate_request_cost(request_obj)
|
||||
if token_counts == 1:
|
||||
# This occurs if tiktoken module fails. See DV3Estimator for more info on why this could fail.
|
||||
return ()
|
||||
else:
|
||||
return token_counts
|
||||
|
||||
|
||||
def __validate_response_data_length(numrequests, numresults):
|
||||
"""
|
||||
Validate the number of outputs from the API response matches the number of requests in the batch.
|
||||
|
||||
Args:
|
||||
numrequests: The number of requests in this batch.
|
||||
numresults: The number of results in the response data.
|
||||
|
||||
Raises:
|
||||
Exception if response length and request length do not match.
|
||||
"""
|
||||
if numresults != numrequests:
|
||||
error_message = f"Result data length {numresults} != " + \
|
||||
f"{numrequests} request batch length."
|
||||
lu.get_logger().error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""This file contains the definition for the output formatter."""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..common.scoring.scoring_result import ScoringResult
|
||||
from .json_encoder_extensions import BatchComponentJSONEncoder
|
||||
from ..batch_pool.quota.estimators import EmbeddingsEstimator
|
||||
from ..common.telemetry import logging_utils as lu
|
||||
|
||||
|
||||
class OutputFormatter(ABC):
|
||||
"""An abstract class for formatting output."""
|
||||
|
||||
@abstractmethod
|
||||
def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
|
||||
"""Abstract output formatting method."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_response_obj(self, result: dict):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_custom_id(self, result: dict):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_single_output(self, *args):
|
||||
pass
|
||||
|
||||
def _convert_to_list_of_output_items(self, result: dict, token_count_estimates: "tuple[int]") -> "list[dict]":
|
||||
"""Convert results to a list of output items."""
|
||||
"""
|
||||
Only used by the Embeddings API with batched HTTP requests, this method takes
|
||||
a scoring result as dictionary of "request", "response", and (optional) "request_metadata".
|
||||
It returns the batch list within request and response mapped out into a list of dictionaries,
|
||||
each with the correlating request and response from the batch.
|
||||
If the scoring result has request metadata, this is persisted in each of the
|
||||
output dictionaries.
|
||||
|
||||
Args:
|
||||
result: The scoring result containing a batch of inputs and outputs.
|
||||
token_count_estimates: The tuple of tiktoken estimates for each input in the batch.
|
||||
|
||||
Returns:
|
||||
List of output objects, each with "request", "response", (optional) "request_metadata".
|
||||
"""
|
||||
output_list = []
|
||||
|
||||
response_obj = self._get_response_obj(result)
|
||||
custom_id = self._get_custom_id(result)
|
||||
format_version = 1 if custom_id is None else 2
|
||||
|
||||
try:
|
||||
response_data = response_obj.pop('data', None)
|
||||
except AttributeError:
|
||||
response_data = None
|
||||
request = result["request"]
|
||||
numrequests = len(request["input"])
|
||||
|
||||
if response_data is not None:
|
||||
# Result has data; response_obj["data"]
|
||||
numresults = len(response_data)
|
||||
self.__validate_response_data_length(numrequests, numresults)
|
||||
output_index_to_embedding_info_map = self.__build_output_idx_to_embedding_mapping(response_data)
|
||||
update_prompt_tokens = self.__tiktoken_estimates_succeeded(token_count_estimates, numrequests)
|
||||
if not update_prompt_tokens:
|
||||
# Single online endpoints will not have computed token estimates, as this occurs in quota client.
|
||||
token_count_estimates = self.__tiktoken_estimates_retry(request)
|
||||
update_prompt_tokens = self.__tiktoken_estimates_succeeded(token_count_estimates, numrequests)
|
||||
else:
|
||||
# Result has error; response_obj["error"]. Copy this for each request in batch below.
|
||||
numresults = -1
|
||||
error_message = "The batch request resulted in an error. See job output for the error message."
|
||||
lu.get_logger().error(error_message)
|
||||
|
||||
# Input can be large. Pop and iterate through the batch to avoid copying repeatedly.
|
||||
input_batch = request.pop('input', None)
|
||||
|
||||
for i in range(numrequests):
|
||||
# The large "input" from request and "data" from response have been popped so copy is smaller.
|
||||
# "input" and "data" are set for each below.
|
||||
if format_version == 1:
|
||||
single_output = self._get_single_output(request, response_obj, input_batch[i])
|
||||
else:
|
||||
single_output = self._get_single_output(custom_id, result)
|
||||
|
||||
if numresults > -1:
|
||||
if format_version == 1:
|
||||
single_output["response"]["data"] = [output_index_to_embedding_info_map[i]]
|
||||
else:
|
||||
single_output["response"]["body"]["data"] = [output_index_to_embedding_info_map[i]]
|
||||
if update_prompt_tokens:
|
||||
self.__override_prompt_tokens(single_output, token_count_estimates[i], format_version)
|
||||
output_list.append(single_output)
|
||||
|
||||
return output_list
|
||||
|
||||
def __build_output_idx_to_embedding_mapping(self, response_data):
|
||||
"""Build a mapping from output index to embedding."""
|
||||
"""
|
||||
Given response data, return a dictionary of the index and embedding info for each element of the batch.
|
||||
Unsure if the responses are always in the correct order by input index,
|
||||
ensure output order by mapping out index.
|
||||
|
||||
Args:
|
||||
response_data: The list of outputs from the 'data' of API response.
|
||||
|
||||
Returns:
|
||||
Dict mapping index to embedding info.
|
||||
"""
|
||||
return {embedding_info['index']: embedding_info for embedding_info in response_data}
|
||||
|
||||
def __override_prompt_tokens(self, output_obj, token_count, format_version):
|
||||
"""
|
||||
Set the token_count as the value for `prompt_tokens` in response's usage info.
|
||||
|
||||
Args:
|
||||
output_obj: The dictionary of info for response, request
|
||||
token_count: The tiktoken count for this input string
|
||||
format_version: The output format version
|
||||
"""
|
||||
try:
|
||||
if format_version == 1:
|
||||
output_obj["response"]["usage"]["prompt_tokens"] = token_count
|
||||
else:
|
||||
output_obj["response"]["body"]["usage"]["prompt_tokens"] = token_count
|
||||
except Exception as exc:
|
||||
lu.get_logger().exception("Unable to set prompt token override.")
|
||||
raise exc
|
||||
|
||||
def __tiktoken_estimates_succeeded(self, token_count_estimates: "tuple[int]", input_length: int) -> bool:
|
||||
"""
|
||||
Return True if the length of the batch of inputs matches the length of the tiktoken estimates.
|
||||
|
||||
Args:
|
||||
token_count_estimates: The tuple of tiktoken estimates for the inputs in this batch
|
||||
input_length: The length of inputs in this batch
|
||||
"""
|
||||
token_est_length = len(token_count_estimates)
|
||||
length_matches = token_est_length == input_length
|
||||
if not length_matches:
|
||||
lu.get_logger().warn(f"Input length {input_length} does not match token estimate "
|
||||
"length {token_est_length}. Skipping prompt_tokens count overrides.")
|
||||
return length_matches
|
||||
|
||||
def __tiktoken_estimates_retry(self, request_obj: dict) -> "tuple[int]":
|
||||
"""
|
||||
Return token counts for the inputs within a batch.
|
||||
|
||||
Args:
|
||||
request_obj: The request dictionary.
|
||||
"""
|
||||
lu.get_logger().debug("Attempting to calculate tokens for the embedding input batch.")
|
||||
if self.estimator is None:
|
||||
self.estimator = EmbeddingsEstimator()
|
||||
token_counts = self.estimator.estimate_request_cost(request_obj)
|
||||
if token_counts == 1:
|
||||
# This occurs if tiktoken module fails. See DV3Estimator for more info on why this could fail.
|
||||
return ()
|
||||
else:
|
||||
return token_counts
|
||||
|
||||
def __validate_response_data_length(self, numrequests, numresults):
|
||||
"""
|
||||
Validate the number of outputs from the API response matches the number of requests in the batch.
|
||||
|
||||
Args:
|
||||
numrequests: The number of requests in this batch.
|
||||
numresults: The number of results in the response data.
|
||||
|
||||
Raises:
|
||||
Exception if response length and request length do not match.
|
||||
"""
|
||||
if numresults != numrequests:
|
||||
error_message = f"Result data length {numresults} != " + \
|
||||
f"{numrequests} request batch length."
|
||||
lu.get_logger().error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
def _stringify_output(self, payload_obj: dict) -> str:
|
||||
return json.dumps(payload_obj, cls=BatchComponentJSONEncoder)
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This file contains the definition for the original (V1) output formatter.
|
||||
|
||||
V1 Output format:
|
||||
{
|
||||
"status": ["SUCCESS" | "FAILURE"],
|
||||
"start": 1709584163.2691997,
|
||||
"end": 1709584165.2570084,
|
||||
"request": { <request_body> },
|
||||
"response": { <response_body> }
|
||||
}
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from .output_formatter import OutputFormatter
|
||||
from ..common.scoring.scoring_result import ScoringResult
|
||||
|
||||
|
||||
class V1OutputFormatter(OutputFormatter):
|
||||
"""Defines a class to format output in V1 format."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize V1OutputFormatter."""
|
||||
self.estimator = None
|
||||
|
||||
def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
|
||||
"""Format output in the V1 format."""
|
||||
output_list: list[dict[str, str]] = []
|
||||
for scoringResult in results:
|
||||
output: dict[str, str] = {}
|
||||
output["status"] = scoringResult.status.name
|
||||
output["start"] = scoringResult.start
|
||||
output["end"] = scoringResult.end
|
||||
output["request"] = scoringResult.request_obj
|
||||
output["response"] = scoringResult.response_body
|
||||
|
||||
if scoringResult.segmented_response_bodies is not None and \
|
||||
len(scoringResult.segmented_response_bodies) > 0:
|
||||
output["segmented_responses"] = scoringResult.segmented_response_bodies
|
||||
|
||||
if scoringResult.request_metadata is not None:
|
||||
output["request_metadata"] = scoringResult.request_metadata
|
||||
|
||||
if batch_size_per_request > 1:
|
||||
batch_output_list = self._convert_to_list_of_output_items(
|
||||
output,
|
||||
scoringResult.estimated_token_counts)
|
||||
output_list.extend(batch_output_list)
|
||||
else:
|
||||
output_list.append(output)
|
||||
|
||||
return list(map(self._stringify_output, output_list))
|
||||
|
||||
def _get_response_obj(self, result: dict):
|
||||
return result["response"]
|
||||
|
||||
def _get_custom_id(self, result: dict):
|
||||
return None
|
||||
|
||||
def _get_request_id(self, result: dict):
|
||||
return None
|
||||
|
||||
def _get_status(self, result: dict):
|
||||
return result["status"]
|
||||
|
||||
def _get_single_output(self, request, response_obj, input_batch):
|
||||
single_output = {"request": deepcopy(request), "response": deepcopy(response_obj)}
|
||||
single_output["request"]["input"] = input_batch
|
||||
|
||||
return single_output
|
|
@@ -4,7 +4,6 @@
"""This file contains the definition for the new (V2) schema input handler."""

import pandas as pd
import json

from .input_handler import InputHandler

@@ -20,8 +19,9 @@ class V2InputSchemaHandler(InputHandler):
    """Convert the new schema input pandas DataFrame to a list of payload strings."""
    body_details = []
    for _, row in data.iterrows():
        body = json.loads(row['body'])
        body = row['body']
        del body['model']
        body['custom_id'] = row['custom_id']
        body_details.append(body)
    original_schema_df = pd.DataFrame(body_details)
    return self._convert_to_list(original_schema_df, additional_properties, batch_size_per_request)
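
For reference, the updated unit tests later in this diff exercise this conversion with rows shaped as below; the handler now reads "body" as an object (no json.loads), drops "model", carries "custom_id" into the body, and emits one payload string per row (values illustrative, taken from the test data):

    Input row:      {"custom_id": "task_123", "method": "POST", "url": "/v1/completions", "body": {"model": "chat-sahara-4", "max_tokens": 1}}
    Output payload: '{"max_tokens": 1, "custom_id": "task_123"}'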
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
This file contains the definition for the new (V2) output formatter.

V2 Output format:
{
    "custom_id": <custom_id>,
    "request_id": "", // MIR endpoint request id?
    "status": <HTTP response status code>,
    // If response is successful, "response" should have the response body and "error" should be null,
    // and vice versa for a failed response.
    "response": { <response_body> | null },
    "error": { null | <response_body> }
}
"""

from copy import deepcopy

from .output_formatter import OutputFormatter
from ..common.scoring.aoai_error import AoaiScoringError
from ..common.scoring.aoai_response import AoaiScoringResponse
from ..common.scoring.scoring_result import ScoringResult


class V2OutputFormatter(OutputFormatter):
    """Defines a class to format output in V2 format."""

    def __init__(self):
        """Initialize V2OutputFormatter."""
        self.estimator = None

    def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
        """Format output in the V2 format."""
        output_list: list[dict[str, str]] = []
        for scoringResult in results:
            output: dict[str, str] = {}

            keys = scoringResult.request_obj.keys()
            if "custom_id" in keys:
                output["custom_id"] = scoringResult.request_obj["custom_id"]
            else:
                raise Exception("V2OutputFormatter called and custom_id not found"
                                "in request object (original payload)")

            if scoringResult.status.name == "SUCCESS":
                response = AoaiScoringResponse(request_id=self.__get_request_id(scoringResult),
                                               status_code=scoringResult.model_response_code,
                                               body=deepcopy(scoringResult.response_body))
                output["response"] = vars(response)
                output["error"] = None
            else:
                error = AoaiScoringError(message=deepcopy(scoringResult.response_body))
                output["response"] = None
                output["error"] = vars(error)

            if batch_size_per_request > 1:
                # _convert_to_list_of_output_items() expects output["request"] to be set.
                output["request"] = scoringResult.request_obj
                batch_output_list = self._convert_to_list_of_output_items(
                    output,
                    scoringResult.estimated_token_counts)
                output_list.extend(batch_output_list)
            else:
                output_list.append(output)

        return list(map(self._stringify_output, output_list))

    def __get_request_id(self, scoring_request: ScoringResult):
        return scoring_request.response_headers.get("x-request-id", "")

    def _get_response_obj(self, result: dict):
        if result.get("response") is None:
            return result.get("error")
        else:
            return result.get("response").get("body")

    def _get_custom_id(self, result: dict) -> str:
        return result.get("custom_id", "")

    def _get_single_output(self, custom_id, result):
        single_output = {
            "id": "",  # TODO: populate this ID
            "custom_id": deepcopy(custom_id),
            "response": deepcopy(result["response"]),
            "error": deepcopy(result["error"])
        }
        return single_output
@@ -68,6 +68,7 @@ YAML_DISALLOW_FAILED_REQUESTS = {"jobs": {JOB_NAME: {

# This test confirms that we can score an MIR endpoint using the scoring_url parameter and
# the batch_score_llm.yml component.
@pytest.mark.skip('Tempararily disabled until the test endpoint is created.')
@pytest.mark.smoke
@pytest.mark.e2e
@pytest.mark.timeout(20 * 60)
@@ -74,6 +74,7 @@ def mock_run(monkeypatch):
async def _run(self, requests: "list[ScoringRequest]") -> "list[ScoringResult]":
    passed_requests.extend(requests)
    return [ScoringResult(status=ScoringResultStatus.SUCCESS,
                          model_response_code=200,
                          response_body={"usage": {}},
                          omit=False,
                          start=0,
@@ -19,6 +19,7 @@ def make_scoring_result():
"""Mock scoring result."""
def make(
        status: ScoringResultStatus = ScoringResultStatus.SUCCESS,
        model_response_code: int = 200,
        start: float = time.time() - 10,
        end: float = time.time(),
        request_obj: any = None,

@@ -30,6 +31,7 @@ def make_scoring_result():
    """Make a mock scoring result."""
    return ScoringResult(
        status=status,
        model_response_code=model_response_code,
        start=start,
        end=end,
        request_obj=request_obj,
@ -12,6 +12,7 @@ from src.batch_score.common.post_processing.callback_factory import CallbackFact
|
|||
from src.batch_score.common.post_processing.mini_batch_context import MiniBatchContext
|
||||
from src.batch_score.common.scoring.scoring_result import ScoringResult
|
||||
from src.batch_score.common.telemetry.events import event_utils
|
||||
from src.batch_score.common.post_processing.output_handler import SingleFileOutputHandler, SeparateFileOutputHandler
|
||||
from tests.fixtures.input_transformer import FakeInputOutputModifier
|
||||
from tests.fixtures.scoring_result import get_test_request_obj
|
||||
from tests.fixtures.test_mini_batch_context import TestMiniBatchContext
|
||||
|
@ -34,6 +35,7 @@ def test_generate_callback_success(mock_get_logger,
|
|||
|
||||
callback_factory = CallbackFactory(
|
||||
configuration=_get_test_configuration(),
|
||||
output_handler=SingleFileOutputHandler(),
|
||||
input_to_output_transformer=mock_input_to_output_transformer)
|
||||
|
||||
callbacks = callback_factory.generate_callback()
|
||||
|
@ -69,6 +71,7 @@ def test_generate_callback_exception_with_mini_batch_id(mock_get_logger,
|
|||
|
||||
callback_factory = CallbackFactory(
|
||||
configuration=_get_test_configuration(),
|
||||
output_handler=SingleFileOutputHandler(),
|
||||
input_to_output_transformer=mock_input_to_output_transformer)
|
||||
|
||||
callbacks = callback_factory.generate_callback()
|
||||
|
@ -109,21 +112,26 @@ def test_output_handler(
|
|||
scoring_result = make_scoring_result(request_obj=get_test_request_obj())
|
||||
gathered_result: list[ScoringResult] = [scoring_result.copy(), scoring_result.copy()]
|
||||
|
||||
test_configuration = _get_test_configuration_for_output_handler(split_output)
|
||||
|
||||
callback_factory = CallbackFactory(
|
||||
configuration=test_configuration,
|
||||
input_to_output_transformer=mock_input_to_output_transformer)
|
||||
|
||||
with patch(
|
||||
"src.batch_score.common.post_processing.callback_factory.SeparateFileOutputHandler",
|
||||
"tests.unit.common.post_processing.test_callback_factory.SeparateFileOutputHandler",
|
||||
return_value=MagicMock()
|
||||
) as mock_separate_file_output_handler, \
|
||||
patch(
|
||||
"src.batch_score.common.post_processing.callback_factory.SingleFileOutputHandler",
|
||||
"tests.unit.common.post_processing.test_callback_factory.SingleFileOutputHandler",
|
||||
return_value=MagicMock()
|
||||
) as mock_single_file_output_handler:
|
||||
|
||||
test_configuration = _get_test_configuration_for_output_handler(split_output)
|
||||
if test_configuration.split_output:
|
||||
output_handler = SeparateFileOutputHandler()
|
||||
else:
|
||||
output_handler = SingleFileOutputHandler()
|
||||
|
||||
callback_factory = CallbackFactory(
|
||||
configuration=test_configuration,
|
||||
output_handler=output_handler,
|
||||
input_to_output_transformer=mock_input_to_output_transformer)
|
||||
|
||||
callbacks = callback_factory.generate_callback()
|
||||
|
||||
_ = callbacks(gathered_result, mini_batch_context)
|
||||
|
|
|
@@ -40,7 +40,7 @@ async def test_score(response_status, response_body, exception_to_raise):
    http_response_handler=http_response_handler,
    scoring_url=None)

scoring_request = ScoringRequest(original_payload='{"prompt":"Test model"}')
scoring_request = ScoringRequest(original_payload='{"custom_id": "task_123", "prompt":"Test model"}')

async with aiohttp.ClientSession() as session:
    with patch.object(session, "post") as mock_post:

@@ -62,6 +62,8 @@ async def test_score(response_status, response_body, exception_to_raise):
assert http_response_handler.handle_response.assert_called_once
response_sent_to_handler = http_response_handler.handle_response.call_args.kwargs['http_response']

assert "custom_id" not in scoring_client._create_http_request(scoring_request).payload

if exception_to_raise:
    assert type(response_sent_to_handler.exception) is exception_to_raise
elif response_status == 200:
@@ -18,6 +18,7 @@ def test_copy():
ScoringResultStatus.SUCCESS,
0,
1,
200,
"request_obj",
{},
response_body={"usage": {"prompt_tokens": 2,

@@ -48,6 +49,7 @@ def test_copy():
assert result2.response_body["usage"]["total_tokens"] == 16

assert result.status == result2.status
assert result.model_response_code == result2.model_response_code
assert result2.estimated_token_counts == (1, 2, 3)

@@ -79,6 +81,7 @@ def test_usage_statistics(
ScoringResultStatus.SUCCESS,
0,
1,
200,
"request_obj",
{},
response_body=response_usage,
@ -17,7 +17,9 @@ from src.batch_score.common.scoring.scoring_result import (
|
|||
ScoringResult,
|
||||
ScoringResultStatus,
|
||||
)
|
||||
from src.batch_score.utils import common
|
||||
from src.batch_score.batch_pool.quota.estimators import EmbeddingsEstimator
|
||||
from src.batch_score.utils.v1_output_formatter import V1OutputFormatter
|
||||
from src.batch_score.utils.v2_output_formatter import V2OutputFormatter
|
||||
from src.batch_score.utils.v1_input_schema_handler import V1InputSchemaHandler
|
||||
from src.batch_score.utils.v2_input_schema_handler import V2InputSchemaHandler
|
||||
|
||||
|
@ -42,29 +44,53 @@ VALID_DATAFRAMES = [
|
|||
NEW_SCHEMA_VALID_DATAFRAMES = [
|
||||
[
|
||||
[
|
||||
{"task_id": "task_123", "method": "POST", "url": "/v1/completions",
|
||||
"body": '{"model": "chat-sahara-4", "max_tokens": 1}'},
|
||||
{"task_id": "task_789", "method": "POST", "url": "/v1/completions",
|
||||
"body": '{"model": "chat-sahara-4", "max_tokens": 2}'}
|
||||
{"custom_id": "task_123", "method": "POST", "url": "/v1/completions",
|
||||
"body": {"model": "chat-sahara-4", "max_tokens": 1}},
|
||||
{"custom_id": "task_789", "method": "POST", "url": "/v1/completions",
|
||||
"body": {"model": "chat-sahara-4", "max_tokens": 2}}
|
||||
],
|
||||
[
|
||||
'{"max_tokens": 1}', '{"max_tokens": 2}'
|
||||
'{"max_tokens": 1, "custom_id": "task_123"}', '{"max_tokens": 2, "custom_id": "task_789"}'
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
{"task_id": "task_123", "method": "POST", "url": "/v1/completions", "body": '{"model": "chat-sahara-4", \
|
||||
"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, \
|
||||
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'},
|
||||
{"task_id": "task_456", "method": "POST", "url": "/v1/completions", "body": '{"model": "chat-sahara-4", \
|
||||
"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, \
|
||||
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'}
|
||||
{
|
||||
"custom_id": "task_123",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "chat-sahara-4",
|
||||
"temperature": 0,
|
||||
"max_tokens": 1024,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."
|
||||
}
|
||||
},
|
||||
{
|
||||
"custom_id": "task_456",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "chat-sahara-4",
|
||||
"temperature": 0,
|
||||
"max_tokens": 1024,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."
|
||||
}
|
||||
}
|
||||
],
|
||||
[
|
||||
('{"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0,'
|
||||
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'),
|
||||
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human...",'
|
||||
' "custom_id": "task_123"}'),
|
||||
('{"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0,'
|
||||
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}')
|
||||
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human...",'
|
||||
' "custom_id": "task_456"}')
|
||||
]
|
||||
]
|
||||
]
|
||||
|
@ -153,95 +179,119 @@ def test_new_schema_convert_input_to_requests_happy_path(obj_list: "list[any]",
|
|||
assert result == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tiktoken_failed",
|
||||
[True, False])
|
||||
def test_convert_result_list_batch_size_one(tiktoken_failed):
|
||||
@pytest.mark.parametrize("input_schema_version", [1, 2])
|
||||
@pytest.mark.parametrize("tiktoken_failed", [True, False])
|
||||
def test_output_formatter_batch_size_one(input_schema_version, tiktoken_failed):
|
||||
"""Test convert result list batch size one case."""
|
||||
# Arrange
|
||||
batch_size_per_request = 1
|
||||
result_list = []
|
||||
inputstring = __get_input_batch(batch_size_per_request)[0]
|
||||
request_obj = {"input": inputstring, "custom_id": "task_123"}
|
||||
outputlist = __get_output_data(batch_size_per_request)
|
||||
result = __get_scoring_result_for_batch(batch_size_per_request,
|
||||
inputstring,
|
||||
request_obj,
|
||||
outputlist,
|
||||
tiktoken_failed=tiktoken_failed)
|
||||
result_list.append(result)
|
||||
|
||||
# Act
|
||||
actual = common.convert_result_list(result_list, batch_size_per_request)
|
||||
if input_schema_version == 1:
|
||||
output_formatter = V1OutputFormatter()
|
||||
else:
|
||||
output_formatter = V2OutputFormatter()
|
||||
actual = output_formatter.format_output(result_list, batch_size_per_request)
|
||||
actual_obj = json.loads(actual[0])
|
||||
|
||||
# Assert
|
||||
assert len(actual) == 1
|
||||
assert actual_obj["status"] == "SUCCESS"
|
||||
assert "start" in actual_obj
|
||||
assert "end" in actual_obj
|
||||
assert actual_obj["request"]["input"] == inputstring
|
||||
assert actual_obj["response"]["usage"]["prompt_tokens"] == 1
|
||||
assert actual_obj["response"]["usage"]["total_tokens"] == 1
|
||||
if input_schema_version == 1:
|
||||
assert actual_obj["status"] == "SUCCESS"
|
||||
assert "start" in actual_obj
|
||||
assert "end" in actual_obj
|
||||
assert actual_obj["request"]["input"] == inputstring
|
||||
assert actual_obj["response"]["usage"]["prompt_tokens"] == 1
|
||||
assert actual_obj["response"]["usage"]["total_tokens"] == 1
|
||||
elif input_schema_version == 2:
|
||||
assert actual_obj["response"]["status_code"] == 200
|
||||
assert actual_obj["response"]["body"]["usage"]["prompt_tokens"] == 1
|
||||
assert actual_obj["response"]["body"]["usage"]["total_tokens"] == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tiktoken_failed",
|
||||
[True, False])
|
||||
def test_convert_result_list_failed_result(tiktoken_failed):
|
||||
@pytest.mark.parametrize("input_schema_version", [1, 2])
|
||||
@pytest.mark.parametrize("tiktoken_failed", [True, False])
|
||||
def test_output_formatter_failed_result(input_schema_version, tiktoken_failed):
|
||||
"""Test convert result list failed result case."""
|
||||
# Arrange
|
||||
batch_size_per_request = 1
|
||||
result_list = []
|
||||
inputstring = __get_input_batch(batch_size_per_request)[0]
|
||||
result = __get_failed_scoring_result_for_batch(inputstring, tiktoken_failed=tiktoken_failed)
|
||||
request_obj = {"input": inputstring, "custom_id": "task_123"}
|
||||
result = __get_failed_scoring_result_for_batch(request_obj, tiktoken_failed=tiktoken_failed)
|
||||
result_list.append(result)
|
||||
|
||||
# Act
|
||||
actual = common.convert_result_list(result_list, batch_size_per_request)
|
||||
if input_schema_version == 1:
|
||||
output_formatter = V1OutputFormatter()
|
||||
else:
|
||||
output_formatter = V2OutputFormatter()
|
||||
actual = output_formatter.format_output(result_list, batch_size_per_request)
|
||||
actual_obj = json.loads(actual[0])
|
||||
|
||||
# Assert
|
||||
assert len(actual) == 1
|
||||
assert actual_obj["status"] == "FAILURE"
|
||||
assert "start" in actual_obj
|
||||
assert "end" in actual_obj
|
||||
assert actual_obj["request"]["input"] == inputstring
|
||||
assert actual_obj["response"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in actual_obj["response"]["error"]["message"]
|
||||
if input_schema_version == 1:
|
||||
assert actual_obj["status"] == "FAILURE"
|
||||
assert "start" in actual_obj
|
||||
assert "end" in actual_obj
|
||||
assert actual_obj["request"]["input"] == inputstring
|
||||
assert actual_obj["response"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in actual_obj["response"]["error"]["message"]
|
||||
elif input_schema_version == 2: # TODO: Confirm this is the actual format for errors
|
||||
assert actual_obj["error"]["message"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in actual_obj["error"]["message"]["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tiktoken_failed",
|
||||
[True, False])
|
||||
def test_convert_result_list_failed_result_batch(tiktoken_failed):
|
||||
@pytest.mark.parametrize("input_schema_version", [1, 2])
|
||||
@pytest.mark.parametrize("tiktoken_failed", [True, False])
|
||||
def test_output_formatter_failed_result_batch(input_schema_version, tiktoken_failed):
|
||||
"""Test convert result list failed result batch case."""
|
||||
# Arrange
|
||||
batch_size_per_request = 2
|
||||
|
||||
inputlist = __get_input_batch(batch_size_per_request)
|
||||
result_list = [__get_failed_scoring_result_for_batch(inputlist, tiktoken_failed)]
|
||||
request_obj = {"input": inputlist, "custom_id": "task_123"}
|
||||
result_list = [__get_failed_scoring_result_for_batch(request_obj, tiktoken_failed)]
|
||||
|
||||
# Act
|
||||
actual = common.convert_result_list(result_list, batch_size_per_request)
|
||||
if input_schema_version == 1:
|
||||
output_formatter = V1OutputFormatter()
|
||||
else:
|
||||
output_formatter = V2OutputFormatter()
|
||||
actual = output_formatter.format_output(result_list, batch_size_per_request)
|
||||
|
||||
# Assert
|
||||
assert len(actual) == batch_size_per_request
|
||||
for idx, result in enumerate(actual):
|
||||
output_obj = json.loads(result)
|
||||
assert output_obj["request"]["input"] == inputlist[idx]
|
||||
|
||||
assert output_obj["response"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in output_obj["response"]["error"]["message"]
|
||||
if input_schema_version == 1:
|
||||
assert output_obj["request"]["input"] == inputlist[idx]
|
||||
assert output_obj["response"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in output_obj["response"]["error"]["message"]
|
||||
elif input_schema_version == 2:
|
||||
assert output_obj["response"] is None
|
||||
assert output_obj["error"]["message"]["error"]["type"] == "invalid_request_error"
|
||||
assert "maximum context length is 8190 tokens" in output_obj["error"]["message"]["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("reorder_results, online_endpoint_url, tiktoken_fails",
|
||||
[(True, True, True),
|
||||
(True, True, False),
|
||||
(True, False, True),
|
||||
(True, False, False),
|
||||
(False, True, True),
|
||||
(False, True, False),
|
||||
(False, False, True),
|
||||
(False, False, False)])
|
||||
def test_convert_result_list_batch_20(
|
||||
@pytest.mark.parametrize("tiktoken_fails", [True, False])
|
||||
@pytest.mark.parametrize("online_endpoint_url", [True, False])
|
||||
@pytest.mark.parametrize("reorder_results", [True, False])
|
||||
@pytest.mark.parametrize("input_schema_version", [1, 2])
|
||||
def test_output_formatter_batch_20(
|
||||
monkeypatch,
|
||||
mock_get_logger,
|
||||
input_schema_version,
|
||||
reorder_results,
|
||||
online_endpoint_url,
|
||||
tiktoken_fails):
|
||||
|
@ -256,22 +306,24 @@ def test_convert_result_list_batch_20(
|
|||
inputlists = []
|
||||
for n in range(full_batches):
|
||||
inputlist = __get_input_batch(batch_size_per_request)
|
||||
request_obj = {"input": inputlist, "custom_id": "task_123"}
|
||||
outputlist = __get_output_data(batch_size_per_request)
|
||||
result = __get_scoring_result_for_batch(
|
||||
batch_size_per_request,
|
||||
inputlist,
|
||||
request_obj,
|
||||
outputlist,
|
||||
reorder_results,
|
||||
online_endpoint_url or tiktoken_fails)
|
||||
inputlists.extend(inputlist)
|
||||
result_list.append(result)
|
||||
inputlist = __get_input_batch(additional_rows)
|
||||
request_obj = {"input": inputlist, "custom_id": "task_456"}
|
||||
outputlist = __get_output_data(additional_rows)
|
||||
|
||||
inputlists.extend(inputlist)
|
||||
result_list.append(__get_scoring_result_for_batch(
|
||||
additional_rows,
|
||||
inputlist,
|
||||
request_obj,
|
||||
outputlist,
|
||||
reorder_results,
|
||||
online_endpoint_url or tiktoken_fails))
|
||||
|
@ -283,13 +335,18 @@ def test_convert_result_list_batch_20(
|
|||
__mock_tiktoken_estimate(monkeypatch)
|
||||
|
||||
# Act
|
||||
actual = common.convert_result_list(result_list, batch_size_per_request)
|
||||
if input_schema_version == 1:
|
||||
output_formatter = V1OutputFormatter()
|
||||
else:
|
||||
output_formatter = V2OutputFormatter()
|
||||
actual = output_formatter.format_output(result_list, batch_size_per_request)
|
||||
|
||||
# Assert
|
||||
assert len(actual) == batch_size_per_request * full_batches + additional_rows
|
||||
for idx, result in enumerate(actual):
|
||||
output_obj = json.loads(result)
|
||||
assert output_obj["request"]["input"] == inputlists[idx]
|
||||
if input_schema_version == 1:
|
||||
assert output_obj["request"]["input"] == inputlists[idx]
|
||||
|
||||
# Assign valid_batch_len for this result. This is the expected total_tokens.
|
||||
if idx >= batch_size_per_request * full_batches:
|
||||
|
@@ -303,38 +360,48 @@ def test_convert_result_list_batch_20(
        # Index values in `response.data` are from [0, batch_size_per_request -1]
        valid_batch_idx = idx % batch_size_per_request

        assert output_obj["response"]["data"][0]["index"] == valid_batch_idx
        assert output_obj["response"]["usage"]["total_tokens"] == valid_batch_len
        response_obj = output_obj["response"] if input_schema_version == 1 else output_obj["response"]["body"]

        assert response_obj["data"][0]["index"] == valid_batch_idx
        assert response_obj["usage"]["total_tokens"] == valid_batch_len
        if tiktoken_fails:
            # Prompt tokens will equal total tokens (equals batch length)
            assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_len
            assert response_obj["usage"]["prompt_tokens"] == valid_batch_len
        elif online_endpoint_url:
            # Prompt tokens is batch index (see helper function: `__mock_tiktoken_estimate`)
            assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_idx
            assert response_obj["usage"]["prompt_tokens"] == valid_batch_idx
        else:
            # Batch pool case; prompt tokens is 10 + batch index (see helper function: `__get_token_counts`)
            assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_idx + 10
            assert response_obj["usage"]["prompt_tokens"] == valid_batch_idx + 10


def test_incorrect_data_length_raises():
@pytest.mark.parametrize("input_schema_version", [1, 2])
def test_incorrect_data_length_raises(input_schema_version):
    """Test incorrect data length raises."""
    # Arrange
    batch_size_per_request = 2
    result_list = []
    inputstring = __get_input_batch(batch_size_per_request)
    request_obj = {"input": inputstring, "custom_id": "task_123"}
    outputlist = __get_output_data(0)
    result = __get_scoring_result_for_batch(batch_size_per_request, inputstring, outputlist)
    result = __get_scoring_result_for_batch(batch_size_per_request, request_obj, outputlist)
    result_list.append(result)

    # Act
    if input_schema_version == 1:
        output_formatter = V1OutputFormatter()
    else:
        output_formatter = V2OutputFormatter()

    with pytest.raises(Exception) as excinfo:
        common.convert_result_list(result_list, batch_size_per_request)
        output_formatter.format_output(result_list, batch_size_per_request)

    # Assert
    assert "Result data length 0 != 2 request batch length." in str(excinfo.value)


def test_endpoint_response_is_not_json(mock_get_logger):
@pytest.mark.parametrize("input_schema_version", [1, 2])
def test_endpoint_response_is_not_json(input_schema_version, mock_get_logger):
    """Test endpoint response is not json."""
    # Arrange failed response payload as a string
    batch_size_per_request = 10
@@ -344,11 +411,12 @@ def test_endpoint_response_is_not_json(mock_get_logger):
        status=ScoringResultStatus.FAILURE,
        start=0,
        end=0,
        model_response_code=400,
        request_metadata="Not important",
        response_headers="Headers",
        num_retries=2,
        token_counts=(1,) * batch_size_per_request,
        request_obj={"input": inputs},
        request_obj={"input": inputs, "custom_id": "task_123"},
        response_body=json.dumps({"object": "list",
                                  "error": {
                                      "message": "This model's maximum context length is 8190 tokens, "
@@ -360,28 +428,38 @@ def test_endpoint_response_is_not_json(mock_get_logger):
                                      "code": None
                                  }}))
    # Act
    actual = common.convert_result_list([result], batch_size_per_request)
    if input_schema_version == 1:
        output_formatter = V1OutputFormatter()
    else:
        output_formatter = V2OutputFormatter()
    actual = output_formatter.format_output([result], batch_size_per_request)

    # Assert
    assert len(actual) == batch_size_per_request
    for idx, result in enumerate(actual):
        unit = json.loads(result)
        assert unit['request']['input'] == inputs[idx]
        assert type(unit['response']) is str
        assert 'maximum context length' in unit['response']
        if input_schema_version == 1:
            assert unit['request']['input'] == inputs[idx]
            assert type(unit['response']) is str
            assert 'maximum context length' in unit['response']
        elif input_schema_version == 2:
            assert unit['response'] is None
            assert type(unit['error']['message']) is str
            assert 'maximum context length' in unit['error']['message']

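For the failure path exercised above, the v1 formatter keeps the raw, non-JSON endpoint payload as a string under "response", while the v2 formatter leaves "response" empty and surfaces the message under "error". A small sketch of how a consumer might read both shapes (illustrative only; any v2 error fields beyond "message" are an assumption):

# Illustrative sketch, not part of this diff: reading a failed row in both schemas.
import json

def extract_error_message(row: str, input_schema_version: int) -> str:
    unit = json.loads(row)
    if input_schema_version == 1:
        return unit["response"]          # v1: raw endpoint payload, kept as a string
    return unit["error"]["message"]      # v2: response is None, message lives under error

v1_row = json.dumps({"request": {"input": "some text"}, "response": "maximum context length exceeded"})
v2_row = json.dumps({"custom_id": "task_123", "response": None, "error": {"message": "maximum context length exceeded"}})
assert "maximum context length" in extract_error_message(v1_row, 1)
assert "maximum context length" in extract_error_message(v2_row, 2)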
def __get_scoring_result_for_batch(batch_size, inputlist, outputlist, reorder_results=False, tiktoken_failed=False):
    token_counts = __get_token_counts(tiktoken_failed, inputlist)
def __get_scoring_result_for_batch(batch_size, request_obj, outputlist, reorder_results=False, tiktoken_failed=False):
    token_counts = __get_token_counts(tiktoken_failed, request_obj["input"])
    result = ScoringResult(
        status=ScoringResultStatus.SUCCESS,
        model_response_code=200,
        start=0,
        end=0,
        request_metadata="Not important",
        response_headers="Headers",
        response_headers={"header1": "value"},
        num_retries=2,
        token_counts=token_counts,
        request_obj={"input": inputlist},
        request_obj=request_obj,
        response_body={"object": "list",
                       "data": outputlist,
                       "model": "text-embedding-ada-002",
@@ -394,17 +472,18 @@ def __get_scoring_result_for_batch(batch_size, inputlist, outputlist, reorder_re
    return result


def __get_failed_scoring_result_for_batch(inputlist, tiktoken_failed=False):
    token_counts = __get_token_counts(tiktoken_failed, inputlist)
def __get_failed_scoring_result_for_batch(request_obj, tiktoken_failed=False):
    token_counts = __get_token_counts(tiktoken_failed, request_obj["input"])
    result = ScoringResult(
        status=ScoringResultStatus.FAILURE,
        start=0,
        end=0,
        model_response_code=400,
        request_metadata="Not important",
        response_headers="Headers",
        num_retries=2,
        token_counts=token_counts,
        request_obj={"input": inputlist},
        request_obj=request_obj,
        response_body={"object": "list",
                       "error": {
                           "message": "This model's maximum context length is 8190 tokens, "
@@ -451,10 +530,10 @@ def __random_string(length: int = 10):
def __mock_tiktoken_permanent_failure(monkeypatch):
    def mock_tiktoken_failure(*args):
        return 1
    monkeypatch.setattr(common.embeddings.EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_failure)
    monkeypatch.setattr(EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_failure)


def __mock_tiktoken_estimate(monkeypatch):
    def mock_tiktoken_override(estimator, request_obj):
        return [i for i in range(len(request_obj['input']))]
    monkeypatch.setattr(common.embeddings.EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_override)
    monkeypatch.setattr(EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_override)
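These helpers now patch estimate_request_cost on the EmbeddingsEstimator class imported at module level rather than reaching through common.embeddings. The same pytest monkeypatch pattern in isolation, using a hypothetical stand-in class since the real estimator is not reproduced in this diff:

# Illustrative sketch, not part of this diff: patching a method on a class for one test.
class FakeEstimator:  # hypothetical stand-in for EmbeddingsEstimator
    def estimate_request_cost(self, request_obj):
        raise RuntimeError("would call the real tokenizer")

def test_cost_is_overridden(monkeypatch):
    # Replace the method for the duration of this test only.
    monkeypatch.setattr(FakeEstimator, "estimate_request_cost",
                        lambda estimator, request_obj: list(range(len(request_obj["input"]))))
    assert FakeEstimator().estimate_request_cost({"input": ["a", "b"]}) == [0, 1]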
@@ -16,6 +16,7 @@ with patch('importlib.import_module', side_effect=mock_import):
    from src.batch_score.batch_pool.meds_client import MEDSClient
    from src.batch_score.common.common_enums import EndpointType
    from src.batch_score.common.configuration.configuration import Configuration
    from src.batch_score.common.post_processing.output_handler import SeparateFileOutputHandler, SingleFileOutputHandler
    from src.batch_score.common.telemetry.events import event_utils
    from src.batch_score.common.telemetry.trace_configs import (
        ConnectionCreateEndTrace,
@@ -103,20 +104,26 @@ def test_output_handler_interface(
        use_single_file_output_handler: bool,
        use_separate_file_output_handler: bool):
    """Test output handler interface."""
    input_data, mini_batch_context = _setup_main()

    with patch(
            "src.batch_score.main.SeparateFileOutputHandler",
            return_value=MagicMock()
            "tests.unit.test_main.SeparateFileOutputHandler"
    ) as mock_separate_file_output_handler, \
        patch(
            "src.batch_score.main.SingleFileOutputHandler",
            return_value=MagicMock()
            "tests.unit.test_main.SingleFileOutputHandler"
    ) as mock_single_file_output_handler:

        input_data, mini_batch_context = _setup_main()

        main.configuration.split_output = split_output
        main.configuration.save_mini_batch_results = "enabled"
        main.configuration.mini_batch_results_out_directory = "driver/tests/unit/unit_test_results/"
        if main.configuration.split_output:
            main.output_handler = SeparateFileOutputHandler(
                main.configuration.batch_size_per_request,
                main.configuration.input_schema_version)
        else:
            main.output_handler = SingleFileOutputHandler(
                main.configuration.batch_size_per_request,
                main.configuration.input_schema_version)
        main.run(input_data=input_data, mini_batch_context=mini_batch_context)

        assert mock_separate_file_output_handler.called == use_separate_file_output_handler
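The updated test wires up the output handler itself, keyed off configuration.split_output, and threads the new input_schema_version into both handler constructors. A minimal sketch of that selection, assuming only the constructor arguments used in this test:

# Illustrative sketch, not part of this diff: handler selection mirroring the test above.
from src.batch_score.common.post_processing.output_handler import (
    SeparateFileOutputHandler,
    SingleFileOutputHandler,
)

def build_output_handler(configuration):
    handler_cls = SeparateFileOutputHandler if configuration.split_output else SingleFileOutputHandler
    return handler_cls(configuration.batch_size_per_request, configuration.input_schema_version)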
@@ -267,6 +274,7 @@ def _setup_main(par_exception=None):
    configuration.scoring_url = "https://scoring_url"
    configuration.batch_pool = "batch_pool"
    configuration.quota_audience = "quota_audience"
    configuration.input_schema_version = 1
    main.configuration = configuration

    main.input_handler = MagicMock()