зеркало из https://github.com/microsoft/SynapseML.git
feat: Support langchain transformer on fabric (#2036)
* support langchain transformer on fabric * avoid addtional param * format code --------- Co-authored-by: cruise <cruiseli@microsoft.com>
This commit is contained in:
Родитель
149c634005
Коммит
8f794c896d
|
@ -44,6 +44,7 @@ from pyspark.ml.util import (
|
|||
)
|
||||
from pyspark.sql.functions import udf
|
||||
from typing import cast, Optional, TypeVar, Type
|
||||
from synapse.ml.core.platform import running_on_synapse_internal
|
||||
|
||||
OPENAI_API_VERSION = "2022-12-01"
|
||||
RL = TypeVar("RL", bound="MLReadable")
|
||||
|
@ -125,6 +126,14 @@ class LangchainTransformer(
|
|||
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
|
||||
self.url = Param(self, "url", "openai api base")
|
||||
self.apiVersion = Param(self, "apiVersion", "openai api version")
|
||||
self.running_on_synapse_internal = running_on_synapse_internal()
|
||||
if running_on_synapse_internal():
|
||||
from synapse.ml.fabric.service_discovery import get_fabric_env_config
|
||||
|
||||
self._setDefault(
|
||||
url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint
|
||||
+ "cognitive/openai"
|
||||
)
|
||||
kwargs = self._input_kwargs
|
||||
if subscriptionKey:
|
||||
kwargs["subscriptionKey"] = subscriptionKey
|
||||
|
@ -196,10 +205,15 @@ class LangchainTransformer(
|
|||
def udfFunction(x):
|
||||
import openai
|
||||
|
||||
openai.api_type = "azure"
|
||||
openai.api_key = self.getSubscriptionKey()
|
||||
openai.api_base = self.getUrl()
|
||||
openai.api_version = self.getApiVersion()
|
||||
if self.running_on_synapse_internal and not self.isSet(self.url):
|
||||
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
|
||||
|
||||
OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
|
||||
else:
|
||||
openai.api_type = "azure"
|
||||
openai.api_key = self.getSubscriptionKey()
|
||||
openai.api_base = self.getUrl()
|
||||
openai.api_version = self.getApiVersion()
|
||||
return self.getChain().run(x)
|
||||
|
||||
outCol = self.getOutputCol()
|
||||
|
|
Загрузка…
Ссылка в новой задаче