Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions chat.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -86,12 +87,35 @@ type ChatMessagePartType string
// Known values for the Type field of a ChatMessagePart (serialized as "type").
const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
// ChatMessagePartTypeAudio marks a part whose Audio field ("input_audio") must be set.
ChatMessagePartTypeAudio ChatMessagePartType = "input_audio"
// ChatMessagePartTypeVideo marks a part whose Video field ("video") must be set.
ChatMessagePartTypeVideo ChatMessagePartType = "video"
// ChatMessagePartTypeVideoURL marks a part whose VideoURL field ("video_url") must be set.
ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url"
)

// InputAudio is the audio payload of a chat message part. ChatMessagePart
// carries it in its Audio field under the "input_audio" JSON key.
//
// NOTE(review): the accepted contents of Data and the valid Format values are
// not visible in this file; confirm them against the references below.
/* reference:
* https://bailian.console.aliyun.com/
* ?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576
* https://help.aliyun.com/zh/model-studio/qwen-omni#423736d367a7x
*/
type InputAudio struct {
// Data is serialized as "data" and is always emitted (no omitempty).
Data string `json:"data"`
// Format is serialized as "format" and is always emitted (no omitempty).
Format string `json:"format"`
}

// CacheControl is the value serialized under the "cache_control" key of a
// ChatMessagePart, which embeds it as an optional pointer.
type CacheControl struct {
Type string `json:"type"` // must be "ephemeral"
}

// ChatMessagePart is a single part of a multi-part chat message. Type selects
// which of the payload fields is expected to be populated; every field is
// omitted from the JSON encoding when it holds its zero value.
//
// Each field is declared exactly once: a Go struct cannot repeat a field
// name, so the Type, Text and ImageURL declarations must not be duplicated.
type ChatMessagePart struct {
	Type      ChatMessagePartType  `json:"type,omitempty"`
	Text      string               `json:"text,omitempty"`
	ImageURL  *ChatMessageImageURL `json:"image_url,omitempty"`
	Audio     *InputAudio          `json:"input_audio,omitempty"` // required when Type is "input_audio"
	VideoURL  *ChatMessageImageURL `json:"video_url,omitempty"`   // required when Type is "video_url"
	Video     []string             `json:"video,omitempty"`       // required when Type is "video", array of image URLs
	MinPixels int                  `json:"min_pixels,omitempty"`  // NOTE(review): semantics not visible here; confirm with provider docs
	MaxPixels int                  `json:"max_pixels,omitempty"`  // NOTE(review): semantics not visible here; confirm with provider docs
	// The explicit tag makes encoding/json treat the embedded pointer as a
	// named "cache_control" object instead of flattening its fields.
	*CacheControl `json:"cache_control,omitempty"`
}

type ChatCompletionMessage struct {
Expand Down Expand Up @@ -333,6 +357,34 @@ type ChatCompletionRequest struct {
SafetyIdentifier string `json:"safety_identifier,omitempty"`
// Embedded struct for non-OpenAI extensions
ChatCompletionRequestExtensions
// non-OpenAI extensions
Extensions map[string]interface{} `json:"-"`
}

// customChatCompletionRequest has the same fields as ChatCompletionRequest but
// not its MarshalJSON method, so encoding it does not recurse back into
// ChatCompletionRequest.MarshalJSON.
type customChatCompletionRequest ChatCompletionRequest

// TrailingLen is the length of the "}\n" suffix that json.Encoder writes after
// an object. It is kept so existing references keep compiling; MarshalJSON no
// longer relies on it.
const TrailingLen = 2 // length of "}\n"

// MarshalJSON encodes the request and merges the entries of r.Extensions into
// the same top-level JSON object, after the regular struct fields.
//
// The method has a pointer receiver, so encoding/json only invokes it when it
// is given a *ChatCompletionRequest (or an addressable value).
//
// Extension keys are emitted in the sorted order encoding/json uses for maps.
// A key that matches the JSON name of a struct field is not deduplicated and
// results in a repeated key in the output.
//
// A nil receiver encodes as JSON null. Any encoding error is returned as-is.
func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) {
	if r == nil {
		return []byte("null"), nil
	}
	base, err := json.Marshal((*customChatCompletionRequest)(r))
	if err != nil {
		return nil, err
	}
	if len(r.Extensions) == 0 {
		return base, nil
	}
	ext, err := json.Marshal(r.Extensions)
	if err != nil {
		return nil, err
	}
	// When the struct part is the empty object there is nothing to join with
	// a comma; the extensions object alone is the complete result.
	if len(base) <= len("{}") {
		return ext, nil
	}
	// Splice `{fields}` and `{extensions}` into `{fields,extensions}`: drop the
	// closing brace of the first and the opening brace of the second.
	var buf bytes.Buffer
	buf.Grow(len(base) + len(ext))
	buf.Write(base[:len(base)-1])
	buf.WriteByte(',')
	buf.Write(ext[1:])
	return buf.Bytes(), nil
}

