[azopenai] A rather large, but needed, cleanup of the tests. (#22707)

It was getting difficult to tell what was and wasn't covered, and also to correlate tests with models, which can sometimes come with differing behavior.

So in this PR I:
- Moved things that don't need to be env vars (i.e. non-secrets) to just be constants in the code - this includes model names and generic server identifiers for regions that we use to coordinate new feature development.
- Consolidated all of the endpoint and model configuration into one spot to make it simpler to double-check.
- Consolidated tests that tested the same thing into sub-tests named OpenAI or AzureOpenAI (see the sketch below).
  - If a function was only called by one test, I moved it into the test as an anonymous func.

Also, I added tests for logit_bias and logprobs/top_logprobs.
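For reference, here's a minimal sketch of the consolidated sub-test pattern (the name `GetSomething` is a placeholder; `endpointWithModel`, `newTestClient`, and the `azureOpenAI`/`openAI` globals are the test-suite helpers defined in this diff):

```go
// Sketch only - endpointWithModel, newTestClient, azureOpenAI, and openAI are
// the shared test-suite types/globals introduced in this PR; the body is illustrative.
func TestClient_GetSomething(t *testing.T) {
	testFn := func(t *testing.T, epm endpointWithModel) {
		// The endpoint and model travel together, configured in one spot.
		client := newTestClient(t, epm.Endpoint)
		_ = client // ... exercise the API, using epm.Model as the deployment name ...
	}

	// Identical logic runs against both services; the sub-test names make it
	// obvious which service a failure came from.
	t.Run("AzureOpenAI", func(t *testing.T) {
		testFn(t, azureOpenAI.Whisper)
	})

	t.Run("OpenAI", func(t *testing.T) {
		testFn(t, openAI.Whisper)
	})
}
```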
Richard Park 2024-04-17 15:19:32 -07:00 committed by GitHub
Parent a51db25793
Commit a0f9b026ec
No key found matching this signature
GPG key ID: B5690EEEBB952194
14 changed files with 1004 additions and 1021 deletions

View file

@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/ai/azopenai",
"Tag": "go/ai/azopenai_a33cdad878"
"Tag": "go/ai/azopenai_a56e3e9e32"
}

View file

@@ -19,28 +19,223 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_GetAudioTranscription_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranscriptionTests(t, client, azureOpenAI.Whisper.Model, true)
func TestClient_GetAudioTranscription(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
model := epm.Model
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if epm.Endpoint.Azure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, "mp3"), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVerboseJSON, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3")
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
return
}
testFiles := []string{
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a`,
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`,
}
for _, audioFile := range testFiles {
ext := filepath.Ext(audioFile)
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatSrt, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatSrt, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVtt, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVtt, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVerboseJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatJSON, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Whisper)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Whisper)
})
}
func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
client := newOpenAIClientForTest(t)
runTranscriptionTests(t, client, openAI.Whisper.Model, false)
}
func TestClient_GetAudioTranslation(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
model := epm.Model
func TestClient_GetAudioTranslation_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranslationTests(t, client, azureOpenAI.Whisper.Model, true)
}
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if epm.Endpoint.Azure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
func TestClient_GetAudioTranslation_OpenAI(t *testing.T) {
client := newOpenAIClientForTest(t)
runTranslationTests(t, client, openAI.Whisper.Model, false)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, "mp3"), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVerboseJSON, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3")
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
return
}
testFiles := []string{
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a`,
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`,
}
for _, audioFile := range testFiles {
ext := filepath.Ext(audioFile)
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatSrt, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatSrt, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVtt, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVtt, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVerboseJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatJSON, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Whisper)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Whisper)
})
}
func TestClient_GetAudioSpeech(t *testing.T) {
client := newOpenAIClientForTest(t)
client := newTestClient(t, openAI.Speech.Endpoint)
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
Input: to.Ptr("i am a computer"),
@@ -73,195 +268,6 @@ func TestClient_GetAudioSpeech(t *testing.T) {
require.Contains(t, *transcriptionResp.Text, "computer")
}
func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if isAzure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, "mp3"), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVerboseJSON, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3")
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
return
}
testFiles := []string{
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a`,
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`,
}
for _, audioFile := range testFiles {
ext := filepath.Ext(audioFile)
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatSrt, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatSrt, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVtt, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVtt, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatVerboseJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatJSON, ext), func(t *testing.T) {
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranscription(t, transcriptResp.AudioTranscription)
})
}
}
func runTranslationTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
// The bulk of the logic will test against OpenAI anyways.
if isAzure {
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, "mp3"), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVerboseJSON, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3")
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
return
}
testFiles := []string{
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a`,
`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`,
}
for _, audioFile := range testFiles {
ext := filepath.Ext(audioFile)
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatSrt, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatSrt, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVtt, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVtt, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatVerboseJSON, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatVerboseJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
require.Greater(t, *transcriptResp.Duration, float32(0.0))
require.NotEmpty(t, *transcriptResp.Language)
require.NotEmpty(t, transcriptResp.Segments)
require.NotEmpty(t, transcriptResp.Segments[0])
require.NotEmpty(t, transcriptResp.Task)
})
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatJSON, ext), func(t *testing.T) {
args := newTranslationOptions(azopenai.AudioTranslationFormatJSON, model, audioFile)
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
require.NoError(t, err)
require.NotEmpty(t, transcriptResp)
require.NotEmpty(t, *transcriptResp.Text)
requireEmptyAudioTranslation(t, transcriptResp.AudioTranslation)
})
}
}
func newTranscriptionOptions(format azopenai.AudioTranscriptionFormat, model string, path string) azopenai.AudioTranscriptionOptions {
audioBytes, err := os.ReadFile(path)

View file

@@ -42,170 +42,252 @@ var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
var expectedRole = azopenai.ChatRoleAssistant
func TestClient_GetChatCompletions(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetChatCompletions(t, client, azureOpenAI.ChatCompletionsRAI.Model, true)
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedModel string, checkRAI bool) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Message: &azopenai.ChatResponseMessage{
Role: &expectedRole,
Content: &expectedContent,
},
Index: to.Ptr(int32(0)),
FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")),
},
},
Usage: &azopenai.CompletionsUsage{
// these change depending on which model you use. These #'s work for gpt-4, which is
// what I'm using for these tests.
CompletionTokens: to.Ptr(int32(29)),
PromptTokens: to.Ptr(int32(42)),
TotalTokens: to.Ptr(int32(71)),
},
Model: &returnedModel,
}
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
if checkRAI {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptFilterResults = []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}
expected.Choices[0].ContentFilterResults = safeContentFilter
}
require.NotEmpty(t, resp.ID)
require.NotEmpty(t, resp.Created)
expected.ID = resp.ID
expected.Created = resp.Created
t.Logf("isAzure: %t, deployment: %s, returnedModel: %s", checkRAI, deployment, *resp.ChatCompletions.Model)
require.Equal(t, expected, resp.ChatCompletions)
}
t.Run("AzureOpenAI", func(t *testing.T) {
client := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testFn(t, client, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4", true)
})
t.Run("AzureOpenAI.DefaultAzureCredential", func(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skipf("Not running this test in playback (for now)")
}
if os.Getenv("USE_TOKEN_CREDS") != "true" {
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
}
recordingTransporter := newRecordingTransporter(t)
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: recordingTransporter,
},
})
require.NoError(t, err)
chatClient, err := azopenai.NewClient(azureOpenAI.ChatCompletions.Endpoint.URL, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
})
require.NoError(t, err)
testFn(t, chatClient, azureOpenAI.ChatCompletions.Model, "gpt-4", true)
})
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
testFn(t, chatClient, openAI.ChatCompletions.Model, "gpt-4-0613", false)
})
}
func TestClient_GetChatCompletions_LogProbs(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
opts := azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
LogProbs: to.Ptr(true),
TopLogProbs: to.Ptr(int32(5)),
}
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
require.NoError(t, err)
for _, choice := range resp.Choices {
require.NotEmpty(t, choice.LogProbs)
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}
func TestClient_GetChatCompletions_LogitBias(t *testing.T) {
// you can use LogitBias to constrain the answer to NOT contain
// certain tokens. More or less following the technique in this OpenAI article:
// https://help.openai.com/en/articles/5247780-using-logit-bias-to-alter-token-probability-with-the-openai-api
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
opts := azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent("Briefly, what are some common roles for people at a circus, names only, one per line?"),
},
},
MaxTokens: to.Ptr(int32(200)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
LogitBias: map[string]*int32{
// you can calculate these tokens using OpenAI's online tool:
// https://platform.openai.com/tokenizer?view=bpe
// These token IDs are all variations of "Clown", which I want to exclude from the response.
"25": to.Ptr(int32(-100)),
"220": to.Ptr(int32(-100)),
"1206": to.Ptr(int32(-100)),
"2493": to.Ptr(int32(-100)),
"5176": to.Ptr(int32(-100)),
"43456": to.Ptr(int32(-100)),
"99423": to.Ptr(int32(-100)),
},
}
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
require.NoError(t, err)
for _, choice := range resp.Choices {
if choice.Message == nil || choice.Message.Content == nil {
continue
}
require.NotContains(t, *choice.Message.Content, "clown")
require.NotContains(t, *choice.Message.Content, "Clown")
}
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}
func TestClient_GetChatCompletionsStream(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testGetChatCompletionsStream(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model)
}
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedDeployment string) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient, openAI.ChatCompletions, false)
}
func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient, openAI.ChatCompletions)
}
func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment string, checkRAI bool) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Message: &azopenai.ChatResponseMessage{
Role: &expectedRole,
Content: &expectedContent,
},
Index: to.Ptr(int32(0)),
FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")),
},
},
Usage: &azopenai.CompletionsUsage{
// these change depending on which model you use. These #'s work for gpt-4, which is
// what I'm using for these tests.
CompletionTokens: to.Ptr(int32(29)),
PromptTokens: to.Ptr(int32(42)),
TotalTokens: to.Ptr(int32(71)),
},
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
// by convention) but if this fails because they _don't_ match we can just adjust the test.
Model: &deployment,
}
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
if checkRAI {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptFilterResults = []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}
expected.Choices[0].ContentFilterResults = safeContentFilter
}
require.NotEmpty(t, resp.ID)
require.NotEmpty(t, resp.Created)
expected.ID = resp.ID
expected.Created = resp.Created
require.Equal(t, expected, resp.ChatCompletions)
}
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deployment string) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
t.Skipf("OpenAI resource overloaded, skipping this test")
}
require.NoError(t, err)
// the data comes back differently for streaming
// 1. the text comes back in the ChatCompletion.Delta field
// 2. the role is only sent on the first streamed ChatCompletion
// check that the role came back as well.
var choices []azopenai.ChatChoice
modelWasReturned := false
for {
completion, err := streamResp.ChatCompletionsStream.Read()
if errors.Is(err, io.EOF) {
break
}
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
// by convention) but if this fails because they _don't_ match we can just adjust the test.
if deployment == *completion.Model {
modelWasReturned = true
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
t.Skipf("OpenAI resource overloaded, skipping this test")
}
require.NoError(t, err)
if completion.PromptFilterResults != nil {
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}, completion.PromptFilterResults)
// the data comes back differently for streaming
// 1. the text comes back in the ChatCompletion.Delta field
// 2. the role is only sent on the first streamed ChatCompletion
// check that the role came back as well.
var choices []azopenai.ChatChoice
modelWasReturned := false
for {
completion, err := streamResp.ChatCompletionsStream.Read()
if errors.Is(err, io.EOF) {
break
}
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
// by convention) but if this fails because they _don't_ match we can just adjust the test.
if returnedDeployment == *completion.Model {
modelWasReturned = true
}
require.NoError(t, err)
if completion.PromptFilterResults != nil {
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}, completion.PromptFilterResults)
}
if len(completion.Choices) == 0 {
// you can get empty entries that contain just metadata (ie, prompt annotations)
continue
}
require.Equal(t, 1, len(completion.Choices))
choices = append(choices, completion.Choices[0])
}
if len(completion.Choices) == 0 {
// you can get empty entries that contain just metadata (ie, prompt annotations)
continue
require.True(t, modelWasReturned)
var message string
for _, choice := range choices {
if choice.Delta.Content == nil {
continue
}
message += *choice.Delta.Content
}
require.Equal(t, 1, len(completion.Choices))
choices = append(choices, completion.Choices[0])
require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead")
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
}
require.True(t, modelWasReturned)
var message string
for _, choice := range choices {
if choice.Delta.Content == nil {
continue
}
message += *choice.Delta.Content
}
require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead")
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
}
func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skipf("Not running this test in playback (for now)")
}
if os.Getenv("USE_TOKEN_CREDS") != "true" {
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
}
recordingTransporter := newRecordingTransporter(t)
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: recordingTransporter,
},
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
testFn(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4")
})
require.NoError(t, err)
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint.URL, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
testFn(t, chatClient, openAI.ChatCompletions.Model, openAI.ChatCompletions.Model)
})
require.NoError(t, err)
testGetChatCompletions(t, chatClient, azureOpenAI.ChatCompletions, true)
}
func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
client := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
_, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
@@ -230,14 +312,14 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions), nil)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions.Model), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions), nil)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions.Model), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
@@ -276,15 +358,15 @@ func TestClient_GetChatCompletions_Vision(t *testing.T) {
t.Logf(*resp.Choices[0].Message.Content)
}
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, openAI.Vision.Model)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
testFn(t, chatClient, azureOpenAI.Vision.Model)
})
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.Vision.Endpoint)
testFn(t, chatClient, openAI.Vision.Model)
})
}
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
@@ -313,13 +395,13 @@ func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
require.NotEmpty(t, v)
}
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, "gpt-3.5-turbo-1106")
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
testFn(t, chatClient, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Model)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
testFn(t, chatClient, "gpt-4-1106-preview")
t.Run("OpenAI", func(t *testing.T) {
chatClient := newTestClient(t, openAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
testFn(t, chatClient, openAI.ChatCompletionsWithJSONResponseFormat.Model)
})
}

View file

@@ -15,53 +15,45 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetCompletions(t, client, true)
}
func TestClient_GetCompletions(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
}, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
// we'll do a general check here - as models change the answers can also change, token usages are different,
// etc... So we'll just make sure data is coming back and is reasonable.
require.NotZero(t, *resp.Completions.Usage.PromptTokens)
require.NotZero(t, *resp.Completions.Usage.CompletionTokens)
require.NotZero(t, *resp.Completions.Usage.TotalTokens)
require.Equal(t, int32(0), *resp.Completions.Choices[0].Index)
require.Equal(t, azopenai.CompletionsFinishReasonStopped, *resp.Completions.Choices[0].FinishReason)
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
if epm.Endpoint.Azure {
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{
PromptIndex: to.Ptr[int32](0),
ContentFilterResults: safeContentFilterResultDetailsForPrompt,
}}, resp.PromptFilterResults)
}
func TestClient_GetCompletions_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
testGetCompletions(t, client, false)
}
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
deploymentID := openAI.Completions
if isAzure {
deploymentID = azureOpenAI.Completions
}
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &deploymentID,
}, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
// we'll do a general check here - as models change the answers can also change, token usages are different,
// etc... So we'll just make sure data is coming back and is reasonable.
require.NotZero(t, *resp.Completions.Usage.PromptTokens)
require.NotZero(t, *resp.Completions.Usage.CompletionTokens)
require.NotZero(t, *resp.Completions.Usage.TotalTokens)
require.Equal(t, int32(0), *resp.Completions.Choices[0].Index)
require.Equal(t, azopenai.CompletionsFinishReasonStopped, *resp.Completions.Choices[0].FinishReason)
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
if isAzure {
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{
PromptIndex: to.Ptr[int32](0),
ContentFilterResults: safeContentFilterResultDetailsForPrompt,
}}, resp.PromptFilterResults)
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Completions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Completions)
})
}

View file

@@ -18,7 +18,7 @@ import (
)
func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
client := newTestClient(t, azureOpenAI.Embeddings.Endpoint)
_, err := client.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
DeploymentName: to.Ptr("thisdoesntexist"),
@@ -29,28 +29,75 @@ func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
}
func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
func TestClient_GetEmbeddings(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
type args struct {
ctx context.Context
deploymentID string
body azopenai.EmbeddingsOptions
options *azopenai.GetEmbeddingsOptions
}
tests := []struct {
name string
client *azopenai.Client
args args
want azopenai.GetEmbeddingsResponse
wantErr bool
}{
{
name: "Embeddings",
client: client,
args: args{
ctx: context.TODO(),
deploymentID: epm.Model,
body: azopenai.EmbeddingsOptions{
Input: []string{"\"Your text string goes here\""},
DeploymentName: &epm.Model,
},
options: nil,
},
want: azopenai.GetEmbeddingsResponse{
azopenai.Embeddings{
Data: []azopenai.EmbeddingItem{},
Usage: &azopenai.EmbeddingsUsage{},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr)
return
}
require.NotEmpty(t, got.Embeddings.Data[0].Embedding)
})
}
}
client := newOpenAIClientForTest(t)
testGetEmbeddings(t, client, openAI.Embeddings)
}
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Embeddings)
})
func TestClient_GetEmbeddings(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Embeddings)
})
}
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
testFn := func(t *testing.T, tv testVars, dimension int32) {
client := newTestClient(t, tv.Endpoint)
testFn := func(t *testing.T, epm endpointWithModel, dimension int32) {
client := newTestClient(t, epm.Endpoint)
arg := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
DeploymentName: &tv.TextEmbedding3Small,
DeploymentName: &epm.Model,
}
if dimension > 0 {
@@ -71,7 +118,7 @@ func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
arg2 := azopenai.EmbeddingsOptions{
Input: []string{"hello"},
DeploymentName: &tv.TextEmbedding3Small,
DeploymentName: &epm.Model,
}
if dimension > 0 {
@@ -93,60 +140,11 @@ func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
for _, dim := range []int32{0, 1, 10, 100} {
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, azureOpenAI, dim)
testFn(t, azureOpenAI.TextEmbedding3Small, dim)
})
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
testFn(t, openAI, dim)
})
}
}
func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
type args struct {
ctx context.Context
deploymentID string
body azopenai.EmbeddingsOptions
options *azopenai.GetEmbeddingsOptions
}
tests := []struct {
name string
client *azopenai.Client
args args
want azopenai.GetEmbeddingsResponse
wantErr bool
}{
{
name: "Embeddings",
client: client,
args: args{
ctx: context.TODO(),
deploymentID: modelOrDeploymentID,
body: azopenai.EmbeddingsOptions{
Input: []string{"\"Your text string goes here\""},
DeploymentName: &modelOrDeploymentID,
},
options: nil,
},
want: azopenai.GetEmbeddingsResponse{
azopenai.Embeddings{
Data: []azopenai.EmbeddingItem{},
Usage: &azopenai.EmbeddingsUsage{},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr)
return
}
require.NotEmpty(t, got.Embeddings.Data[0].Embedding)
testFn(t, openAI.TextEmbedding3Small, dim)
})
}
}

View file

@@ -6,6 +6,7 @@ package azopenai_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
@@ -28,12 +29,65 @@ type ParamProperty struct {
func TestGetChatCompletions_usingFunctions(t *testing.T) {
// https://platform.openai.com/docs/guides/gpt/function-calling
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
Tools: []azopenai.ChatCompletionsToolDefinitionClassification{
&azopenai.ChatCompletionsFunctionToolDefinition{
Function: &azopenai.FunctionDefinition{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
ToolChoice: toolChoice,
Temperature: to.Ptr[float32](0.0),
}
resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.Choices[0].Message.ToolCalls[0].(*azopenai.ChatCompletionsFunctionToolCall).Function
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
)
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
testData := []struct {
Model string
@@ -41,255 +95,207 @@
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: openAI.ChatCompletions, ToolChoice: nil},
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: nil},
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: useSpecificTool},
}
for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
testFn(t, chatClient, td.Model, td.ToolChoice)
}
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
t.Run("OpenAI", func(t *testing.T) {
testData := []struct {
Model string
EPM endpointWithModel
ToolChoice *azopenai.ChatCompletionsToolChoice
}{
// all of these variants use the tool provided - auto just also works since we did provide
// a tool reference and ask a question to use it.
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
{EPM: openAI.ChatCompletions, ToolChoice: nil},
{EPM: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
{EPM: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
}
for _, td := range testData {
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
for i, td := range testData {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
chatClient := newTestClient(t, td.EPM.Endpoint)
testFn(t, chatClient, td.EPM.Model, td.ToolChoice)
})
}
})
}
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletionsLegacyFunctions)
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletions)
})
testFn := func(t *testing.T, epm endpointWithModel) {
client := newTestClient(t, epm.Endpoint)
body := azopenai.ChatCompletionsOptions{
DeploymentName: &epm.Model,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
Value: to.Ptr("auto"),
},
Functions: []azopenai.FunctionDefinition{
{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
resp, err := client.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctionsOlderStyle(t, chatClient, azureOpenAI.ChatCompletionsLegacyFunctions)
testFn(t, azureOpenAI.ChatCompletionsLegacyFunctions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
t.Run("OpenAI.LegacyFunctions", func(t *testing.T) {
testFn(t, openAI.ChatCompletionsLegacyFunctions)
})
}
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &epm.Model,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
Tools: []azopenai.ChatCompletionsToolDefinitionClassification{
&azopenai.ChatCompletionsFunctionToolDefinition{
Function: &azopenai.FunctionDefinition{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
chatClient := newTestClient(t, epm.Endpoint)
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
require.NoError(t, err)
require.NotEmpty(t, resp)
defer func() {
err := resp.ChatCompletionsStream.Close()
require.NoError(t, err)
}()
// these results are way trickier than they should be, but we have to accumulate across
// multiple fields to get a full result.
funcCall := &azopenai.FunctionCall{
Arguments: to.Ptr(""),
Name: to.Ptr(""),
}
for {
streamResp, err := resp.ChatCompletionsStream.Read()
require.NoError(t, err)
if len(streamResp.Choices) == 0 {
// there are prompt filter results.
require.NotEmpty(t, streamResp.PromptFilterResults)
continue
}
if streamResp.Choices[0].FinishReason != nil {
break
}
var functionToolCall *azopenai.ChatCompletionsFunctionToolCall = streamResp.Choices[0].Delta.ToolCalls[0].(*azopenai.ChatCompletionsFunctionToolCall)
require.NotEmpty(t, functionToolCall.Function)
if functionToolCall.Function.Arguments != nil {
*funcCall.Arguments += *functionToolCall.Function.Arguments
}
if functionToolCall.Function.Name != nil {
*funcCall.Name += *functionToolCall.Function.Name
}
}
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
// https://platform.openai.com/docs/guides/gpt/function-calling
t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctionsStreaming(t, chatClient, openAI)
})
t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctionsStreaming(t, chatClient, azureOpenAI)
testFn(t, azureOpenAI.ChatCompletions)
})
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.ChatCompletions)
})
}
func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Client, deploymentName string) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
Value: to.Ptr("auto"),
},
Functions: []azopenai.FunctionDefinition{
{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
resp, err := client.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &deploymentName,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
Tools: []azopenai.ChatCompletionsToolDefinitionClassification{
&azopenai.ChatCompletionsFunctionToolDefinition{
Function: &azopenai.FunctionDefinition{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
ToolChoice: toolChoice,
Temperature: to.Ptr[float32](0.0),
}
resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)
funcCall := resp.Choices[0].Message.ToolCalls[0].(*azopenai.ChatCompletionsFunctionToolCall).Function
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}
func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Client, tv testVars) {
body := azopenai.ChatCompletionsOptions{
DeploymentName: &tv.ChatCompletions,
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestAssistantMessage{
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
},
},
Tools: []azopenai.ChatCompletionsToolDefinitionClassification{
&azopenai.ChatCompletionsFunctionToolDefinition{
Function: &azopenai.FunctionDefinition{
Name: to.Ptr("get_current_weather"),
Description: to.Ptr("Get the current weather in a given location"),
Parameters: Params{
Required: []string{"location"},
Type: "object",
Properties: map[string]ParamProperty{
"location": {
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
Temperature: to.Ptr[float32](0.0),
}
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
require.NoError(t, err)
require.NotEmpty(t, resp)
defer func() {
err := resp.ChatCompletionsStream.Close()
require.NoError(t, err)
}()
// these results are way trickier than they should be, but we have to accumulate across
// multiple fields to get a full result.
funcCall := &azopenai.FunctionCall{
Arguments: to.Ptr(""),
Name: to.Ptr(""),
}
for {
streamResp, err := resp.ChatCompletionsStream.Read()
require.NoError(t, err)
if len(streamResp.Choices) == 0 {
// there are prompt filter results.
require.NotEmpty(t, streamResp.PromptFilterResults)
continue
}
if streamResp.Choices[0].FinishReason != nil {
break
}
var functionToolCall *azopenai.ChatCompletionsFunctionToolCall = streamResp.Choices[0].Delta.ToolCalls[0].(*azopenai.ChatCompletionsFunctionToolCall)
require.NotEmpty(t, functionToolCall.Function)
if functionToolCall.Function.Arguments != nil {
*funcCall.Arguments += *functionToolCall.Function.Arguments
}
if functionToolCall.Function.Name != nil {
*funcCall.Name += *functionToolCall.Function.Name
}
}
require.Equal(t, "get_current_weather", *funcCall.Name)
type location struct {
Location string `json:"location"`
Unit string `json:"unit"`
}
var funcParams *location
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
require.NoError(t, err)
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
}

View file

@@ -22,13 +22,13 @@ import (
func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) {
// Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered
// https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered
client := newAzureOpenAIClientForTest(t, azureOpenAI)
client := newTestClient(t, azureOpenAI.Completions.Endpoint)
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"How do I rob a bank with violence?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &azureOpenAI.Completions,
DeploymentName: &azureOpenAI.Completions.Model,
}, nil)
require.Empty(t, resp)

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
@ -26,11 +27,6 @@ import (
"github.com/stretchr/testify/require"
)
var (
azureOpenAI testVars
openAI testVars
)
type endpoint struct {
URL string
APIKey string
@ -38,22 +34,19 @@ type endpoint struct {
}
type testVars struct {
Endpoint endpoint
Completions string
ChatCompletions string
ChatCompletionsLegacyFunctions string
Embeddings string
TextEmbedding3Small string
Cognitive azopenai.AzureSearchChatExtensionConfiguration
Whisper endpointWithModel
DallE endpointWithModel
Vision endpointWithModel
ChatCompletionsRAI endpointWithModel // at the moment this is Azure only
// "own your data" - bringing in Azure resources as part of a chat completions
// request.
ChatCompletionsOYD endpointWithModel
ChatCompletions endpointWithModel
ChatCompletionsLegacyFunctions endpointWithModel
ChatCompletionsOYD endpointWithModel // azure only
ChatCompletionsRAI endpointWithModel // azure only
ChatCompletionsWithJSONResponseFormat endpointWithModel
Cognitive azopenai.AzureSearchChatExtensionConfiguration
Completions endpointWithModel
DallE endpointWithModel
Embeddings endpointWithModel
Speech endpointWithModel
TextEmbedding3Small endpointWithModel
Vision endpointWithModel
Whisper endpointWithModel
}
type endpointWithModel struct {
@ -61,15 +54,129 @@ type endpointWithModel struct {
Model string
}
type testClientOption func(opt *azopenai.ClientOptions)
func ifAzure[T string | endpoint](azure bool, forAzure T, forOpenAI T) T {
if azure {
return forAzure
}
return forOpenAI
}
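// A couple of illustrative ifAzure call sites (values assumed, mirroring the
// table below):
//
//	ep := ifAzure(azure, servers.USEast, servers.OpenAI)                       // endpoint form
//	model := ifAzure(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct") // string form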
func withForgivingRetryOption() testClientOption {
return func(opt *azopenai.ClientOptions) {
opt.Retry = policy.RetryOptions{
MaxRetries: 10,
var azureOpenAI, openAI, servers = func() (testVars, testVars, []string) {
if recording.GetRecordMode() != recording.PlaybackMode {
if err := godotenv.Load(); err != nil {
log.Fatalf("Failed to load .env file: %s\n", err)
}
}
}
servers := struct {
USEast endpoint
USNorthCentral endpoint
USEast2 endpoint
SWECentral endpoint
OpenAI endpoint
}{
OpenAI: endpoint{
URL: getEndpoint("OPENAI_ENDPOINT"), // ex: https://api.openai.com/v1/
APIKey: recording.GetEnvVariable("OPENAI_API_KEY", fakeAPIKey),
Azure: false,
},
USEast: endpoint{
URL: getEndpoint("ENDPOINT_USEAST"),
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST_API_KEY", fakeAPIKey),
Azure: true,
},
USEast2: endpoint{
URL: getEndpoint("ENDPOINT_USEAST2"),
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST2_API_KEY", fakeAPIKey),
Azure: true,
},
USNorthCentral: endpoint{
URL: getEndpoint("ENDPOINT_USNORTHCENTRAL"),
APIKey: recording.GetEnvVariable("ENDPOINT_USNORTHCENTRAL_API_KEY", fakeAPIKey),
Azure: true,
},
SWECentral: endpoint{
URL: getEndpoint("ENDPOINT_SWECENTRAL"),
APIKey: recording.GetEnvVariable("ENDPOINT_SWECENTRAL_API_KEY", fakeAPIKey),
Azure: true,
},
}
	// used when we set up the recording policy
endpoints := []string{
servers.OpenAI.URL,
servers.USEast.URL,
servers.USEast2.URL,
servers.USNorthCentral.URL,
servers.SWECentral.URL,
}
newTestVarsFn := func(azure bool) testVars {
return testVars{
ChatCompletions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
},
ChatCompletionsLegacyFunctions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
},
ChatCompletionsOYD: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
},
ChatCompletionsRAI: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
},
ChatCompletionsWithJSONResponseFormat: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-1106-preview", "gpt-3.5-turbo-1106"),
},
Completions: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"),
},
DallE: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "dall-e-3", "dall-e-3"),
},
Embeddings: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "text-embedding-ada-002", "text-embedding-ada-002"),
},
Speech: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "tts-1", "tts-1"),
},
TextEmbedding3Small: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
Model: ifAzure(azure, "text-embedding-3-small", "text-embedding-3-small"),
},
Vision: endpointWithModel{
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
Model: ifAzure(azure, "gpt-4-vision-preview", "gpt-4-vision-preview"),
},
Whisper: endpointWithModel{
Endpoint: ifAzure(azure, servers.USEast2, servers.OpenAI),
Model: ifAzure(azure, "whisper-deployment", "whisper-1"),
},
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_ENDPOINT", fakeCognitiveEndpoint)),
IndexName: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_INDEX", fakeCognitiveIndexName)),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_KEY", fakeAPIKey)),
},
},
},
}
}
return newTestVarsFn(true), newTestVarsFn(false), endpoints
}()
type testClientOption func(opt *azopenai.ClientOptions)
// newTestClient creates a client enabled for HTTP recording, if needed.
// See [newRecordingTransporter] for sanitization code.
@ -101,197 +208,11 @@ func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azop
}
}
// getEndpoint retrieves details for an endpoint and a model.
// - res - the resource type for a particular endpoint. Ex: "DALLE".
//
// For example, if isAzure is true we'll load these environment values based on res:
// - AOAI_DALLE_ENDPOINT
// - AOAI_DALLE_API_KEY
//
// If isAzure is false we'll load these environment values instead:
// - OPENAI_ENDPOINT
// - OPENAI_API_KEY
func getEndpoint(res string, isAzure bool) endpointWithModel {
var ep endpointWithModel
if isAzure {
// during development resources are often shifted between different
// internal Azure OpenAI resources.
ep = endpointWithModel{
Endpoint: endpoint{
URL: getRequired("AOAI_" + res + "_ENDPOINT"),
APIKey: getRequired("AOAI_" + res + "_API_KEY"),
Azure: true,
},
}
} else {
ep = endpointWithModel{
Endpoint: endpoint{
URL: getRequired("OPENAI_ENDPOINT"),
APIKey: getRequired("OPENAI_API_KEY"),
Azure: false,
},
}
}
if !strings.HasSuffix(ep.Endpoint.URL, "/") {
// (this just makes recording replacement easier)
ep.Endpoint.URL += "/"
}
return ep
}
func model(azure bool, azureModel, openAIModel string) string {
if azure {
return azureModel
}
return openAIModel
}
func updateModels(azure bool, tv *testVars) {
	// each model is effectively its own API surface, so it's worth knowing exactly
	// which models our tests were written against.
tv.Completions = model(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct")
tv.ChatCompletions = model(azure, "gpt-35-turbo-0613", "gpt-4-0613")
tv.ChatCompletionsLegacyFunctions = model(azure, "gpt-4-0613", "gpt-4-0613")
tv.Embeddings = model(azure, "text-embedding-ada-002", "text-embedding-ada-002")
tv.TextEmbedding3Small = model(azure, "text-embedding-3-small", "text-embedding-3-small")
tv.DallE.Model = model(azure, "dall-e-3", "dall-e-3")
tv.Whisper.Model = model(azure, "whisper-deployment", "whisper-1")
tv.Vision.Model = model(azure, "gpt-4-vision-preview", "gpt-4-vision-preview")
// these are Azure-only features
tv.ChatCompletionsOYD.Model = model(azure, "gpt-4", "")
tv.ChatCompletionsRAI.Model = model(azure, "gpt-4", "")
}
func newTestVars(prefix string) testVars {
azure := prefix == "AOAI"
tv := testVars{
Endpoint: endpoint{
URL: getRequired(prefix + "_ENDPOINT"),
APIKey: getRequired(prefix + "_API_KEY"),
Azure: azure,
},
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(getRequired("COGNITIVE_SEARCH_API_ENDPOINT")),
IndexName: to.Ptr(getRequired("COGNITIVE_SEARCH_API_INDEX")),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(getRequired("COGNITIVE_SEARCH_API_KEY")),
},
},
},
DallE: getEndpoint("DALLE", azure),
Whisper: getEndpoint("WHISPER", azure),
Vision: getEndpoint("VISION", azure),
}
if azure {
tv.ChatCompletionsRAI = getEndpoint("CHAT_COMPLETIONS_RAI", azure)
tv.ChatCompletionsOYD = getEndpoint("OYD", azure)
}
updateModels(azure, &tv)
if tv.Endpoint.URL != "" && !strings.HasSuffix(tv.Endpoint.URL, "/") {
// (this just makes recording replacement easier)
tv.Endpoint.URL += "/"
}
return tv
}
const fakeEndpoint = "https://fake-recorded-host.microsoft.com/"
const fakeAPIKey = "redacted"
const fakeCognitiveEndpoint = "https://fake-cognitive-endpoint.microsoft.com"
const fakeCognitiveIndexName = "index"
func initEnvVars() {
if recording.GetRecordMode() == recording.PlaybackMode {
		// Set up our variables so our requests are consistent with what we recorded.
// Endpoints are sanitized using the recording policy
azureOpenAI.Endpoint = endpoint{
URL: fakeEndpoint,
APIKey: fakeAPIKey,
Azure: true,
}
azureOpenAI.Whisper = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.ChatCompletionsRAI = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.ChatCompletionsOYD = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.DallE = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
azureOpenAI.Vision = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
}
openAI.Endpoint = endpoint{
APIKey: fakeAPIKey,
URL: fakeEndpoint,
}
openAI.Whisper = endpointWithModel{
Endpoint: endpoint{
APIKey: fakeAPIKey,
URL: fakeEndpoint,
},
}
openAI.DallE = endpointWithModel{
Endpoint: openAI.Endpoint,
}
updateModels(true, &azureOpenAI)
updateModels(false, &openAI)
openAI.Vision = azureOpenAI.Vision
azureOpenAI.Completions = "gpt-35-turbo-instruct"
openAI.Completions = "gpt-3.5-turbo-instruct"
azureOpenAI.ChatCompletions = "gpt-35-turbo-0613"
azureOpenAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
openAI.ChatCompletions = "gpt-4-0613"
openAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
openAI.Embeddings = "text-embedding-ada-002"
azureOpenAI.Embeddings = "text-embedding-ada-002"
azureOpenAI.Cognitive = azopenai.AzureSearchChatExtensionConfiguration{
Parameters: &azopenai.AzureSearchChatExtensionParameters{
Endpoint: to.Ptr(fakeCognitiveEndpoint),
IndexName: to.Ptr(fakeCognitiveIndexName),
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
Key: to.Ptr(fakeAPIKey),
},
},
}
} else {
if err := godotenv.Load(); err != nil {
fmt.Printf("Failed to load .env file: %s\n", err)
}
azureOpenAI = newTestVars("AOAI")
openAI = newTestVars("OPENAI")
}
}
type MultipartRecordingPolicy struct {
}
@ -315,17 +236,7 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", "", nil)
require.NoError(t, err)
endpoints := []string{
azureOpenAI.Endpoint.URL,
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
azureOpenAI.Whisper.Endpoint.URL,
azureOpenAI.DallE.Endpoint.URL,
azureOpenAI.Vision.Endpoint.URL,
azureOpenAI.ChatCompletionsOYD.Endpoint.URL,
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
}
for _, ep := range endpoints {
for _, ep := range servers {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(ep), nil)
require.NoError(t, err)
}
@ -333,8 +244,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil)
require.NoError(t, err)
if openAI.Endpoint.URL != "" {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.Endpoint.URL), nil)
// there's only one OpenAI endpoint
if openAI.ChatCompletions.Endpoint.URL != "" {
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.ChatCompletions.Endpoint.URL), nil)
require.NoError(t, err)
}
@ -401,20 +313,12 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {
return co
}
func newAzureOpenAIClientForTest(t *testing.T, tv testVars, options ...testClientOption) *azopenai.Client {
return newTestClient(t, tv.Endpoint, options...)
}
func newOpenAIClientForTest(t *testing.T, options ...testClientOption) *azopenai.Client {
return newTestClient(t, openAI.Endpoint, options...)
}
// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
// a failure.
func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
cred := azcore.NewKeyCredential("bogus-api-key")
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint.URL, cred, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
@ -425,7 +329,7 @@ func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
cred := azcore.NewKeyCredential("bogus-api-key")
client, err := azopenai.NewClientForOpenAI(openAI.Endpoint.URL, cred, newClientOptionsForTest(t))
client, err := azopenai.NewClientForOpenAI(openAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}
@ -440,11 +344,12 @@ func assertResponseIsError(t *testing.T, err error) {
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
}
func getRequired(name string) string {
v := os.Getenv(name)
func getEndpoint(ev string) string {
v := recording.GetEnvVariable(ev, fakeEndpoint)
if v == "" {
panic(fmt.Sprintf("Env variable %s is missing", name))
if !strings.HasSuffix(v, "/") {
// (this just makes recording replacement easier)
v += "/"
}
return v

View file

@ -23,7 +23,7 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) {
t.Skip()
}
chatClient := newOpenAIClientForTest(t)
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
_, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{

View file

@ -29,11 +29,7 @@ func TestImageGeneration_AzureOpenAI(t *testing.T) {
}
func TestImageGeneration_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
client := newTestClient(t, openAI.DallE.Endpoint)
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatURL)
}
@ -60,7 +56,7 @@ func TestImageGeneration_OpenAI_Base64(t *testing.T) {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
client := newTestClient(t, openAI.DallE.Endpoint)
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatBase64)
}

View file

@ -76,74 +76,73 @@ func TestNewClientWithKeyCredential(t *testing.T) {
}
}
func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
client := newTestClient(t, azureOpenAI.Endpoint)
testGetCompletionsStream(t, client, azureOpenAI)
}
func TestGetCompletionsStream_OpenAI(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}
client := newOpenAIClientForTest(t)
testGetCompletionsStream(t, client, openAI)
}
func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
body := azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &tv.Completions,
}
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
if err != nil {
t.Errorf("Client.GetCompletionsStream() error = %v", err)
return
}
reader := response.CompletionsStream
defer reader.Close()
var sb strings.Builder
var eventCount int
for {
completion, err := reader.Read()
if err == io.EOF {
break
func TestGetCompletionsStream(t *testing.T) {
testFn := func(t *testing.T, epm endpointWithModel) {
body := azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
DeploymentName: &epm.Model,
}
if completion.PromptFilterResults != nil {
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}, completion.PromptFilterResults)
}
client := newTestClient(t, epm.Endpoint)
eventCount++
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
skipNowIfThrottled(t, err)
require.NoError(t, err)
if err != nil {
t.Errorf("reader.Read() error = %v", err)
t.Errorf("Client.GetCompletionsStream() error = %v", err)
return
}
if len(completion.Choices) > 0 {
sb.WriteString(*completion.Choices[0].Text)
reader := response.CompletionsStream
defer reader.Close()
var sb strings.Builder
var eventCount int
for {
completion, err := reader.Read()
if err == io.EOF {
break
}
if completion.PromptFilterResults != nil {
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: safeContentFilterResultDetailsForPrompt},
}, completion.PromptFilterResults)
}
eventCount++
if err != nil {
t.Errorf("reader.Read() error = %v", err)
return
}
if len(completion.Choices) > 0 {
sb.WriteString(*completion.Choices[0].Text)
}
}
got := sb.String()
require.NotEmpty(t, got)
		// there's no strict requirement on how the response is chunked, so we just
		// pick a reasonable floor that's lower than typical usage (usually around
		// 80 events).
require.GreaterOrEqual(t, eventCount, 50)
}
got := sb.String()
require.NotEmpty(t, got)
t.Run("AzureOpenAI", func(t *testing.T) {
testFn(t, azureOpenAI.Completions)
})
	// there's no strict requirement on how the response is chunked, so we just
	// pick a reasonable floor that's lower than typical usage (usually around
	// 80 events).
require.GreaterOrEqual(t, eventCount, 50)
t.Run("OpenAI", func(t *testing.T) {
testFn(t, openAI.Completions)
})
}
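// A minimal consumption sketch for the streaming API exercised above, assuming
// an already-configured client. It mirrors the Read-until-io.EOF loop in the
// test and is illustrative only:
func readAllCompletions(client *azopenai.Client, opts azopenai.CompletionsOptions) (string, int, error) {
	resp, err := client.GetCompletionsStream(context.TODO(), opts, nil)
	if err != nil {
		return "", 0, err
	}
	defer resp.CompletionsStream.Close()

	var sb strings.Builder
	events := 0
	for {
		completion, err := resp.CompletionsStream.Read()
		if err == io.EOF {
			break // the service closed the event stream
		}
		if err != nil {
			return sb.String(), events, err
		}
		events++
		if len(completion.Choices) > 0 {
			sb.WriteString(*completion.Choices[0].Text)
		}
	}
	return sb.String(), events, nil
}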
func TestClient_GetCompletions_Error(t *testing.T) {
@ -151,7 +150,9 @@ func TestClient_GetCompletions_Error(t *testing.T) {
t.Skip()
}
doTest := func(t *testing.T, client *azopenai.Client, model string) {
doTest := func(t *testing.T, model string) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
@ -163,12 +164,10 @@ func TestClient_GetCompletions_Error(t *testing.T) {
}
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
doTest(t, client, azureOpenAI.Completions)
doTest(t, azureOpenAI.Completions.Model)
})
t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client, openAI.Completions)
doTest(t, openAI.Completions.Model)
})
}

View file

@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
go 1.18
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0
github.com/joho/godotenv v1.3.0
@ -19,9 +19,9 @@ require (
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/net v0.22.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View file

@ -1,5 +1,5 @@
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 h1:c4k2FIYIh4xtwqrQwV0Ct1v5+ehlNXj5NI/MWVsiTkQ=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2/go.mod h1:5FDJtLEO/GxwNgUxbwrY3LP0pEoThTQJtk2oysdXHxM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0 h1:sUFnFjzDUie80h24I7mrKtwCKgLY9L8h5Tp2x9+TWqk=
@ -25,13 +25,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

View file

@ -15,7 +15,6 @@ import (
const RecordingDirectory = "sdk/ai/azopenai/testdata"
func TestMain(m *testing.M) {
initEnvVars()
code := run(m)
os.Exit(code)
}