Refresh batch score v2 data schema to align with OpenAI public API (#2768)

* Refresh batch score v2 data schema to align with OpenAI public API

* bump version

* Styling fixes

* Add missing file

* Update git ignore

* Fix styling

* Fix
Ye Tao 2024-04-24 13:35:10 -07:00 committed by GitHub
Parent 27b8fbbb20
Commit 9a7fc2d2fa
No key matching this signature was found
GPG key ID: B5690EEEBB952194
27 changed files: 635 additions and 308 deletions
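
For orientation, a minimal sketch of the V2 (OpenAI-style) line formats this change targets, pieced together from the V2 input handler, V2OutputFormatter, and the updated tests below; the concrete values are illustrative rather than taken from a real run.

Input line (V2 schema):
    {"custom_id": "task_123", "method": "POST", "url": "/v1/completions", "body": {"model": "chat-sahara-4", "max_tokens": 1}}

Output line on success (as emitted by V2OutputFormatter.format_output):
    {"custom_id": "task_123", "response": {"body": {...response body...}, "request_id": "<x-request-id header>", "status_code": 200}, "error": null}

Output line on failure:
    {"custom_id": "task_123", "response": null, "error": {"code": null, "message": {...response body...}}}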

.gitignore vendored
View file

@ -139,4 +139,3 @@ mlruns/
# ignore config files
config.json
out*

View file

@ -2,7 +2,7 @@ $schema: http://azureml/sdk-2-0/ParallelComponent.json
type: parallel
name: batch_score_llm
version: 1.1.5
version: 1.1.6
display_name: Batch Score Large Language Models
is_deterministic: False

View file

@ -65,6 +65,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
if response_status == 200:
return self._create_scoring_result(
status=ScoringResultStatus.SUCCESS,
model_response_code=response_status,
scoring_request=scoring_request,
start=start,
end=end,
@ -78,6 +79,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
result = self._create_scoring_result(
status=ScoringResultStatus.FAILURE,
model_response_code=response_status,
scoring_request=scoring_request,
start=start,
end=end,
@ -130,6 +132,7 @@ class AoaiHttpResponseHandler(HttpResponseHandler):
except Exception:
return self._create_scoring_result(
status=ScoringResultStatus.FAILURE,
model_response_code=http_response.status,
scoring_request=scoring_request,
start=start,
end=end,

View file

@ -5,7 +5,9 @@
import asyncio
from ...utils.common import convert_result_list
from ...utils.output_formatter import OutputFormatter
from ...utils.v1_output_formatter import V1OutputFormatter
from ...utils.v2_output_formatter import V2OutputFormatter
from ..configuration.configuration import Configuration
from ..post_processing.mini_batch_context import MiniBatchContext
from ..post_processing.result_utils import apply_input_transformer
@ -48,7 +50,12 @@ class Parallel:
apply_input_transformer(self.__input_to_output_transformer, scoring_results)
results = convert_result_list(
output_formatter: OutputFormatter
if self._configuration.input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
results = output_formatter.format_output(
results=scoring_results,
batch_size_per_request=self._configuration.batch_size_per_request)

View file

@ -5,7 +5,9 @@
import traceback
from ...utils.common import convert_result_list
from ...utils.output_formatter import OutputFormatter
from ...utils.v1_output_formatter import V1OutputFormatter
from ...utils.v2_output_formatter import V2OutputFormatter
from ..configuration.configuration import Configuration
from ..scoring.scoring_result import ScoringResult
from ..telemetry import logging_utils as lu
@ -16,7 +18,7 @@ from .result_utils import (
apply_input_transformer,
get_return_value,
)
from .output_handler import SingleFileOutputHandler, SeparateFileOutputHandler
from .output_handler import OutputHandler
def add_callback(callback, cur):
@ -33,9 +35,11 @@ class CallbackFactory:
def __init__(self,
configuration: Configuration,
output_handler: OutputHandler,
input_to_output_transformer):
"""Initialize CallbackFactory."""
self._configuration = configuration
self._output_handler = output_handler
self.__input_to_output_transformer = input_to_output_transformer
def generate_callback(self):
@ -46,7 +50,12 @@ class CallbackFactory:
return callback
def _convert_result_list(self, scoring_results: "list[ScoringResult]", mini_batch_context: MiniBatchContext):
return convert_result_list(
output_formatter: OutputFormatter
if self._configuration.input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
return output_formatter.format_output(
results=scoring_results,
batch_size_per_request=self._configuration.batch_size_per_request)
@ -64,13 +73,7 @@ class CallbackFactory:
if mini_batch_context.exception is None:
if self._configuration.save_mini_batch_results == "enabled":
lu.get_logger().info("save_mini_batch_results is enabled")
if (self._configuration.split_output):
output_handler = SeparateFileOutputHandler()
lu.get_logger().info("Saving successful results and errors to separate files")
else:
output_handler = SingleFileOutputHandler()
lu.get_logger().info("Saving results to single file")
output_handler.save_mini_batch_results(
self._output_handler.save_mini_batch_results(
scoring_results,
self._configuration.mini_batch_results_out_directory,
mini_batch_context.raw_mini_batch_context

View file

@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""This file contains the data class for Azure OpenAI scoring error."""
from dataclasses import dataclass
@dataclass
class AoaiScoringError:
"""Azure OpenAI scoring error."""
code: str = None
message: str = None

View file

@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""This file contains the data class for Azure OpenAI scoring response."""
from dataclasses import dataclass
@dataclass
class AoaiScoringResponse:
"""Azure OpenAI scoring response."""
body: any = None
request_id: str = None
status_code: int = None
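
A quick illustration of how these two dataclasses are consumed (mirroring the vars() calls in V2OutputFormatter further down); the body, message, and request_id values are placeholders:

    # Successful result: the response body plus transport metadata become the "response" field.
    response = AoaiScoringResponse(body={"usage": {"prompt_tokens": 1}}, request_id="abc-123", status_code=200)
    success_line = {"custom_id": "task_123", "response": vars(response), "error": None}

    # Failed result: the "error" field carries an AoaiScoringError and "response" is None.
    error = AoaiScoringError(message="This model's maximum context length is 8190 tokens, ...")
    failure_line = {"custom_id": "task_123", "response": None, "error": vars(error)}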

View file

@ -32,6 +32,7 @@ class HttpResponseHandler:
def _create_scoring_result(
self,
status: ScoringResultStatus,
model_response_code: int,
scoring_request: ScoringRequest,
start: float,
end: float,
@ -43,6 +44,7 @@ class HttpResponseHandler:
status=status,
start=start,
end=end,
model_response_code=model_response_code,
request_obj=scoring_request.original_payload_obj,
request_metadata=scoring_request.request_metadata,
response_body=http_post_response.payload,

View file

@ -17,6 +17,7 @@ class ScoringRequest:
__BATCH_REQUEST_METADATA = "_batch_request_metadata"
__REQUEST_METADATA = "request_metadata"
__CUSTOM_ID = "custom_id"
def __init__(
self,
@ -55,6 +56,8 @@ class ScoringRequest:
# These properties do not need to be sent to the model & will be added to the output file directly
self.__request_metadata = self.__cleaned_payload_obj.pop(self.__BATCH_REQUEST_METADATA, None)
self.__request_metadata = self.__cleaned_payload_obj.pop(self.__REQUEST_METADATA, self.__request_metadata)
# If custom_id exists (V2 input schema), make sure it is not sent to MIR endpoint
self.__CUSTOM_ID = self.__cleaned_payload_obj.pop(self.__CUSTOM_ID, None)
self.__cleaned_payload = json.dumps(self.__cleaned_payload_obj, cls=BatchComponentJSONEncoder)
self.__loggable_payload = json.dumps(self.__loggable_payload_obj, cls=BatchComponentJSONEncoder)
@ -136,6 +139,12 @@ class ScoringRequest:
"""Get the segment id."""
return self.__segment_id
# read-only
@property
def custom_id(self) -> str:
"""Get the custom id. Only valid for V2 input schema."""
return self.__CUSTOM_ID
@estimated_cost.setter
def estimated_cost(self, cost: int):
"""Set the estimated cost."""

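A small sketch of the new custom_id handling in ScoringRequest above, modeled on the updated scoring client test further down; the payload is illustrative:

    request = ScoringRequest(original_payload='{"custom_id": "task_123", "prompt": "Test model"}')
    assert request.custom_id == "task_123"
    # custom_id is popped from the cleaned payload, so the body actually posted to the
    # endpoint no longer contains it (the test below asserts this via
    # scoring_client._create_http_request(scoring_request).payload).
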
View file

@ -58,6 +58,7 @@ class ScoringResult:
status: ScoringResultStatus,
start: float,
end: float,
model_response_code: int,
request_obj: any,
request_metadata: any,
response_body: any,
@ -70,6 +71,7 @@ class ScoringResult:
self.status = status
self.start = start
self.end = end
self.model_response_code = model_response_code
self.request_obj = request_obj # Normalize to json
self.request_metadata = request_metadata
self.response_body = response_body
@ -121,6 +123,7 @@ class ScoringResult:
status=ScoringResultStatus.FAILURE,
start=0,
end=0,
model_response_code=None,
request_obj=scoring_request.original_payload_obj if scoring_request else None,
request_metadata=scoring_request.request_metadata if scoring_request else None,
response_body=None,
@ -140,6 +143,7 @@ class ScoringResult:
self.status,
self.start,
self.end,
self.model_response_code,
self.request_obj,
self.request_metadata,
deepcopy(self.response_body),

View file

@ -90,6 +90,7 @@ def init():
global par
global configuration
global input_handler
global output_handler
start = time.time()
parser = ConfigurationParserFactory().get_parser()
@ -103,6 +104,13 @@ def init():
else:
raise ValueError(f"Invalid input_schema_version: {configuration.input_schema_version}")
if (configuration.split_output):
output_handler = SeparateFileOutputHandler()
lu.get_logger().info("Will save successful results and errors to separate files")
else:
output_handler = SingleFileOutputHandler()
lu.get_logger().info("Will save all results to a single file")
event_utils.setup_context_vars(configuration, metadata)
setup_geneva_event_handlers()
setup_job_log_event_handlers()
@ -147,6 +155,7 @@ def init():
if configuration.async_mode:
callback_factory = CallbackFactory(
configuration=configuration,
output_handler=output_handler,
input_to_output_transformer=input_to_output_transformer)
finished_callback = callback_factory.generate_callback()
@ -201,13 +210,6 @@ def run(input_data: pd.DataFrame, mini_batch_context):
try:
ret = par.run(data_list, mini_batch_context)
if (configuration.split_output):
output_handler = SeparateFileOutputHandler()
lu.get_logger().info("Saving successful results and errors to separate files")
else:
output_handler = SingleFileOutputHandler()
lu.get_logger().info("Saving results to single file")
if configuration.save_mini_batch_results == "enabled":
lu.get_logger().info("save_mini_batch_results is enabled")
output_handler.save_mini_batch_results(

View file

@ -1,7 +1,7 @@
{
"component_version": "1.1.5",
"component_version": "1.1.6",
"component_directory": "driver/batch_score_llm",
"component_name": "batch_score_llm",
"virtual_environment_name": null,
"registry_name": "azureml"
}

View file

@ -104,6 +104,7 @@ class MirHttpResponseHandler(HttpResponseHandler):
if response_status == 200:
result = self._create_scoring_result(
status=ScoringResultStatus.SUCCESS,
model_response_code=response_status,
scoring_request=scoring_request,
start=start,
end=end,
@ -124,6 +125,7 @@ class MirHttpResponseHandler(HttpResponseHandler):
else: # Score failed
result = self._create_scoring_result(
status=ScoringResultStatus.FAILURE,
model_response_code=response_status,
scoring_request=scoring_request,
start=start,
end=end,

View file

@ -3,14 +3,9 @@
"""Common utilities."""
import json
from argparse import ArgumentParser
from urllib.parse import urlparse
from ..common.scoring.scoring_result import ScoringResult
from . import embeddings_utils as embeddings
from .json_encoder_extensions import BatchComponentJSONEncoder
def get_base_url(url: str) -> str:
"""Get base url."""
@ -38,39 +33,7 @@ def str2bool(v):
raise ArgumentParser.ArgumentTypeError('Boolean value expected.')
def convert_result_list(results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
"""Convert scoring results to the result list."""
output_list: list[dict[str, str]] = []
for scoringResult in results:
output: dict[str, str] = {}
output["status"] = scoringResult.status.name
output["start"] = scoringResult.start
output["end"] = scoringResult.end
output["request"] = scoringResult.request_obj
output["response"] = scoringResult.response_body
if scoringResult.segmented_response_bodies is not None and len(scoringResult.segmented_response_bodies) > 0:
output["segmented_responses"] = scoringResult.segmented_response_bodies
if scoringResult.request_metadata is not None:
output["request_metadata"] = scoringResult.request_metadata
if batch_size_per_request > 1:
batch_output_list = embeddings._convert_to_list_of_output_items(
output,
scoringResult.estimated_token_counts)
output_list.extend(batch_output_list)
else:
output_list.append(output)
return list(map(__stringify_output, output_list))
def get_mini_batch_id(mini_batch_context: any):
"""Get mini batch id from mini batch context."""
if mini_batch_context:
return mini_batch_context.mini_batch_id
def __stringify_output(payload_obj: dict) -> str:
return json.dumps(payload_obj, cls=BatchComponentJSONEncoder)

View file

@ -3,15 +3,8 @@
"""Embeddings utilities."""
from copy import deepcopy
import pandas as pd
from ..batch_pool.quota.estimators import EmbeddingsEstimator
from ..common.telemetry import logging_utils as lu
estimator = None
def _convert_to_list_of_input_batches(
data: pd.DataFrame,
@ -32,145 +25,3 @@ def _convert_to_list_of_input_batches(
payload_obj = {"input": list_of_strings}
list_of_input_batches.append(payload_obj)
return list_of_input_batches
def _convert_to_list_of_output_items(result: dict, token_count_estimates: "tuple[int]") -> "list[dict]":
"""Convert results to a list of output items."""
"""
Only used by the Embeddings API with batched HTTP requests, this method takes
a scoring result as dictionary of "request", "response", and (optional) "request_metadata".
It returns the batch list within request and response mapped out into a list of dictionaries,
each with the correlating request and response from the batch.
If the scoring result has request metadata, this is persisted in each of the
output dictionaries.
Args:
result: The scoring result containing a batch of inputs and outputs.
token_count_estimates: The tuple of tiktoken estimates for each input in the batch.
Returns:
List of output objects, each with "request", "response", (optional) "request_metadata".
"""
output_list = []
response_obj = result["response"]
try:
response_data = response_obj.pop('data', None)
except AttributeError:
response_data = None
request = result["request"]
numrequests = len(request["input"])
if response_data is not None:
# Result has data; response_obj["data"]
numresults = len(response_data)
__validate_response_data_length(numrequests, numresults)
output_index_to_embedding_info_map = __build_output_idx_to_embedding_mapping(response_data)
update_prompt_tokens = __tiktoken_estimates_succeeded(token_count_estimates, numrequests)
if not update_prompt_tokens:
# Single online endpoints will not have computed token estimates, as this occurs in quota client.
token_count_estimates = __tiktoken_estimates_retry(request)
update_prompt_tokens = __tiktoken_estimates_succeeded(token_count_estimates, numrequests)
else:
# Result has error; response_obj["error"]. Copy this for each request in batch below.
numresults = -1
error_message = "The batch request resulted in an error. See job output for the error message."
lu.get_logger().error(error_message)
# Input can be large. Pop and iterate through the batch to avoid copying repeatedly.
input_batch = request.pop('input', None)
for i in range(numrequests):
# The large "input" from request and "data" from response have been popped so copy is smaller.
# "input" and "data" are set for each below.
single_output = {"request": deepcopy(request), "response": deepcopy(response_obj)}
single_output["request"]["input"] = input_batch[i]
if numresults > -1:
single_output["response"]["data"] = [output_index_to_embedding_info_map[i]]
if update_prompt_tokens:
__override_prompt_tokens(single_output, token_count_estimates[i])
output_list.append(single_output)
return output_list
def __build_output_idx_to_embedding_mapping(response_data):
"""Build a mapping from output index to embedding."""
"""
Given response data, return a dictionary of the index and embedding info for each element of the batch.
Unsure if the responses are always in the correct order by input index, ensure output order by mapping out index.
Args:
response_data: The list of outputs from the 'data' of API response.
Returns:
Dict mapping index to embedding info.
"""
return {embedding_info['index']: embedding_info for embedding_info in response_data}
def __override_prompt_tokens(output_obj, token_count):
"""
Set the token_count as the value for `prompt_tokens` in response's usage info.
Args:
output_obj: The dictionary of info for response, request
token_count: The tiktoken count for this input string
"""
try:
output_obj["response"]["usage"]["prompt_tokens"] = token_count
except Exception as exc:
lu.get_logger().exception("Unable to set prompt token override.")
raise exc
def __tiktoken_estimates_succeeded(token_count_estimates: "tuple[int]", input_length: int) -> bool:
"""
Return True if the length of the batch of inputs matches the length of the tiktoken estimates.
Args:
token_count_estimates: The tuple of tiktoken estimates for the inputs in this batch
input_length: The length of inputs in this batch
"""
token_est_length = len(token_count_estimates)
length_matches = token_est_length == input_length
if not length_matches:
lu.get_logger().warn(f"Input length {input_length} does not match token estimate length {token_est_length}. "
"Skipping prompt_tokens count overrides.")
return length_matches
def __tiktoken_estimates_retry(request_obj: dict) -> "tuple[int]":
"""
Return token counts for the inputs within a batch.
Args:
request_obj: The request dictionary.
"""
lu.get_logger().debug("Attempting to calculate tokens for the embedding input batch.")
global estimator
if estimator is None:
estimator = EmbeddingsEstimator()
token_counts = estimator.estimate_request_cost(request_obj)
if token_counts == 1:
# This occurs if tiktoken module fails. See DV3Estimator for more info on why this could fail.
return ()
else:
return token_counts
def __validate_response_data_length(numrequests, numresults):
"""
Validate the number of outputs from the API response matches the number of requests in the batch.
Args:
numrequests: The number of requests in this batch.
numresults: The number of results in the response data.
Raises:
Exception if response length and request length do not match.
"""
if numresults != numrequests:
error_message = f"Result data length {numresults} != " + \
f"{numrequests} request batch length."
lu.get_logger().error(error_message)
raise Exception(error_message)

View file

@ -0,0 +1,186 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""This file contains the definition for the output formatter."""
import json
from abc import ABC, abstractmethod
from ..common.scoring.scoring_result import ScoringResult
from .json_encoder_extensions import BatchComponentJSONEncoder
from ..batch_pool.quota.estimators import EmbeddingsEstimator
from ..common.telemetry import logging_utils as lu
class OutputFormatter(ABC):
"""An abstract class for formatting output."""
@abstractmethod
def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
"""Abstract output formatting method."""
pass
@abstractmethod
def _get_response_obj(self, result: dict):
pass
@abstractmethod
def _get_custom_id(self, result: dict):
pass
@abstractmethod
def _get_single_output(self, *args):
pass
def _convert_to_list_of_output_items(self, result: dict, token_count_estimates: "tuple[int]") -> "list[dict]":
"""Convert results to a list of output items."""
"""
Only used by the Embeddings API with batched HTTP requests, this method takes
a scoring result as dictionary of "request", "response", and (optional) "request_metadata".
It returns the batch list within request and response mapped out into a list of dictionaries,
each with the correlating request and response from the batch.
If the scoring result has request metadata, this is persisted in each of the
output dictionaries.
Args:
result: The scoring result containing a batch of inputs and outputs.
token_count_estimates: The tuple of tiktoken estimates for each input in the batch.
Returns:
List of output objects, each with "request", "response", (optional) "request_metadata".
"""
output_list = []
response_obj = self._get_response_obj(result)
custom_id = self._get_custom_id(result)
format_version = 1 if custom_id is None else 2
try:
response_data = response_obj.pop('data', None)
except AttributeError:
response_data = None
request = result["request"]
numrequests = len(request["input"])
if response_data is not None:
# Result has data; response_obj["data"]
numresults = len(response_data)
self.__validate_response_data_length(numrequests, numresults)
output_index_to_embedding_info_map = self.__build_output_idx_to_embedding_mapping(response_data)
update_prompt_tokens = self.__tiktoken_estimates_succeeded(token_count_estimates, numrequests)
if not update_prompt_tokens:
# Single online endpoints will not have computed token estimates, as this occurs in quota client.
token_count_estimates = self.__tiktoken_estimates_retry(request)
update_prompt_tokens = self.__tiktoken_estimates_succeeded(token_count_estimates, numrequests)
else:
# Result has error; response_obj["error"]. Copy this for each request in batch below.
numresults = -1
error_message = "The batch request resulted in an error. See job output for the error message."
lu.get_logger().error(error_message)
# Input can be large. Pop and iterate through the batch to avoid copying repeatedly.
input_batch = request.pop('input', None)
for i in range(numrequests):
# The large "input" from request and "data" from response have been popped so copy is smaller.
# "input" and "data" are set for each below.
if format_version == 1:
single_output = self._get_single_output(request, response_obj, input_batch[i])
else:
single_output = self._get_single_output(custom_id, result)
if numresults > -1:
if format_version == 1:
single_output["response"]["data"] = [output_index_to_embedding_info_map[i]]
else:
single_output["response"]["body"]["data"] = [output_index_to_embedding_info_map[i]]
if update_prompt_tokens:
self.__override_prompt_tokens(single_output, token_count_estimates[i], format_version)
output_list.append(single_output)
return output_list
def __build_output_idx_to_embedding_mapping(self, response_data):
"""Build a mapping from output index to embedding."""
"""
Given response data, return a dictionary of the index and embedding info for each element of the batch.
Because the responses are not guaranteed to be in input-index order,
ensure output order by mapping each output index to its embedding info.
Args:
response_data: The list of outputs from the 'data' of API response.
Returns:
Dict mapping index to embedding info.
"""
return {embedding_info['index']: embedding_info for embedding_info in response_data}
def __override_prompt_tokens(self, output_obj, token_count, format_version):
"""
Set the token_count as the value for `prompt_tokens` in response's usage info.
Args:
output_obj: The dictionary of info for response, request
token_count: The tiktoken count for this input string
format_version: The output format version
"""
try:
if format_version == 1:
output_obj["response"]["usage"]["prompt_tokens"] = token_count
else:
output_obj["response"]["body"]["usage"]["prompt_tokens"] = token_count
except Exception as exc:
lu.get_logger().exception("Unable to set prompt token override.")
raise exc
def __tiktoken_estimates_succeeded(self, token_count_estimates: "tuple[int]", input_length: int) -> bool:
"""
Return True if the length of the batch of inputs matches the length of the tiktoken estimates.
Args:
token_count_estimates: The tuple of tiktoken estimates for the inputs in this batch
input_length: The length of inputs in this batch
"""
token_est_length = len(token_count_estimates)
length_matches = token_est_length == input_length
if not length_matches:
lu.get_logger().warn(f"Input length {input_length} does not match token estimate "
f"length {token_est_length}. Skipping prompt_tokens count overrides.")
return length_matches
def __tiktoken_estimates_retry(self, request_obj: dict) -> "tuple[int]":
"""
Return token counts for the inputs within a batch.
Args:
request_obj: The request dictionary.
"""
lu.get_logger().debug("Attempting to calculate tokens for the embedding input batch.")
if self.estimator is None:
self.estimator = EmbeddingsEstimator()
token_counts = self.estimator.estimate_request_cost(request_obj)
if token_counts == 1:
# This occurs if tiktoken module fails. See DV3Estimator for more info on why this could fail.
return ()
else:
return token_counts
def __validate_response_data_length(self, numrequests, numresults):
"""
Validate the number of outputs from the API response matches the number of requests in the batch.
Args:
numrequests: The number of requests in this batch.
numresults: The number of results in the response data.
Raises:
Exception if response length and request length do not match.
"""
if numresults != numrequests:
error_message = f"Result data length {numresults} != " + \
f"{numrequests} request batch length."
lu.get_logger().error(error_message)
raise Exception(error_message)
def _stringify_output(self, payload_obj: dict) -> str:
return json.dumps(payload_obj, cls=BatchComponentJSONEncoder)
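
A usage sketch of how callers pick a concrete formatter, mirroring the branching added to Parallel and CallbackFactory above; configuration and scoring_results stand in for real objects:

    output_formatter: OutputFormatter
    if configuration.input_schema_version == 1:
        output_formatter = V1OutputFormatter()
    else:
        output_formatter = V2OutputFormatter()
    # Each element of output_lines is a JSON string ready to be written by the output handler.
    output_lines = output_formatter.format_output(
        results=scoring_results,
        batch_size_per_request=configuration.batch_size_per_request)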

View file

@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This file contains the definition for the original (V1) output formatter.
V1 Output format:
{
"status": ["SUCCESS" | "FAILURE"],
"start": 1709584163.2691997,
"end": 1709584165.2570084,
"request": { <request_body> },
"response": { <response_body> }
}
"""
from copy import deepcopy
from .output_formatter import OutputFormatter
from ..common.scoring.scoring_result import ScoringResult
class V1OutputFormatter(OutputFormatter):
"""Defines a class to format output in V1 format."""
def __init__(self):
"""Initialize V1OutputFormatter."""
self.estimator = None
def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
"""Format output in the V1 format."""
output_list: list[dict[str, str]] = []
for scoringResult in results:
output: dict[str, str] = {}
output["status"] = scoringResult.status.name
output["start"] = scoringResult.start
output["end"] = scoringResult.end
output["request"] = scoringResult.request_obj
output["response"] = scoringResult.response_body
if scoringResult.segmented_response_bodies is not None and \
len(scoringResult.segmented_response_bodies) > 0:
output["segmented_responses"] = scoringResult.segmented_response_bodies
if scoringResult.request_metadata is not None:
output["request_metadata"] = scoringResult.request_metadata
if batch_size_per_request > 1:
batch_output_list = self._convert_to_list_of_output_items(
output,
scoringResult.estimated_token_counts)
output_list.extend(batch_output_list)
else:
output_list.append(output)
return list(map(self._stringify_output, output_list))
def _get_response_obj(self, result: dict):
return result["response"]
def _get_custom_id(self, result: dict):
return None
def _get_request_id(self, result: dict):
return None
def _get_status(self, result: dict):
return result["status"]
def _get_single_output(self, request, response_obj, input_batch):
single_output = {"request": deepcopy(request), "response": deepcopy(response_obj)}
single_output["request"]["input"] = input_batch
return single_output

View file

@ -4,7 +4,6 @@
"""This file contains the definition for the new (V2) schema input handler."""
import pandas as pd
import json
from .input_handler import InputHandler
@ -20,8 +19,9 @@ class V2InputSchemaHandler(InputHandler):
"""Convert the new schema input pandas DataFrame to a list of payload strings."""
body_details = []
for _, row in data.iterrows():
body = json.loads(row['body'])
body = row['body']
del body['model']
body['custom_id'] = row['custom_id']
body_details.append(body)
original_schema_df = pd.DataFrame(body_details)
return self._convert_to_list(original_schema_df, additional_properties, batch_size_per_request)
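
Roughly what the handler produces for a V2 input row, per the new-schema test data further down (note that the body column now arrives as a dict rather than a JSON string); the values are illustrative:

    import pandas as pd

    data = pd.DataFrame([
        {"custom_id": "task_123", "method": "POST", "url": "/v1/completions",
         "body": {"model": "chat-sahara-4", "max_tokens": 1}},
    ])
    # The handler drops "model", folds custom_id into the body, and hands the result to
    # self._convert_to_list, yielding payload strings such as:
    #     '{"max_tokens": 1, "custom_id": "task_123"}'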

View file

@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This file contains the definition for the new (V2) output formatter.
V2 Output format:
{
"custom_id": <custom_id>,
"request_id": "", // MIR endpoint request id?
"status": <HTTP response status code>,
// If response is successful, "response" should have the response body and "error" should be null,
// and vice versa for a failed response.
"response": { <response_body> | null },
"error": { null | <response_body> }
}
"""
from copy import deepcopy
from .output_formatter import OutputFormatter
from ..common.scoring.aoai_error import AoaiScoringError
from ..common.scoring.aoai_response import AoaiScoringResponse
from ..common.scoring.scoring_result import ScoringResult
class V2OutputFormatter(OutputFormatter):
"""Defines a class to format output in V2 format."""
def __init__(self):
"""Initialize V2OutputFormatter."""
self.estimator = None
def format_output(self, results: "list[ScoringResult]", batch_size_per_request: int) -> "list[str]":
"""Format output in the V2 format."""
output_list: list[dict[str, str]] = []
for scoringResult in results:
output: dict[str, str] = {}
keys = scoringResult.request_obj.keys()
if "custom_id" in keys:
output["custom_id"] = scoringResult.request_obj["custom_id"]
else:
raise Exception("V2OutputFormatter called and custom_id not found "
"in request object (original payload)")
if scoringResult.status.name == "SUCCESS":
response = AoaiScoringResponse(request_id=self.__get_request_id(scoringResult),
status_code=scoringResult.model_response_code,
body=deepcopy(scoringResult.response_body))
output["response"] = vars(response)
output["error"] = None
else:
error = AoaiScoringError(message=deepcopy(scoringResult.response_body))
output["response"] = None
output["error"] = vars(error)
if batch_size_per_request > 1:
# _convert_to_list_of_output_items() expects output["request"] to be set.
output["request"] = scoringResult.request_obj
batch_output_list = self._convert_to_list_of_output_items(
output,
scoringResult.estimated_token_counts)
output_list.extend(batch_output_list)
else:
output_list.append(output)
return list(map(self._stringify_output, output_list))
def __get_request_id(self, scoring_request: ScoringResult):
return scoring_request.response_headers.get("x-request-id", "")
def _get_response_obj(self, result: dict):
if result.get("response") is None:
return result.get("error")
else:
return result.get("response").get("body")
def _get_custom_id(self, result: dict) -> str:
return result.get("custom_id", "")
def _get_single_output(self, custom_id, result):
single_output = {
"id": "", # TODO: populate this ID
"custom_id": deepcopy(custom_id),
"response": deepcopy(result["response"]),
"error": deepcopy(result["error"])
}
return single_output

View file

@ -68,6 +68,7 @@ YAML_DISALLOW_FAILED_REQUESTS = {"jobs": {JOB_NAME: {
# This test confirms that we can score an MIR endpoint using the scoring_url parameter and
# the batch_score_llm.yml component.
@pytest.mark.skip('Temporarily disabled until the test endpoint is created.')
@pytest.mark.smoke
@pytest.mark.e2e
@pytest.mark.timeout(20 * 60)

View file

@ -74,6 +74,7 @@ def mock_run(monkeypatch):
async def _run(self, requests: "list[ScoringRequest]") -> "list[ScoringResult]":
passed_requests.extend(requests)
return [ScoringResult(status=ScoringResultStatus.SUCCESS,
model_response_code=200,
response_body={"usage": {}},
omit=False,
start=0,

View file

@ -19,6 +19,7 @@ def make_scoring_result():
"""Mock scoring result."""
def make(
status: ScoringResultStatus = ScoringResultStatus.SUCCESS,
model_response_code: int = 200,
start: float = time.time() - 10,
end: float = time.time(),
request_obj: any = None,
@ -30,6 +31,7 @@ def make_scoring_result():
"""Make a mock scoring result."""
return ScoringResult(
status=status,
model_response_code=model_response_code,
start=start,
end=end,
request_obj=request_obj,

View file

@ -12,6 +12,7 @@ from src.batch_score.common.post_processing.callback_factory import CallbackFact
from src.batch_score.common.post_processing.mini_batch_context import MiniBatchContext
from src.batch_score.common.scoring.scoring_result import ScoringResult
from src.batch_score.common.telemetry.events import event_utils
from src.batch_score.common.post_processing.output_handler import SingleFileOutputHandler, SeparateFileOutputHandler
from tests.fixtures.input_transformer import FakeInputOutputModifier
from tests.fixtures.scoring_result import get_test_request_obj
from tests.fixtures.test_mini_batch_context import TestMiniBatchContext
@ -34,6 +35,7 @@ def test_generate_callback_success(mock_get_logger,
callback_factory = CallbackFactory(
configuration=_get_test_configuration(),
output_handler=SingleFileOutputHandler(),
input_to_output_transformer=mock_input_to_output_transformer)
callbacks = callback_factory.generate_callback()
@ -69,6 +71,7 @@ def test_generate_callback_exception_with_mini_batch_id(mock_get_logger,
callback_factory = CallbackFactory(
configuration=_get_test_configuration(),
output_handler=SingleFileOutputHandler(),
input_to_output_transformer=mock_input_to_output_transformer)
callbacks = callback_factory.generate_callback()
@ -109,21 +112,26 @@ def test_output_handler(
scoring_result = make_scoring_result(request_obj=get_test_request_obj())
gathered_result: list[ScoringResult] = [scoring_result.copy(), scoring_result.copy()]
test_configuration = _get_test_configuration_for_output_handler(split_output)
callback_factory = CallbackFactory(
configuration=test_configuration,
input_to_output_transformer=mock_input_to_output_transformer)
with patch(
"src.batch_score.common.post_processing.callback_factory.SeparateFileOutputHandler",
"tests.unit.common.post_processing.test_callback_factory.SeparateFileOutputHandler",
return_value=MagicMock()
) as mock_separate_file_output_handler, \
patch(
"src.batch_score.common.post_processing.callback_factory.SingleFileOutputHandler",
"tests.unit.common.post_processing.test_callback_factory.SingleFileOutputHandler",
return_value=MagicMock()
) as mock_single_file_output_handler:
test_configuration = _get_test_configuration_for_output_handler(split_output)
if test_configuration.split_output:
output_handler = SeparateFileOutputHandler()
else:
output_handler = SingleFileOutputHandler()
callback_factory = CallbackFactory(
configuration=test_configuration,
output_handler=output_handler,
input_to_output_transformer=mock_input_to_output_transformer)
callbacks = callback_factory.generate_callback()
_ = callbacks(gathered_result, mini_batch_context)

View file

@ -40,7 +40,7 @@ async def test_score(response_status, response_body, exception_to_raise):
http_response_handler=http_response_handler,
scoring_url=None)
scoring_request = ScoringRequest(original_payload='{"prompt":"Test model"}')
scoring_request = ScoringRequest(original_payload='{"custom_id": "task_123", "prompt":"Test model"}')
async with aiohttp.ClientSession() as session:
with patch.object(session, "post") as mock_post:
@ -62,6 +62,8 @@ async def test_score(response_status, response_body, exception_to_raise):
assert http_response_handler.handle_response.assert_called_once
response_sent_to_handler = http_response_handler.handle_response.call_args.kwargs['http_response']
assert "custom_id" not in scoring_client._create_http_request(scoring_request).payload
if exception_to_raise:
assert type(response_sent_to_handler.exception) is exception_to_raise
elif response_status == 200:

View file

@ -18,6 +18,7 @@ def test_copy():
ScoringResultStatus.SUCCESS,
0,
1,
200,
"request_obj",
{},
response_body={"usage": {"prompt_tokens": 2,
@ -48,6 +49,7 @@ def test_copy():
assert result2.response_body["usage"]["total_tokens"] == 16
assert result.status == result2.status
assert result.model_response_code == result2.model_response_code
assert result2.estimated_token_counts == (1, 2, 3)
@ -79,6 +81,7 @@ def test_usage_statistics(
ScoringResultStatus.SUCCESS,
0,
1,
200,
"request_obj",
{},
response_body=response_usage,

View file

@ -17,7 +17,9 @@ from src.batch_score.common.scoring.scoring_result import (
ScoringResult,
ScoringResultStatus,
)
from src.batch_score.utils import common
from src.batch_score.batch_pool.quota.estimators import EmbeddingsEstimator
from src.batch_score.utils.v1_output_formatter import V1OutputFormatter
from src.batch_score.utils.v2_output_formatter import V2OutputFormatter
from src.batch_score.utils.v1_input_schema_handler import V1InputSchemaHandler
from src.batch_score.utils.v2_input_schema_handler import V2InputSchemaHandler
@ -42,29 +44,53 @@ VALID_DATAFRAMES = [
NEW_SCHEMA_VALID_DATAFRAMES = [
[
[
{"task_id": "task_123", "method": "POST", "url": "/v1/completions",
"body": '{"model": "chat-sahara-4", "max_tokens": 1}'},
{"task_id": "task_789", "method": "POST", "url": "/v1/completions",
"body": '{"model": "chat-sahara-4", "max_tokens": 2}'}
{"custom_id": "task_123", "method": "POST", "url": "/v1/completions",
"body": {"model": "chat-sahara-4", "max_tokens": 1}},
{"custom_id": "task_789", "method": "POST", "url": "/v1/completions",
"body": {"model": "chat-sahara-4", "max_tokens": 2}}
],
[
'{"max_tokens": 1}', '{"max_tokens": 2}'
'{"max_tokens": 1, "custom_id": "task_123"}', '{"max_tokens": 2, "custom_id": "task_789"}'
]
],
[
[
{"task_id": "task_123", "method": "POST", "url": "/v1/completions", "body": '{"model": "chat-sahara-4", \
"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, \
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'},
{"task_id": "task_456", "method": "POST", "url": "/v1/completions", "body": '{"model": "chat-sahara-4", \
"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, \
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'}
{
"custom_id": "task_123",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "chat-sahara-4",
"temperature": 0,
"max_tokens": 1024,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."
}
},
{
"custom_id": "task_456",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "chat-sahara-4",
"temperature": 0,
"max_tokens": 1024,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."
}
}
],
[
('{"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0,'
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}'),
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human...",'
' "custom_id": "task_123"}'),
('{"temperature": 0, "max_tokens": 1024, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0,'
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human..."}')
' "prompt": "# You will be given a conversation between a chatbot called Sydney and Human...",'
' "custom_id": "task_456"}')
]
]
]
@ -153,95 +179,119 @@ def test_new_schema_convert_input_to_requests_happy_path(obj_list: "list[any]",
assert result == expected_result
@pytest.mark.parametrize("tiktoken_failed",
[True, False])
def test_convert_result_list_batch_size_one(tiktoken_failed):
@pytest.mark.parametrize("input_schema_version", [1, 2])
@pytest.mark.parametrize("tiktoken_failed", [True, False])
def test_output_formatter_batch_size_one(input_schema_version, tiktoken_failed):
"""Test convert result list batch size one case."""
# Arrange
batch_size_per_request = 1
result_list = []
inputstring = __get_input_batch(batch_size_per_request)[0]
request_obj = {"input": inputstring, "custom_id": "task_123"}
outputlist = __get_output_data(batch_size_per_request)
result = __get_scoring_result_for_batch(batch_size_per_request,
inputstring,
request_obj,
outputlist,
tiktoken_failed=tiktoken_failed)
result_list.append(result)
# Act
actual = common.convert_result_list(result_list, batch_size_per_request)
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
actual = output_formatter.format_output(result_list, batch_size_per_request)
actual_obj = json.loads(actual[0])
# Assert
assert len(actual) == 1
assert actual_obj["status"] == "SUCCESS"
assert "start" in actual_obj
assert "end" in actual_obj
assert actual_obj["request"]["input"] == inputstring
assert actual_obj["response"]["usage"]["prompt_tokens"] == 1
assert actual_obj["response"]["usage"]["total_tokens"] == 1
if input_schema_version == 1:
assert actual_obj["status"] == "SUCCESS"
assert "start" in actual_obj
assert "end" in actual_obj
assert actual_obj["request"]["input"] == inputstring
assert actual_obj["response"]["usage"]["prompt_tokens"] == 1
assert actual_obj["response"]["usage"]["total_tokens"] == 1
elif input_schema_version == 2:
assert actual_obj["response"]["status_code"] == 200
assert actual_obj["response"]["body"]["usage"]["prompt_tokens"] == 1
assert actual_obj["response"]["body"]["usage"]["total_tokens"] == 1
@pytest.mark.parametrize("tiktoken_failed",
[True, False])
def test_convert_result_list_failed_result(tiktoken_failed):
@pytest.mark.parametrize("input_schema_version", [1, 2])
@pytest.mark.parametrize("tiktoken_failed", [True, False])
def test_output_formatter_failed_result(input_schema_version, tiktoken_failed):
"""Test convert result list failed result case."""
# Arrange
batch_size_per_request = 1
result_list = []
inputstring = __get_input_batch(batch_size_per_request)[0]
result = __get_failed_scoring_result_for_batch(inputstring, tiktoken_failed=tiktoken_failed)
request_obj = {"input": inputstring, "custom_id": "task_123"}
result = __get_failed_scoring_result_for_batch(request_obj, tiktoken_failed=tiktoken_failed)
result_list.append(result)
# Act
actual = common.convert_result_list(result_list, batch_size_per_request)
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
actual = output_formatter.format_output(result_list, batch_size_per_request)
actual_obj = json.loads(actual[0])
# Assert
assert len(actual) == 1
assert actual_obj["status"] == "FAILURE"
assert "start" in actual_obj
assert "end" in actual_obj
assert actual_obj["request"]["input"] == inputstring
assert actual_obj["response"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in actual_obj["response"]["error"]["message"]
if input_schema_version == 1:
assert actual_obj["status"] == "FAILURE"
assert "start" in actual_obj
assert "end" in actual_obj
assert actual_obj["request"]["input"] == inputstring
assert actual_obj["response"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in actual_obj["response"]["error"]["message"]
elif input_schema_version == 2: # TODO: Confirm this is the actual format for errors
assert actual_obj["error"]["message"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in actual_obj["error"]["message"]["error"]["message"]
@pytest.mark.parametrize("tiktoken_failed",
[True, False])
def test_convert_result_list_failed_result_batch(tiktoken_failed):
@pytest.mark.parametrize("input_schema_version", [1, 2])
@pytest.mark.parametrize("tiktoken_failed", [True, False])
def test_output_formatter_failed_result_batch(input_schema_version, tiktoken_failed):
"""Test convert result list failed result batch case."""
# Arrange
batch_size_per_request = 2
inputlist = __get_input_batch(batch_size_per_request)
result_list = [__get_failed_scoring_result_for_batch(inputlist, tiktoken_failed)]
request_obj = {"input": inputlist, "custom_id": "task_123"}
result_list = [__get_failed_scoring_result_for_batch(request_obj, tiktoken_failed)]
# Act
actual = common.convert_result_list(result_list, batch_size_per_request)
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
actual = output_formatter.format_output(result_list, batch_size_per_request)
# Assert
assert len(actual) == batch_size_per_request
for idx, result in enumerate(actual):
output_obj = json.loads(result)
assert output_obj["request"]["input"] == inputlist[idx]
assert output_obj["response"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in output_obj["response"]["error"]["message"]
if input_schema_version == 1:
assert output_obj["request"]["input"] == inputlist[idx]
assert output_obj["response"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in output_obj["response"]["error"]["message"]
elif input_schema_version == 2:
assert output_obj["response"] is None
assert output_obj["error"]["message"]["error"]["type"] == "invalid_request_error"
assert "maximum context length is 8190 tokens" in output_obj["error"]["message"]["error"]["message"]
@pytest.mark.parametrize("reorder_results, online_endpoint_url, tiktoken_fails",
[(True, True, True),
(True, True, False),
(True, False, True),
(True, False, False),
(False, True, True),
(False, True, False),
(False, False, True),
(False, False, False)])
def test_convert_result_list_batch_20(
@pytest.mark.parametrize("tiktoken_fails", [True, False])
@pytest.mark.parametrize("online_endpoint_url", [True, False])
@pytest.mark.parametrize("reorder_results", [True, False])
@pytest.mark.parametrize("input_schema_version", [1, 2])
def test_output_formatter_batch_20(
monkeypatch,
mock_get_logger,
input_schema_version,
reorder_results,
online_endpoint_url,
tiktoken_fails):
@ -256,22 +306,24 @@ def test_convert_result_list_batch_20(
inputlists = []
for n in range(full_batches):
inputlist = __get_input_batch(batch_size_per_request)
request_obj = {"input": inputlist, "custom_id": "task_123"}
outputlist = __get_output_data(batch_size_per_request)
result = __get_scoring_result_for_batch(
batch_size_per_request,
inputlist,
request_obj,
outputlist,
reorder_results,
online_endpoint_url or tiktoken_fails)
inputlists.extend(inputlist)
result_list.append(result)
inputlist = __get_input_batch(additional_rows)
request_obj = {"input": inputlist, "custom_id": "task_456"}
outputlist = __get_output_data(additional_rows)
inputlists.extend(inputlist)
result_list.append(__get_scoring_result_for_batch(
additional_rows,
inputlist,
request_obj,
outputlist,
reorder_results,
online_endpoint_url or tiktoken_fails))
@ -283,13 +335,18 @@ def test_convert_result_list_batch_20(
__mock_tiktoken_estimate(monkeypatch)
# Act
actual = common.convert_result_list(result_list, batch_size_per_request)
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
actual = output_formatter.format_output(result_list, batch_size_per_request)
# Assert
assert len(actual) == batch_size_per_request * full_batches + additional_rows
for idx, result in enumerate(actual):
output_obj = json.loads(result)
assert output_obj["request"]["input"] == inputlists[idx]
if input_schema_version == 1:
assert output_obj["request"]["input"] == inputlists[idx]
# Assign valid_batch_len for this result. This is the expected total_tokens.
if idx >= batch_size_per_request * full_batches:
@ -303,38 +360,48 @@ def test_convert_result_list_batch_20(
# Index values in `response.data` are from [0, batch_size_per_request -1]
valid_batch_idx = idx % batch_size_per_request
assert output_obj["response"]["data"][0]["index"] == valid_batch_idx
assert output_obj["response"]["usage"]["total_tokens"] == valid_batch_len
response_obj = output_obj["response"] if input_schema_version == 1 else output_obj["response"]["body"]
assert response_obj["data"][0]["index"] == valid_batch_idx
assert response_obj["usage"]["total_tokens"] == valid_batch_len
if tiktoken_fails:
# Prompt tokens will equal total tokens (equals batch length)
assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_len
assert response_obj["usage"]["prompt_tokens"] == valid_batch_len
elif online_endpoint_url:
# Prompt tokens is batch index (see helper function: `__mock_tiktoken_estimate`)
assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_idx
assert response_obj["usage"]["prompt_tokens"] == valid_batch_idx
else:
# Batch pool case; prompt tokens is 10 + batch index (see helper function: `__get_token_counts`)
assert output_obj["response"]["usage"]["prompt_tokens"] == valid_batch_idx + 10
assert response_obj["usage"]["prompt_tokens"] == valid_batch_idx + 10
def test_incorrect_data_length_raises():
@pytest.mark.parametrize("input_schema_version", [1, 2])
def test_incorrect_data_length_raises(input_schema_version):
"""Test incorrect data length raises."""
# Arrange
batch_size_per_request = 2
result_list = []
inputstring = __get_input_batch(batch_size_per_request)
request_obj = {"input": inputstring, "custom_id": "task_123"}
outputlist = __get_output_data(0)
result = __get_scoring_result_for_batch(batch_size_per_request, inputstring, outputlist)
result = __get_scoring_result_for_batch(batch_size_per_request, request_obj, outputlist)
result_list.append(result)
# Act
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
with pytest.raises(Exception) as excinfo:
common.convert_result_list(result_list, batch_size_per_request)
output_formatter.format_output(result_list, batch_size_per_request)
# Assert
assert "Result data length 0 != 2 request batch length." in str(excinfo.value)
def test_endpoint_response_is_not_json(mock_get_logger):
@pytest.mark.parametrize("input_schema_version", [1, 2])
def test_endpoint_response_is_not_json(input_schema_version, mock_get_logger):
"""Test endpoint response is not json."""
# Arrange failed response payload as a string
batch_size_per_request = 10
@ -344,11 +411,12 @@ def test_endpoint_response_is_not_json(mock_get_logger):
status=ScoringResultStatus.FAILURE,
start=0,
end=0,
model_response_code=400,
request_metadata="Not important",
response_headers="Headers",
num_retries=2,
token_counts=(1,) * batch_size_per_request,
request_obj={"input": inputs},
request_obj={"input": inputs, "custom_id": "task_123"},
response_body=json.dumps({"object": "list",
"error": {
"message": "This model's maximum context length is 8190 tokens, "
@ -360,28 +428,38 @@ def test_endpoint_response_is_not_json(mock_get_logger):
"code": None
}}))
# Act
actual = common.convert_result_list([result], batch_size_per_request)
if input_schema_version == 1:
output_formatter = V1OutputFormatter()
else:
output_formatter = V2OutputFormatter()
actual = output_formatter.format_output([result], batch_size_per_request)
# Assert
assert len(actual) == batch_size_per_request
for idx, result in enumerate(actual):
unit = json.loads(result)
assert unit['request']['input'] == inputs[idx]
assert type(unit['response']) is str
assert 'maximum context length' in unit['response']
if input_schema_version == 1:
assert unit['request']['input'] == inputs[idx]
assert type(unit['response']) is str
assert 'maximum context length' in unit['response']
elif input_schema_version == 2:
assert unit['response'] is None
assert type(unit['error']['message']) is str
assert 'maximum context length' in unit['error']['message']
def __get_scoring_result_for_batch(batch_size, inputlist, outputlist, reorder_results=False, tiktoken_failed=False):
token_counts = __get_token_counts(tiktoken_failed, inputlist)
def __get_scoring_result_for_batch(batch_size, request_obj, outputlist, reorder_results=False, tiktoken_failed=False):
token_counts = __get_token_counts(tiktoken_failed, request_obj["input"])
result = ScoringResult(
status=ScoringResultStatus.SUCCESS,
model_response_code=200,
start=0,
end=0,
request_metadata="Not important",
response_headers="Headers",
response_headers={"header1": "value"},
num_retries=2,
token_counts=token_counts,
request_obj={"input": inputlist},
request_obj=request_obj,
response_body={"object": "list",
"data": outputlist,
"model": "text-embedding-ada-002",
@ -394,17 +472,18 @@ def __get_scoring_result_for_batch(batch_size, inputlist, outputlist, reorder_re
return result
def __get_failed_scoring_result_for_batch(inputlist, tiktoken_failed=False):
token_counts = __get_token_counts(tiktoken_failed, inputlist)
def __get_failed_scoring_result_for_batch(request_obj, tiktoken_failed=False):
token_counts = __get_token_counts(tiktoken_failed, request_obj["input"])
result = ScoringResult(
status=ScoringResultStatus.FAILURE,
start=0,
end=0,
model_response_code=400,
request_metadata="Not important",
response_headers="Headers",
num_retries=2,
token_counts=token_counts,
request_obj={"input": inputlist},
request_obj=request_obj,
response_body={"object": "list",
"error": {
"message": "This model's maximum context length is 8190 tokens, "
@ -451,10 +530,10 @@ def __random_string(length: int = 10):
def __mock_tiktoken_permanent_failure(monkeypatch):
def mock_tiktoken_failure(*args):
return 1
monkeypatch.setattr(common.embeddings.EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_failure)
monkeypatch.setattr(EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_failure)
def __mock_tiktoken_estimate(monkeypatch):
def mock_tiktoken_override(estimator, request_obj):
return [i for i in range(len(request_obj['input']))]
monkeypatch.setattr(common.embeddings.EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_override)
monkeypatch.setattr(EmbeddingsEstimator, "estimate_request_cost", mock_tiktoken_override)

View file

@ -16,6 +16,7 @@ with patch('importlib.import_module', side_effect=mock_import):
from src.batch_score.batch_pool.meds_client import MEDSClient
from src.batch_score.common.common_enums import EndpointType
from src.batch_score.common.configuration.configuration import Configuration
from src.batch_score.common.post_processing.output_handler import SeparateFileOutputHandler, SingleFileOutputHandler
from src.batch_score.common.telemetry.events import event_utils
from src.batch_score.common.telemetry.trace_configs import (
ConnectionCreateEndTrace,
@ -103,20 +104,26 @@ def test_output_handler_interface(
use_single_file_output_handler: bool,
use_separate_file_output_handler: bool):
"""Test output handler interface."""
input_data, mini_batch_context = _setup_main()
with patch(
"src.batch_score.main.SeparateFileOutputHandler",
return_value=MagicMock()
"tests.unit.test_main.SeparateFileOutputHandler"
) as mock_separate_file_output_handler, \
patch(
"src.batch_score.main.SingleFileOutputHandler",
return_value=MagicMock()
"tests.unit.test_main.SingleFileOutputHandler"
) as mock_single_file_output_handler:
input_data, mini_batch_context = _setup_main()
main.configuration.split_output = split_output
main.configuration.save_mini_batch_results = "enabled"
main.configuration.mini_batch_results_out_directory = "driver/tests/unit/unit_test_results/"
if main.configuration.split_output:
main.output_handler = SeparateFileOutputHandler(
main.configuration.batch_size_per_request,
main.configuration.input_schema_version)
else:
main.output_handler = SingleFileOutputHandler(
main.configuration.batch_size_per_request,
main.configuration.input_schema_version)
main.run(input_data=input_data, mini_batch_context=mini_batch_context)
assert mock_separate_file_output_handler.called == use_separate_file_output_handler
@ -267,6 +274,7 @@ def _setup_main(par_exception=None):
configuration.scoring_url = "https://scoring_url"
configuration.batch_pool = "batch_pool"
configuration.quota_audience = "quota_audience"
configuration.input_schema_version = 1
main.configuration = configuration
main.input_handler = MagicMock()