diff --git a/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py b/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py index cbf6b528b8..fffbf13cdc 100644 --- a/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py +++ b/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py @@ -44,6 +44,7 @@ from pyspark.ml.util import ( ) from pyspark.sql.functions import udf from typing import cast, Optional, TypeVar, Type +from synapse.ml.core.platform import running_on_synapse_internal OPENAI_API_VERSION = "2022-12-01" RL = TypeVar("RL", bound="MLReadable") @@ -125,6 +126,14 @@ class LangchainTransformer( self.subscriptionKey = Param(self, "subscriptionKey", "openai api key") self.url = Param(self, "url", "openai api base") self.apiVersion = Param(self, "apiVersion", "openai api version") + self.running_on_synapse_internal = running_on_synapse_internal() + if running_on_synapse_internal(): + from synapse.ml.fabric.service_discovery import get_fabric_env_config + + self._setDefault( + url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint + + "cognitive/openai" + ) kwargs = self._input_kwargs if subscriptionKey: kwargs["subscriptionKey"] = subscriptionKey @@ -196,10 +205,15 @@ class LangchainTransformer( def udfFunction(x): import openai - openai.api_type = "azure" - openai.api_key = self.getSubscriptionKey() - openai.api_base = self.getUrl() - openai.api_version = self.getApiVersion() + if self.running_on_synapse_internal and not self.isSet(self.url): + from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun + + OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None) + else: + openai.api_type = "azure" + openai.api_key = self.getSubscriptionKey() + openai.api_base = self.getUrl() + openai.api_version = self.getApiVersion() return self.getChain().run(x) outCol = self.getOutputCol()