From 99b18041e45743b284389baf01087fa5e41d7f87 Mon Sep 17 00:00:00 2001
From: razorJulius
Date: Tue, 14 Feb 2023 18:43:12 +0900
Subject: [PATCH] Sync completion with official documentation

---
 gpt3.go      | 12 ++++++------
 gpt3_test.go | 18 +++++++++---------
 models.go    | 13 ++++++++++++-
 3 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/gpt3.go b/gpt3.go
index 207612f..2bd9552 100644
--- a/gpt3.go
+++ b/gpt3.go
@@ -55,10 +55,6 @@ const (
 	defaultTimeoutSeconds = 30
 )
 
-func getEngineURL(engine string) string {
-	return fmt.Sprintf("%s/engines/%s/completions", defaultBaseURL, engine)
-}
-
 // A Client is an API client to communicate with the OpenAI gpt-3 APIs
 type Client interface {
 	// Engines lists the currently available engines, and provides basic information about each
@@ -125,6 +121,8 @@ func NewClient(apiKey string, options ...ClientOption) Client {
 	return c
 }
 
+// Deprecated: the Engines endpoints are deprecated.
+// Please use their replacement, Models, instead.
 func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) {
 	req, err := c.newRequest(ctx, "GET", "/engines", nil)
 	if err != nil {
@@ -142,6 +140,8 @@ func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) {
 	return output, nil
 }
 
+// Deprecated: the Engines endpoints are deprecated.
+// Please use their replacement, Models, instead.
 func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, error) {
 	req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/engines/%s", engine), nil)
 	if err != nil {
@@ -165,7 +165,7 @@ func (c *client) Completion(ctx context.Context, request CompletionRequest) (*Co
 func (c *client) CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error) {
 	request.Stream = false
-	req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/completions", engine), request)
+	req, err := c.newRequest(ctx, "POST", "/completions", request)
 	if err != nil {
 		return nil, err
 	}
@@ -195,7 +195,7 @@ func (c *client) CompletionStreamWithEngine(
 	onData func(*CompletionResponse),
 ) error {
 	request.Stream = true
-	req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/completions", engine), request)
+	req, err := c.newRequest(ctx, "POST", "/completions", request)
 	if err != nil {
 		return err
 	}
diff --git a/gpt3_test.go b/gpt3_test.go
index 6b36b65..244991f 100644
--- a/gpt3_test.go
+++ b/gpt3_test.go
@@ -61,7 +61,7 @@ func TestRequestCreationFails(t *testing.T) {
 		func() (interface{}, error) {
 			return client.Completion(ctx, gpt3.CompletionRequest{})
 		},
-		"Post \"https://api.openai.com/v1/engines/davinci/completions\": request error",
+		"Post \"https://api.openai.com/v1/completions\": request error",
 	}, {
 		"CompletionStream",
 		func() (interface{}, error) {
@@ -71,13 +71,13 @@
 			}
 			return rsp, client.CompletionStream(ctx, gpt3.CompletionRequest{}, onData)
 		},
-		"Post \"https://api.openai.com/v1/engines/davinci/completions\": request error",
+		"Post \"https://api.openai.com/v1/completions\": request error",
 	}, {
 		"CompletionWithEngine",
 		func() (interface{}, error) {
 			return client.CompletionWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{})
 		},
-		"Post \"https://api.openai.com/v1/engines/ada/completions\": request error",
+		"Post \"https://api.openai.com/v1/completions\": request error",
 	}, {
 		"CompletionStreamWithEngine",
 		func() (interface{}, error) {
@@ -87,7 +87,7 @@
 			}
 			return rsp, client.CompletionStreamWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}, onData)
 		},
-		"Post \"https://api.openai.com/v1/engines/ada/completions\": request error",
+		"Post \"https://api.openai.com/v1/completions\": request error",
 	}, {
 		"Edits",
 		func() (interface{}, error) {
@@ -149,7 +149,7 @@ func TestResponses(t *testing.T) {
 			},
 			&gpt3.EnginesResponse{
 				Data: []gpt3.EngineObject{
-					gpt3.EngineObject{
+					{
 						ID:     "123",
 						Object: "list",
 						Owner:  "owner",
@@ -181,7 +181,7 @@ func TestResponses(t *testing.T) {
 			Created: 123456789,
 			Model:   "davinci-12",
 			Choices: []gpt3.CompletionResponseChoice{
-				gpt3.CompletionResponseChoice{
+				{
 					Text:         "output",
 					FinishReason: "stop",
 				},
@@ -208,7 +208,7 @@ func TestResponses(t *testing.T) {
 			Created: 123456789,
 			Model:   "davinci-12",
 			Choices: []gpt3.CompletionResponseChoice{
-				gpt3.CompletionResponseChoice{
+				{
 					Text:         "output",
 					FinishReason: "stop",
 				},
@@ -231,7 +231,7 @@
 			},
 			&gpt3.SearchResponse{
 				Data: []gpt3.SearchData{
-					gpt3.SearchData{
+					{
 						Document: 1,
 						Object:   "search_result",
 						Score:    40.312,
@@ -245,7 +245,7 @@
 			},
 			&gpt3.SearchResponse{
 				Data: []gpt3.SearchData{
-					gpt3.SearchData{
+					{
 						Document: 1,
 						Object:   "search_result",
 						Score:    40.312,
diff --git a/models.go b/models.go
index 884906c..6f00ae9 100644
--- a/models.go
+++ b/models.go
@@ -35,9 +35,13 @@ type EnginesResponse struct {
 // CompletionRequest is a request for the completions API
 type CompletionRequest struct {
+	// ID of the model to use.
+	Model string `json:"model"`
 	// A list of string prompts to use.
 	// TODO there are other prompt types here for using token integers that we could add support for.
 	Prompt []string `json:"prompt"`
-	// How many tokens to complete up to. Max of 512
+	// The suffix that comes after a completion of inserted text.
+	Suffix string `json:"suffix,omitempty"`
+	// How many tokens to complete up to. Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
 	MaxTokens *int `json:"max_tokens,omitempty"`
 	// Sampling temperature to use
 	Temperature *float32 `json:"temperature,omitempty"`
@@ -55,6 +59,13 @@ type CompletionRequest struct {
 	PresencePenalty float32 `json:"presence_penalty"`
 	// FrequencyPenalty number between 0 and 1 that penalizes tokens on existing frequency in the text so far.
 	FrequencyPenalty float32 `json:"frequency_penalty"`
+	// Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token).
+	// Results cannot be streamed.
+	BestOf int `json:"best_of,omitempty"`
+	// Modify the likelihood of specified tokens appearing in the completion.
+	LogitBias map[string]int `json:"logit_bias,omitempty"`
+	// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+	User string `json:"user,omitempty"`
 	// Whether to stream back results or not. Don't set this value in the request yourself
 	// as it will be overriden depending on if you use CompletionStream or Completion methods.