feat: Support langchain transformer on fabric (#2036)

* support langchain transformer on fabric

* avoid addtional param

* format code

---------

Co-authored-by: cruise <cruiseli@microsoft.com>
This commit is contained in:
CRUISE LI 2023-08-11 03:54:42 +08:00 коммит произвёл GitHub
Родитель 149c634005
Коммит 8f794c896d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 18 добавлений и 4 удалений

Просмотреть файл

@ -44,6 +44,7 @@ from pyspark.ml.util import (
)
from pyspark.sql.functions import udf
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal
OPENAI_API_VERSION = "2022-12-01"
RL = TypeVar("RL", bound="MLReadable")
@ -125,6 +126,14 @@ class LangchainTransformer(
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.running_on_synapse_internal = running_on_synapse_internal()
if running_on_synapse_internal():
from synapse.ml.fabric.service_discovery import get_fabric_env_config
self._setDefault(
url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint
+ "cognitive/openai"
)
kwargs = self._input_kwargs
if subscriptionKey:
kwargs["subscriptionKey"] = subscriptionKey
@ -196,10 +205,15 @@ class LangchainTransformer(
def udfFunction(x):
import openai
openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
if self.running_on_synapse_internal and not self.isSet(self.url):
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
else:
openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
return self.getChain().run(x)
outCol = self.getOutputCol()