From e1f2f794d81b9746d688a4b625b65e8669d58bc8 Mon Sep 17 00:00:00 2001 From: Shalom Yiblet Date: Sat, 16 Mar 2024 11:13:16 -0400 Subject: [PATCH] feat: add tests for chatting --- chat.go | 10 ++++-- chat_test.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 chat_test.go diff --git a/chat.go b/chat.go index 14fc0e1..008a7ea 100644 --- a/chat.go +++ b/chat.go @@ -15,9 +15,15 @@ type aiStreamInput struct { Timeout time.Duration } +type chatCompletionStreamer interface { + // ChatCompletion creates a completion with the Chat completion endpoint which + // is what powers the ChatGPT experience. + ChatCompletionStream(ctx context.Context, request gpt3.ChatCompletionRequest, onData func(*gpt3.ChatCompletionStreamResponse) error) error +} + func aiStream( ctx context.Context, - client gpt3.Client, + streamer chatCompletionStreamer, input aiStreamInput, handler func(message string) error, ) error { @@ -29,7 +35,7 @@ func aiStream( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - err := client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ + err := streamer.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ Messages: input.Messages, MaxTokens: input.MaxTokens, Temperature: input.Temperature, diff --git a/chat_test.go b/chat_test.go new file mode 100644 index 0000000..5c5fa36 --- /dev/null +++ b/chat_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "fmt" + "strings" + "testing" + + gpt3 "github.com/PullRequestInc/go-gpt3" +) + +type testStream struct { + t *testing.T + response string + request gpt3.ChatCompletionRequest +} + +// ChatCompletionStream implements streamClient. +func (t *testStream) ChatCompletionStream(ctx context.Context, request gpt3.ChatCompletionRequest, onData func(*gpt3.ChatCompletionStreamResponse) error) error { + convert := func(s string) *gpt3.ChatCompletionStreamResponse { + return &gpt3.ChatCompletionStreamResponse{ + Choices: []gpt3.ChatCompletionStreamResponseChoice{ + { + Delta: gpt3.ChatCompletionResponseMessage{ + Content: s, + }, + }, + }, + } + } + + t.request = request + for i, elem := range strings.Split(t.response, " ") { + t.t.Logf("inserting: %#v", elem) + elem := elem + if i == 0 { + if err := onData(convert(elem)); err != nil { + return err + } + } else { + if err := onData(convert(fmt.Sprintf(" %s", elem))); err != nil { + return err + } + } + } + + return nil +} + +func TestAIStream(t *testing.T) { + t.Parallel() + + t.Run("valid stream", func(t *testing.T) { + stream := &testStream{response: "valid stream", t: t} + var sb strings.Builder + err := aiStream(context.Background(), stream, aiStreamInput{}, func(message string) error { + t.Logf("retrieved: %#v", message) + _, err := sb.WriteString(message) + return err + }) + + if err != nil { + t.Errorf("stream should not error: %v", err) + } + + expected := stream.response + test := sb.String() + if expected != test { + t.Errorf("invalid stream response. Expected %#v got %#v", expected, test) + } + }) + + t.Run("valid stream", func(t *testing.T) { + stream := &testStream{response: "valid stream", t: t} + var sb strings.Builder + aiStream(context.Background(), stream, aiStreamInput{}, func(message string) error { + t.Logf("retrieved: %#v", message) + _, err := sb.WriteString(message) + return err + }) + + expected := stream.response + test := sb.String() + if expected != test { + t.Errorf("invalid stream response. Expected %#v got %#v", expected, test) + } + }) + +}