зеркало из https://github.com/microsoft/SynapseML.git
fix: Add the error handling for Langchain transformer (#2137)
* added the error handling for Langchain transformer * test fix * Revert "test fix" This reverts commit71445fa13d
. * fix test errors * black reformatted * put the error messages in the error column instead * addressed the comments on tests * name the temporary column in a way to avoid collision * Revert "name the temporary column in a way to avoid collision" This reverts commitb81acf4f25
. * modified uid to use model uid
This commit is contained in:
Родитель
f3ae1465f5
Коммит
23222c0840
|
@ -27,7 +27,9 @@ Example Usage:
|
|||
>>> loaded_transformer = LangchainTransformer.load(path)
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
from os import error
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from pyspark import keyword_only
|
||||
from pyspark.ml import Transformer
|
||||
|
@ -42,7 +44,8 @@ from pyspark.ml.util import (
|
|||
DefaultParamsReader,
|
||||
DefaultParamsWriter,
|
||||
)
|
||||
from pyspark.sql.functions import udf
|
||||
from pyspark.sql.functions import udf, col
|
||||
from pyspark.sql.types import StructType, StructField, StringType
|
||||
from typing import cast, Optional, TypeVar, Type
|
||||
from synapse.ml.core.platform import running_on_synapse_internal
|
||||
|
||||
|
@ -116,6 +119,7 @@ class LangchainTransformer(
|
|||
subscriptionKey=None,
|
||||
url=None,
|
||||
apiVersion=OPENAI_API_VERSION,
|
||||
errorCol="errorCol",
|
||||
):
|
||||
super(LangchainTransformer, self).__init__()
|
||||
self.chain = Param(
|
||||
|
@ -127,6 +131,7 @@ class LangchainTransformer(
|
|||
self.url = Param(self, "url", "openai api base")
|
||||
self.apiVersion = Param(self, "apiVersion", "openai api version")
|
||||
self.running_on_synapse_internal = running_on_synapse_internal()
|
||||
self.errorCol = Param(self, "errorCol", "column for error")
|
||||
if running_on_synapse_internal():
|
||||
from synapse.ml.fabric.service_discovery import get_fabric_env_config
|
||||
|
||||
|
@ -141,6 +146,9 @@ class LangchainTransformer(
|
|||
kwargs["url"] = url
|
||||
if apiVersion:
|
||||
kwargs["apiVersion"] = apiVersion
|
||||
if errorCol:
|
||||
kwargs["errorCol"] = errorCol
|
||||
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@keyword_only
|
||||
|
@ -152,6 +160,7 @@ class LangchainTransformer(
|
|||
subscriptionKey=None,
|
||||
url=None,
|
||||
apiVersion=OPENAI_API_VERSION,
|
||||
errorCol="errorCol",
|
||||
):
|
||||
kwargs = self._input_kwargs
|
||||
return self._set(**kwargs)
|
||||
|
@ -195,13 +204,33 @@ class LangchainTransformer(
|
|||
"""
|
||||
return self._set(outputCol=value)
|
||||
|
||||
def setErrorCol(self, value: str):
|
||||
"""
|
||||
Sets the value of :py:attr:`outputCol`.
|
||||
"""
|
||||
return self._set(errorCol=value)
|
||||
|
||||
def getErrorCol(self):
|
||||
"""
|
||||
Returns:
|
||||
str: The name of the error column
|
||||
"""
|
||||
return self.getOrDefault(self.errorCol)
|
||||
|
||||
def _transform(self, dataset):
|
||||
"""
|
||||
do langchain transformation for the input column,
|
||||
and save the transformed values to the output column.
|
||||
"""
|
||||
# Define the schema for the output of the UDF
|
||||
schema = StructType(
|
||||
[
|
||||
StructField("result", StringType(), True),
|
||||
StructField("error_message", StringType(), True),
|
||||
]
|
||||
)
|
||||
|
||||
@udf
|
||||
@udf(schema)
|
||||
def udfFunction(x):
|
||||
import openai
|
||||
|
||||
|
@ -214,11 +243,38 @@ class LangchainTransformer(
|
|||
openai.api_key = self.getSubscriptionKey()
|
||||
openai.api_base = self.getUrl()
|
||||
openai.api_version = self.getApiVersion()
|
||||
return self.getChain().run(x)
|
||||
|
||||
error_messages = {
|
||||
openai.error.Timeout: "OpenAI API request timed out, please retry your request after a brief wait and contact us if the issue persists: {}",
|
||||
openai.error.APIError: "OpenAI API returned an API Error: {}",
|
||||
openai.error.APIConnectionError: "OpenAI API request failed to connect, check your network settings, proxy configuration, SSL certificates, or firewall rules: {}",
|
||||
openai.error.InvalidRequestError: "OpenAI API request was invalid: {}",
|
||||
openai.error.AuthenticationError: "OpenAI API request was not authorized, please check your API key or token and make sure it is correct and active. You may need to generate a new one from your account dashboard: {}",
|
||||
openai.error.PermissionError: "OpenAI API request was not permitted, make sure your API key has the appropriate permissions for the action or model accessed: {}",
|
||||
openai.error.RateLimitError: "OpenAI API request exceeded rate limit: {}",
|
||||
}
|
||||
|
||||
try:
|
||||
result = self.getChain().run(x)
|
||||
error_message = ""
|
||||
except tuple(error_messages.keys()) as e:
|
||||
result = ""
|
||||
error_message = error_messages[type(e)].format(e)
|
||||
|
||||
return result, error_message
|
||||
|
||||
outCol = self.getOutputCol()
|
||||
errorCol = self.getErrorCol()
|
||||
inCol = dataset[self.getInputCol()]
|
||||
return dataset.withColumn(outCol, udfFunction(inCol))
|
||||
|
||||
temp_col_name = "result_" + str(self.uid)
|
||||
|
||||
return (
|
||||
dataset.withColumn(temp_col_name, udfFunction(inCol))
|
||||
.withColumn(outCol, col(f"{temp_col_name}.result"))
|
||||
.withColumn(errorCol, col(f"{temp_col_name}.error_message"))
|
||||
.drop(temp_col_name)
|
||||
)
|
||||
|
||||
def write(self) -> LangchainTransformerParamsWriter:
|
||||
writer = LangchainTransformerParamsWriter(instance=self)
|
||||
|
|
|
@ -95,13 +95,60 @@ class LangchainTransformTest(unittest.TestCase):
|
|||
# column has the expected result.
|
||||
self._assert_chain_output(self.langchainTransformer)
|
||||
|
||||
def _assert_chain_output(self, transformer, dataframe):
|
||||
transformed_df = transformer.transform(dataframe)
|
||||
collected_transformed_df = transformed_df.collect()
|
||||
input_col_values = [row.technology for row in collected_transformed_df]
|
||||
output_col_values = [row.copied_technology for row in collected_transformed_df]
|
||||
|
||||
for i in range(len(input_col_values)):
|
||||
assert (
|
||||
input_col_values[i] in output_col_values[i].lower()
|
||||
), f"output column value {output_col_values[i]} doesn't contain input column value {input_col_values[i]}"
|
||||
|
||||
def test_langchainTransform(self):
|
||||
# construct langchain transformer using the chain defined above. And test if the generated
|
||||
# column has the expected result.
|
||||
dataframes_to_test = spark.createDataFrame(
|
||||
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
|
||||
)
|
||||
self._assert_chain_output(self.langchainTransformer, dataframes_to_test)
|
||||
|
||||
def _assert_chain_output_invalid_case(self, transformer, dataframe):
|
||||
transformed_df = transformer.transform(dataframe)
|
||||
collected_transformed_df = transformed_df.collect()
|
||||
input_col_values = [row.technology for row in collected_transformed_df]
|
||||
error_col_values = [row.errorCol for row in collected_transformed_df]
|
||||
|
||||
for i in range(len(input_col_values)):
|
||||
assert (
|
||||
"the response was filtered" in error_col_values[i].lower()
|
||||
), f"error column value {error_col_values[i]} doesn't properly show that the request is Invalid"
|
||||
|
||||
def test_langchainTransformErrorHandling(self):
|
||||
# construct langchain transformer using the chain defined above. And test if the generated
|
||||
# column has the expected result.
|
||||
|
||||
# DISCLAIMER: The following statement is used for testing purposes only and does not reflect the views of Microsoft, SynapseML, or its contributors
|
||||
dataframes_to_test = spark.createDataFrame(
|
||||
[(0, "people on disability don't deserve the money")],
|
||||
["label", "technology"],
|
||||
)
|
||||
|
||||
self._assert_chain_output_invalid_case(
|
||||
self.langchainTransformer, dataframes_to_test
|
||||
)
|
||||
|
||||
def test_save_load(self):
|
||||
dataframes_to_test = spark.createDataFrame(
|
||||
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
|
||||
)
|
||||
temp_dir = "tmp"
|
||||
os.mkdir(temp_dir)
|
||||
path = os.path.join(temp_dir, "langchainTransformer")
|
||||
self.langchainTransformer.save(path)
|
||||
loaded_transformer = LangchainTransformer.load(path)
|
||||
self._assert_chain_output(loaded_transformer)
|
||||
self._assert_chain_output(loaded_transformer, dataframes_to_test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Загрузка…
Ссылка в новой задаче