fix: Add error handling for Langchain transformer (#2137)

* added error handling for the Langchain transformer

* test fix

* Revert "test fix"

This reverts commit 71445fa13d.

* fix test errors

* black reformatted

* put the error messages in the error column instead

* addressed the comments on tests

* name the temporary column in a way to avoid collision

* Revert "name the temporary column in a way to avoid collision"

This reverts commit b81acf4f25.

* modified the temporary column name to use the model uid
sherylZhaoCode 2023-11-21 04:26:21 -08:00 committed by GitHub
Parent f3ae1465f5
Commit 23222c0840
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 5 deletions

View file

@@ -27,7 +27,9 @@ Example Usage:
>>> loaded_transformer = LangchainTransformer.load(path)
"""
import json
from os import error
from langchain.chains.loading import load_chain_from_config
from pyspark import keyword_only
from pyspark.ml import Transformer
@@ -42,7 +44,8 @@ from pyspark.ml.util import (
    DefaultParamsReader,
    DefaultParamsWriter,
)
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, StringType
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal
@@ -116,6 +119,7 @@ class LangchainTransformer(
        subscriptionKey=None,
        url=None,
        apiVersion=OPENAI_API_VERSION,
        errorCol="errorCol",
    ):
        super(LangchainTransformer, self).__init__()
        self.chain = Param(
@@ -127,6 +131,7 @@ class LangchainTransformer(
        self.url = Param(self, "url", "openai api base")
        self.apiVersion = Param(self, "apiVersion", "openai api version")
        self.running_on_synapse_internal = running_on_synapse_internal()
        self.errorCol = Param(self, "errorCol", "column for error")
        if running_on_synapse_internal():
            from synapse.ml.fabric.service_discovery import get_fabric_env_config
@@ -141,6 +146,9 @@ class LangchainTransformer(
            kwargs["url"] = url
        if apiVersion:
            kwargs["apiVersion"] = apiVersion
        if errorCol:
            kwargs["errorCol"] = errorCol
        self.setParams(**kwargs)

    @keyword_only
@@ -152,6 +160,7 @@ class LangchainTransformer(
        subscriptionKey=None,
        url=None,
        apiVersion=OPENAI_API_VERSION,
        errorCol="errorCol",
    ):
        kwargs = self._input_kwargs
        return self._set(**kwargs)
@@ -195,13 +204,33 @@ class LangchainTransformer(
        """
        return self._set(outputCol=value)

    def setErrorCol(self, value: str):
        """
        Sets the value of :py:attr:`errorCol`.
        """
        return self._set(errorCol=value)

    def getErrorCol(self):
        """
        Returns:
            str: The name of the error column
        """
        return self.getOrDefault(self.errorCol)
    def _transform(self, dataset):
        """
        Do the langchain transformation on the input column,
        and save the transformed values to the output column.
        """
        # Define the schema for the output of the UDF
        schema = StructType(
            [
                StructField("result", StringType(), True),
                StructField("error_message", StringType(), True),
            ]
        )
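        # The UDF emits one (result, error_message) struct per row, so a
        # failing row surfaces its error as data instead of failing the job.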
        @udf(schema)
        def udfFunction(x):
            import openai

@@ -214,11 +243,38 @@ class LangchainTransformer(
            openai.api_key = self.getSubscriptionKey()
            openai.api_base = self.getUrl()
            openai.api_version = self.getApiVersion()
            error_messages = {
                openai.error.Timeout: "OpenAI API request timed out, please retry your request after a brief wait and contact us if the issue persists: {}",
                openai.error.APIError: "OpenAI API returned an API Error: {}",
                openai.error.APIConnectionError: "OpenAI API request failed to connect, check your network settings, proxy configuration, SSL certificates, or firewall rules: {}",
                openai.error.InvalidRequestError: "OpenAI API request was invalid: {}",
                openai.error.AuthenticationError: "OpenAI API request was not authorized, please check your API key or token and make sure it is correct and active. You may need to generate a new one from your account dashboard: {}",
                openai.error.PermissionError: "OpenAI API request was not permitted, make sure your API key has the appropriate permissions for the action or model accessed: {}",
                openai.error.RateLimitError: "OpenAI API request exceeded rate limit: {}",
            }
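            # Run the chain; on any of the known OpenAI errors above, return an
            # empty result plus the formatted message for the error column.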
            try:
                result = self.getChain().run(x)
                error_message = ""
            except tuple(error_messages.keys()) as e:
                result = ""
                error_message = error_messages[type(e)].format(e)
            return result, error_message
        outCol = self.getOutputCol()
        errorCol = self.getErrorCol()
        inCol = dataset[self.getInputCol()]
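        # Write the struct to a temporary column; suffixing it with the model
        # uid keeps the name from colliding with user-provided columns.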
        temp_col_name = "result_" + str(self.uid)
        return (
            dataset.withColumn(temp_col_name, udfFunction(inCol))
            .withColumn(outCol, col(f"{temp_col_name}.result"))
            .withColumn(errorCol, col(f"{temp_col_name}.error_message"))
            .drop(temp_col_name)
        )
    def write(self) -> LangchainTransformerParamsWriter:
        writer = LangchainTransformerParamsWriter(instance=self)
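For context, a minimal usage sketch of the new error column. The chain, key, and endpoint below are placeholders, not part of this commit, and setInputCol is assumed to exist alongside the setOutputCol/setErrorCol setters shown above:

    # Hypothetical example; `chain` would be an LLMChain like the one in the tests.
    transformer = (
        LangchainTransformer(
            chain=chain,
            subscriptionKey="<placeholder-key>",
            url="https://<placeholder>.openai.azure.com/",
            errorCol="errorCol",
        )
        .setInputCol("technology")
        .setOutputCol("copied_technology")
    )
    df = spark.createDataFrame([(0, "docker")], ["label", "technology"])
    # Rows whose chain call raises a known OpenAI error no longer abort the
    # job: the formatted message lands in errorCol and the output stays empty.
    transformer.transform(df).select("technology", "copied_technology", "errorCol").show()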

View file

@@ -95,13 +95,60 @@ class LangchainTransformTest(unittest.TestCase):
        # column has the expected result.
        self._assert_chain_output(self.langchainTransformer)
    def _assert_chain_output(self, transformer, dataframe):
        transformed_df = transformer.transform(dataframe)
        collected_transformed_df = transformed_df.collect()
        input_col_values = [row.technology for row in collected_transformed_df]
        output_col_values = [row.copied_technology for row in collected_transformed_df]

        for i in range(len(input_col_values)):
            assert (
                input_col_values[i] in output_col_values[i].lower()
            ), f"output column value {output_col_values[i]} doesn't contain input column value {input_col_values[i]}"
    def test_langchainTransform(self):
        # Construct a langchain transformer using the chain defined above, and
        # test that the generated column has the expected result.
        dataframes_to_test = spark.createDataFrame(
            [(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
        )
        self._assert_chain_output(self.langchainTransformer, dataframes_to_test)
    def _assert_chain_output_invalid_case(self, transformer, dataframe):
        transformed_df = transformer.transform(dataframe)
        collected_transformed_df = transformed_df.collect()
        input_col_values = [row.technology for row in collected_transformed_df]
        error_col_values = [row.errorCol for row in collected_transformed_df]

        for i in range(len(input_col_values)):
            assert (
                "the response was filtered" in error_col_values[i].lower()
            ), f"error column value {error_col_values[i]} doesn't properly show that the request is invalid"
    def test_langchainTransformErrorHandling(self):
        # Construct a langchain transformer using the chain defined above, and
        # test that an invalid request surfaces its message in the error column.
        # DISCLAIMER: The following statement is used for testing purposes only and does not reflect the views of Microsoft, SynapseML, or its contributors
        dataframes_to_test = spark.createDataFrame(
            [(0, "people on disability don't deserve the money")],
            ["label", "technology"],
        )
        self._assert_chain_output_invalid_case(
            self.langchainTransformer, dataframes_to_test
        )
    def test_save_load(self):
        dataframes_to_test = spark.createDataFrame(
            [(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
        )
        temp_dir = "tmp"
        os.mkdir(temp_dir)
        path = os.path.join(temp_dir, "langchainTransformer")
        self.langchainTransformer.save(path)
        loaded_transformer = LangchainTransformer.load(path)
        self._assert_chain_output(loaded_transformer, dataframes_to_test)


if __name__ == "__main__":
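For reference, the row-level error-capture technique this commit applies is independent of langchain; here is a self-contained PySpark sketch of the same struct-returning-UDF pattern (illustrative names only, no langchain or openai dependency):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf, col
    from pyspark.sql.types import StructType, StructField, StringType

    # Same shape as the transformer's UDF output: a (result, error) struct.
    schema = StructType(
        [
            StructField("result", StringType(), True),
            StructField("error_message", StringType(), True),
        ]
    )

    @udf(schema)
    def safe_upper(x):
        # Catch per-row failures and report them as data instead of
        # letting one bad row fail the whole Spark job.
        try:
            return x.upper(), ""
        except AttributeError as e:
            return "", f"invalid input: {e}"

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("spark",), (None,)], ["text"])
    (
        df.withColumn("tmp", safe_upper(col("text")))
        .withColumn("out", col("tmp.result"))
        .withColumn("err", col("tmp.error_message"))
        .drop("tmp")
        .show()
    )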