type StreamOptions struct {
Expand Down
10 changes: 9 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import (
"net/http"
)

// OutputAudio is the audio payload of a streamed chat completion delta; it is
// the type of the Audio field ("audio") of ChatCompletionStreamChoiceDelta.
type OutputAudio struct {
Transcript string `json:"transcript"` // streamed text content
Data string `json:"data"` // base64-encoded audio data
// NOTE(review): the field name suggests an expiry time, while the comment
// below describes a creation time; confirm against the provider documentation.
ExpiresAt int `json:"expires_at"` // the timestamp when the request was created
}

type ChatCompletionStreamChoiceDelta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
Expand All @@ -17,6 +23,8 @@ type ChatCompletionStreamChoiceDelta struct {
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
// Audio is used for audio responses, if supported by the model, such as "qwen-omni".
Audio *OutputAudio `json:"audio,omitempty"`
}

type ChatCompletionStreamChoiceLogprobs struct {
Expand Down Expand Up @@ -95,7 +103,7 @@ func (c *Client) CreateChatCompletionStream(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
withBody(&request),
)
if err != nil {
return nil, err
Expand Down
180 changes: 180 additions & 0 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,183 @@
}
return true
}

// TestOutputAudio checks that an OutputAudio value keeps its transcript, data
// and expiry fields across a JSON marshal/unmarshal round trip.
func TestOutputAudio(t *testing.T) {
	want := openai.OutputAudio{
		Transcript: "Hello, world!",
		Data:       "base64encodedaudiodata",
		ExpiresAt:  1234567890,
	}

	encoded, err := json.Marshal(want)
	if err != nil {
		t.Errorf("Failed to marshal OutputAudio: %v", err)
		return
	}

	var got openai.OutputAudio
	err = json.Unmarshal(encoded, &got)
	if err != nil {
		t.Errorf("Failed to unmarshal OutputAudio: %v", err)
		return
	}

	// Compare field by field so each mismatch is reported on its own.
	if want.Transcript != got.Transcript {
		t.Errorf("Expected transcript %s, got %s", want.Transcript, got.Transcript)
	}
	if want.Data != got.Data {
		t.Errorf("Expected data %s, got %s", want.Data, got.Data)
	}
	if want.ExpiresAt != got.ExpiresAt {
		t.Errorf("Expected expires_at %d, got %d", want.ExpiresAt, got.ExpiresAt)
	}
}

// verifyAudioContent checks if the audio content matches between expected and actual

Check failure on line 1055 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) {
if actual.Transcript != expected.Transcript {
t.Errorf("Expected audio transcript %s, got %s", expected.Transcript, actual.Transcript)
}
if actual.Data != expected.Data {
t.Errorf("Expected audio data %s, got %s", expected.Data, actual.Data)
}
if actual.ExpiresAt != expected.ExpiresAt {
t.Errorf("Expected audio expires_at %d, got %d", expected.ExpiresAt, actual.ExpiresAt)
}
}

// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta

Check failure on line 1068 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStreamChoiceDelta) {
if expected.Audio != nil {
if actual.Audio == nil {
t.Error("Expected audio to be present, but it's nil")
return
}
verifyAudioContent(t, expected.Audio, actual.Audio)
} else if actual.Audio != nil {
t.Error("Expected audio to be nil, but it's present")
}
}

// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta

Check failure on line 1081 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
func testDeltaSerialization(t *testing.T, delta openai.ChatCompletionStreamChoiceDelta) openai.ChatCompletionStreamChoiceDelta {

Check failure on line 1082 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

The line is 128 characters long, which exceeds the maximum of 120 characters. (lll)
// Test JSON marshaling
data, err := json.Marshal(delta)
if err != nil {
t.Errorf("Failed to marshal ChatCompletionStreamChoiceDelta: %v", err)
return openai.ChatCompletionStreamChoiceDelta{}
}

// Test JSON unmarshaling
var result openai.ChatCompletionStreamChoiceDelta
if err = json.Unmarshal(data, &result); err != nil {
t.Errorf("Failed to unmarshal ChatCompletionStreamChoiceDelta: %v", err)
return openai.ChatCompletionStreamChoiceDelta{}
}

return result
}

// TestChatCompletionStreamChoiceDelta_Audio round-trips stream deltas through
// JSON, once with an audio payload and once without, and checks that the
// content and the audio field come back unchanged.
func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) {
	type roundTripCase struct {
		name  string
		delta openai.ChatCompletionStreamChoiceDelta
	}

	deltaWithAudio := openai.ChatCompletionStreamChoiceDelta{
		Content: "Hello",
		Audio: &openai.OutputAudio{
			Transcript: "Hello, world!",
			Data:       "base64encodedaudiodata",
			ExpiresAt:  1234567890,
		},
	}
	deltaWithoutAudio := openai.ChatCompletionStreamChoiceDelta{
		Content: "Hello",
	}

	cases := []roundTripCase{
		{name: "with audio", delta: deltaWithAudio},
		{name: "without audio", delta: deltaWithoutAudio},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			decoded := testDeltaSerialization(t, tc.delta)

			// The text content must survive the round trip.
			if decoded.Content != tc.delta.Content {
				t.Errorf("Expected content %s, got %s", tc.delta.Content, decoded.Content)
			}

			// The audio payload must be preserved, or stay absent.
			verifyAudioInDelta(t, tc.delta, decoded)
		})
	}
}

func TestCreateChatCompletionStreamWithAudio(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()

server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses with audio
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`

Check failure on line 1149 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

The line is 169 characters long, which exceeds the maximum of 120 characters. (lll)
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message\n")...)
data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}`

Check failure on line 1153 in chat_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

The line is 245 characters long, which exceeds the maximum of 120 characters. (lll)
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, _ = w.Write(dataBytes)
})

ctx := context.Background()
req := openai.ChatCompletionRequest{
Model: "qwen-omni",
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
}

stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletionStream error: %v", err)
}
defer stream.Close()

hasAudio := false
for {
var resp openai.ChatCompletionStreamResponse
resp, err = stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
t.Fatalf("Stream error: %v", err)
}

if len(resp.Choices) > 0 && resp.Choices[0].Delta.Audio != nil {
hasAudio = true
if resp.Choices[0].Delta.Audio.Transcript != "Hello, world!" {
t.Errorf("Expected transcript 'Hello, world!', got %s", resp.Choices[0].Delta.Audio.Transcript)
}
if resp.Choices[0].Delta.Audio.Data != "base64encodedaudiodata" {
t.Errorf("Expected audio data 'base64encodedaudiodata', got %s", resp.Choices[0].Delta.Audio.Data)
}
}
}

if !hasAudio {
t.Error("Expected to receive audio in stream response")
}
}
Loading
Loading