зеркало из https://github.com/microsoft/SynapseML.git
fix: Add the error handling for Langchain transformer (#2137)
* added the error handling for Langchain transformer
* test fix
* Revert "test fix". This reverts commit 71445fa13d.
* fix test errors
* black reformatted
* put the error messages in the error column instead
* addressed the comments on tests
* name the temporary column in a way to avoid collision
* Revert "name the temporary column in a way to avoid collision". This reverts commit b81acf4f25.
* modified uid to use model uid
This commit is contained in:
Родитель
f3ae1465f5
Коммит
23222c0840
|
@ -27,7 +27,9 @@ Example Usage:
|
||||||
>>> loaded_transformer = LangchainTransformer.load(path)
|
>>> loaded_transformer = LangchainTransformer.load(path)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from os import error
|
||||||
from langchain.chains.loading import load_chain_from_config
|
from langchain.chains.loading import load_chain_from_config
|
||||||
from pyspark import keyword_only
|
from pyspark import keyword_only
|
||||||
from pyspark.ml import Transformer
|
from pyspark.ml import Transformer
|
||||||
|
@ -42,7 +44,8 @@ from pyspark.ml.util import (
|
||||||
DefaultParamsReader,
|
DefaultParamsReader,
|
||||||
DefaultParamsWriter,
|
DefaultParamsWriter,
|
||||||
)
|
)
|
||||||
from pyspark.sql.functions import udf
|
from pyspark.sql.functions import udf, col
|
||||||
|
from pyspark.sql.types import StructType, StructField, StringType
|
||||||
from typing import cast, Optional, TypeVar, Type
|
from typing import cast, Optional, TypeVar, Type
|
||||||
from synapse.ml.core.platform import running_on_synapse_internal
|
from synapse.ml.core.platform import running_on_synapse_internal
|
||||||
|
|
||||||
|
@ -116,6 +119,7 @@ class LangchainTransformer(
|
||||||
subscriptionKey=None,
|
subscriptionKey=None,
|
||||||
url=None,
|
url=None,
|
||||||
apiVersion=OPENAI_API_VERSION,
|
apiVersion=OPENAI_API_VERSION,
|
||||||
|
errorCol="errorCol",
|
||||||
):
|
):
|
||||||
super(LangchainTransformer, self).__init__()
|
super(LangchainTransformer, self).__init__()
|
||||||
self.chain = Param(
|
self.chain = Param(
|
||||||
|
@ -127,6 +131,7 @@ class LangchainTransformer(
|
||||||
self.url = Param(self, "url", "openai api base")
|
self.url = Param(self, "url", "openai api base")
|
||||||
self.apiVersion = Param(self, "apiVersion", "openai api version")
|
self.apiVersion = Param(self, "apiVersion", "openai api version")
|
||||||
self.running_on_synapse_internal = running_on_synapse_internal()
|
self.running_on_synapse_internal = running_on_synapse_internal()
|
||||||
|
self.errorCol = Param(self, "errorCol", "column for error")
|
||||||
if running_on_synapse_internal():
|
if running_on_synapse_internal():
|
||||||
from synapse.ml.fabric.service_discovery import get_fabric_env_config
|
from synapse.ml.fabric.service_discovery import get_fabric_env_config
|
||||||
|
|
||||||
|
@ -141,6 +146,9 @@ class LangchainTransformer(
|
||||||
kwargs["url"] = url
|
kwargs["url"] = url
|
||||||
if apiVersion:
|
if apiVersion:
|
||||||
kwargs["apiVersion"] = apiVersion
|
kwargs["apiVersion"] = apiVersion
|
||||||
|
if errorCol:
|
||||||
|
kwargs["errorCol"] = errorCol
|
||||||
|
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
|
@ -152,6 +160,7 @@ class LangchainTransformer(
|
||||||
subscriptionKey=None,
|
subscriptionKey=None,
|
||||||
url=None,
|
url=None,
|
||||||
apiVersion=OPENAI_API_VERSION,
|
apiVersion=OPENAI_API_VERSION,
|
||||||
|
errorCol="errorCol",
|
||||||
):
|
):
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
return self._set(**kwargs)
|
return self._set(**kwargs)
|
||||||
|
@ -195,13 +204,33 @@ class LangchainTransformer(
|
||||||
"""
|
"""
|
||||||
return self._set(outputCol=value)
|
return self._set(outputCol=value)
|
||||||
|
|
||||||
|
def setErrorCol(self, value: str):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`errorCol`.
|
||||||
|
"""
|
||||||
|
return self._set(errorCol=value)
|
||||||
|
|
||||||
|
def getErrorCol(self):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
str: The name of the error column
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.errorCol)
|
||||||
|
|
||||||
def _transform(self, dataset):
|
def _transform(self, dataset):
|
||||||
"""
|
"""
|
||||||
do langchain transformation for the input column,
|
do langchain transformation for the input column,
|
||||||
and save the transformed values to the output column.
|
and save the transformed values to the output column.
|
||||||
"""
|
"""
|
||||||
|
# Define the schema for the output of the UDF
|
||||||
|
schema = StructType(
|
||||||
|
[
|
||||||
|
StructField("result", StringType(), True),
|
||||||
|
StructField("error_message", StringType(), True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@udf
|
@udf(schema)
|
||||||
def udfFunction(x):
|
def udfFunction(x):
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
@ -214,11 +243,38 @@ class LangchainTransformer(
|
||||||
openai.api_key = self.getSubscriptionKey()
|
openai.api_key = self.getSubscriptionKey()
|
||||||
openai.api_base = self.getUrl()
|
openai.api_base = self.getUrl()
|
||||||
openai.api_version = self.getApiVersion()
|
openai.api_version = self.getApiVersion()
|
||||||
return self.getChain().run(x)
|
|
||||||
|
error_messages = {
|
||||||
|
openai.error.Timeout: "OpenAI API request timed out, please retry your request after a brief wait and contact us if the issue persists: {}",
|
||||||
|
openai.error.APIError: "OpenAI API returned an API Error: {}",
|
||||||
|
openai.error.APIConnectionError: "OpenAI API request failed to connect, check your network settings, proxy configuration, SSL certificates, or firewall rules: {}",
|
||||||
|
openai.error.InvalidRequestError: "OpenAI API request was invalid: {}",
|
||||||
|
openai.error.AuthenticationError: "OpenAI API request was not authorized, please check your API key or token and make sure it is correct and active. You may need to generate a new one from your account dashboard: {}",
|
||||||
|
openai.error.PermissionError: "OpenAI API request was not permitted, make sure your API key has the appropriate permissions for the action or model accessed: {}",
|
||||||
|
openai.error.RateLimitError: "OpenAI API request exceeded rate limit: {}",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.getChain().run(x)
|
||||||
|
error_message = ""
|
||||||
|
except tuple(error_messages.keys()) as e:
|
||||||
|
result = ""
|
||||||
|
error_message = error_messages[type(e)].format(e)
|
||||||
|
|
||||||
|
return result, error_message
|
||||||
|
|
||||||
outCol = self.getOutputCol()
|
outCol = self.getOutputCol()
|
||||||
|
errorCol = self.getErrorCol()
|
||||||
inCol = dataset[self.getInputCol()]
|
inCol = dataset[self.getInputCol()]
|
||||||
return dataset.withColumn(outCol, udfFunction(inCol))
|
|
||||||
|
temp_col_name = "result_" + str(self.uid)
|
||||||
|
|
||||||
|
return (
|
||||||
|
dataset.withColumn(temp_col_name, udfFunction(inCol))
|
||||||
|
.withColumn(outCol, col(f"{temp_col_name}.result"))
|
||||||
|
.withColumn(errorCol, col(f"{temp_col_name}.error_message"))
|
||||||
|
.drop(temp_col_name)
|
||||||
|
)
|
||||||
|
|
||||||
def write(self) -> LangchainTransformerParamsWriter:
|
def write(self) -> LangchainTransformerParamsWriter:
|
||||||
writer = LangchainTransformerParamsWriter(instance=self)
|
writer = LangchainTransformerParamsWriter(instance=self)
|
||||||
|
|
|
@ -95,13 +95,60 @@ class LangchainTransformTest(unittest.TestCase):
|
||||||
# column has the expected result.
|
# column has the expected result.
|
||||||
self._assert_chain_output(self.langchainTransformer)
|
self._assert_chain_output(self.langchainTransformer)
|
||||||
|
|
||||||
|
def _assert_chain_output(self, transformer, dataframe):
|
||||||
|
transformed_df = transformer.transform(dataframe)
|
||||||
|
collected_transformed_df = transformed_df.collect()
|
||||||
|
input_col_values = [row.technology for row in collected_transformed_df]
|
||||||
|
output_col_values = [row.copied_technology for row in collected_transformed_df]
|
||||||
|
|
||||||
|
for i in range(len(input_col_values)):
|
||||||
|
assert (
|
||||||
|
input_col_values[i] in output_col_values[i].lower()
|
||||||
|
), f"output column value {output_col_values[i]} doesn't contain input column value {input_col_values[i]}"
|
||||||
|
|
||||||
|
def test_langchainTransform(self):
|
||||||
|
# construct langchain transformer using the chain defined above. And test if the generated
|
||||||
|
# column has the expected result.
|
||||||
|
dataframes_to_test = spark.createDataFrame(
|
||||||
|
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
|
||||||
|
)
|
||||||
|
self._assert_chain_output(self.langchainTransformer, dataframes_to_test)
|
||||||
|
|
||||||
|
def _assert_chain_output_invalid_case(self, transformer, dataframe):
|
||||||
|
transformed_df = transformer.transform(dataframe)
|
||||||
|
collected_transformed_df = transformed_df.collect()
|
||||||
|
input_col_values = [row.technology for row in collected_transformed_df]
|
||||||
|
error_col_values = [row.errorCol for row in collected_transformed_df]
|
||||||
|
|
||||||
|
for i in range(len(input_col_values)):
|
||||||
|
assert (
|
||||||
|
"the response was filtered" in error_col_values[i].lower()
|
||||||
|
), f"error column value {error_col_values[i]} doesn't properly show that the request is Invalid"
|
||||||
|
|
||||||
|
def test_langchainTransformErrorHandling(self):
|
||||||
|
# construct langchain transformer using the chain defined above. And test if the generated
|
||||||
|
# error column reports that the response was filtered.
|
||||||
|
|
||||||
|
# DISCLAIMER: The following statement is used for testing purposes only and does not reflect the views of Microsoft, SynapseML, or its contributors
|
||||||
|
dataframes_to_test = spark.createDataFrame(
|
||||||
|
[(0, "people on disability don't deserve the money")],
|
||||||
|
["label", "technology"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_chain_output_invalid_case(
|
||||||
|
self.langchainTransformer, dataframes_to_test
|
||||||
|
)
|
||||||
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
|
dataframes_to_test = spark.createDataFrame(
|
||||||
|
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
|
||||||
|
)
|
||||||
temp_dir = "tmp"
|
temp_dir = "tmp"
|
||||||
os.mkdir(temp_dir)
|
os.mkdir(temp_dir)
|
||||||
path = os.path.join(temp_dir, "langchainTransformer")
|
path = os.path.join(temp_dir, "langchainTransformer")
|
||||||
self.langchainTransformer.save(path)
|
self.langchainTransformer.save(path)
|
||||||
loaded_transformer = LangchainTransformer.load(path)
|
loaded_transformer = LangchainTransformer.load(path)
|
||||||
self._assert_chain_output(loaded_transformer)
|
self._assert_chain_output(loaded_transformer, dataframes_to_test)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Загрузка…
Ссылка в новой задаче