fix: Add the error handling for Langchain transformer (#2137)

* added the error handling for Langchain transformer

* test fix

* Revert "test fix"

This reverts commit 71445fa13d.

* fix test errors

* black reformatted

* put the error messages in the error column instead

* addressed the comments on tests

* name the temporary column in a way to avoid collision

* Revert "name the temporary column in a way to avoid collision"

This reverts commit b81acf4f25.

* modified uid to use model uid
This commit is contained in:
sherylZhaoCode 2023-11-21 04:26:21 -08:00 коммит произвёл GitHub
Родитель f3ae1465f5
Коммит 23222c0840
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 108 добавлений и 5 удалений

Просмотреть файл

@ -27,7 +27,9 @@ Example Usage:
>>> loaded_transformer = LangchainTransformer.load(path)
"""
import json
from os import error
from langchain.chains.loading import load_chain_from_config
from pyspark import keyword_only
from pyspark.ml import Transformer
@ -42,7 +44,8 @@ from pyspark.ml.util import (
DefaultParamsReader,
DefaultParamsWriter,
)
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, StringType
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal
@ -116,6 +119,7 @@ class LangchainTransformer(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
errorCol="errorCol",
):
super(LangchainTransformer, self).__init__()
self.chain = Param(
@ -127,6 +131,7 @@ class LangchainTransformer(
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.running_on_synapse_internal = running_on_synapse_internal()
self.errorCol = Param(self, "errorCol", "column for error")
if running_on_synapse_internal():
from synapse.ml.fabric.service_discovery import get_fabric_env_config
@ -141,6 +146,9 @@ class LangchainTransformer(
kwargs["url"] = url
if apiVersion:
kwargs["apiVersion"] = apiVersion
if errorCol:
kwargs["errorCol"] = errorCol
self.setParams(**kwargs)
@keyword_only
@ -152,6 +160,7 @@ class LangchainTransformer(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
errorCol="errorCol",
):
kwargs = self._input_kwargs
return self._set(**kwargs)
@ -195,13 +204,33 @@ class LangchainTransformer(
"""
return self._set(outputCol=value)
def setErrorCol(self, value: str):
"""
Sets the value of :py:attr:`outputCol`.
"""
return self._set(errorCol=value)
def getErrorCol(self):
"""
Returns:
str: The name of the error column
"""
return self.getOrDefault(self.errorCol)
def _transform(self, dataset):
"""
do langchain transformation for the input column,
and save the transformed values to the output column.
"""
# Define the schema for the output of the UDF
schema = StructType(
[
StructField("result", StringType(), True),
StructField("error_message", StringType(), True),
]
)
@udf
@udf(schema)
def udfFunction(x):
import openai
@ -214,11 +243,38 @@ class LangchainTransformer(
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
return self.getChain().run(x)
error_messages = {
openai.error.Timeout: "OpenAI API request timed out, please retry your request after a brief wait and contact us if the issue persists: {}",
openai.error.APIError: "OpenAI API returned an API Error: {}",
openai.error.APIConnectionError: "OpenAI API request failed to connect, check your network settings, proxy configuration, SSL certificates, or firewall rules: {}",
openai.error.InvalidRequestError: "OpenAI API request was invalid: {}",
openai.error.AuthenticationError: "OpenAI API request was not authorized, please check your API key or token and make sure it is correct and active. You may need to generate a new one from your account dashboard: {}",
openai.error.PermissionError: "OpenAI API request was not permitted, make sure your API key has the appropriate permissions for the action or model accessed: {}",
openai.error.RateLimitError: "OpenAI API request exceeded rate limit: {}",
}
try:
result = self.getChain().run(x)
error_message = ""
except tuple(error_messages.keys()) as e:
result = ""
error_message = error_messages[type(e)].format(e)
return result, error_message
outCol = self.getOutputCol()
errorCol = self.getErrorCol()
inCol = dataset[self.getInputCol()]
return dataset.withColumn(outCol, udfFunction(inCol))
temp_col_name = "result_" + str(self.uid)
return (
dataset.withColumn(temp_col_name, udfFunction(inCol))
.withColumn(outCol, col(f"{temp_col_name}.result"))
.withColumn(errorCol, col(f"{temp_col_name}.error_message"))
.drop(temp_col_name)
)
def write(self) -> LangchainTransformerParamsWriter:
writer = LangchainTransformerParamsWriter(instance=self)

Просмотреть файл

@ -95,13 +95,60 @@ class LangchainTransformTest(unittest.TestCase):
# column has the expected result.
self._assert_chain_output(self.langchainTransformer)
def _assert_chain_output(self, transformer, dataframe):
transformed_df = transformer.transform(dataframe)
collected_transformed_df = transformed_df.collect()
input_col_values = [row.technology for row in collected_transformed_df]
output_col_values = [row.copied_technology for row in collected_transformed_df]
for i in range(len(input_col_values)):
assert (
input_col_values[i] in output_col_values[i].lower()
), f"output column value {output_col_values[i]} doesn't contain input column value {input_col_values[i]}"
def test_langchainTransform(self):
# construct langchain transformer using the chain defined above. And test if the generated
# column has the expected result.
dataframes_to_test = spark.createDataFrame(
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
)
self._assert_chain_output(self.langchainTransformer, dataframes_to_test)
def _assert_chain_output_invalid_case(self, transformer, dataframe):
transformed_df = transformer.transform(dataframe)
collected_transformed_df = transformed_df.collect()
input_col_values = [row.technology for row in collected_transformed_df]
error_col_values = [row.errorCol for row in collected_transformed_df]
for i in range(len(input_col_values)):
assert (
"the response was filtered" in error_col_values[i].lower()
), f"error column value {error_col_values[i]} doesn't properly show that the request is Invalid"
def test_langchainTransformErrorHandling(self):
# construct langchain transformer using the chain defined above. And test if the generated
# column has the expected result.
# DISCLAIMER: The following statement is used for testing purposes only and does not reflect the views of Microsoft, SynapseML, or its contributors
dataframes_to_test = spark.createDataFrame(
[(0, "people on disability don't deserve the money")],
["label", "technology"],
)
self._assert_chain_output_invalid_case(
self.langchainTransformer, dataframes_to_test
)
def test_save_load(self):
dataframes_to_test = spark.createDataFrame(
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
)
temp_dir = "tmp"
os.mkdir(temp_dir)
path = os.path.join(temp_dir, "langchainTransformer")
self.langchainTransformer.save(path)
loaded_transformer = LangchainTransformer.load(path)
self._assert_chain_output(loaded_transformer)
self._assert_chain_output(loaded_transformer, dataframes_to_test)
if __name__ == "__main__":