[azopenai] A rather large, but needed, cleanup of the tests. (#22707)

It was getting difficult to tell what was and wasn't covered, and also to correlate tests with models, which can sometimes come with differing behavior.

So in this PR I:
- Moved things that don't need to be env vars (i.e., non-secrets) into constants in the code. This includes model names and the generic server identifiers for the regions we use to coordinate new feature development.
- Consolidated all of the endpoint and model configuration into one spot, making it simpler to double-check.
- Consolidated tests that covered the same thing into sub-tests named OpenAI and AzureOpenAI (a minimal sketch of the pattern is below).
  - If a function was only called by one test, I moved it into that test as an anonymous func.

Also, I added tests for logit_bias and logprobs/top_logprobs.
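The consolidated sub-test pattern looks roughly like this. This is a simplified sketch, not a verbatim excerpt from the diff: endpointWithModel, azureOpenAI, openAI, and newTestClient are the shared test-harness types and helpers this PR reworks.

package azopenai_test

import (
	"context"
	"testing"

	"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
	"github.com/stretchr/testify/require"
)

func TestClient_GetCompletions(t *testing.T) {
	// Shared body, parameterized by an endpoint + model pair. Helpers that
	// are only used by one test now live inside it as anonymous funcs.
	testFn := func(t *testing.T, epm endpointWithModel) {
		client := newTestClient(t, epm.Endpoint)

		resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
			Prompt:         []string{"What is Azure OpenAI?"},
			DeploymentName: &epm.Model,
		}, nil)
		require.NoError(t, err)
		require.NotEmpty(t, resp.Completions.Choices)
	}

	// One sub-test per service, sharing the same body.
	t.Run("AzureOpenAI", func(t *testing.T) { testFn(t, azureOpenAI.Completions) })
	t.Run("OpenAI", func(t *testing.T) { testFn(t, openAI.Completions) })
}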
Richard Park 2024-04-17 15:19:32 -07:00 committed by GitHub
Parent a51db25793
Commit a0f9b026ec
No key found matching this signature
GPG key ID: B5690EEEBB952194
14 changed files: 1004 additions and 1021 deletions

View file

@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/ai/azopenai",
"Tag": "go/ai/azopenai_a33cdad878"
"Tag": "go/ai/azopenai_a56e3e9e32"
}

View file

@@ -19,64 +19,14 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_GetAudioTranscription_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranscriptionTests(t, client, azureOpenAI.Whisper.Model, true)
}
func TestClient_GetAudioTranscription(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
model := epm.Model
func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
client := newOpenAIClientForTest(t)
runTranscriptionTests(t, client, openAI.Whisper.Model, false)
}
func TestClient_GetAudioTranslation_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranslationTests(t, client, azureOpenAI.Whisper.Model, true)
}
func TestClient_GetAudioTranslation_OpenAI(t *testing.T) {
client := newOpenAIClientForTest(t)
runTranslationTests(t, client, openAI.Whisper.Model, false)
}
func TestClient_GetAudioSpeech(t *testing.T) {
client := newOpenAIClientForTest(t)
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
Input: to.Ptr("i am a computer"),
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
DeploymentName: to.Ptr("tts-1"),
}, nil)
require.NoError(t, err)
audioBytes, err := io.ReadAll(audioResp.Body)
require.NoError(t, err)
require.NotEmpty(t, audioBytes)
require.Equal(t, "fLaC", string(audioBytes[0:4]))
// now send _it_ back through the transcription API and see if we can get something useful.
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
Filename: to.Ptr("test.flac"),
File: audioBytes,
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
DeploymentName: &openAI.Whisper.Model,
Temperature: to.Ptr[float32](0.0),
}, nil)
require.NoError(t, err)
require.NotZero(t, *transcriptionResp.Duration)
// it occasionally comes back with different punctuation or makes a complete sentence but
// the major words always come through.
require.Contains(t, *transcriptionResp.Text, "computer")
}
func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if isAzure {
if epm.Endpoint.Azure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
@@ -167,10 +117,23 @@ func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string,
}
}
func runTranslationTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Whisper)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Whisper)
})
}
func TestClient_GetAudioTranslation(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
model := epm.Model
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if isAzure {
if epm.Endpoint.Azure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
@@ -262,6 +225,49 @@ func runTranslationTests(t *testing.T, client *azopenai.Client, model string, is
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Whisper)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Whisper)
})
}
func TestClient_GetAudioSpeech(t *testing.T) {
client := newTestClient(t, openAI.Speech.Endpoint)
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
Input: to.Ptr("i am a computer"),
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
DeploymentName: to.Ptr("tts-1"),
}, nil)
require.NoError(t, err)
audioBytes, err := io.ReadAll(audioResp.Body)
require.NoError(t, err)
require.NotEmpty(t, audioBytes)
require.Equal(t, "fLaC", string(audioBytes[0:4]))
// now send _it_ back through the transcription API and see if we can get something useful.
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
Filename: to.Ptr("test.flac"),
File: audioBytes,
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
DeploymentName: &openAI.Whisper.Model,
Temperature: to.Ptr[float32](0.0),
}, nil)
require.NoError(t, err)
require.NotZero(t, *transcriptionResp.Duration)
// it occasionally comes back with different punctuation or makes a complete sentence but
// the major words always come through.
require.Contains(t, *transcriptionResp.Text, "computer")
}
func newTranscriptionOptions(format azopenai.AudioTranscriptionFormat, model string, path string) azopenai.AudioTranscriptionOptions {
audioBytes, err := os.ReadFile(path)

View file

@@ -42,34 +42,7 @@ var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
var expectedRole = azopenai.ChatRoleAssistant
func TestClient_GetChatCompletions(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetChatCompletions(t, client, azureOpenAI.ChatCompletionsRAI.Model, true)
}
func TestClient_GetChatCompletionsStream(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testGetChatCompletionsStream(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model)
}
func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient, openAI.ChatCompletions, false)
}
func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient, openAI.ChatCompletions)
}
func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment string, checkRAI bool) {
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedModel string, checkRAI bool) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
@@ -88,9 +61,7 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
PromptTokens: to.Ptr(int32(42)),
TotalTokens: to.Ptr(int32(71)),
},
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
// by convention) but if this fails because they _don't_ match we can just adjust the test.
Model: &deployment,
Model: &returnedModel,
}
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
@@ -112,10 +83,136 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
expected.ID = resp.ID
expected.Created = resp.Created
t.Logf("isAzure: %t, deployment: %s, returnedModel: %s", checkRAI, deployment, *resp.ChatCompletions.Model)
require.Equal(t, expected, resp.ChatCompletions)
}
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deployment string) {
t.Run("AzureOpenAI", func(t *testing.T) {
client := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testFn(t, client, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4", true)
})
t.Run("AzureOpenAI.DefaultAzureCredential", func(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skipf("Not running this test in playback (for now)")
}
if os.Getenv("USE_TOKEN_CREDS") != "true" {
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
}
recordingTransporter := newRecordingTransporter(t)
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: recordingTransporter,
},
})
require.NoError(t, err)
chatClient, err := azopenai.NewClient(azureOpenAI.ChatCompletions.Endpoint.URL, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
})
require.NoError(t, err)
testFn(t, chatClient, azureOpenAI.ChatCompletions.Model, "gpt-4", true)
})
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
testFn(t, chatClient, openAI.ChatCompletions.Model, "gpt-4-0613", false)
})
}
func TestClient_GetChatCompletions_LogProbs(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
opts := azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
LogProbs: to.Ptr(true),
TopLogProbs: to.Ptr(int32(5)),
}
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
require.NoError(t, err)
for _, choice := range resp.Choices {
require.NotEmpty(t, choice.LogProbs)
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}
func TestClient_GetChatCompletions_LogitBias(t *testing.T) {
// you can use LogitBias to constrain the answer to NOT contain
// certain tokens. More or less following the technique in this OpenAI article:
// https://help.openai.com/en/articles/5247780-using-logit-bias-to-alter-token-probability-with-the-openai-api
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
opts := azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent("Briefly, what are some common roles for people at a circus, names only, one per line?"),
},
},
MaxTokens: to.Ptr(int32(200)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
LogitBias: map[string]*int32{
// you can calculate these tokens using OpenAI's online tool:
// https://platform.openai.com/tokenizer?view=bpe
// These token IDs are all variations of "Clown", which I want to exclude from the response.
"25": to.Ptr(int32(-100)),
"220": to.Ptr(int32(-100)),
"1206": to.Ptr(int32(-100)),
"2493": to.Ptr(int32(-100)),
"5176": to.Ptr(int32(-100)),
"43456": to.Ptr(int32(-100)),
"99423": to.Ptr(int32(-100)),
},
}
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
require.NoError(t, err)
for _, choice := range resp.Choices {
if choice.Message == nil || choice.Message.Content == nil {
continue
}
require.NotContains(t, *choice.Message.Content, "clown")
require.NotContains(t, *choice.Message.Content, "Clown")
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}
func TestClient_GetChatCompletionsStream(t *testing.T) {
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedDeployment string) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
@@ -141,7 +238,7 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
// by convention) but if this fails because they _don't_ match we can just adjust the test.
if deployment == *completion.Model {
if returnedDeployment == *completion.Model {
modelWasReturned = true
}
@@ -178,34 +275,19 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
}
func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skipf("Not running this test in playback (for now)")
}
if os.Getenv("USE_TOKEN_CREDS") != "true" {
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
}
recordingTransporter := newRecordingTransporter(t)
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: recordingTransporter,
},
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testFn(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4")
})
require.NoError(t, err)
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint.URL, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
testFn(t, chatClient, openAI.ChatCompletions.Model, openAI.ChatCompletions.Model)
})
require.NoError(t, err)
testGetChatCompletions(t, chatClient, azureOpenAI.ChatCompletions, true)
}
func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
client := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
_, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
@@ -230,14 +312,14 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions), nil)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions.Model), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions), nil)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions.Model), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
@@ -276,15 +358,15 @@ func TestClient_GetChatCompletions_Vision(t *testing.T) {
t.Logf(*resp.Choices[0].Message.Content)
}
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, openAI.Vision.Model)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
testFn(t, chatClient, azureOpenAI.Vision.Model)
})
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.Vision.Endpoint)
testFn(t, chatClient, openAI.Vision.Model)
})
}
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
@@ -313,13 +395,13 @@ func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
require.NotEmpty(t, v)
}
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, "gpt-3.5-turbo-1106")
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
testFn(t, chatClient, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Model)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
testFn(t, chatClient, "gpt-4-1106-preview")
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
testFn(t, chatClient, openAI.ChatCompletionsWithJSONResponseFormat.Model)
})
}

View file

@@ -15,32 +15,15 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetCompletions(t, client, true)
}
func TestClient_GetCompletions_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
testGetCompletions(t, client, false)
}
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
deploymentID := openAI.Completions
if isAzure {
deploymentID = azureOpenAI.Completions
}
func TestClient_GetCompletions(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &deploymentID,
DeploymentName: &epm.Model,
}, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
@@ -55,7 +38,7 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
if isAzure {
if epm.Endpoint.Azure {
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{
@@ -65,3 +48,12 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Completions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Completions)
})
}

View file

@@ -18,7 +18,7 @@ import (
)
func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
client := newTestClient(t, azureOpenAI.Embeddings.Endpoint)
_, err := client.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
DeploymentName: to.Ptr("thisdoesntexist"),
@@ -29,80 +29,10 @@ func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
}
func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
testGetEmbeddings(t, client, openAI.Embeddings)
}
func TestClient_GetEmbeddings(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
}
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
testFn := func(t *testing.T, tv testVars, dimension int32) {
client := newTestClient(t, tv.Endpoint)
arg := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
DeploymentName: &tv.TextEmbedding3Small,
}
if dimension > 0 {
arg.Dimensions = &dimension
}
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
require.NoError(t, err)
require.NotEmpty(t, base64Resp.Data)
require.Empty(t, base64Resp.Data[0].Embedding)
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
// sanity checks - we deserialized everything and didn't create anything impossible.
for _, v := range embeddings {
require.True(t, v <= 1.0 && v >= -1.0)
}
arg2 := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
DeploymentName: &tv.TextEmbedding3Small,
}
if dimension > 0 {
arg2.Dimensions = &dimension
}
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
require.NoError(t, err)
require.NotEmpty(t, floatResp.Data)
require.NotEmpty(t, floatResp.Data[0].Embedding)
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
// This works "most of the time" but it's non-deterministic since two separate calls don't always
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
}
for _, dim := range []int32{0, 1, 10, 100} {
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, azureOpenAI, dim)
})
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, openAI, dim)
})
}
}
func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
type args struct {
ctx context.Context
deploymentID string
@@ -122,10 +52,10 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
client: client,
args: args{
ctx: context.TODO(),
deploymentID: modelOrDeploymentID,
deploymentID: epm.Model,
body: azopenai.EmbeddingsOptions{
Input: []string{"\"Your text string goes here\""},
DeploymentName: &modelOrDeploymentID,
DeploymentName: &epm.Model,
},
options: nil,
},
@@ -151,6 +81,74 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Embeddings)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Embeddings)
})
}
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel, dimension int32) {
client := newTestClient(t, epm.Endpoint)
arg := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
DeploymentName: &epm.Model,
}
if dimension > 0 {
arg.Dimensions = &dimension
}
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
require.NoError(t, err)
require.NotEmpty(t, base64Resp.Data)
require.Empty(t, base64Resp.Data[0].Embedding)
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
// sanity checks - we deserialized everything and didn't create anything impossible.
for _, v := range embeddings {
require.True(t, v <= 1.0 && v >= -1.0)
}
arg2 := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
DeploymentName: &epm.Model,
}
if dimension > 0 {
arg2.Dimensions = &dimension
}
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
require.NoError(t, err)
require.NotEmpty(t, floatResp.Data)
require.NotEmpty(t, floatResp.Data[0].Embedding)
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
// This works "most of the time" but it's non-deterministic since two separate calls don't always
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
}
for _, dim := range []int32{0, 1, 10, 100} {
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, azureOpenAI.TextEmbedding3Small, dim)
})
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, openAI.TextEmbedding3Small, dim)
})
}
}
func deserializeBase64Embeddings(t *testing.T, ei azopenai.EmbeddingItem) []float32 {
destBytes, err := base64.StdEncoding.DecodeString(ei.EmbeddingBase64)
require.NoError(t, err)

View file

@@ -6,6 +6,7 @@ package azopenai_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
@@ -28,130 +29,7 @@ type ParamProperty struct {
func TestGetChatCompletions_usingFunctions(t *testing.T) {
// https://platform.openai.com/docs/guides/gpt/function-calling
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
)
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testData := []struct {
Model string
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: openAI.ChatCompletions, ToolChoice: nil},
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
}
for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
}
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testData := []struct {
Model string
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
}
for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
}
})
}
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletionsLegacyFunctions)
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletions)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctionsOlderStyle(t, chatClient, azureOpenAI.ChatCompletionsLegacyFunctions)
})
}
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
// https://platform.openai.com/docs/guides/gpt/function-calling
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctionsStreaming(t, chatClient, openAI)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctionsStreaming(t, chatClient, azureOpenAI)
})
}
func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Client, deploymentName string) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
Value: to.Ptr("auto"),
},
Functions: []azopenai.FunctionDefinition{
{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
resp, err := client.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
@@ -204,9 +82,124 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Client, tv testVars) {
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
)
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
testData := []struct {
Model string
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: nil},
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: useSpecificTool},
}
for _, td := range testData {
testFn(t, chatClient, td.Model, td.ToolChoice)
}
})
t.Run("OpenAI", func(t *testing.T) {
testData := []struct {
EPM endpointWithModel
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{EPM: openAI.ChatCompletions, ToolChoice: nil},
{EPM: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{EPM: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
}
for i, td := range testData {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
chatClient := newTestClient(t, td.EPM.Endpoint)
testFn(t, chatClient, td.EPM.Model, td.ToolChoice)
})
}
})
}
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
body := azopenai.ChatCompletionsOptions{
DeploymentName: &tv.ChatCompletions,
DeploymentName: &epm.Model,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
Value: to.Ptr("auto"),
},
Functions: []azopenai.FunctionDefinition{
{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
resp, err := client.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletionsLegacyFunctions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
t.Run("OpenAI.LegacyFunctions", func(t *testing.T) {
testFn(t, openAI.ChatCompletionsLegacyFunctions)
})
}
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &epm.Model,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
@@ -237,6 +230,8 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
Temperature: to.Ptr[float32](0.0),
}
chatClient := newTestClient(t, epm.Endpoint)
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
require.NoError(t, err)
require.NotEmpty(t, resp)
@@ -293,3 +288,14 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
// https://platform.openai.com/docs/guides/gpt/function-calling
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}

View file

@@ -22,13 +22,13 @@ import (
func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) {
// Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered
// https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered
client := newAzureOpenAIClientForTest(t, azureOpenAI)
client := newTestClient(t, azureOpenAI.Completions.Endpoint)
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"How do I rob a bank with violence?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &azureOpenAI.Completions,
DeploymentName: &azureOpenAI.Completions.Model,
}, nil)
require.Empty(t, resp)

View file

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
@@ -26,11 +27,6 @@ import (
"github.com/stretchr/testify/require"
)
var (
azureOpenAI testVars
openAI testVars
)
type endpoint struct {
URL string
APIKey string
@@ -38,22 +34,19 @@ }
}
type testVars struct {
Endpoint endpoint
Completions string
ChatCompletions string
ChatCompletionsLegacyFunctions string
Embeddings string
TextEmbedding3Small string
ChatCompletions endpointWithModel
ChatCompletionsLegacyFunctions endpointWithModel
ChatCompletionsOYD endpointWithModel // azure only
ChatCompletionsRAI endpointWithModel // azure only
ChatCompletionsWithJSONResponseFormat endpointWithModel
Cognitive azopenai.AzureSearchChatExtensionConfiguration
Whisper endpointWithModel
Completions endpointWithModel
DallE endpointWithModel
Embeddings endpointWithModel
Speech endpointWithModel
TextEmbedding3Small endpointWithModel
Vision endpointWithModel
ChatCompletionsRAI endpointWithModel // at the moment this is Azure only
// "own your data" - bringing in Azure resources as part of a chat completions
// request.
ChatCompletionsOYD endpointWithModel
Whisper endpointWithModel
}
type endpointWithModel struct {
@@ -61,16 +54,130 @@
Model string
}
type testClientOption func(opt *azopenai.ClientOptions)
func ifAzure[T string | endpoint](azure bool, forAzure T, forOpenAI T) T {
if azure {
return forAzure
}
return forOpenAI
}
func withForgivingRetryOption() testClientOption {
return func(opt *azopenai.ClientOptions) {
opt.Retry = policy.RetryOptions{
MaxRetries: 10,
var azureOpenAI, openAI, servers = func() (testVars, testVars, []string) {
if recording.GetRecordMode() != recording.PlaybackMode {
if err := godotenv.Load(); err != nil {
log.Fatalf("Failed to load .env file: %s\n", err)
}
}
servers := struct {
USEast endpoint
USNorthCentral endpoint
USEast2 endpoint
SWECentral endpoint
OpenAI endpoint
}{
OpenAI: endpoint{
URL: getEndpoint("OPENAI_ENDPOINT"), // ex: https://api.openai.com/v1/
APIKey: recording.GetEnvVariable("OPENAI_API_KEY", fakeAPIKey),
Azure: false,
},
USEast: endpoint{
URL: getEndpoint("ENDPOINT_USEAST"),
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST_API_KEY", fakeAPIKey),
Azure: true,
},
USEast2: endpoint{
URL: getEndpoint("ENDPOINT_USEAST2"),
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST2_API_KEY", fakeAPIKey),
Azure: true,
},
USNorthCentral: endpoint{
URL: getEndpoint("ENDPOINT_USNORTHCENTRAL"),
APIKey: recording.GetEnvVariable("ENDPOINT_USNORTHCENTRAL_API_KEY", fakeAPIKey),
Azure: true,
},
SWECentral: endpoint{
URL: getEndpoint("ENDPOINT_SWECENTRAL"),
APIKey: recording.GetEnvVariable("ENDPOINT_SWECENTRAL_API_KEY", fakeAPIKey),
Azure: true,
},
}
// used when we setup the recording policy
endpoints := []string{
servers.OpenAI.URL,
servers.USEast.URL,
servers.USEast2.URL,
servers.USNorthCentral.URL,
servers.SWECentral.URL,
}
newTestVarsFn := func(azure bool) testVars {
return testVars{
ChatCompletions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
},
ChatCompletionsLegacyFunctions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
},
ChatCompletionsOYD: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
},
ChatCompletionsRAI: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
},
ChatCompletionsWithJSONResponseFormat: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-1106-preview", "gpt-3.5-turbo-1106"),
},
Completions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"),
},
DallE: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "dall-e-3", "dall-e-3"),
},
Embeddings: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "text-embedding-ada-002", "text-embedding-ada-002"),
},
Speech: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "tts-1", "tts-1"),
},
TextEmbedding3Small: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "text-embedding-3-small", "text-embedding-3-small"),
},
Vision: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-vision-preview", "gpt-4-vision-preview"),
},
Whisper: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast2, servers.OpenAI),
Model: ifAzure(azure, "whisper-deployment", "whisper-1"),
},
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_ENDPOINT", fakeCognitiveEndpoint)),
IndexName: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_INDEX", fakeCognitiveIndexName)),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_KEY", fakeAPIKey)),
},
},
},
}
}
return newTestVarsFn(true), newTestVarsFn(false), endpoints
}()
type testClientOption func(opt *azopenai.ClientOptions)
// newTestClient creates a client enabled for HTTP recording, if needed.
// See [newRecordingTransporter] for sanitization code.
func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azopenai.Client {
@@ -101,197 +208,11 @@ func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azop
}
}
// getEndpoint retrieves details for an endpoint and a model.
// - res - the resource type for a particular endpoint. Ex: "DALLE".
//
// For example, if azure is true we'll load these environment values based on res:
// - AOAI_DALLE_ENDPOINT
// - AOAI_DALLE_API_KEY
//
// if azure is false we'll load these environment values based on res:
// - OPENAI_ENDPOINT
// - OPENAI_API_KEY
func getEndpoint(res string, isAzure bool) endpointWithModel {
var ep endpointWithModel
if isAzure {
// during development resources are often shifted between different
// internal Azure OpenAI resources.
ep = endpointWithModel{
Endpoint: endpoint{
URL: getRequired("AOAI_" + res + "_ENDPOINT"),
APIKey: getRequired("AOAI_" + res + "_API_KEY"),
Azure: true,
},
}
} else {
ep = endpointWithModel{
Endpoint: endpoint{
URL: getRequired("OPENAI_ENDPOINT"),
APIKey: getRequired("OPENAI_API_KEY"),
Azure: false,
},
}
}
if !strings.HasSuffix(ep.Endpoint.URL, "/") {
// (this just makes recording replacement easier)
ep.Endpoint.URL += "/"
}
return ep
}
func model(azure bool, azureModel, openAIModel string) string {
if azure {
return azureModel
}
return openAIModel
}
func updateModels(azure bool, tv *testVars) {
// the models we use are basically their own API surface so it's good to know which
// specific models our tests were written against.
tv.Completions = model(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct")
tv.ChatCompletions = model(azure, "gpt-35-turbo-0613", "gpt-4-0613")
tv.ChatCompletionsLegacyFunctions = model(azure, "gpt-4-0613", "gpt-4-0613")
tv.Embeddings = model(azure, "text-embedding-ada-002", "text-embedding-ada-002")
tv.TextEmbedding3Small = model(azure, "text-embedding-3-small", "text-embedding-3-small")
tv.DallE.Model = model(azure, "dall-e-3", "dall-e-3")
tv.Whisper.Model = model(azure, "whisper-deployment", "whisper-1")
tv.Vision.Model = model(azure, "gpt-4-vision-preview", "gpt-4-vision-preview")
// these are Azure-only features
tv.ChatCompletionsOYD.Model = model(azure, "gpt-4", "")
tv.ChatCompletionsRAI.Model = model(azure, "gpt-4", "")
}
func newTestVars(prefix string) testVars {
azure := prefix == "AOAI"
tv := testVars{
Endpoint: endpoint{
URL: getRequired(prefix + "_ENDPOINT"),
APIKey: getRequired(prefix + "_API_KEY"),
Azure: azure,
},
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(getRequired("COGNITIVE_SEARCH_API_ENDPOINT")),
IndexName: to.Ptr(getRequired("COGNITIVE_SEARCH_API_INDEX")),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(getRequired("COGNITIVE_SEARCH_API_KEY")),
},
},
},
DallE: getEndpoint("DALLE", azure),
Whisper: getEndpoint("WHISPER", azure),
Vision: getEndpoint("VISION", azure),
}
if azure {
tv.ChatCompletionsRAI = getEndpoint("CHAT_COMPLETIONS_RAI", azure)
tv.ChatCompletionsOYD = getEndpoint("OYD", azure)
}
updateModels(azure, &tv)
if tv.Endpoint.URL != "" && !strings.HasSuffix(tv.Endpoint.URL, "/") {
// (this just makes recording replacement easier)
tv.Endpoint.URL += "/"
}
return tv
}
const fakeEndpoint = "https://fake-recorded-host.microsoft.com/"
const fakeAPIKey = "redacted"
const fakeCognitiveEndpoint = "https://fake-cognitive-endpoint.microsoft.com"
const fakeCognitiveIndexName = "index"
func initEnvVars() {
if recording.GetRecordMode() == recording.PlaybackMode {
// Setup our variables so our requests are consistent with what we recorded.
// Endpoints are sanitized using the recording policy
azureOpenAI.Endpoint = endpoint{
URL: fakeEndpoint,
APIKey: fakeAPIKey,
Azure: true,
}
azureOpenAI.Whisper = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.ChatCompletionsRAI = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.ChatCompletionsOYD = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.DallE = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.Vision = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
openAI.Endpoint = endpoint{
APIKey: fakeAPIKey,
URL: fakeEndpoint,
}
openAI.Whisper = endpointWithModel{
Endpoint: endpoint{
APIKey: fakeAPIKey,
URL: fakeEndpoint,
},
}
openAI.DallE = endpointWithModel{
Endpoint: openAI.Endpoint,
}
updateModels(true, &azureOpenAI)
updateModels(false, &openAI)
openAI.Vision = azureOpenAI.Vision
azureOpenAI.Completions = "gpt-35-turbo-instruct"
openAI.Completions = "gpt-3.5-turbo-instruct"
azureOpenAI.ChatCompletions = "gpt-35-turbo-0613"
azureOpenAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
openAI.ChatCompletions = "gpt-4-0613"
openAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
openAI.Embeddings = "text-embedding-ada-002"
azureOpenAI.Embeddings = "text-embedding-ada-002"
azureOpenAI.Cognitive = azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(fakeCognitiveEndpoint),
IndexName: to.Ptr(fakeCognitiveIndexName),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(fakeAPIKey),
},
},
}
} else {
if err := godotenv.Load(); err != nil {
fmt.Printf("Failed to load .env file: %s\n", err)
}
azureOpenAI = newTestVars("AOAI")
openAI = newTestVars("OPENAI")
}
}
type MultipartRecordingPolicy struct {
}
@@ -315,17 +236,7 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", "", nil)
require.NoError(t, err)
endpoints := []string{
azureOpenAI.Endpoint.URL,
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
azureOpenAI.Whisper.Endpoint.URL,
azureOpenAI.DallE.Endpoint.URL,
azureOpenAI.Vision.Endpoint.URL,
azureOpenAI.ChatCompletionsOYD.Endpoint.URL,
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
}
for _, ep := range endpoints {
for _, ep := range servers {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(ep), nil)
require.NoError(t, err)
}
@@ -333,8 +244,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil)
require.NoError(t, err)
if openAI.Endpoint.URL != "" {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.Endpoint.URL), nil)
// there's only one OpenAI endpoint
if openAI.ChatCompletions.Endpoint.URL != "" {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.ChatCompletions.Endpoint.URL), nil)
require.NoError(t, err)
}
@@ -401,20 +313,12 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {
return co
}
func newAzureOpenAIClientForTest(t *testing.T, tv testVars, options ...testClientOption) *azopenai.Client {
return newTestClient(t, tv.Endpoint, options...)
}
func newOpenAIClientForTest(t *testing.T, options ...testClientOption) *azopenai.Client {
return newTestClient(t, openAI.Endpoint, options...)
}
// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
// a failure.
func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
cred := azcore.NewKeyCredential("bogus-api-key")
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint.URL, cred, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
@@ -425,7 +329,7 @@ func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
cred := azcore.NewKeyCredential("bogus-api-key")
client, err := azopenai.NewClientForOpenAI(openAI.Endpoint.URL, cred, newClientOptionsForTest(t))
client, err := azopenai.NewClientForOpenAI(openAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}
@@ -440,11 +344,12 @@ func assertResponseIsError(t *testing.T, err error) {
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
}
func getRequired(name string) string {
v := os.Getenv(name)
func getEndpoint(ev string) string {
v := recording.GetEnvVariable(ev, fakeEndpoint)
if v == "" {
panic(fmt.Sprintf("Env variable %s is missing", name))
if !strings.HasSuffix(v, "/") {
// (this just makes recording replacement easier)
v += "/"
}
return v

View file

@@ -23,7 +23,7 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) {
t.Skip()
}
chatClient := newOpenAIClientForTest(t)
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
_, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{

View file

@@ -29,11 +29,7 @@ func TestImageGeneration_AzureOpenAI(t *testing.T) {
}
func TestImageGeneration_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
client := newTestClient(t, openAI.DallE.Endpoint)
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatURL)
}
@@ -60,7 +56,7 @@ func TestImageGeneration_OpenAI_Base64(t *testing.T) {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
client := newTestClient(t, openAI.DallE.Endpoint)
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatBase64)
}

View file

@@ -76,28 +76,17 @@ func TestNewClientWithKeyCredential(t *testing.T) {
}
}
func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetCompletionsStream(t, client, azureOpenAI)
}
func TestGetCompletionsStream_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
testGetCompletionsStream(t, client, openAI)
}
func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
func TestGetCompletionsStream(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
body := azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &tv.Completions,
DeploymentName: &epm.Model,
}
client := newTestClient(t, epm.Endpoint)
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
@@ -106,6 +95,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
t.Errorf("Client.GetCompletionsStream() error = %v", err)
return
}
reader := response.CompletionsStream
defer reader.Close()
@@ -146,12 +136,23 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
require.GreaterOrEqual(t, eventCount, 50)
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Completions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Completions)
})
}
func TestClient_GetCompletions_Error(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}
doTest := func(t *testing.T, client *azopenai.Client, model string) {
doTest := func(t *testing.T, model string) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
@@ -163,12 +164,10 @@ func TestClient_GetCompletions_Error(t *testing.T) {
}
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
doTest(t, client, azureOpenAI.Completions)
doTest(t, azureOpenAI.Completions.Model)
})
t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client, openAI.Completions)
doTest(t, openAI.Completions.Model)
})
}

View file

@@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
go 1.18
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0
github.com/joho/godotenv v1.3.0
@@ -19,9 +19,9 @@ require (
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/net v0.22.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View file

@@ -1,5 +1,5 @@
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 h1:c4k2FIYIh4xtwqrQwV0Ct1v5+ehlNXj5NI/MWVsiTkQ=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2/go.mod h1:5FDJtLEO/GxwNgUxbwrY3LP0pEoThTQJtk2oysdXHxM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0 h1:sUFnFjzDUie80h24I7mrKtwCKgLY9L8h5Tp2x9+TWqk=
@@ -25,13 +25,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

View file

@@ -15,7 +15,6 @@ import (
const RecordingDirectory = "sdk/ai/azopenai/testdata"
func TestMain(m *testing.M) {
initEnvVars()
code := run(m)
os.Exit(code)
}