From 5efdcf34672cfef0b44c02f30403da8d45546b5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Serta=C3=A7=20=C3=96zercan?= <852750+sozercan@users.noreply.github.com>
Date: Tue, 21 Mar 2023 22:40:45 -0700
Subject: [PATCH] add chatgpt model support (#13)

---
 README.md             |  11 +++-
 cmd/cli/completion.go | 124 ++++++++++++++++++++++++++++++++++++++++++
 cmd/cli/openai.go     | 116 +++++++++++++++++----------------------
 cmd/cli/root.go       |   5 --
 pkg/gpt3/gpt3.go      |  26 ++++++++-
 pkg/gpt3/models.go    |  78 ++++++++++++++++++++++++++
 6 files changed, 286 insertions(+), 74 deletions(-)
 create mode 100644 cmd/cli/completion.go

diff --git a/README.md b/README.md
index f5f32cc..7dc27c4 100644
--- a/README.md
+++ b/README.md
@@ -14,13 +14,20 @@ For both OpenAI and Azure OpenAI, you can use the following environment variable
 ```shell
 export OPENAI_API_KEY=
+export OPENAI_DEPLOYMENT_NAME=
 ```
 
+> The following models are supported:
+> - `code-davinci-002`
+> - `text-davinci-003`
+> - `gpt-3.5-turbo-0301` (deployment must be named `gpt-35-turbo-0301` for Azure)
+> - `gpt-3.5-turbo`
+> - `gpt-35-turbo-0301`
+
 For Azure OpenAI Service, you can use the following environment variables:
 
 ```shell
-export AZURE_OPENAI_ENDPOINT=
-export OPENAI_DEPLOYMENT_NAME=
+export AZURE_OPENAI_ENDPOINT=
 ```
 
 If `AZURE_OPENAI_ENDPOINT` variable is set, then it will use the Azure OpenAI Service. Otherwise, it will use OpenAI API.
diff --git a/cmd/cli/completion.go b/cmd/cli/completion.go
new file mode 100644
index 0000000..3c511f9
--- /dev/null
+++ b/cmd/cli/completion.go
@@ -0,0 +1,124 @@
+package cli
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+
+	openai "github.com/PullRequestInc/go-gpt3"
+	gptEncoder "github.com/samber/go-gpt-3-encoder"
+	azureopenai "github.com/sozercan/kubectl-ai/pkg/gpt3"
+)
+
+const userRole = "user"
+
+var maxTokensMap = map[string]int{
+	"code-davinci-002":   8001,
+	"text-davinci-003":   4097,
+	"gpt-3.5-turbo-0301": 4096,
+	"gpt-3.5-turbo":      4096,
+	"gpt-35-turbo-0301":  4096, // for azure
+}
+
+type oaiClients struct {
+	azureClient  azureopenai.Client
+	openAIClient openai.Client
+}
+
+func newOAIClients() (oaiClients, error) {
+	var oaiClient openai.Client
+	var azureClient azureopenai.Client
+	var err error
+
+	if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
+		oaiClient = openai.NewClient(*openAIAPIKey)
+	} else {
+		re := regexp.MustCompile(`^[a-zA-Z0-9]+([_-]?[a-zA-Z0-9]+)*$`)
+		if !re.MatchString(*openAIDeploymentName) {
+			err := errors.New("azure openai deployment name can only include alphanumeric characters and '_' or '-', and can't start or end with '_' or '-'")
+			return oaiClients{}, err
+		}
+
+		azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName)
+		if err != nil {
+			return oaiClients{}, err
+		}
+	}
+
+	clients := oaiClients{
+		azureClient:  azureClient,
+		openAIClient: oaiClient,
+	}
+	return clients, nil
+}
+
+func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) {
+	temp := float32(*temperature)
+	maxTokens, err := calculateMaxTokens(prompts, deploymentName)
+	if err != nil {
+		return "", err
+	}
+
+	var prompt strings.Builder
+	fmt.Fprintf(&prompt, "You are a Kubernetes YAML generator, only generate valid Kubernetes YAML manifests.\n")
+	for _, p := range prompts {
+		fmt.Fprintf(&prompt, "%s\n", p)
+	}
+
+	if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
+		if *openAIDeploymentName == "gpt-3.5-turbo-0301" || *openAIDeploymentName == "gpt-3.5-turbo" {
+			resp, err := 
client.openaiGptChatCompletion(ctx, prompt, maxTokens, temp) + if err != nil { + return "", err + } + return resp, nil + } + + resp, err := client.openaiGptCompletion(ctx, prompt, maxTokens, temp) + if err != nil { + return "", err + } + return resp, nil + } + + if *openAIDeploymentName == "gpt-35-turbo-0301" || *openAIDeploymentName == "gpt-35-turbo" { + resp, err := client.azureGptChatCompletion(ctx, prompt, maxTokens, temp) + if err != nil { + return "", err + } + return resp, nil + } + + resp, err := client.azureGptCompletion(ctx, prompt, maxTokens, temp) + if err != nil { + return "", err + } + return resp, nil +} + +func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) { + maxTokens, ok := maxTokensMap[deploymentName] + if !ok { + return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName) + } + + encoder, err := gptEncoder.NewEncoder() + if err != nil { + return nil, err + } + + // start at 100 since the encoder at times doesn't get it exactly correct + totalTokens := 100 + for _, prompt := range prompts { + tokens, err := encoder.Encode(prompt) + if err != nil { + return nil, err + } + totalTokens += len(tokens) + } + + remainingTokens := maxTokens - totalTokens + return &remainingTokens, nil +} diff --git a/cmd/cli/openai.go b/cmd/cli/openai.go index 9c73ea9..d331d52 100644 --- a/cmd/cli/openai.go +++ b/cmd/cli/openai.go @@ -6,70 +6,55 @@ import ( "strings" openai "github.com/PullRequestInc/go-gpt3" - gptEncoder "github.com/samber/go-gpt-3-encoder" azureopenai "github.com/sozercan/kubectl-ai/pkg/gpt3" "github.com/sozercan/kubectl-ai/pkg/utils" ) -type oaiClients struct { - azureClient azureopenai.Client - openAIClient openai.Client -} - -func newOAIClients() (oaiClients, error) { - var oaiClient openai.Client - var azureClient azureopenai.Client - var err error - - if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" { - oaiClient = openai.NewClient(*openAIAPIKey) - } else { - azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName) - if err != nil { - return oaiClients{}, err - } +func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) { + resp, err := c.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{ + Prompt: []string{prompt.String()}, + MaxTokens: maxTokens, + Echo: false, + N: utils.ToPtr(1), + Temperature: &temp, + }) + if err != nil { + return "", err } - clients := oaiClients{ - azureClient: azureClient, - openAIClient: oaiClient, + if len(resp.Choices) != 1 { + return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) } - return clients, nil + + return resp.Choices[0].Text, nil } -func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) { - temp := float32(*temperature) - maxTokens, err := calculateMaxTokens(prompts, deploymentName) +func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) { + resp, err := c.openAIClient.ChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: *openAIDeploymentName, + Messages: []openai.ChatCompletionRequestMessage{ + { + Role: userRole, + Content: prompt.String(), + }, + }, + MaxTokens: *maxTokens, + N: 1, + Temperature: &temp, + }) if err != nil { return "", err } - var prompt strings.Builder - fmt.Fprintf(&prompt, "You are a Kubernetes YAML generator, 
only generate valid Kubernetes YAML manifests.") - for _, p := range prompts { - fmt.Fprintf(&prompt, "%s\n", p) + if len(resp.Choices) != 1 { + return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) } - if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" { - resp, err := client.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{ - Prompt: []string{prompt.String()}, - MaxTokens: maxTokens, - Echo: false, - N: utils.ToPtr(1), - Temperature: &temp, - }) - if err != nil { - return "", err - } - - if len(resp.Choices) != 1 { - return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) - } - - return resp.Choices[0].Text, nil - } + return resp.Choices[0].Message.Content, nil +} - resp, err := client.azureClient.Completion(ctx, azureopenai.CompletionRequest{ +func (c *oaiClients) azureGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) { + resp, err := c.azureClient.Completion(ctx, azureopenai.CompletionRequest{ Prompt: []string{prompt.String()}, MaxTokens: maxTokens, Echo: false, @@ -87,27 +72,26 @@ func gptCompletion(ctx context.Context, client oaiClients, prompts []string, dep return resp.Choices[0].Text, nil } -func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) { - maxTokens, ok := maxTokensMap[deploymentName] - if !ok { - return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName) - } - - encoder, err := gptEncoder.NewEncoder() +func (c *oaiClients) azureGptChatCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) { + resp, err := c.azureClient.ChatCompletion(ctx, azureopenai.ChatCompletionRequest{ + Model: *openAIDeploymentName, + Messages: []azureopenai.ChatCompletionRequestMessage{ + { + Role: userRole, + Content: prompt.String(), + }, + }, + MaxTokens: *maxTokens, + N: 1, + Temperature: &temp, + }) if err != nil { - return nil, err + return "", err } - // start at 100 since the encoder at times doesn't get it exactly correct - totalTokens := 100 - for _, prompt := range prompts { - tokens, err := encoder.Encode(prompt) - if err != nil { - return nil, err - } - totalTokens += len(tokens) + if len(resp.Choices) != 1 { + return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) } - remainingTokens := maxTokens - totalTokens - return &remainingTokens, nil + return resp.Choices[0].Message.Content, nil } diff --git a/cmd/cli/root.go b/cmd/cli/root.go index 579b4b3..d8b2e9c 100644 --- a/cmd/cli/root.go +++ b/cmd/cli/root.go @@ -26,11 +26,6 @@ var ( temperature = flag.Float64("temperature", env.GetOr("TEMPERATURE", env.WithBitSize(strconv.ParseFloat, 64), 0.0), "The temperature to use for the model. Range is between 0 and 1. Set closer to 0 if your want output to be more deterministic but less creative. Defaults to 0.0.") ) -var maxTokensMap = map[string]int{ - "text-davinci-003": 4097, - "code-davinci-002": 8001, -} - func InitAndExecute() { flag.Parse() diff --git a/pkg/gpt3/gpt3.go b/pkg/gpt3/gpt3.go index f44becb..37e6319 100644 --- a/pkg/gpt3/gpt3.go +++ b/pkg/gpt3/gpt3.go @@ -12,13 +12,17 @@ import ( ) const ( - defaultAPIVersion = "2022-12-01" + defaultAPIVersion = "2023-03-15-preview" defaultUserAgent = "kubectl-openai" defaultTimeoutSeconds = 30 ) // A Client is an API client to communicate with the OpenAI gpt-3 APIs. 
type Client interface {
+	// ChatCompletion creates a completion with the Chat completion endpoint which
+	// is what powers the ChatGPT experience.
+	ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error)
+
 	// Completion creates a completion with the default engine. This is the main endpoint of the API
 	// which auto-completes based on the given prompt.
 	Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error)
@@ -86,6 +90,26 @@ func (c *client) Completion(ctx context.Context, request CompletionRequest) (*Co
 	return output, nil
 }
 
+func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	request.Stream = false
+
+	req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/openai/deployments/%s/chat/completions", c.deploymentName), request)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := c.performRequest(req)
+	if err != nil {
+		return nil, err
+	}
+
+	output := new(ChatCompletionResponse)
+	if err := getResponseObject(resp, output); err != nil {
+		return nil, err
+	}
+	return output, nil
+}
+
 var (
 	dataPrefix   = []byte("data: ")
 	doneSequence = []byte("[DONE]")
diff --git a/pkg/gpt3/models.go b/pkg/gpt3/models.go
index 6c48712..ee0a802 100644
--- a/pkg/gpt3/models.go
+++ b/pkg/gpt3/models.go
@@ -32,6 +32,54 @@ type EnginesResponse struct {
 	Object string `json:"object"`
 }
 
+// ChatCompletionRequestMessage is a message to use as the context for the chat completion API.
+type ChatCompletionRequestMessage struct {
+	// Role is the role of the message. Can be "system", "user", or "assistant".
+	Role string `json:"role"`
+
+	// Content is the content of the message.
+	Content string `json:"content"`
+}
+
+// ChatCompletionRequest is a request for the chat completion API.
+type ChatCompletionRequest struct {
+	// Model is the name of the model to use. If not specified, will default to gpt-3.5-turbo.
+	Model string `json:"model"`
+
+	// Messages is a list of messages to use as the context for the chat completion.
+	Messages []ChatCompletionRequestMessage `json:"messages"`
+
+	// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
+	Temperature *float32 `json:"temperature,omitempty"`
+
+	// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+	TopP float32 `json:"top_p,omitempty"`
+
+	// Number of responses to generate.
+	N int `json:"n,omitempty"`
+
+	// Whether or not to stream responses back as they are generated.
+	Stream bool `json:"stream,omitempty"`
+
+	// Up to 4 sequences where the API will stop generating further tokens.
+	Stop []string `json:"stop,omitempty"`
+
+	// MaxTokens is the maximum number of tokens to return.
+	MaxTokens int `json:"max_tokens,omitempty"`
+
+	// (-2, 2) Penalize tokens that haven't appeared yet in the history.
+	PresencePenalty float32 `json:"presence_penalty,omitempty"`
+
+	// (-2, 2) Penalize tokens that appear too frequently in the history.
+	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
+
+	// Modify the probability of specific tokens appearing in the completion. 
+	LogitBias map[string]float32 `json:"logit_bias,omitempty"`
+
+	// Can be used to identify an end-user.
+	User string `json:"user,omitempty"`
+}
+
 // CompletionRequest is a request for the completions API.
 type CompletionRequest struct {
 	// A list of string prompts to use.
@@ -94,6 +142,36 @@ type EmbeddingsRequest struct {
 	User string `json:"user,omitempty"`
 }
 
+// ChatCompletionResponseMessage is a message returned in the response to the Chat Completions API.
+type ChatCompletionResponseMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatCompletionResponseChoice is one of the choices returned in the response to the Chat Completions API.
+type ChatCompletionResponseChoice struct {
+	Index        int                           `json:"index"`
+	FinishReason string                        `json:"finish_reason"`
+	Message      ChatCompletionResponseMessage `json:"message"`
+}
+
+// ChatCompletionsResponseUsage is the object that reports how many tokens the completion request used.
+type ChatCompletionsResponseUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// ChatCompletionResponse is the full response from a request to the Chat Completions API.
+type ChatCompletionResponse struct {
+	ID      string                         `json:"id"`
+	Object  string                         `json:"object"`
+	Created int                            `json:"created"`
+	Model   string                         `json:"model"`
+	Choices []ChatCompletionResponseChoice `json:"choices"`
+	Usage   ChatCompletionsResponseUsage   `json:"usage"`
+}
+
 // LogprobResult represents logprob result of Choice.
 type LogprobResult struct {
 	Tokens []string `json:"tokens"`
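
A note on the token budgeting in `calculateMaxTokens` above: the request's `max_tokens` is derived by subtracting the encoded prompt length, plus a flat 100-token cushion (the local encoder's count can drift slightly from the server's), from the model's context limit in `maxTokensMap`. For example, with `text-davinci-003` (4097-token limit) and prompts that encode to 200 tokens, the completion budget comes out to 4097 - (200 + 100) = 3797.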
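
To exercise the new Azure chat path end to end, a minimal sketch against the `pkg/gpt3` client added in this patch might look like the following. The endpoint, API key, deployment name, and prompt are hypothetical placeholders; in `kubectl-ai` itself these values come from `AZURE_OPENAI_ENDPOINT`, `OPENAI_API_KEY`, and `OPENAI_DEPLOYMENT_NAME`.

```go
package main

import (
	"context"
	"fmt"
	"log"

	azureopenai "github.com/sozercan/kubectl-ai/pkg/gpt3"
)

func main() {
	// Placeholder values; kubectl-ai reads these from AZURE_OPENAI_ENDPOINT,
	// OPENAI_API_KEY, and OPENAI_DEPLOYMENT_NAME.
	client, err := azureopenai.NewClient("https://example.openai.azure.com", "<api-key>", "gpt-35-turbo-0301")
	if err != nil {
		log.Fatal(err)
	}

	temp := float32(0.0)
	resp, err := client.ChatCompletion(context.Background(), azureopenai.ChatCompletionRequest{
		// Mirrors azureGptChatCompletion: the deployment name doubles as the model name.
		Model: "gpt-35-turbo-0301",
		Messages: []azureopenai.ChatCompletionRequestMessage{
			{Role: "user", Content: "Generate a YAML manifest for an nginx pod."},
		},
		MaxTokens:   1024,
		N:           1,
		Temperature: &temp,
	})
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Choices) == 0 {
		log.Fatal("no choices returned")
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```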