[azopenai] A rather large, but needed, cleanup of the tests. (#22707)
It was getting difficult to tell what was and wasn't covered and also correlate tests to models, which can sometimes come with differing behavior. So in this PR I: - Moved things that don't need to be env vars (ie: non-secrets) to just be constants in the code - this includes model names and generic server identifiers for regions that we use to coordinate new feature development. - Consolidated all of the setting of endpoint and models into one spot to make it simpler to double-check. - Consolidated tests that tested the same thing into sub-tests with OpenAI or AzureOpenAI names. - If a function was only called by one test moved it into the test as an anonymous func Also, I added in a test for logit_probs and logprobs/toplogprobs.
This commit is contained in:
Родитель
a51db25793
Коммит
a0f9b026ec
|
@ -2,5 +2,5 @@
|
|||
"AssetsRepo": "Azure/azure-sdk-assets",
|
||||
"AssetsRepoPrefixPath": "go",
|
||||
"TagPrefix": "go/ai/azopenai",
|
||||
"Tag": "go/ai/azopenai_a33cdad878"
|
||||
"Tag": "go/ai/azopenai_a56e3e9e32"
|
||||
}
|
||||
|
|
|
@ -19,64 +19,14 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClient_GetAudioTranscription_AzureOpenAI(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
|
||||
runTranscriptionTests(t, client, azureOpenAI.Whisper.Model, true)
|
||||
}
|
||||
func TestClient_GetAudioTranscription(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
model := epm.Model
|
||||
|
||||
func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
|
||||
client := newOpenAIClientForTest(t)
|
||||
runTranscriptionTests(t, client, openAI.Whisper.Model, false)
|
||||
}
|
||||
|
||||
func TestClient_GetAudioTranslation_AzureOpenAI(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
|
||||
runTranslationTests(t, client, azureOpenAI.Whisper.Model, true)
|
||||
}
|
||||
|
||||
func TestClient_GetAudioTranslation_OpenAI(t *testing.T) {
|
||||
client := newOpenAIClientForTest(t)
|
||||
runTranslationTests(t, client, openAI.Whisper.Model, false)
|
||||
}
|
||||
|
||||
func TestClient_GetAudioSpeech(t *testing.T) {
|
||||
client := newOpenAIClientForTest(t)
|
||||
|
||||
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
|
||||
Input: to.Ptr("i am a computer"),
|
||||
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
|
||||
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
|
||||
DeploymentName: to.Ptr("tts-1"),
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
audioBytes, err := io.ReadAll(audioResp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, audioBytes)
|
||||
require.Equal(t, "fLaC", string(audioBytes[0:4]))
|
||||
|
||||
// now send _it_ back through the transcription API and see if we can get something useful.
|
||||
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
|
||||
Filename: to.Ptr("test.flac"),
|
||||
File: audioBytes,
|
||||
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
|
||||
DeploymentName: &openAI.Whisper.Model,
|
||||
Temperature: to.Ptr[float32](0.0),
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotZero(t, *transcriptionResp.Duration)
|
||||
|
||||
// it occasionally comes back with different punctuation or makes a complete sentence but
|
||||
// the major words always come through.
|
||||
require.Contains(t, *transcriptionResp.Text, "computer")
|
||||
}
|
||||
|
||||
func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
|
||||
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
||||
// The bulk of the logic will test against OpenAI anyways.
|
||||
if isAzure {
|
||||
if epm.Endpoint.Azure {
|
||||
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
||||
args := newTranscriptionOptions(azopenai.AudioTranscriptionFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
||||
transcriptResp, err := client.GetAudioTranscription(context.Background(), args, nil)
|
||||
|
@ -167,10 +117,23 @@ func runTranscriptionTests(t *testing.T, client *azopenai.Client, model string,
|
|||
}
|
||||
}
|
||||
|
||||
func runTranslationTests(t *testing.T, client *azopenai.Client, model string, isAzure bool) {
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.Whisper)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.Whisper)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetAudioTranslation(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
model := epm.Model
|
||||
|
||||
// We're experiencing load issues on some of our shared test resources so we'll just spot check.
|
||||
// The bulk of the logic will test against OpenAI anyways.
|
||||
if isAzure {
|
||||
if epm.Endpoint.Azure {
|
||||
t.Run(fmt.Sprintf("%s (%s)", azopenai.AudioTranscriptionFormatText, "m4a"), func(t *testing.T) {
|
||||
args := newTranslationOptions(azopenai.AudioTranslationFormatText, model, "testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.m4a")
|
||||
transcriptResp, err := client.GetAudioTranslation(context.Background(), args, nil)
|
||||
|
@ -262,6 +225,49 @@ func runTranslationTests(t *testing.T, client *azopenai.Client, model string, is
|
|||
}
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.Whisper)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.Whisper)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetAudioSpeech(t *testing.T) {
|
||||
client := newTestClient(t, openAI.Speech.Endpoint)
|
||||
|
||||
audioResp, err := client.GenerateSpeechFromText(context.Background(), azopenai.SpeechGenerationOptions{
|
||||
Input: to.Ptr("i am a computer"),
|
||||
Voice: to.Ptr(azopenai.SpeechVoiceAlloy),
|
||||
ResponseFormat: to.Ptr(azopenai.SpeechGenerationResponseFormatFlac),
|
||||
DeploymentName: to.Ptr("tts-1"),
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
audioBytes, err := io.ReadAll(audioResp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, audioBytes)
|
||||
require.Equal(t, "fLaC", string(audioBytes[0:4]))
|
||||
|
||||
// now send _it_ back through the transcription API and see if we can get something useful.
|
||||
transcriptionResp, err := client.GetAudioTranscription(context.Background(), azopenai.AudioTranscriptionOptions{
|
||||
Filename: to.Ptr("test.flac"),
|
||||
File: audioBytes,
|
||||
ResponseFormat: to.Ptr(azopenai.AudioTranscriptionFormatVerboseJSON),
|
||||
DeploymentName: &openAI.Whisper.Model,
|
||||
Temperature: to.Ptr[float32](0.0),
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotZero(t, *transcriptionResp.Duration)
|
||||
|
||||
// it occasionally comes back with different punctuation or makes a complete sentence but
|
||||
// the major words always come through.
|
||||
require.Contains(t, *transcriptionResp.Text, "computer")
|
||||
}
|
||||
|
||||
func newTranscriptionOptions(format azopenai.AudioTranscriptionFormat, model string, path string) azopenai.AudioTranscriptionOptions {
|
||||
audioBytes, err := os.ReadFile(path)
|
||||
|
||||
|
|
|
@ -42,34 +42,7 @@ var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
|
|||
var expectedRole = azopenai.ChatRoleAssistant
|
||||
|
||||
func TestClient_GetChatCompletions(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
testGetChatCompletions(t, client, azureOpenAI.ChatCompletionsRAI.Model, true)
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletionsStream(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
||||
testGetChatCompletionsStream(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model)
|
||||
}
|
||||
|
||||
func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testGetChatCompletions(t, chatClient, openAI.ChatCompletions, false)
|
||||
}
|
||||
|
||||
func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testGetChatCompletionsStream(t, chatClient, openAI.ChatCompletions)
|
||||
}
|
||||
|
||||
func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment string, checkRAI bool) {
|
||||
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedModel string, checkRAI bool) {
|
||||
expected := azopenai.ChatCompletions{
|
||||
Choices: []azopenai.ChatChoice{
|
||||
{
|
||||
|
@ -88,9 +61,7 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
|
|||
PromptTokens: to.Ptr(int32(42)),
|
||||
TotalTokens: to.Ptr(int32(71)),
|
||||
},
|
||||
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
|
||||
// by convention) but if this fails because they _don't_ match we can just adjust the test.
|
||||
Model: &deployment,
|
||||
Model: &returnedModel,
|
||||
}
|
||||
|
||||
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
||||
|
@ -112,10 +83,136 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, deployment st
|
|||
expected.ID = resp.ID
|
||||
expected.Created = resp.Created
|
||||
|
||||
t.Logf("isAzure: %t, deployment: %s, returnedModel: %s", checkRAI, deployment, *resp.ChatCompletions.Model)
|
||||
require.Equal(t, expected, resp.ChatCompletions)
|
||||
}
|
||||
|
||||
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deployment string) {
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
||||
testFn(t, client, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4", true)
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI.DefaultAzureCredential", func(t *testing.T) {
|
||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||
t.Skipf("Not running this test in playback (for now)")
|
||||
}
|
||||
|
||||
if os.Getenv("USE_TOKEN_CREDS") != "true" {
|
||||
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
|
||||
}
|
||||
|
||||
recordingTransporter := newRecordingTransporter(t)
|
||||
|
||||
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
|
||||
ClientOptions: policy.ClientOptions{
|
||||
Transport: recordingTransporter,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chatClient, err := azopenai.NewClient(azureOpenAI.ChatCompletions.Endpoint.URL, dac, &azopenai.ClientOptions{
|
||||
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testFn(t, chatClient, azureOpenAI.ChatCompletions.Model, "gpt-4", true)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||
testFn(t, chatClient, openAI.ChatCompletions.Model, "gpt-4-0613", false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletions_LogProbs(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
opts := azopenai.ChatCompletionsOptions{
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
&azopenai.ChatRequestUserMessage{
|
||||
Content: azopenai.NewChatRequestUserMessageContent("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
|
||||
},
|
||||
},
|
||||
MaxTokens: to.Ptr(int32(1024)),
|
||||
Temperature: to.Ptr(float32(0.0)),
|
||||
DeploymentName: &epm.Model,
|
||||
LogProbs: to.Ptr(true),
|
||||
TopLogProbs: to.Ptr(int32(5)),
|
||||
}
|
||||
|
||||
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, choice := range resp.Choices {
|
||||
require.NotEmpty(t, choice.LogProbs)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.ChatCompletions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.ChatCompletions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletions_LogitBias(t *testing.T) {
|
||||
// you can use LogitBias to constrain the answer to NOT contain
|
||||
// certain tokens. More or less following the technique in this OpenAI article:
|
||||
// https://help.openai.com/en/articles/5247780-using-logit-bias-to-alter-token-probability-with-the-openai-api
|
||||
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
opts := azopenai.ChatCompletionsOptions{
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
&azopenai.ChatRequestUserMessage{
|
||||
Content: azopenai.NewChatRequestUserMessageContent("Briefly, what are some common roles for people at a circus, names only, one per line?"),
|
||||
},
|
||||
},
|
||||
MaxTokens: to.Ptr(int32(200)),
|
||||
Temperature: to.Ptr(float32(0.0)),
|
||||
DeploymentName: &epm.Model,
|
||||
LogitBias: map[string]*int32{
|
||||
// you can calculate these tokens using OpenAI's online tool:
|
||||
// https://platform.openai.com/tokenizer?view=bpe
|
||||
// These token IDs are all variations of "Clown", which I want to exclude from the response.
|
||||
"25": to.Ptr(int32(-100)),
|
||||
"220": to.Ptr(int32(-100)),
|
||||
"1206": to.Ptr(int32(-100)),
|
||||
"2493": to.Ptr(int32(-100)),
|
||||
"5176": to.Ptr(int32(-100)),
|
||||
"43456": to.Ptr(int32(-100)),
|
||||
"99423": to.Ptr(int32(-100)),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.GetChatCompletions(context.Background(), opts, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, choice := range resp.Choices {
|
||||
if choice.Message == nil || choice.Message.Content == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
require.NotContains(t, *choice.Message.Content, "clown")
|
||||
require.NotContains(t, *choice.Message.Content, "Clown")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.ChatCompletions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.ChatCompletions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletionsStream(t *testing.T) {
|
||||
testFn := func(t *testing.T, client *azopenai.Client, deployment string, returnedDeployment string) {
|
||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(deployment), nil)
|
||||
|
||||
if respErr := (*azcore.ResponseError)(nil); errors.As(err, &respErr) && respErr.StatusCode == http.StatusTooManyRequests {
|
||||
|
@ -141,7 +238,7 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
|
|||
|
||||
// NOTE: this is actually the name of the _model_, not the deployment. They usually match (just
|
||||
// by convention) but if this fails because they _don't_ match we can just adjust the test.
|
||||
if deployment == *completion.Model {
|
||||
if returnedDeployment == *completion.Model {
|
||||
modelWasReturned = true
|
||||
}
|
||||
|
||||
|
@ -178,34 +275,19 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, deploym
|
|||
require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
|
||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||
t.Skipf("Not running this test in playback (for now)")
|
||||
}
|
||||
|
||||
if os.Getenv("USE_TOKEN_CREDS") != "true" {
|
||||
t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests")
|
||||
}
|
||||
|
||||
recordingTransporter := newRecordingTransporter(t)
|
||||
|
||||
dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
|
||||
ClientOptions: policy.ClientOptions{
|
||||
Transport: recordingTransporter,
|
||||
},
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsRAI.Endpoint)
|
||||
testFn(t, chatClient, azureOpenAI.ChatCompletionsRAI.Model, "gpt-4")
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint.URL, dac, &azopenai.ClientOptions{
|
||||
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||
testFn(t, chatClient, openAI.ChatCompletions.Model, openAI.ChatCompletions.Model)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testGetChatCompletions(t, chatClient, azureOpenAI.ChatCompletions, true)
|
||||
}
|
||||
|
||||
func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
client := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
|
||||
|
||||
_, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
|
@ -230,14 +312,14 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
|
|||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
client := newBogusAzureOpenAIClient(t)
|
||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions), nil)
|
||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI.ChatCompletions.Model), nil)
|
||||
require.Empty(t, streamResp)
|
||||
assertResponseIsError(t, err)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
client := newBogusOpenAIClient(t)
|
||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions), nil)
|
||||
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI.ChatCompletions.Model), nil)
|
||||
require.Empty(t, streamResp)
|
||||
assertResponseIsError(t, err)
|
||||
})
|
||||
|
@ -276,15 +358,15 @@ func TestClient_GetChatCompletions_Vision(t *testing.T) {
|
|||
t.Logf(*resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testFn(t, chatClient, openAI.Vision.Model)
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
|
||||
testFn(t, chatClient, azureOpenAI.Vision.Model)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, openAI.Vision.Endpoint)
|
||||
testFn(t, chatClient, openAI.Vision.Model)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
|
||||
|
@ -313,13 +395,13 @@ func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
|
|||
require.NotEmpty(t, v)
|
||||
}
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testFn(t, chatClient, "gpt-3.5-turbo-1106")
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
|
||||
testFn(t, chatClient, azureOpenAI.ChatCompletionsWithJSONResponseFormat.Model)
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
|
||||
testFn(t, chatClient, "gpt-4-1106-preview")
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, openAI.ChatCompletionsWithJSONResponseFormat.Endpoint)
|
||||
testFn(t, chatClient, openAI.ChatCompletionsWithJSONResponseFormat.Model)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,32 +15,15 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
testGetCompletions(t, client, true)
|
||||
}
|
||||
|
||||
func TestClient_GetCompletions_OpenAI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
client := newOpenAIClientForTest(t)
|
||||
testGetCompletions(t, client, false)
|
||||
}
|
||||
|
||||
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
||||
deploymentID := openAI.Completions
|
||||
|
||||
if isAzure {
|
||||
deploymentID = azureOpenAI.Completions
|
||||
}
|
||||
func TestClient_GetCompletions(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
||||
Prompt: []string{"What is Azure OpenAI?"},
|
||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||
Temperature: to.Ptr(float32(0.0)),
|
||||
DeploymentName: &deploymentID,
|
||||
DeploymentName: &epm.Model,
|
||||
}, nil)
|
||||
skipNowIfThrottled(t, err)
|
||||
require.NoError(t, err)
|
||||
|
@ -55,7 +38,7 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
|||
|
||||
require.NotEmpty(t, *resp.Completions.Choices[0].Text)
|
||||
|
||||
if isAzure {
|
||||
if epm.Endpoint.Azure {
|
||||
require.Equal(t, safeContentFilter, resp.Completions.Choices[0].ContentFilterResults)
|
||||
require.Equal(t, []azopenai.ContentFilterResultsForPrompt{
|
||||
{
|
||||
|
@ -65,3 +48,12 @@ func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.Completions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.Completions)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
client := newTestClient(t, azureOpenAI.Embeddings.Endpoint)
|
||||
|
||||
_, err := client.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
|
||||
DeploymentName: to.Ptr("thisdoesntexist"),
|
||||
|
@ -29,80 +29,10 @@ func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
|
|||
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
|
||||
}
|
||||
|
||||
func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
client := newOpenAIClientForTest(t)
|
||||
testGetEmbeddings(t, client, openAI.Embeddings)
|
||||
}
|
||||
|
||||
func TestClient_GetEmbeddings(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
|
||||
}
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
|
||||
testFn := func(t *testing.T, tv testVars, dimension int32) {
|
||||
client := newTestClient(t, tv.Endpoint)
|
||||
|
||||
arg := azopenai.EmbeddingsOptions{
|
||||
Input: []string{"hello"},
|
||||
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
|
||||
DeploymentName: &tv.TextEmbedding3Small,
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
arg.Dimensions = &dimension
|
||||
}
|
||||
|
||||
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, base64Resp.Data)
|
||||
require.Empty(t, base64Resp.Data[0].Embedding)
|
||||
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
|
||||
|
||||
// sanity checks - we deserialized everything and didn't create anything impossible.
|
||||
for _, v := range embeddings {
|
||||
require.True(t, v <= 1.0 && v >= -1.0)
|
||||
}
|
||||
|
||||
arg2 := azopenai.EmbeddingsOptions{
|
||||
Input: []string{"hello"},
|
||||
DeploymentName: &tv.TextEmbedding3Small,
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
arg2.Dimensions = &dimension
|
||||
}
|
||||
|
||||
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, floatResp.Data)
|
||||
require.NotEmpty(t, floatResp.Data[0].Embedding)
|
||||
|
||||
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
|
||||
|
||||
// This works "most of the time" but it's non-deterministic since two separate calls don't always
|
||||
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
|
||||
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
|
||||
}
|
||||
|
||||
for _, dim := range []int32{0, 1, 10, 100} {
|
||||
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||
testFn(t, azureOpenAI, dim)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||
testFn(t, openAI, dim)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
deploymentID string
|
||||
|
@ -122,10 +52,10 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
|
|||
client: client,
|
||||
args: args{
|
||||
ctx: context.TODO(),
|
||||
deploymentID: modelOrDeploymentID,
|
||||
deploymentID: epm.Model,
|
||||
body: azopenai.EmbeddingsOptions{
|
||||
Input: []string{"\"Your text string goes here\""},
|
||||
DeploymentName: &modelOrDeploymentID,
|
||||
DeploymentName: &epm.Model,
|
||||
},
|
||||
options: nil,
|
||||
},
|
||||
|
@ -151,6 +81,74 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
|
|||
}
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.Embeddings)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.Embeddings)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetEmbeddings_embeddingsFormat(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel, dimension int32) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
arg := azopenai.EmbeddingsOptions{
|
||||
Input: []string{"hello"},
|
||||
EncodingFormat: to.Ptr(azopenai.EmbeddingEncodingFormatBase64),
|
||||
DeploymentName: &epm.Model,
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
arg.Dimensions = &dimension
|
||||
}
|
||||
|
||||
base64Resp, err := client.GetEmbeddings(context.Background(), arg, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, base64Resp.Data)
|
||||
require.Empty(t, base64Resp.Data[0].Embedding)
|
||||
embeddings := deserializeBase64Embeddings(t, base64Resp.Data[0])
|
||||
|
||||
// sanity checks - we deserialized everything and didn't create anything impossible.
|
||||
for _, v := range embeddings {
|
||||
require.True(t, v <= 1.0 && v >= -1.0)
|
||||
}
|
||||
|
||||
arg2 := azopenai.EmbeddingsOptions{
|
||||
Input: []string{"hello"},
|
||||
DeploymentName: &epm.Model,
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
arg2.Dimensions = &dimension
|
||||
}
|
||||
|
||||
floatResp, err := client.GetEmbeddings(context.Background(), arg2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, floatResp.Data)
|
||||
require.NotEmpty(t, floatResp.Data[0].Embedding)
|
||||
|
||||
require.Equal(t, len(floatResp.Data[0].Embedding), len(embeddings))
|
||||
|
||||
// This works "most of the time" but it's non-deterministic since two separate calls don't always
|
||||
// produce the exact same data. Leaving it here in case you want to do some rough checks later.
|
||||
// require.Equal(t, floatResp.Data[0].Embedding[0:dimension], base64Resp.Data[0].Embedding[0:dimension])
|
||||
}
|
||||
|
||||
for _, dim := range []int32{0, 1, 10, 100} {
|
||||
t.Run(fmt.Sprintf("AzureOpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.TextEmbedding3Small, dim)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("OpenAI(dimensions=%d)", dim), func(t *testing.T) {
|
||||
testFn(t, openAI.TextEmbedding3Small, dim)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func deserializeBase64Embeddings(t *testing.T, ei azopenai.EmbeddingItem) []float32 {
|
||||
destBytes, err := base64.StdEncoding.DecodeString(ei.EmbeddingBase64)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -6,6 +6,7 @@ package azopenai_test
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
|
||||
|
@ -28,130 +29,7 @@ type ParamProperty struct {
|
|||
func TestGetChatCompletions_usingFunctions(t *testing.T) {
|
||||
// https://platform.openai.com/docs/guides/gpt/function-calling
|
||||
|
||||
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
|
||||
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
|
||||
)
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
|
||||
testData := []struct {
|
||||
Model string
|
||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||
}{
|
||||
// all of these variants use the tool provided - auto just also works since we did provide
|
||||
// a tool reference and ask a question to use it.
|
||||
{Model: openAI.ChatCompletions, ToolChoice: nil},
|
||||
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
|
||||
}
|
||||
|
||||
for _, td := range testData {
|
||||
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
||||
|
||||
testData := []struct {
|
||||
Model string
|
||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||
}{
|
||||
// all of these variants use the tool provided - auto just also works since we did provide
|
||||
// a tool reference and ask a question to use it.
|
||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
|
||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
|
||||
}
|
||||
|
||||
for _, td := range testData {
|
||||
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletionsLegacyFunctions)
|
||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, openAI.ChatCompletions)
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
||||
testChatCompletionsFunctionsOlderStyle(t, chatClient, azureOpenAI.ChatCompletionsLegacyFunctions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
|
||||
// https://platform.openai.com/docs/guides/gpt/function-calling
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
testChatCompletionsFunctionsStreaming(t, chatClient, openAI)
|
||||
})
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
|
||||
testChatCompletionsFunctionsStreaming(t, chatClient, azureOpenAI)
|
||||
})
|
||||
}
|
||||
|
||||
func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Client, deploymentName string) {
|
||||
body := azopenai.ChatCompletionsOptions{
|
||||
DeploymentName: &deploymentName,
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
&azopenai.ChatRequestAssistantMessage{
|
||||
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
||||
},
|
||||
},
|
||||
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
|
||||
Value: to.Ptr("auto"),
|
||||
},
|
||||
Functions: []azopenai.FunctionDefinition{
|
||||
{
|
||||
Name: to.Ptr("get_current_weather"),
|
||||
Description: to.Ptr("Get the current weather in a given location"),
|
||||
Parameters: Params{
|
||||
Required: []string{"location"},
|
||||
Type: "object",
|
||||
Properties: map[string]ParamProperty{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Temperature: to.Ptr[float32](0.0),
|
||||
}
|
||||
|
||||
resp, err := client.GetChatCompletions(context.Background(), body, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
|
||||
|
||||
require.Equal(t, "get_current_weather", *funcCall.Name)
|
||||
|
||||
type location struct {
|
||||
Location string `json:"location"`
|
||||
Unit string `json:"unit"`
|
||||
}
|
||||
|
||||
var funcParams *location
|
||||
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||
}
|
||||
|
||||
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
|
||||
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
|
||||
body := azopenai.ChatCompletionsOptions{
|
||||
DeploymentName: &deploymentName,
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
|
@ -204,9 +82,124 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep
|
|||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||
}
|
||||
|
||||
func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Client, tv testVars) {
|
||||
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
|
||||
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
|
||||
)
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
chatClient := newTestClient(t, azureOpenAI.ChatCompletions.Endpoint)
|
||||
|
||||
testData := []struct {
|
||||
Model string
|
||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||
}{
|
||||
// all of these variants use the tool provided - auto just also works since we did provide
|
||||
// a tool reference and ask a question to use it.
|
||||
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: nil},
|
||||
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||
{Model: azureOpenAI.ChatCompletions.Model, ToolChoice: useSpecificTool},
|
||||
}
|
||||
|
||||
for _, td := range testData {
|
||||
testFn(t, chatClient, td.Model, td.ToolChoice)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testData := []struct {
|
||||
EPM endpointWithModel
|
||||
ToolChoice *azopenai.ChatCompletionsToolChoice
|
||||
}{
|
||||
// all of these variants use the tool provided - auto just also works since we did provide
|
||||
// a tool reference and ask a question to use it.
|
||||
{EPM: openAI.ChatCompletions, ToolChoice: nil},
|
||||
{EPM: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
|
||||
{EPM: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
|
||||
}
|
||||
|
||||
for i, td := range testData {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
chatClient := newTestClient(t, td.EPM.Endpoint)
|
||||
testFn(t, chatClient, td.EPM.Model, td.ToolChoice)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatCompletions_usingFunctions_legacy(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
body := azopenai.ChatCompletionsOptions{
|
||||
DeploymentName: &tv.ChatCompletions,
|
||||
DeploymentName: &epm.Model,
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
&azopenai.ChatRequestAssistantMessage{
|
||||
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
||||
},
|
||||
},
|
||||
FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{
|
||||
Value: to.Ptr("auto"),
|
||||
},
|
||||
Functions: []azopenai.FunctionDefinition{
|
||||
{
|
||||
Name: to.Ptr("get_current_weather"),
|
||||
Description: to.Ptr("Get the current weather in a given location"),
|
||||
Parameters: Params{
|
||||
Required: []string{"location"},
|
||||
Type: "object",
|
||||
Properties: map[string]ParamProperty{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Temperature: to.Ptr[float32](0.0),
|
||||
}
|
||||
|
||||
resp, err := client.GetChatCompletions(context.Background(), body, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
|
||||
|
||||
require.Equal(t, "get_current_weather", *funcCall.Name)
|
||||
|
||||
type location struct {
|
||||
Location string `json:"location"`
|
||||
Unit string `json:"unit"`
|
||||
}
|
||||
|
||||
var funcParams *location
|
||||
err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.ChatCompletionsLegacyFunctions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.ChatCompletions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI.LegacyFunctions", func(t *testing.T) {
|
||||
testFn(t, openAI.ChatCompletionsLegacyFunctions)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatCompletions_usingFunctions_streaming(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
body := azopenai.ChatCompletionsOptions{
|
||||
DeploymentName: &epm.Model,
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
&azopenai.ChatRequestAssistantMessage{
|
||||
Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"),
|
||||
|
@ -237,6 +230,8 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
|
|||
Temperature: to.Ptr[float32](0.0),
|
||||
}
|
||||
|
||||
chatClient := newTestClient(t, epm.Endpoint)
|
||||
|
||||
resp, err := chatClient.GetChatCompletionsStream(context.Background(), body, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, resp)
|
||||
|
@ -293,3 +288,14 @@ func testChatCompletionsFunctionsStreaming(t *testing.T, chatClient *azopenai.Cl
|
|||
|
||||
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/guides/gpt/function-calling
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.ChatCompletions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.ChatCompletions)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -22,13 +22,13 @@ import (
|
|||
func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) {
|
||||
// Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered
|
||||
// https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered
|
||||
client := newAzureOpenAIClientForTest(t, azureOpenAI)
|
||||
client := newTestClient(t, azureOpenAI.Completions.Endpoint)
|
||||
|
||||
resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
|
||||
Prompt: []string{"How do I rob a bank with violence?"},
|
||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||
Temperature: to.Ptr(float32(0.0)),
|
||||
DeploymentName: &azureOpenAI.Completions,
|
||||
DeploymentName: &azureOpenAI.Completions.Model,
|
||||
}, nil)
|
||||
|
||||
require.Empty(t, resp)
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -26,11 +27,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
azureOpenAI testVars
|
||||
openAI testVars
|
||||
)
|
||||
|
||||
type endpoint struct {
|
||||
URL string
|
||||
APIKey string
|
||||
|
@ -38,22 +34,19 @@ type endpoint struct {
|
|||
}
|
||||
|
||||
type testVars struct {
|
||||
Endpoint endpoint
|
||||
Completions string
|
||||
ChatCompletions string
|
||||
ChatCompletionsLegacyFunctions string
|
||||
Embeddings string
|
||||
TextEmbedding3Small string
|
||||
ChatCompletions endpointWithModel
|
||||
ChatCompletionsLegacyFunctions endpointWithModel
|
||||
ChatCompletionsOYD endpointWithModel // azure only
|
||||
ChatCompletionsRAI endpointWithModel // azure only
|
||||
ChatCompletionsWithJSONResponseFormat endpointWithModel
|
||||
Cognitive azopenai.AzureSearchChatExtensionConfiguration
|
||||
Whisper endpointWithModel
|
||||
Completions endpointWithModel
|
||||
DallE endpointWithModel
|
||||
Embeddings endpointWithModel
|
||||
Speech endpointWithModel
|
||||
TextEmbedding3Small endpointWithModel
|
||||
Vision endpointWithModel
|
||||
|
||||
ChatCompletionsRAI endpointWithModel // at the moment this is Azure only
|
||||
|
||||
// "own your data" - bringing in Azure resources as part of a chat completions
|
||||
// request.
|
||||
ChatCompletionsOYD endpointWithModel
|
||||
Whisper endpointWithModel
|
||||
}
|
||||
|
||||
type endpointWithModel struct {
|
||||
|
@ -61,16 +54,130 @@ type endpointWithModel struct {
|
|||
Model string
|
||||
}
|
||||
|
||||
type testClientOption func(opt *azopenai.ClientOptions)
|
||||
func ifAzure[T string | endpoint](azure bool, forAzure T, forOpenAI T) T {
|
||||
if azure {
|
||||
return forAzure
|
||||
}
|
||||
return forOpenAI
|
||||
}
|
||||
|
||||
func withForgivingRetryOption() testClientOption {
|
||||
return func(opt *azopenai.ClientOptions) {
|
||||
opt.Retry = policy.RetryOptions{
|
||||
MaxRetries: 10,
|
||||
var azureOpenAI, openAI, servers = func() (testVars, testVars, []string) {
|
||||
if recording.GetRecordMode() != recording.PlaybackMode {
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Fatalf("Failed to load .env file: %s\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
servers := struct {
|
||||
USEast endpoint
|
||||
USNorthCentral endpoint
|
||||
USEast2 endpoint
|
||||
SWECentral endpoint
|
||||
OpenAI endpoint
|
||||
}{
|
||||
OpenAI: endpoint{
|
||||
URL: getEndpoint("OPENAI_ENDPOINT"), // ex: https://api.openai.com/v1/
|
||||
APIKey: recording.GetEnvVariable("OPENAI_API_KEY", fakeAPIKey),
|
||||
Azure: false,
|
||||
},
|
||||
USEast: endpoint{
|
||||
URL: getEndpoint("ENDPOINT_USEAST"),
|
||||
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST_API_KEY", fakeAPIKey),
|
||||
Azure: true,
|
||||
},
|
||||
USEast2: endpoint{
|
||||
URL: getEndpoint("ENDPOINT_USEAST2"),
|
||||
APIKey: recording.GetEnvVariable("ENDPOINT_USEAST2_API_KEY", fakeAPIKey),
|
||||
Azure: true,
|
||||
},
|
||||
USNorthCentral: endpoint{
|
||||
URL: getEndpoint("ENDPOINT_USNORTHCENTRAL"),
|
||||
APIKey: recording.GetEnvVariable("ENDPOINT_USNORTHCENTRAL_API_KEY", fakeAPIKey),
|
||||
Azure: true,
|
||||
},
|
||||
SWECentral: endpoint{
|
||||
URL: getEndpoint("ENDPOINT_SWECENTRAL"),
|
||||
APIKey: recording.GetEnvVariable("ENDPOINT_SWECENTRAL_API_KEY", fakeAPIKey),
|
||||
Azure: true,
|
||||
},
|
||||
}
|
||||
|
||||
// used when we setup the recording policy
|
||||
endpoints := []string{
|
||||
servers.OpenAI.URL,
|
||||
servers.USEast.URL,
|
||||
servers.USEast2.URL,
|
||||
servers.USNorthCentral.URL,
|
||||
servers.SWECentral.URL,
|
||||
}
|
||||
|
||||
newTestVarsFn := func(azure bool) testVars {
|
||||
return testVars{
|
||||
ChatCompletions: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
|
||||
},
|
||||
ChatCompletionsLegacyFunctions: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-0613", "gpt-4-0613"),
|
||||
},
|
||||
ChatCompletionsOYD: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
|
||||
},
|
||||
ChatCompletionsRAI: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-0613", ""), // azure only
|
||||
},
|
||||
ChatCompletionsWithJSONResponseFormat: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-1106-preview", "gpt-3.5-turbo-1106"),
|
||||
},
|
||||
Completions: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"),
|
||||
},
|
||||
DallE: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||
Model: ifAzure(azure, "dall-e-3", "dall-e-3"),
|
||||
},
|
||||
Embeddings: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "text-embedding-ada-002", "text-embedding-ada-002"),
|
||||
},
|
||||
Speech: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "tts-1", "tts-1"),
|
||||
},
|
||||
TextEmbedding3Small: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast, servers.OpenAI),
|
||||
Model: ifAzure(azure, "text-embedding-3-small", "text-embedding-3-small"),
|
||||
},
|
||||
Vision: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.SWECentral, servers.OpenAI),
|
||||
Model: ifAzure(azure, "gpt-4-vision-preview", "gpt-4-vision-preview"),
|
||||
},
|
||||
Whisper: endpointWithModel{
|
||||
Endpoint: ifAzure(azure, servers.USEast2, servers.OpenAI),
|
||||
Model: ifAzure(azure, "whisper-deployment", "whisper-1"),
|
||||
},
|
||||
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
|
||||
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
||||
Endpoint: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_ENDPOINT", fakeCognitiveEndpoint)),
|
||||
IndexName: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_INDEX", fakeCognitiveIndexName)),
|
||||
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
||||
Key: to.Ptr(recording.GetEnvVariable("COGNITIVE_SEARCH_API_KEY", fakeAPIKey)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return newTestVarsFn(true), newTestVarsFn(false), endpoints
|
||||
}()
|
||||
|
||||
type testClientOption func(opt *azopenai.ClientOptions)
|
||||
|
||||
// newTestClient creates a client enabled for HTTP recording, if needed.
|
||||
// See [newRecordingTransporter] for sanitization code.
|
||||
func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azopenai.Client {
|
||||
|
@ -101,197 +208,11 @@ func newTestClient(t *testing.T, ep endpoint, options ...testClientOption) *azop
|
|||
}
|
||||
}
|
||||
|
||||
// getEndpoint retrieves details for an endpoint and a model.
|
||||
// - res - the resource type for a particular endpoint. Ex: "DALLE".
|
||||
//
|
||||
// For example, if azure is true we'll load these environment values based on res:
|
||||
// - AOAI_DALLE_ENDPOINT
|
||||
// - AOAI_DALLE_API_KEY
|
||||
//
|
||||
// if azure is false we'll load these environment values based on res:
|
||||
// - OPENAI_ENDPOINT
|
||||
// - OPENAI_API_KEY
|
||||
func getEndpoint(res string, isAzure bool) endpointWithModel {
|
||||
var ep endpointWithModel
|
||||
if isAzure {
|
||||
// during development resources are often shifted between different
|
||||
// internal Azure OpenAI resources.
|
||||
ep = endpointWithModel{
|
||||
Endpoint: endpoint{
|
||||
URL: getRequired("AOAI_" + res + "_ENDPOINT"),
|
||||
APIKey: getRequired("AOAI_" + res + "_API_KEY"),
|
||||
Azure: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ep = endpointWithModel{
|
||||
Endpoint: endpoint{
|
||||
URL: getRequired("OPENAI_ENDPOINT"),
|
||||
APIKey: getRequired("OPENAI_API_KEY"),
|
||||
Azure: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(ep.Endpoint.URL, "/") {
|
||||
// (this just makes recording replacement easier)
|
||||
ep.Endpoint.URL += "/"
|
||||
}
|
||||
|
||||
return ep
|
||||
}
|
||||
|
||||
func model(azure bool, azureModel, openAIModel string) string {
|
||||
if azure {
|
||||
return azureModel
|
||||
}
|
||||
|
||||
return openAIModel
|
||||
}
|
||||
|
||||
func updateModels(azure bool, tv *testVars) {
|
||||
// the models we use are basically their own API surface so it's good to know which
|
||||
// specific models our tests were written against.
|
||||
tv.Completions = model(azure, "gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct")
|
||||
tv.ChatCompletions = model(azure, "gpt-35-turbo-0613", "gpt-4-0613")
|
||||
tv.ChatCompletionsLegacyFunctions = model(azure, "gpt-4-0613", "gpt-4-0613")
|
||||
tv.Embeddings = model(azure, "text-embedding-ada-002", "text-embedding-ada-002")
|
||||
tv.TextEmbedding3Small = model(azure, "text-embedding-3-small", "text-embedding-3-small")
|
||||
|
||||
tv.DallE.Model = model(azure, "dall-e-3", "dall-e-3")
|
||||
tv.Whisper.Model = model(azure, "whisper-deployment", "whisper-1")
|
||||
tv.Vision.Model = model(azure, "gpt-4-vision-preview", "gpt-4-vision-preview")
|
||||
|
||||
// these are Azure-only features
|
||||
tv.ChatCompletionsOYD.Model = model(azure, "gpt-4", "")
|
||||
tv.ChatCompletionsRAI.Model = model(azure, "gpt-4", "")
|
||||
}
|
||||
|
||||
func newTestVars(prefix string) testVars {
|
||||
azure := prefix == "AOAI"
|
||||
|
||||
tv := testVars{
|
||||
Endpoint: endpoint{
|
||||
URL: getRequired(prefix + "_ENDPOINT"),
|
||||
APIKey: getRequired(prefix + "_API_KEY"),
|
||||
Azure: azure,
|
||||
},
|
||||
Cognitive: azopenai.AzureSearchChatExtensionConfiguration{
|
||||
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
||||
Endpoint: to.Ptr(getRequired("COGNITIVE_SEARCH_API_ENDPOINT")),
|
||||
IndexName: to.Ptr(getRequired("COGNITIVE_SEARCH_API_INDEX")),
|
||||
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
||||
Key: to.Ptr(getRequired("COGNITIVE_SEARCH_API_KEY")),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
DallE: getEndpoint("DALLE", azure),
|
||||
Whisper: getEndpoint("WHISPER", azure),
|
||||
Vision: getEndpoint("VISION", azure),
|
||||
}
|
||||
|
||||
if azure {
|
||||
tv.ChatCompletionsRAI = getEndpoint("CHAT_COMPLETIONS_RAI", azure)
|
||||
tv.ChatCompletionsOYD = getEndpoint("OYD", azure)
|
||||
}
|
||||
|
||||
updateModels(azure, &tv)
|
||||
|
||||
if tv.Endpoint.URL != "" && !strings.HasSuffix(tv.Endpoint.URL, "/") {
|
||||
// (this just makes recording replacement easier)
|
||||
tv.Endpoint.URL += "/"
|
||||
}
|
||||
|
||||
return tv
|
||||
}
|
||||
|
||||
const fakeEndpoint = "https://fake-recorded-host.microsoft.com/"
|
||||
const fakeAPIKey = "redacted"
|
||||
const fakeCognitiveEndpoint = "https://fake-cognitive-endpoint.microsoft.com"
|
||||
const fakeCognitiveIndexName = "index"
|
||||
|
||||
func initEnvVars() {
|
||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||
// Setup our variables so our requests are consistent with what we recorded.
|
||||
// Endpoints are sanitized using the recording policy
|
||||
azureOpenAI.Endpoint = endpoint{
|
||||
URL: fakeEndpoint,
|
||||
APIKey: fakeAPIKey,
|
||||
Azure: true,
|
||||
}
|
||||
|
||||
azureOpenAI.Whisper = endpointWithModel{
|
||||
Endpoint: azureOpenAI.Endpoint,
|
||||
}
|
||||
|
||||
azureOpenAI.ChatCompletionsRAI = endpointWithModel{
|
||||
Endpoint: azureOpenAI.Endpoint,
|
||||
}
|
||||
|
||||
azureOpenAI.ChatCompletionsOYD = endpointWithModel{
|
||||
Endpoint: azureOpenAI.Endpoint,
|
||||
}
|
||||
|
||||
azureOpenAI.DallE = endpointWithModel{
|
||||
Endpoint: azureOpenAI.Endpoint,
|
||||
}
|
||||
|
||||
azureOpenAI.Vision = endpointWithModel{
|
||||
Endpoint: azureOpenAI.Endpoint,
|
||||
}
|
||||
|
||||
openAI.Endpoint = endpoint{
|
||||
APIKey: fakeAPIKey,
|
||||
URL: fakeEndpoint,
|
||||
}
|
||||
|
||||
openAI.Whisper = endpointWithModel{
|
||||
Endpoint: endpoint{
|
||||
APIKey: fakeAPIKey,
|
||||
URL: fakeEndpoint,
|
||||
},
|
||||
}
|
||||
|
||||
openAI.DallE = endpointWithModel{
|
||||
Endpoint: openAI.Endpoint,
|
||||
}
|
||||
|
||||
updateModels(true, &azureOpenAI)
|
||||
updateModels(false, &openAI)
|
||||
|
||||
openAI.Vision = azureOpenAI.Vision
|
||||
|
||||
azureOpenAI.Completions = "gpt-35-turbo-instruct"
|
||||
openAI.Completions = "gpt-3.5-turbo-instruct"
|
||||
|
||||
azureOpenAI.ChatCompletions = "gpt-35-turbo-0613"
|
||||
azureOpenAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
|
||||
openAI.ChatCompletions = "gpt-4-0613"
|
||||
openAI.ChatCompletionsLegacyFunctions = "gpt-4-0613"
|
||||
|
||||
openAI.Embeddings = "text-embedding-ada-002"
|
||||
azureOpenAI.Embeddings = "text-embedding-ada-002"
|
||||
|
||||
azureOpenAI.Cognitive = azopenai.AzureSearchChatExtensionConfiguration{
|
||||
Parameters: &azopenai.AzureSearchChatExtensionParameters{
|
||||
Endpoint: to.Ptr(fakeCognitiveEndpoint),
|
||||
IndexName: to.Ptr(fakeCognitiveIndexName),
|
||||
Authentication: &azopenai.OnYourDataAPIKeyAuthenticationOptions{
|
||||
Key: to.Ptr(fakeAPIKey),
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
if err := godotenv.Load(); err != nil {
|
||||
fmt.Printf("Failed to load .env file: %s\n", err)
|
||||
}
|
||||
|
||||
azureOpenAI = newTestVars("AOAI")
|
||||
openAI = newTestVars("OPENAI")
|
||||
}
|
||||
}
|
||||
|
||||
type MultipartRecordingPolicy struct {
|
||||
}
|
||||
|
||||
|
@ -315,17 +236,7 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
|
|||
err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
endpoints := []string{
|
||||
azureOpenAI.Endpoint.URL,
|
||||
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
|
||||
azureOpenAI.Whisper.Endpoint.URL,
|
||||
azureOpenAI.DallE.Endpoint.URL,
|
||||
azureOpenAI.Vision.Endpoint.URL,
|
||||
azureOpenAI.ChatCompletionsOYD.Endpoint.URL,
|
||||
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
for _, ep := range servers {
|
||||
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(ep), nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -333,8 +244,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
|
|||
err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
if openAI.Endpoint.URL != "" {
|
||||
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.Endpoint.URL), nil)
|
||||
// there's only one OpenAI endpoint
|
||||
if openAI.ChatCompletions.Endpoint.URL != "" {
|
||||
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAI.ChatCompletions.Endpoint.URL), nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -401,20 +313,12 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {
|
|||
return co
|
||||
}
|
||||
|
||||
func newAzureOpenAIClientForTest(t *testing.T, tv testVars, options ...testClientOption) *azopenai.Client {
|
||||
return newTestClient(t, tv.Endpoint, options...)
|
||||
}
|
||||
|
||||
func newOpenAIClientForTest(t *testing.T, options ...testClientOption) *azopenai.Client {
|
||||
return newTestClient(t, openAI.Endpoint, options...)
|
||||
}
|
||||
|
||||
// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
|
||||
// a failure.
|
||||
func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
|
||||
cred := azcore.NewKeyCredential("bogus-api-key")
|
||||
|
||||
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
return client
|
||||
|
@ -425,7 +329,7 @@ func newBogusAzureOpenAIClient(t *testing.T) *azopenai.Client {
|
|||
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
|
||||
cred := azcore.NewKeyCredential("bogus-api-key")
|
||||
|
||||
client, err := azopenai.NewClientForOpenAI(openAI.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||
client, err := azopenai.NewClientForOpenAI(openAI.Completions.Endpoint.URL, cred, newClientOptionsForTest(t))
|
||||
require.NoError(t, err)
|
||||
return client
|
||||
}
|
||||
|
@ -440,11 +344,12 @@ func assertResponseIsError(t *testing.T, err error) {
|
|||
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
|
||||
}
|
||||
|
||||
func getRequired(name string) string {
|
||||
v := os.Getenv(name)
|
||||
func getEndpoint(ev string) string {
|
||||
v := recording.GetEnvVariable(ev, fakeEndpoint)
|
||||
|
||||
if v == "" {
|
||||
panic(fmt.Sprintf("Env variable %s is missing", name))
|
||||
if !strings.HasSuffix(v, "/") {
|
||||
// (this just makes recording replacement easier)
|
||||
v += "/"
|
||||
}
|
||||
|
||||
return v
|
||||
|
|
|
@ -23,7 +23,7 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) {
|
|||
t.Skip()
|
||||
}
|
||||
|
||||
chatClient := newOpenAIClientForTest(t)
|
||||
chatClient := newTestClient(t, openAI.ChatCompletions.Endpoint)
|
||||
|
||||
_, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
|
||||
Messages: []azopenai.ChatRequestMessageClassification{
|
||||
|
|
|
@ -29,11 +29,7 @@ func TestImageGeneration_AzureOpenAI(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImageGeneration_OpenAI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
client := newOpenAIClientForTest(t)
|
||||
client := newTestClient(t, openAI.DallE.Endpoint)
|
||||
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatURL)
|
||||
}
|
||||
|
||||
|
@ -60,7 +56,7 @@ func TestImageGeneration_OpenAI_Base64(t *testing.T) {
|
|||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
client := newOpenAIClientForTest(t)
|
||||
client := newTestClient(t, openAI.DallE.Endpoint)
|
||||
testImageGeneration(t, client, openAI.DallE.Model, azopenai.ImageGenerationResponseFormatBase64)
|
||||
}
|
||||
|
||||
|
|
|
@ -76,28 +76,17 @@ func TestNewClientWithKeyCredential(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
|
||||
client := newTestClient(t, azureOpenAI.Endpoint)
|
||||
testGetCompletionsStream(t, client, azureOpenAI)
|
||||
}
|
||||
|
||||
func TestGetCompletionsStream_OpenAI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
|
||||
}
|
||||
|
||||
client := newOpenAIClientForTest(t)
|
||||
testGetCompletionsStream(t, client, openAI)
|
||||
}
|
||||
|
||||
func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
|
||||
func TestGetCompletionsStream(t *testing.T) {
|
||||
testFn := func(t *testing.T, epm endpointWithModel) {
|
||||
body := azopenai.CompletionsOptions{
|
||||
Prompt: []string{"What is Azure OpenAI?"},
|
||||
MaxTokens: to.Ptr(int32(2048)),
|
||||
Temperature: to.Ptr(float32(0.0)),
|
||||
DeploymentName: &tv.Completions,
|
||||
DeploymentName: &epm.Model,
|
||||
}
|
||||
|
||||
client := newTestClient(t, epm.Endpoint)
|
||||
|
||||
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
|
||||
skipNowIfThrottled(t, err)
|
||||
require.NoError(t, err)
|
||||
|
@ -106,6 +95,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
|
|||
t.Errorf("Client.GetCompletionsStream() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
reader := response.CompletionsStream
|
||||
defer reader.Close()
|
||||
|
||||
|
@ -146,12 +136,23 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars
|
|||
require.GreaterOrEqual(t, eventCount, 50)
|
||||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
testFn(t, azureOpenAI.Completions)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
testFn(t, openAI.Completions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetCompletions_Error(t *testing.T) {
|
||||
if recording.GetRecordMode() == recording.PlaybackMode {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
doTest := func(t *testing.T, client *azopenai.Client, model string) {
|
||||
doTest := func(t *testing.T, model string) {
|
||||
client := newBogusAzureOpenAIClient(t)
|
||||
|
||||
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
|
||||
Prompt: []string{"What is Azure OpenAI?"},
|
||||
MaxTokens: to.Ptr(int32(2048 - 127)),
|
||||
|
@ -163,12 +164,10 @@ func TestClient_GetCompletions_Error(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("AzureOpenAI", func(t *testing.T) {
|
||||
client := newBogusAzureOpenAIClient(t)
|
||||
doTest(t, client, azureOpenAI.Completions)
|
||||
doTest(t, azureOpenAI.Completions.Model)
|
||||
})
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
client := newBogusOpenAIClient(t)
|
||||
doTest(t, client, openAI.Completions)
|
||||
doTest(t, openAI.Completions.Model)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
|
|||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0
|
||||
github.com/joho/godotenv v1.3.0
|
||||
|
@ -19,9 +19,9 @@ require (
|
|||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/net v0.22.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.22.0 // indirect
|
||||
golang.org/x/net v0.24.0 // indirect
|
||||
golang.org/x/sys v0.19.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 h1:c4k2FIYIh4xtwqrQwV0Ct1v5+ehlNXj5NI/MWVsiTkQ=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2/go.mod h1:5FDJtLEO/GxwNgUxbwrY3LP0pEoThTQJtk2oysdXHxM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.6.0 h1:sUFnFjzDUie80h24I7mrKtwCKgLY9L8h5Tp2x9+TWqk=
|
||||
|
@ -25,13 +25,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
const RecordingDirectory = "sdk/ai/azopenai/testdata"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
initEnvVars()
|
||||
code := run(m)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче