[azopenai] A rather large, but needed, cleanup of the tests. (#22707)
It was getting difficult to tell what was and wasn't covered and also correlate tests to models, which can sometimes come with differing behavior. So in this PR I: - Moved things that don't need to be env vars (ie: non-secrets) to just be constants in the code - this includes model names and generic server identifiers for regions that we use to coordinate new feature development. - Consolidated all of the setting of endpoint and models into one spot to make it simpler to double-check. - Consolidated tests that tested the same thing into sub-tests with OpenAI or AzureOpenAI names. - If a function was only called by one test moved it into the test as an anonymous func Also, I added in a test for logit_probs and logprobs/toplogprobs.
This commit is contained in:
Родитель
a51db25793
Коммит
a0f9b026ec
|
@ -2,5 +2,5 @@
|
||||||
"AssetsRepo": "Azure/azure-sdk-assets",
|
"AssetsRepo": "Azure/azure-sdk-assets",
|
||||||
"AssetsRepoPrefixPath": "go",
|
"AssetsRepoPrefixPath": "go",
|
||||||
"TagPrefix": "go/ai/azopenai",
|
"TagPrefix": "go/ai/azopenai",
|
||||||
"Tag": "go/ai/azopenai_a33cdad878"
|
"Tag": "go/ai/azopenai_a56e3e9e32"
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,64 +19,14 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestClient_GetAudioTranscription_AzureOpenAI(t *testing.T) {
|
func TestClient_GetAudioTranscription(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
runTranscriptionTests(t, client, azureOpenAI.Whisper.Model, true)
|
client := newTestClient(t, epm.Endpoint)
|
||||||
}
|
model := epm.Model
|
||||||
|
|
||||||
func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
runTranscriptionTests(t, client, openAI.Whisper.Model, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetAudioTranslation_AzureOpenAI(t *testing.T) {
|
|
||||||
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
|
|
||||||
runTranslationTests(t, client, azureOpenAI.Whisper.Model, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetAudioTranslation_OpenAI(t *testing.T) {
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
runTranslationTests(t, client, openAI.Whisper.Model, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetAudioSpeech(t *testing.T) {
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
|
|
||||||
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
|
|
||||||
Input: to.Ptr("i am a computer"),
|
|
||||||
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
|
|
||||||
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
|
|
||||||
DeploymentName: to.Ptr("tts-1"),
|
|
||||||
}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
audioBytes, err := io.ReadAll(audioResp.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NotEmpty(t, audioBytes)
|
|
||||||
require.Equal(t, "fLaC", string(audioBytes[0:4]))
|
|
||||||
|
|
||||||
// now send _it_ back through the transcription API and see if we can get something useful.
|
|
||||||
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
|
|
||||||
Filename: to.Ptr("test.flac"),
|
|
||||||
File: audioBytes,
|
|
||||||
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
|
|
||||||
DeploymentName: &openAI.Whisper.Model,
|
|
||||||
Temperature: to.Ptr[float32](0.0),
|
|
||||||
}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NotZero(t, *transcriptionResp.Duration)
|
|
||||||
|
|
||||||
// it occasionally comes back with different punctuation or makes a complete sentence but
|
|
||||||
// the major words always come through.
|
|
||||||
require.Contains(t, *transcriptionResp.Text, "computer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
|
|
||||||
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
||||||
// The bulk of the logic will test against OpenAI anyways.
|
// The bulk of the logic will test against OpenAI anyways.
|
||||||
if isAzure {
|
if epm.Endpoint.Azure {
|
||||||
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
||||||
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
||||||
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
|
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
|
||||||
|
@ -167,10 +117,23 @@ func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTranslationTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.Whisper)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.Whisper)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetAudioTranslation(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
model := epm.Model
|
||||||
|
|
||||||
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
||||||
// The bulk of the logic will test against OpenAI anyways.
|
// The bulk of the logic will test against OpenAI anyways.
|
||||||
if isAzure {
|
if epm.Endpoint.Azure {
|
||||||
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
||||||
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
||||||
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
|
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
|
||||||
|
@ -262,6 +225,49 @@ func runTranslationTests(t *testing.T, client *azopenai.Client, model string, is
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.Whisper)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.Whisper)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetAudioSpeech(t *testing.T) {
|
||||||
|
client := newTestClient(t, openAI.Speech.Endpoint)
|
||||||
|
|
||||||
|
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
|
||||||
|
Input: to.Ptr("i am a computer"),
|
||||||
|
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
|
||||||
|
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
|
||||||
|
DeploymentName: to.Ptr("tts-1"),
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
audioBytes, err := io.ReadAll(audioResp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEmpty(t, audioBytes)
|
||||||
|
require.Equal(t, "fLaC", string(audioBytes[0:4]))
|
||||||
|
|
||||||
|
// now send _it_ back through the transcription API and see if we can get something useful.
|
||||||
|
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
|
||||||
|
Filename: to.Ptr("test.flac"),
|
||||||
|
File: audioBytes,
|
||||||
|
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
|
||||||
|
DeploymentName: &openAI.Whisper.Model,
|
||||||
|
Temperature: to.Ptr[float32](0.0),
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotZero(t, *transcriptionResp.Duration)
|
||||||
|
|
||||||
|
// it occasionally comes back with different punctuation or makes a complete sentence but
|
||||||
|
// the major words always come through.
|
||||||
|
require.Contains(t, *transcriptionResp.Text, "computer")
|
||||||
|
}
|
||||||
|
|
||||||
func newTranscriptionOptions(format azopenai.AudioTranscriptionFormat, model string, path string) azopenai.AudioTranscriptionOptions {
|
func newTranscriptionOptions(format azopenai.AudioTranscriptionFormat, model string, path string) azopenai.AudioTranscriptionOptions {
|
||||||
audioBytes, err := os.ReadFile(path)
|
audioBytes, err := os.ReadFile(path)
|
||||||
|
|
||||||
|
|
|
@ -42,34 +42,7 @@ var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
|
||||||
var expectedRole = azopenai.ChatRoleAssistant
|
var expectedRole = azopenai.ChatRoleAssistant
|
||||||
|
|
||||||
func TestClient_GetChatCompletions(t *testing.T) {
|
func TestClient_GetChatCompletions(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedModel string, checkRAI bool) {
|
||||||
testGetChatCompletions(t, client, azureOpenAI.ChatCompletionsRAI.Model, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetChatCompletionsStream(t *testing.T) {
|
|
||||||
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
|
||||||
testGetChatCompletionsStream(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
testGetChatCompletions(t, chatClient, openAI.ChatCompletions, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
testGetChatCompletionsStream(t, chatClient, openAI.ChatCompletions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment string, checkRAI bool) {
|
|
||||||
expected := azopenai.ChatCompletions{
|
expected := azopenai.ChatCompletions{
|
||||||
Choices: []azopenai.ChatChoice{
|
Choices: []azopenai.ChatChoice{
|
||||||
{
|
{
|
||||||
|
@ -88,9 +61,7 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
|
||||||
PromptTokens: to.Ptr(int32(42)),
|
PromptTokens: to.Ptr(int32(42)),
|
||||||
TotalTokens: to.Ptr(int32(71)),
|
TotalTokens: to.Ptr(int32(71)),
|
||||||
},
|
},
|
||||||
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
|
Model: &returnedModel,
|
||||||
// by convention) but if this fails because they _don't_ match we can just adjust the test.
|
|
||||||
Model: &deployment,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
||||||
|
@ -112,10 +83,136 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
|
||||||
expected.ID = resp.ID
|
expected.ID = resp.ID
|
||||||
expected.Created = resp.Created
|
expected.Created = resp.Created
|
||||||
|
|
||||||
|
t.Logf("isAzure: %t, deployment: %s, returnedModel: %s", checkRAI, deployment, *resp.ChatCompletions.Model)
|
||||||
require.Equal(t, expected, resp.ChatCompletions)
|
require.Equal(t, expected, resp.ChatCompletions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deployment string) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
client := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
||||||
|
testFn(t, client, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4", true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI.DefaultAzureCredential", func(t *testing.T) {
|
||||||
|
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||||
|
t.Skipf("Not running this test in playback (for now)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.Getenv("USE_TOKEN_CREDS") != "true" {
|
||||||
|
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
recordingTransporter := newRecordingTransporter(t)
|
||||||
|
|
||||||
|
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
|
||||||
|
ClientOptions: policy.ClientOptions{
|
||||||
|
Transport: recordingTransporter,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
chatClient, err := azopenai.NewClient(azureOpenAI.ChatCompletions.Endpoint.URL, dac, &azopenai.ClientOptions{
|
||||||
|
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testFn(t, chatClient, azureOpenAI.ChatCompletions.Model, "gpt-4", true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||||
|
testFn(t, chatClient, openAI.ChatCompletions.Model, "gpt-4-0613", false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetChatCompletions_LogProbs(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
|
opts := azopenai.ChatCompletionsOptions{
|
||||||
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
&azopenai.ChatRequestUserMessage{
|
||||||
|
Content: azopenai.NewChatRequestUserMessageContent("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MaxTokens: to.Ptr(int32(1024)),
|
||||||
|
Temperature: to.Ptr(float32(0.0)),
|
||||||
|
DeploymentName: &epm.Model,
|
||||||
|
LogProbs: to.Ptr(true),
|
||||||
|
TopLogProbs: to.Ptr(int32(5)),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, choice := range resp.Choices {
|
||||||
|
require.NotEmpty(t, choice.LogProbs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetChatCompletions_LogitBias(t *testing.T) {
|
||||||
|
// you can use LogitBias to constrain the answer to NOT contain
|
||||||
|
// certain tokens. More or less following the technique in this OpenAI article:
|
||||||
|
// https://help.openai.com/en/articles/5247780-using-logit-bias-to-alter-token-probability-with-the-openai-api
|
||||||
|
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
|
opts := azopenai.ChatCompletionsOptions{
|
||||||
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
&azopenai.ChatRequestUserMessage{
|
||||||
|
Content: azopenai.NewChatRequestUserMessageContent("Briefly, what are some common roles for people at a circus, names only, one per line?"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MaxTokens: to.Ptr(int32(200)),
|
||||||
|
Temperature: to.Ptr(float32(0.0)),
|
||||||
|
DeploymentName: &epm.Model,
|
||||||
|
LogitBias: map[string]*int32{
|
||||||
|
// you can calculate these tokens using OpenAI's online tool:
|
||||||
|
// https://platform.openai.com/tokenizer?view=bpe
|
||||||
|
// These token IDs are all variations of "Clown", which I want to exclude from the response.
|
||||||
|
"25": to.Ptr(int32(-100)),
|
||||||
|
"220": to.Ptr(int32(-100)),
|
||||||
|
"1206": to.Ptr(int32(-100)),
|
||||||
|
"2493": to.Ptr(int32(-100)),
|
||||||
|
"5176": to.Ptr(int32(-100)),
|
||||||
|
"43456": to.Ptr(int32(-100)),
|
||||||
|
"99423": to.Ptr(int32(-100)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, choice := range resp.Choices {
|
||||||
|
if choice.Message == nil || choice.Message.Content == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotContains(t, *choice.Message.Content, "clown")
|
||||||
|
require.NotContains(t, *choice.Message.Content, "Clown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetChatCompletionsStream(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedDeployment string) {
|
||||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
||||||
|
|
||||||
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
|
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
@ -141,7 +238,7 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
|
||||||
|
|
||||||
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
|
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
|
||||||
// by convention) but if this fails because they _don't_ match we can just adjust the test.
|
// by convention) but if this fails because they _don't_ match we can just adjust the test.
|
||||||
if deployment == *completion.Model {
|
if returnedDeployment == *completion.Model {
|
||||||
modelWasReturned = true
|
modelWasReturned = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,34 +275,19 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
|
||||||
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
|
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
||||||
t.Skipf("Not running this test in playback (for now)")
|
testFn(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4")
|
||||||
}
|
|
||||||
|
|
||||||
if os.Getenv("USE_TOKEN_CREDS") != "true" {
|
|
||||||
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
recordingTransporter := newRecordingTransporter(t)
|
|
||||||
|
|
||||||
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
|
|
||||||
ClientOptions: policy.ClientOptions{
|
|
||||||
Transport: recordingTransporter,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint.URL, dac, &azopenai.ClientOptions{
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
|
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||||
|
testFn(t, chatClient, openAI.ChatCompletions.Model, openAI.ChatCompletions.Model)
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
testGetChatCompletions(t, chatClient, azureOpenAI.ChatCompletions, true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
|
func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
client := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
|
||||||
|
|
||||||
_, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
_, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
||||||
Messages: []azopenai.ChatRequestMessageClassification{
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
@ -230,14 +312,14 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
client := newBogusAzureOpenAIClient(t)
|
client := newBogusAzureOpenAIClient(t)
|
||||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions), nil)
|
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions.Model), nil)
|
||||||
require.Empty(t, streamResp)
|
require.Empty(t, streamResp)
|
||||||
assertResponseIsError(t, err)
|
assertResponseIsError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
client := newBogusOpenAIClient(t)
|
client := newBogusOpenAIClient(t)
|
||||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions), nil)
|
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions.Model), nil)
|
||||||
require.Empty(t, streamResp)
|
require.Empty(t, streamResp)
|
||||||
assertResponseIsError(t, err)
|
assertResponseIsError(t, err)
|
||||||
})
|
})
|
||||||
|
@ -276,15 +358,15 @@ func TestClient_GetChatCompletions_Vision(t *testing.T) {
|
||||||
t.Logf(*resp.Choices[0].Message.Content)
|
t.Logf(*resp.Choices[0].Message.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
testFn(t, chatClient, openAI.Vision.Model)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
|
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
|
||||||
testFn(t, chatClient, azureOpenAI.Vision.Model)
|
testFn(t, chatClient, azureOpenAI.Vision.Model)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
chatClient := newTestClient(t, openAI.Vision.Endpoint)
|
||||||
|
testFn(t, chatClient, openAI.Vision.Model)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
|
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
|
||||||
|
@ -313,13 +395,13 @@ func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
|
||||||
require.NotEmpty(t, v)
|
require.NotEmpty(t, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
chatClient := newOpenAIClientForTest(t)
|
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
|
||||||
testFn(t, chatClient, "gpt-3.5-turbo-1106")
|
testFn(t, chatClient, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Model)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
|
chatClient := newTestClient(t, openAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
|
||||||
testFn(t, chatClient, "gpt-4-1106-preview")
|
testFn(t, chatClient, openAI.ChatCompletionsWithJSONResponseFormat.Model)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,32 +15,15 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
|
func TestClient_GetCompletions(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
testGetCompletions(t, client, true)
|
client := newTestClient(t, epm.Endpoint)
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetCompletions_OpenAI(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
testGetCompletions(t, client, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
|
||||||
deploymentID := openAI.Completions
|
|
||||||
|
|
||||||
if isAzure {
|
|
||||||
deploymentID = azureOpenAI.Completions
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
||||||
Prompt: []string{"What is Azure OpenAI?"},
|
Prompt: []string{"What is Azure OpenAI?"},
|
||||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||||
Temperature: to.Ptr(float32(0.0)),
|
Temperature: to.Ptr(float32(0.0)),
|
||||||
DeploymentName: &deploymentID,
|
DeploymentName: &epm.Model,
|
||||||
}, nil)
|
}, nil)
|
||||||
skipNowIfThrottled(t, err)
|
skipNowIfThrottled(t, err)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -55,7 +38,7 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
||||||
|
|
||||||
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
|
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
|
||||||
|
|
||||||
if isAzure {
|
if epm.Endpoint.Azure {
|
||||||
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
|
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
|
||||||
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
|
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
|
||||||
{
|
{
|
||||||
|
@ -65,3 +48,12 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.Completions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.Completions)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
|
func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
client := newTestClient(t, azureOpenAI.Embeddings.Endpoint)
|
||||||
|
|
||||||
_, err := client.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
|
_, err := client.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
|
||||||
DeploymentName: to.Ptr("thisdoesntexist"),
|
DeploymentName: to.Ptr("thisdoesntexist"),
|
||||||
|
@ -29,80 +29,10 @@ func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
|
||||||
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
|
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
testGetEmbeddings(t, client, openAI.Embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetEmbeddings(t *testing.T) {
|
func TestClient_GetEmbeddings(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
|
client := newTestClient(t, epm.Endpoint)
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
|
|
||||||
testFn := func(t *testing.T, tv testVars, dimension int32) {
|
|
||||||
client := newTestClient(t, tv.Endpoint)
|
|
||||||
|
|
||||||
arg := azopenai.EmbeddingsOptions{
|
|
||||||
Input: []string{"hello"},
|
|
||||||
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
|
|
||||||
DeploymentName: &tv.TextEmbedding3Small,
|
|
||||||
}
|
|
||||||
|
|
||||||
if dimension > 0 {
|
|
||||||
arg.Dimensions = &dimension
|
|
||||||
}
|
|
||||||
|
|
||||||
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NotEmpty(t, base64Resp.Data)
|
|
||||||
require.Empty(t, base64Resp.Data[0].Embedding)
|
|
||||||
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
|
|
||||||
|
|
||||||
// sanity checks - we deserialized everything and didn't create anything impossible.
|
|
||||||
for _, v := range embeddings {
|
|
||||||
require.True(t, v <= 1.0 && v >= -1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
arg2 := azopenai.EmbeddingsOptions{
|
|
||||||
Input: []string{"hello"},
|
|
||||||
DeploymentName: &tv.TextEmbedding3Small,
|
|
||||||
}
|
|
||||||
|
|
||||||
if dimension > 0 {
|
|
||||||
arg2.Dimensions = &dimension
|
|
||||||
}
|
|
||||||
|
|
||||||
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NotEmpty(t, floatResp.Data)
|
|
||||||
require.NotEmpty(t, floatResp.Data[0].Embedding)
|
|
||||||
|
|
||||||
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
|
|
||||||
|
|
||||||
// This works "most of the time" but it's non-deterministic since two separate calls don't always
|
|
||||||
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
|
|
||||||
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dim := range []int32{0, 1, 10, 100} {
|
|
||||||
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
|
||||||
testFn(t, azureOpenAI, dim)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
|
||||||
testFn(t, openAI, dim)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
|
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
deploymentID string
|
deploymentID string
|
||||||
|
@ -122,10 +52,10 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
|
||||||
client: client,
|
client: client,
|
||||||
args: args{
|
args: args{
|
||||||
ctx: context.TODO(),
|
ctx: context.TODO(),
|
||||||
deploymentID: modelOrDeploymentID,
|
deploymentID: epm.Model,
|
||||||
body: azopenai.EmbeddingsOptions{
|
body: azopenai.EmbeddingsOptions{
|
||||||
Input: []string{"\"Your text string goes here\""},
|
Input: []string{"\"Your text string goes here\""},
|
||||||
DeploymentName: &modelOrDeploymentID,
|
DeploymentName: &epm.Model,
|
||||||
},
|
},
|
||||||
options: nil,
|
options: nil,
|
||||||
},
|
},
|
||||||
|
@ -151,6 +81,74 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.Embeddings)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.Embeddings)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel, dimension int32) {
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
|
arg := azopenai.EmbeddingsOptions{
|
||||||
|
Input: []string{"hello"},
|
||||||
|
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
|
||||||
|
DeploymentName: &epm.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dimension > 0 {
|
||||||
|
arg.Dimensions = &dimension
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEmpty(t, base64Resp.Data)
|
||||||
|
require.Empty(t, base64Resp.Data[0].Embedding)
|
||||||
|
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
|
||||||
|
|
||||||
|
// sanity checks - we deserialized everything and didn't create anything impossible.
|
||||||
|
for _, v := range embeddings {
|
||||||
|
require.True(t, v <= 1.0 && v >= -1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg2 := azopenai.EmbeddingsOptions{
|
||||||
|
Input: []string{"hello"},
|
||||||
|
DeploymentName: &epm.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dimension > 0 {
|
||||||
|
arg2.Dimensions = &dimension
|
||||||
|
}
|
||||||
|
|
||||||
|
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEmpty(t, floatResp.Data)
|
||||||
|
require.NotEmpty(t, floatResp.Data[0].Embedding)
|
||||||
|
|
||||||
|
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
|
||||||
|
|
||||||
|
// This works "most of the time" but it's non-deterministic since two separate calls don't always
|
||||||
|
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
|
||||||
|
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dim := range []int32{0, 1, 10, 100} {
|
||||||
|
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.TextEmbedding3Small, dim)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||||
|
testFn(t, openAI.TextEmbedding3Small, dim)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func deserializeBase64Embeddings(t *testing.T, ei azopenai.EmbeddingItem) []float32 {
|
func deserializeBase64Embeddings(t *testing.T, ei azopenai.EmbeddingItem) []float32 {
|
||||||
destBytes, err := base64.StdEncoding.DecodeString(ei.EmbeddingBase64)
|
destBytes, err := base64.StdEncoding.DecodeString(ei.EmbeddingBase64)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -6,6 +6,7 @@ package azopenai_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
|
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
|
||||||
|
@ -28,130 +29,7 @@ type ParamProperty struct {
|
||||||
func TestGetChatCompletions_usingFunctions(t *testing.T) {
|
func TestGetChatCompletions_usingFunctions(t *testing.T) {
|
||||||
// https://platform.openai.com/docs/guides/gpt/function-calling
|
// https://platform.openai.com/docs/guides/gpt/function-calling
|
||||||
|
|
||||||
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
|
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
|
||||||
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
|
|
||||||
)
|
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
|
|
||||||
testData := []struct {
|
|
||||||
Model string
|
|
||||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
|
||||||
}{
|
|
||||||
// all of these variants use the tool provided - auto just also works since we did provide
|
|
||||||
// a tool reference and ask a question to use it.
|
|
||||||
{Model: openAI.ChatCompletions, ToolChoice: nil},
|
|
||||||
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
|
||||||
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, td := range testData {
|
|
||||||
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
|
||||||
|
|
||||||
testData := []struct {
|
|
||||||
Model string
|
|
||||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
|
||||||
}{
|
|
||||||
// all of these variants use the tool provided - auto just also works since we did provide
|
|
||||||
// a tool reference and ask a question to use it.
|
|
||||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
|
|
||||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
|
||||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, td := range testData {
|
|
||||||
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletionsLegacyFunctions)
|
|
||||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletions)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
|
||||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, azureOpenAI.ChatCompletionsLegacyFunctions)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
|
|
||||||
// https://platform.openai.com/docs/guides/gpt/function-calling
|
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
|
||||||
testChatCompletionsFunctionsStreaming(t, chatClient, openAI)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
|
||||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
|
||||||
testChatCompletionsFunctionsStreaming(t, chatClient, azureOpenAI)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Client, deploymentName string) {
|
|
||||||
body := azopenai.ChatCompletionsOptions{
|
|
||||||
DeploymentName: &deploymentName,
|
|
||||||
Messages: []azopenai.ChatRequestMessageClassification{
|
|
||||||
&azopenai.ChatRequestAssistantMessage{
|
|
||||||
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
|
|
||||||
Value: to.Ptr("auto"),
|
|
||||||
},
|
|
||||||
Functions: []azopenai.FunctionDefinition{
|
|
||||||
{
|
|
||||||
Name: to.Ptr("get_current_weather"),
|
|
||||||
Description: to.Ptr("Get the current weather in a given location"),
|
|
||||||
Parameters: Params{
|
|
||||||
Required: []string{"location"},
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]ParamProperty{
|
|
||||||
"location": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "The city and state, e.g. San Francisco, CA",
|
|
||||||
},
|
|
||||||
"unit": {
|
|
||||||
Type: "string",
|
|
||||||
Enum: []string{"celsius", "fahrenheit"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Temperature: to.Ptr[float32](0.0),
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.GetChatCompletions(context.Background(), body, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
|
|
||||||
|
|
||||||
require.Equal(t, "get_current_weather", *funcCall.Name)
|
|
||||||
|
|
||||||
type location struct {
|
|
||||||
Location string `json:"location"`
|
|
||||||
Unit string `json:"unit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var funcParams *location
|
|
||||||
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
|
|
||||||
body := azopenai.ChatCompletionsOptions{
|
body := azopenai.ChatCompletionsOptions{
|
||||||
DeploymentName: &deploymentName,
|
DeploymentName: &deploymentName,
|
||||||
Messages: []azopenai.ChatRequestMessageClassification{
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
@ -204,9 +82,124 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep
|
||||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Client, tv testVars) {
|
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
|
||||||
|
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
chatClient := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
|
||||||
|
|
||||||
|
testData := []struct {
|
||||||
|
Model string
|
||||||
|
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||||
|
}{
|
||||||
|
// all of these variants use the tool provided - auto just also works since we did provide
|
||||||
|
// a tool reference and ask a question to use it.
|
||||||
|
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: nil},
|
||||||
|
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||||
|
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: useSpecificTool},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, td := range testData {
|
||||||
|
testFn(t, chatClient, td.Model, td.ToolChoice)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testData := []struct {
|
||||||
|
EPM endpointWithModel
|
||||||
|
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||||
|
}{
|
||||||
|
// all of these variants use the tool provided - auto just also works since we did provide
|
||||||
|
// a tool reference and ask a question to use it.
|
||||||
|
{EPM: openAI.ChatCompletions, ToolChoice: nil},
|
||||||
|
{EPM: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||||
|
{EPM: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, td := range testData {
|
||||||
|
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||||
|
chatClient := newTestClient(t, td.EPM.Endpoint)
|
||||||
|
testFn(t, chatClient, td.EPM.Model, td.ToolChoice)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
body := azopenai.ChatCompletionsOptions{
|
body := azopenai.ChatCompletionsOptions{
|
||||||
DeploymentName: &tv.ChatCompletions,
|
DeploymentName: &epm.Model,
|
||||||
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
&azopenai.ChatRequestAssistantMessage{
|
||||||
|
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
|
||||||
|
Value: to.Ptr("auto"),
|
||||||
|
},
|
||||||
|
Functions: []azopenai.FunctionDefinition{
|
||||||
|
{
|
||||||
|
Name: to.Ptr("get_current_weather"),
|
||||||
|
Description: to.Ptr("Get the current weather in a given location"),
|
||||||
|
Parameters: Params{
|
||||||
|
Required: []string{"location"},
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]ParamProperty{
|
||||||
|
"location": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
Type: "string",
|
||||||
|
Enum: []string{"celsius", "fahrenheit"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Temperature: to.Ptr[float32](0.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.GetChatCompletions(context.Background(), body, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
|
||||||
|
|
||||||
|
require.Equal(t, "get_current_weather", *funcCall.Name)
|
||||||
|
|
||||||
|
type location struct {
|
||||||
|
Location string `json:"location"`
|
||||||
|
Unit string `json:"unit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var funcParams *location
|
||||||
|
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.ChatCompletionsLegacyFunctions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI.LegacyFunctions", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.ChatCompletionsLegacyFunctions)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
|
||||||
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
|
body := azopenai.ChatCompletionsOptions{
|
||||||
|
DeploymentName: &epm.Model,
|
||||||
Messages: []azopenai.ChatRequestMessageClassification{
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
&azopenai.ChatRequestAssistantMessage{
|
&azopenai.ChatRequestAssistantMessage{
|
||||||
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
||||||
|
@ -237,6 +230,8 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
|
||||||
Temperature: to.Ptr[float32](0.0),
|
Temperature: to.Ptr[float32](0.0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chatClient := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
|
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, resp)
|
require.NotEmpty(t, resp)
|
||||||
|
@ -293,3 +288,14 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
|
||||||
|
|
||||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/guides/gpt/function-calling
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.ChatCompletions)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -22,13 +22,13 @@ import (
|
||||||
func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) {
|
func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) {
|
||||||
// Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered
|
// Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered
|
||||||
// https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered
|
// https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered
|
||||||
client := newAzureOpenAIClientForTest(t, azureOpenAI)
|
client := newTestClient(t, azureOpenAI.Completions.Endpoint)
|
||||||
|
|
||||||
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
||||||
Prompt: []string{"How do I rob a bank with violence?"},
|
Prompt: []string{"How do I rob a bank with violence?"},
|
||||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||||
Temperature: to.Ptr(float32(0.0)),
|
Temperature: to.Ptr(float32(0.0)),
|
||||||
DeploymentName: &azureOpenAI.Completions,
|
DeploymentName: &azureOpenAI.Completions.Model,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
require.Empty(t, resp)
|
require.Empty(t, resp)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -26,11 +27,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
azureOpenAI testVars
|
|
||||||
openAI testVars
|
|
||||||
)
|
|
||||||
|
|
||||||
type endpoint struct {
|
type endpoint struct {
|
||||||
URL string
|
URL string
|
||||||
APIKey string
|
APIKey string
|
||||||
|
@ -38,22 +34,19 @@ type endpoint struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type testVars struct {
|
type testVars struct {
|
||||||
Endpoint endpoint
|
ChatCompletions endpointWithModel
|
||||||
Completions string
|
ChatCompletionsLegacyFunctions endpointWithModel
|
||||||
ChatCompletions string
|
ChatCompletionsOYD endpointWithModel // azure only
|
||||||
ChatCompletionsLegacyFunctions string
|
ChatCompletionsRAI endpointWithModel // azure only
|
||||||
Embeddings string
|
ChatCompletionsWithJSONResponseFormat endpointWithModel
|
||||||
TextEmbedding3Small string
|
|
||||||
Cognitive azopenai.AzureSearchChatExtensionConfiguration
|
Cognitive azopenai.AzureSearchChatExtensionConfiguration
|
||||||
Whisper endpointWithModel
|
Completions endpointWithModel
|
||||||
DallE endpointWithModel
|
DallE endpointWithModel
|
||||||
|
Embeddings endpointWithModel
|
||||||
|
Speech endpointWithModel
|
||||||
|
TextEmbedding3Small endpointWithModel
|
||||||
Vision endpointWithModel
|
Vision endpointWithModel
|
||||||
|
Whisper endpointWithModel
|
||||||
ChatCompletionsRAI endpointWithModel // at the moment this is Azure only
|
|
||||||
|
|
||||||
// "own your data" - bringing in Azure resources as part of a chat completions
|
|
||||||
// request.
|
|
||||||
ChatCompletionsOYD endpointWithModel
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type endpointWithModel struct {
|
type endpointWithModel struct {
|
||||||
|
@ -61,16 +54,130 @@ type endpointWithModel struct {
|
||||||
Model string
|
Model string
|
||||||
}
|
}
|
||||||
|
|
||||||
type testClientOption func(opt *azopenai.ClientOptions)
|
func ifAzure[T string | endpoint](azure bool, forAzure T, forOpenAI T) T {
|
||||||
|
if azure {
|
||||||
|
return forAzure
|
||||||
|
}
|
||||||
|
return forOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
func withForgivingRetryOption() testClientOption {
|
var azureOpenAI, openAI, servers = func() (testVars, testVars, []string) {
|
||||||
return func(opt *azopenai.ClientOptions) {
|
if recording.GetRecordMode() != recording.PlaybackMode {
|
||||||
opt.Retry = policy.RetryOptions{
|
if err := godotenv.Load(); err != nil {
|
||||||
MaxRetries: 10,
|
log.Fatalf("Failed to load .env file: %s\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
servers := struct {
|
||||||
|
USEast endpoint
|
||||||
|
USNorthCentral endpoint
|
||||||
|
USEast2 endpoint
|
||||||
|
SWECentral endpoint
|
||||||
|
OpenAI endpoint
|
||||||
|
}{
|
||||||
|
OpenAI: endpoint{
|
||||||
|
URL: getEndpoint("OPENAI_ENDPOINT"), // ex: https://api.openai.com/v1/
|
||||||
|
APIKey: recording.GetEnvVariable("OPENAI_API_KEY", fakeAPIKey),
|
||||||
|
Azure: false,
|
||||||
|
},
|
||||||
|
USEast: endpoint{
|
||||||
|
URL: getEndpoint("ENDPOINT_USEAST"),
|
||||||
|
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST_API_KEY", fakeAPIKey),
|
||||||
|
Azure: true,
|
||||||
|
},
|
||||||
|
USEast2: endpoint{
|
||||||
|
URL: getEndpoint("ENDPOINT_USEAST2"),
|
||||||
|
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST2_API_KEY", fakeAPIKey),
|
||||||
|
Azure: true,
|
||||||
|
},
|
||||||
|
USNorthCentral: endpoint{
|
||||||
|
URL: getEndpoint("ENDPOINT_USNORTHCENTRAL"),
|
||||||
|
APIKey: recording.GetEnvVariable("ENDPOINT_USNORTHCENTRAL_API_KEY", fakeAPIKey),
|
||||||
|
Azure: true,
|
||||||
|
},
|
||||||
|
SWECentral: endpoint{
|
||||||
|
URL: getEndpoint("ENDPOINT_SWECENTRAL"),
|
||||||
|
APIKey: recording.GetEnvVariable("ENDPOINT_SWECENTRAL_API_KEY", fakeAPIKey),
|
||||||
|
Azure: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// used when we setup the recording policy
|
||||||
|
endpoints := []string{
|
||||||
|
servers.OpenAI.URL,
|
||||||
|
servers.USEast.URL,
|
||||||
|
servers.USEast2.URL,
|
||||||
|
servers.USNorthCentral.URL,
|
||||||
|
servers.SWECentral.URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
newTestVarsFn := func(azure bool) testVars {
|
||||||
|
return testVars{
|
||||||
|
ChatCompletions: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
|
||||||
|
},
|
||||||
|
ChatCompletionsLegacyFunctions: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
|
||||||
|
},
|
||||||
|
ChatCompletionsOYD: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
|
||||||
|
},
|
||||||
|
ChatCompletionsRAI: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
|
||||||
|
},
|
||||||
|
ChatCompletionsWithJSONResponseFormat: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-1106-preview", "gpt-3.5-turbo-1106"),
|
||||||
|
},
|
||||||
|
Completions: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"),
|
||||||
|
},
|
||||||
|
DallE: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "dall-e-3", "dall-e-3"),
|
||||||
|
},
|
||||||
|
Embeddings: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "text-embedding-ada-002", "text-embedding-ada-002"),
|
||||||
|
},
|
||||||
|
Speech: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "tts-1", "tts-1"),
|
||||||
|
},
|
||||||
|
TextEmbedding3Small: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "text-embedding-3-small", "text-embedding-3-small"),
|
||||||
|
},
|
||||||
|
Vision: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "gpt-4-vision-preview", "gpt-4-vision-preview"),
|
||||||
|
},
|
||||||
|
Whisper: endpointWithModel{
|
||||||
|
Endpoint: ifAzure(azure, servers.USEast2, servers.OpenAI),
|
||||||
|
Model: ifAzure(azure, "whisper-deployment", "whisper-1"),
|
||||||
|
},
|
||||||
|
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
|
||||||
|
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
||||||
|
Endpoint: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_ENDPOINT", fakeCognitiveEndpoint)),
|
||||||
|
IndexName: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_INDEX", fakeCognitiveIndexName)),
|
||||||
|
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
||||||
|
Key: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_KEY", fakeAPIKey)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newTestVarsFn(true), newTestVarsFn(false), endpoints
|
||||||
|
}()
|
||||||
|
|
||||||
|
type testClientOption func(opt *azopenai.ClientOptions)
|
||||||
|
|
||||||
// newTestClient creates a client enabled for HTTP recording, if needed.
|
// newTestClient creates a client enabled for HTTP recording, if needed.
|
||||||
// See [newRecordingTransporter] for sanitization code.
|
// See [newRecordingTransporter] for sanitization code.
|
||||||
func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azopenai.Client {
|
func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azopenai.Client {
|
||||||
|
@ -101,197 +208,11 @@ func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getEndpoint retrieves details for an endpoint and a model.
|
|
||||||
// - res - the resource type for a particular endpoint. Ex: "DALLE".
|
|
||||||
//
|
|
||||||
// For example, if azure is true we'll load these environment values based on res:
|
|
||||||
// - AOAI_DALLE_ENDPOINT
|
|
||||||
// - AOAI_DALLE_API_KEY
|
|
||||||
//
|
|
||||||
// if azure is false we'll load these environment values based on res:
|
|
||||||
// - OPENAI_ENDPOINT
|
|
||||||
// - OPENAI_API_KEY
|
|
||||||
func getEndpoint(res string, isAzure bool) endpointWithModel {
|
|
||||||
var ep endpointWithModel
|
|
||||||
if isAzure {
|
|
||||||
// during development resources are often shifted between different
|
|
||||||
// internal Azure OpenAI resources.
|
|
||||||
ep = endpointWithModel{
|
|
||||||
Endpoint: endpoint{
|
|
||||||
URL: getRequired("AOAI_" + res + "_ENDPOINT"),
|
|
||||||
APIKey: getRequired("AOAI_" + res + "_API_KEY"),
|
|
||||||
Azure: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ep = endpointWithModel{
|
|
||||||
Endpoint: endpoint{
|
|
||||||
URL: getRequired("OPENAI_ENDPOINT"),
|
|
||||||
APIKey: getRequired("OPENAI_API_KEY"),
|
|
||||||
Azure: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(ep.Endpoint.URL, "/") {
|
|
||||||
// (this just makes recording replacement easier)
|
|
||||||
ep.Endpoint.URL += "/"
|
|
||||||
}
|
|
||||||
|
|
||||||
return ep
|
|
||||||
}
|
|
||||||
|
|
||||||
func model(azure bool, azureModel, openAIModel string) string {
|
|
||||||
if azure {
|
|
||||||
return azureModel
|
|
||||||
}
|
|
||||||
|
|
||||||
return openAIModel
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateModels(azure bool, tv *testVars) {
|
|
||||||
// the models we use are basically their own API surface so it's good to know which
|
|
||||||
// specific models our tests were written against.
|
|
||||||
tv.Completions = model(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct")
|
|
||||||
tv.ChatCompletions = model(azure, "gpt-35-turbo-0613", "gpt-4-0613")
|
|
||||||
tv.ChatCompletionsLegacyFunctions = model(azure, "gpt-4-0613", "gpt-4-0613")
|
|
||||||
tv.Embeddings = model(azure, "text-embedding-ada-002", "text-embedding-ada-002")
|
|
||||||
tv.TextEmbedding3Small = model(azure, "text-embedding-3-small", "text-embedding-3-small")
|
|
||||||
|
|
||||||
tv.DallE.Model = model(azure, "dall-e-3", "dall-e-3")
|
|
||||||
tv.Whisper.Model = model(azure, "whisper-deployment", "whisper-1")
|
|
||||||
tv.Vision.Model = model(azure, "gpt-4-vision-preview", "gpt-4-vision-preview")
|
|
||||||
|
|
||||||
// these are Azure-only features
|
|
||||||
tv.ChatCompletionsOYD.Model = model(azure, "gpt-4", "")
|
|
||||||
tv.ChatCompletionsRAI.Model = model(azure, "gpt-4", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestVars(prefix string) testVars {
|
|
||||||
azure := prefix == "AOAI"
|
|
||||||
|
|
||||||
tv := testVars{
|
|
||||||
Endpoint: endpoint{
|
|
||||||
URL: getRequired(prefix + "_ENDPOINT"),
|
|
||||||
APIKey: getRequired(prefix + "_API_KEY"),
|
|
||||||
Azure: azure,
|
|
||||||
},
|
|
||||||
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
|
|
||||||
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
|
||||||
Endpoint: to.Ptr(getRequired("COGNITIVE_SEARCH_API_ENDPOINT")),
|
|
||||||
IndexName: to.Ptr(getRequired("COGNITIVE_SEARCH_API_INDEX")),
|
|
||||||
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
|
||||||
Key: to.Ptr(getRequired("COGNITIVE_SEARCH_API_KEY")),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
DallE: getEndpoint("DALLE", azure),
|
|
||||||
Whisper: getEndpoint("WHISPER", azure),
|
|
||||||
Vision: getEndpoint("VISION", azure),
|
|
||||||
}
|
|
||||||
|
|
||||||
if azure {
|
|
||||||
tv.ChatCompletionsRAI = getEndpoint("CHAT_COMPLETIONS_RAI", azure)
|
|
||||||
tv.ChatCompletionsOYD = getEndpoint("OYD", azure)
|
|
||||||
}
|
|
||||||
|
|
||||||
updateModels(azure, &tv)
|
|
||||||
|
|
||||||
if tv.Endpoint.URL != "" && !strings.HasSuffix(tv.Endpoint.URL, "/") {
|
|
||||||
// (this just makes recording replacement easier)
|
|
||||||
tv.Endpoint.URL += "/"
|
|
||||||
}
|
|
||||||
|
|
||||||
return tv
|
|
||||||
}
|
|
||||||
|
|
||||||
const fakeEndpoint = "https://fake-recorded-host.microsoft.com/"
|
const fakeEndpoint = "https://fake-recorded-host.microsoft.com/"
|
||||||
const fakeAPIKey = "redacted"
|
const fakeAPIKey = "redacted"
|
||||||
const fakeCognitiveEndpoint = "https://fake-cognitive-endpoint.microsoft.com"
|
const fakeCognitiveEndpoint = "https://fake-cognitive-endpoint.microsoft.com"
|
||||||
const fakeCognitiveIndexName = "index"
|
const fakeCognitiveIndexName = "index"
|
||||||
|
|
||||||
func initEnvVars() {
|
|
||||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
|
||||||
// Setup our variables so our requests are consistent with what we recorded.
|
|
||||||
// Endpoints are sanitized using the recording policy
|
|
||||||
azureOpenAI.Endpoint = endpoint{
|
|
||||||
URL: fakeEndpoint,
|
|
||||||
APIKey: fakeAPIKey,
|
|
||||||
Azure: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI.Whisper = endpointWithModel{
|
|
||||||
Endpoint: azureOpenAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI.ChatCompletionsRAI = endpointWithModel{
|
|
||||||
Endpoint: azureOpenAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI.ChatCompletionsOYD = endpointWithModel{
|
|
||||||
Endpoint: azureOpenAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI.DallE = endpointWithModel{
|
|
||||||
Endpoint: azureOpenAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI.Vision = endpointWithModel{
|
|
||||||
Endpoint: azureOpenAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
openAI.Endpoint = endpoint{
|
|
||||||
APIKey: fakeAPIKey,
|
|
||||||
URL: fakeEndpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
openAI.Whisper = endpointWithModel{
|
|
||||||
Endpoint: endpoint{
|
|
||||||
APIKey: fakeAPIKey,
|
|
||||||
URL: fakeEndpoint,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
openAI.DallE = endpointWithModel{
|
|
||||||
Endpoint: openAI.Endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
updateModels(true, &azureOpenAI)
|
|
||||||
updateModels(false, &openAI)
|
|
||||||
|
|
||||||
openAI.Vision = azureOpenAI.Vision
|
|
||||||
|
|
||||||
azureOpenAI.Completions = "gpt-35-turbo-instruct"
|
|
||||||
openAI.Completions = "gpt-3.5-turbo-instruct"
|
|
||||||
|
|
||||||
azureOpenAI.ChatCompletions = "gpt-35-turbo-0613"
|
|
||||||
azureOpenAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
|
|
||||||
openAI.ChatCompletions = "gpt-4-0613"
|
|
||||||
openAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
|
|
||||||
|
|
||||||
openAI.Embeddings = "text-embedding-ada-002"
|
|
||||||
azureOpenAI.Embeddings = "text-embedding-ada-002"
|
|
||||||
|
|
||||||
azureOpenAI.Cognitive = azopenai.AzureSearchChatExtensionConfiguration{
|
|
||||||
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
|
||||||
Endpoint: to.Ptr(fakeCognitiveEndpoint),
|
|
||||||
IndexName: to.Ptr(fakeCognitiveIndexName),
|
|
||||||
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
|
||||||
Key: to.Ptr(fakeAPIKey),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := godotenv.Load(); err != nil {
|
|
||||||
fmt.Printf("Failed to load .env file: %s\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
azureOpenAI = newTestVars("AOAI")
|
|
||||||
openAI = newTestVars("OPENAI")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type MultipartRecordingPolicy struct {
|
type MultipartRecordingPolicy struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -315,17 +236,7 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
|
||||||
err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", "", nil)
|
err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", "", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
endpoints := []string{
|
for _, ep := range servers {
|
||||||
azureOpenAI.Endpoint.URL,
|
|
||||||
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
|
|
||||||
azureOpenAI.Whisper.Endpoint.URL,
|
|
||||||
azureOpenAI.DallE.Endpoint.URL,
|
|
||||||
azureOpenAI.Vision.Endpoint.URL,
|
|
||||||
azureOpenAI.ChatCompletionsOYD.Endpoint.URL,
|
|
||||||
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ep := range endpoints {
|
|
||||||
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(ep), nil)
|
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(ep), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
@ -333,8 +244,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
|
||||||
err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil)
|
err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if openAI.Endpoint.URL != "" {
|
// there's only one OpenAI endpoint
|
||||||
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.Endpoint.URL), nil)
|
if openAI.ChatCompletions.Endpoint.URL != "" {
|
||||||
|
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.ChatCompletions.Endpoint.URL), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,20 +313,12 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {
|
||||||
return co
|
return co
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAzureOpenAIClientForTest(t *testing.T, tv testVars, options ...testClientOption) *azopenai.Client {
|
|
||||||
return newTestClient(t, tv.Endpoint, options...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newOpenAIClientForTest(t *testing.T, options ...testClientOption) *azopenai.Client {
|
|
||||||
return newTestClient(t, openAI.Endpoint, options...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
|
// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
|
||||||
// a failure.
|
// a failure.
|
||||||
func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
|
func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
|
||||||
cred := azcore.NewKeyCredential("bogus-api-key")
|
cred := azcore.NewKeyCredential("bogus-api-key")
|
||||||
|
|
||||||
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint.URL, cred, newClientOptionsForTest(t))
|
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
@ -425,7 +329,7 @@ func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
|
||||||
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
|
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
|
||||||
cred := azcore.NewKeyCredential("bogus-api-key")
|
cred := azcore.NewKeyCredential("bogus-api-key")
|
||||||
|
|
||||||
client, err := azopenai.NewClientForOpenAI(openAI.Endpoint.URL, cred, newClientOptionsForTest(t))
|
client, err := azopenai.NewClientForOpenAI(openAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
@ -440,11 +344,12 @@ func assertResponseIsError(t *testing.T, err error) {
|
||||||
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
|
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRequired(name string) string {
|
func getEndpoint(ev string) string {
|
||||||
v := os.Getenv(name)
|
v := recording.GetEnvVariable(ev, fakeEndpoint)
|
||||||
|
|
||||||
if v == "" {
|
if !strings.HasSuffix(v, "/") {
|
||||||
panic(fmt.Sprintf("Env variable %s is missing", name))
|
// (this just makes recording replacement easier)
|
||||||
|
v += "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
|
|
@ -23,7 +23,7 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) {
|
||||||
t.Skip()
|
t.Skip()
|
||||||
}
|
}
|
||||||
|
|
||||||
chatClient := newOpenAIClientForTest(t)
|
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||||
|
|
||||||
_, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
_, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
||||||
Messages: []azopenai.ChatRequestMessageClassification{
|
Messages: []azopenai.ChatRequestMessageClassification{
|
||||||
|
|
|
@ -29,11 +29,7 @@ func TestImageGeneration_AzureOpenAI(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestImageGeneration_OpenAI(t *testing.T) {
|
func TestImageGeneration_OpenAI(t *testing.T) {
|
||||||
if testing.Short() {
|
client := newTestClient(t, openAI.DallE.Endpoint)
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatURL)
|
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,7 +56,7 @@ func TestImageGeneration_OpenAI_Base64(t *testing.T) {
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||||
}
|
}
|
||||||
|
|
||||||
client := newOpenAIClientForTest(t)
|
client := newTestClient(t, openAI.DallE.Endpoint)
|
||||||
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatBase64)
|
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatBase64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -76,28 +76,17 @@ func TestNewClientWithKeyCredential(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
|
func TestGetCompletionsStream(t *testing.T) {
|
||||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||||
testGetCompletionsStream(t, client, azureOpenAI)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetCompletionsStream_OpenAI(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := newOpenAIClientForTest(t)
|
|
||||||
testGetCompletionsStream(t, client, openAI)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
|
|
||||||
body := azopenai.CompletionsOptions{
|
body := azopenai.CompletionsOptions{
|
||||||
Prompt: []string{"What is Azure OpenAI?"},
|
Prompt: []string{"What is Azure OpenAI?"},
|
||||||
MaxTokens: to.Ptr(int32(2048)),
|
MaxTokens: to.Ptr(int32(2048)),
|
||||||
Temperature: to.Ptr(float32(0.0)),
|
Temperature: to.Ptr(float32(0.0)),
|
||||||
DeploymentName: &tv.Completions,
|
DeploymentName: &epm.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := newTestClient(t, epm.Endpoint)
|
||||||
|
|
||||||
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
|
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
|
||||||
skipNowIfThrottled(t, err)
|
skipNowIfThrottled(t, err)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -106,6 +95,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
|
||||||
t.Errorf("Client.GetCompletionsStream() error = %v", err)
|
t.Errorf("Client.GetCompletionsStream() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reader := response.CompletionsStream
|
reader := response.CompletionsStream
|
||||||
defer reader.Close()
|
defer reader.Close()
|
||||||
|
|
||||||
|
@ -146,12 +136,23 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
|
||||||
require.GreaterOrEqual(t, eventCount, 50)
|
require.GreaterOrEqual(t, eventCount, 50)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, azureOpenAI.Completions)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
|
testFn(t, openAI.Completions)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_GetCompletions_Error(t *testing.T) {
|
func TestClient_GetCompletions_Error(t *testing.T) {
|
||||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||||
t.Skip()
|
t.Skip()
|
||||||
}
|
}
|
||||||
|
|
||||||
doTest := func(t *testing.T, client *azopenai.Client, model string) {
|
doTest := func(t *testing.T, model string) {
|
||||||
|
client := newBogusAzureOpenAIClient(t)
|
||||||
|
|
||||||
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
|
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
|
||||||
Prompt: []string{"What is Azure OpenAI?"},
|
Prompt: []string{"What is Azure OpenAI?"},
|
||||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||||
|
@ -163,12 +164,10 @@ func TestClient_GetCompletions_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||||
client := newBogusAzureOpenAIClient(t)
|
doTest(t, azureOpenAI.Completions.Model)
|
||||||
doTest(t, client, azureOpenAI.Completions)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OpenAI", func(t *testing.T) {
|
t.Run("OpenAI", func(t *testing.T) {
|
||||||
client := newBogusOpenAIClient(t)
|
doTest(t, openAI.Completions.Model)
|
||||||
doTest(t, client, openAI.Completions)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
|
||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0
|
||||||
github.com/joho/godotenv v1.3.0
|
github.com/joho/godotenv v1.3.0
|
||||||
|
@ -19,9 +19,9 @@ require (
|
||||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
golang.org/x/crypto v0.21.0 // indirect
|
golang.org/x/crypto v0.22.0 // indirect
|
||||||
golang.org/x/net v0.22.0 // indirect
|
golang.org/x/net v0.24.0 // indirect
|
||||||
golang.org/x/sys v0.18.0 // indirect
|
golang.org/x/sys v0.19.0 // indirect
|
||||||
golang.org/x/text v0.14.0 // indirect
|
golang.org/x/text v0.14.0 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 h1:c4k2FIYIh4xtwqrQwV0Ct1v5+ehlNXj5NI/MWVsiTkQ=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2/go.mod h1:5FDJtLEO/GxwNgUxbwrY3LP0pEoThTQJtk2oysdXHxM=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0 h1:sUFnFjzDUie80h24I7mrKtwCKgLY9L8h5Tp2x9+TWqk=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0 h1:sUFnFjzDUie80h24I7mrKtwCKgLY9L8h5Tp2x9+TWqk=
|
||||||
|
@ -25,13 +25,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||||
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
const RecordingDirectory = "sdk/ai/azopenai/testdata"
|
const RecordingDirectory = "sdk/ai/azopenai/testdata"
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
initEnvVars()
|
|
||||||
code := run(m)
|
code := run(m)
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче