add chatgpt model support (#13)
sozercan authored Mar 22, 2023
1 parent 2de181b commit 5efdcf3
Showing 6 changed files with 286 additions and 74 deletions.
11 changes: 9 additions & 2 deletions README.md
@@ -14,13 +14,20 @@ For both OpenAI and Azure OpenAI, you can use the following environment variables:

 ```shell
 export OPENAI_API_KEY=<your OpenAI key>
+export OPENAI_DEPLOYMENT_NAME=<your OpenAI deployment/model name. defaults to "gpt-3.5-turbo">
 ```
 
+> The following models are supported:
+> - `code-davinci-002`
+> - `text-davinci-003`
+> - `gpt-3.5-turbo-0301` (deployment must be named `gpt-35-turbo-0301` for Azure)
+> - `gpt-3.5-turbo`
+> - `gpt-35-turbo-0301`
+
 For Azure OpenAI Service, you can use the following environment variables:
 
 ```shell
-export AZURE_OPENAI_ENDPOINT=<your Azure OpenAI endpoint, like https://my-aoi-endpoint.openai.azure.com>
-export OPENAI_DEPLOYMENT_NAME=<your OpenAI deployment/model name. defaults to "text-davinci-003">
+export AZURE_OPENAI_ENDPOINT=<your Azure OpenAI endpoint, like "https://my-aoi-endpoint.openai.azure.com">
 ```
 
 If the `AZURE_OPENAI_ENDPOINT` variable is set, the Azure OpenAI Service is used. Otherwise, the OpenAI API is used.
124 changes: 124 additions & 0 deletions cmd/cli/completion.go
@@ -0,0 +1,124 @@
+package cli
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+
+	openai "github.com/PullRequestInc/go-gpt3"
+	gptEncoder "github.com/samber/go-gpt-3-encoder"
+	azureopenai "github.com/sozercan/kubectl-ai/pkg/gpt3"
+)
+
+const userRole = "user"
+
+var maxTokensMap = map[string]int{
+	"code-davinci-002":   8001,
+	"text-davinci-003":   4097,
+	"gpt-3.5-turbo-0301": 4096,
+	"gpt-3.5-turbo":      4096,
+	"gpt-35-turbo-0301":  4096, // for azure
+}
+
+type oaiClients struct {
+	azureClient  azureopenai.Client
+	openAIClient openai.Client
+}
+
+func newOAIClients() (oaiClients, error) {
+	var oaiClient openai.Client
+	var azureClient azureopenai.Client
+	var err error
+
+	if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
+		oaiClient = openai.NewClient(*openAIAPIKey)
+	} else {
+		re := regexp.MustCompile(`^[a-zA-Z0-9]+([_-]?[a-zA-Z0-9]+)*$`)
+		if !re.MatchString(*openAIDeploymentName) {
+			err := errors.New("azure openai deployment can only include alphanumeric characters, '_,-', and can't end with '_' or '-'")
+			return oaiClients{}, err
+		}
+
+		azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName)
+		if err != nil {
+			return oaiClients{}, err
+		}
+	}
+
+	clients := oaiClients{
+		azureClient:  azureClient,
+		openAIClient: oaiClient,
+	}
+	return clients, nil
+}
+
+func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) {
+	temp := float32(*temperature)
+	maxTokens, err := calculateMaxTokens(prompts, deploymentName)
+	if err != nil {
+		return "", err
+	}
+
+	var prompt strings.Builder
+	fmt.Fprintf(&prompt, "You are a Kubernetes YAML generator, only generate valid Kubernetes YAML manifests.")
+	for _, p := range prompts {
+		fmt.Fprintf(&prompt, "%s\n", p)
+	}
+
+	if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
+		if *openAIDeploymentName == "gpt-3.5-turbo-0301" || *openAIDeploymentName == "gpt-3.5-turbo" {
+			resp, err := client.openaiGptChatCompletion(ctx, prompt, maxTokens, temp)
+			if err != nil {
+				return "", err
+			}
+			return resp, nil
+		}
+
+		resp, err := client.openaiGptCompletion(ctx, prompt, maxTokens, temp)
+		if err != nil {
+			return "", err
+		}
+		return resp, nil
+	}
+
+	if *openAIDeploymentName == "gpt-35-turbo-0301" || *openAIDeploymentName == "gpt-35-turbo" {
+		resp, err := client.azureGptChatCompletion(ctx, prompt, maxTokens, temp)
+		if err != nil {
+			return "", err
+		}
+		return resp, nil
+	}
+
+	resp, err := client.azureGptCompletion(ctx, prompt, maxTokens, temp)
+	if err != nil {
+		return "", err
+	}
+	return resp, nil
+}
+
+func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) {
+	maxTokens, ok := maxTokensMap[deploymentName]
+	if !ok {
+		return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName)
+	}
+
+	encoder, err := gptEncoder.NewEncoder()
+	if err != nil {
+		return nil, err
+	}
+
+	// start at 100 since the encoder at times doesn't get it exactly correct
+	totalTokens := 100
+	for _, prompt := range prompts {
+		tokens, err := encoder.Encode(prompt)
+		if err != nil {
+			return nil, err
+		}
+		totalTokens += len(tokens)
+	}
+
+	remainingTokens := maxTokens - totalTokens
+	return &remainingTokens, nil
+}
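
The `max_tokens` budget above is simple arithmetic: take the model's context limit from `maxTokensMap`, subtract the encoder's count of the prompt tokens, and keep a 100-token safety margin because the client-side encoder does not always agree with the server-side tokenizer. A minimal standalone sketch of the same calculation, reusing only the `samber/go-gpt-3-encoder` calls that appear in the diff (the prompt string and the hard-coded 4096 limit are illustrative):

```go
package main

import (
	"fmt"
	"log"

	gptEncoder "github.com/samber/go-gpt-3-encoder"
)

func main() {
	encoder, err := gptEncoder.NewEncoder()
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative prompt; kubectl-ai prepends its YAML-generator instruction.
	prompt := "create a deployment with 3 nginx replicas"
	tokens, err := encoder.Encode(prompt)
	if err != nil {
		log.Fatal(err)
	}

	// Same margin as the diff: reserve 100 tokens of slack for encoder drift.
	used := 100 + len(tokens)
	budget := 4096 - used // 4096 is gpt-3.5-turbo's limit in maxTokensMap

	fmt.Printf("prompt tokens: %d, max_tokens budget: %d\n", len(tokens), budget)
}
```

Whatever remains after the subtraction is what gets sent as `MaxTokens` in the request.
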
116 changes: 50 additions & 66 deletions cmd/cli/openai.go
Expand Up @@ -6,70 +6,55 @@ import (
"strings"

openai "github.com/PullRequestInc/go-gpt3"
gptEncoder "github.com/samber/go-gpt-3-encoder"
azureopenai "github.com/sozercan/kubectl-ai/pkg/gpt3"
"github.com/sozercan/kubectl-ai/pkg/utils"
)

type oaiClients struct {
azureClient azureopenai.Client
openAIClient openai.Client
}

func newOAIClients() (oaiClients, error) {
var oaiClient openai.Client
var azureClient azureopenai.Client
var err error

if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
oaiClient = openai.NewClient(*openAIAPIKey)
} else {
azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName)
if err != nil {
return oaiClients{}, err
}
func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) {
resp, err := c.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{
Prompt: []string{prompt.String()},
MaxTokens: maxTokens,
Echo: false,
N: utils.ToPtr(1),
Temperature: &temp,
})
if err != nil {
return "", err
}

clients := oaiClients{
azureClient: azureClient,
openAIClient: oaiClient,
if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
}
return clients, nil

return resp.Choices[0].Text, nil
}

func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) {
temp := float32(*temperature)
maxTokens, err := calculateMaxTokens(prompts, deploymentName)
func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) {
resp, err := c.openAIClient.ChatCompletion(ctx, openai.ChatCompletionRequest{
Model: *openAIDeploymentName,
Messages: []openai.ChatCompletionRequestMessage{
{
Role: userRole,
Content: prompt.String(),
},
},
MaxTokens: *maxTokens,
N: 1,
Temperature: &temp,
})
if err != nil {
return "", err
}

var prompt strings.Builder
fmt.Fprintf(&prompt, "You are a Kubernetes YAML generator, only generate valid Kubernetes YAML manifests.")
for _, p := range prompts {
fmt.Fprintf(&prompt, "%s\n", p)
if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
}

if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
resp, err := client.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{
Prompt: []string{prompt.String()},
MaxTokens: maxTokens,
Echo: false,
N: utils.ToPtr(1),
Temperature: &temp,
})
if err != nil {
return "", err
}

if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
}

return resp.Choices[0].Text, nil
}
return resp.Choices[0].Message.Content, nil
}

resp, err := client.azureClient.Completion(ctx, azureopenai.CompletionRequest{
func (c *oaiClients) azureGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) {
resp, err := c.azureClient.Completion(ctx, azureopenai.CompletionRequest{
Prompt: []string{prompt.String()},
MaxTokens: maxTokens,
Echo: false,
@@ -87,27 +72,26 @@ func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) {
 	return resp.Choices[0].Text, nil
 }
 
-func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) {
-	maxTokens, ok := maxTokensMap[deploymentName]
-	if !ok {
-		return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName)
-	}
-
-	encoder, err := gptEncoder.NewEncoder()
-	if err != nil {
-		return nil, err
-	}
-
-	// start at 100 since the encoder at times doesn't get it exactly correct
-	totalTokens := 100
-	for _, prompt := range prompts {
-		tokens, err := encoder.Encode(prompt)
-		if err != nil {
-			return nil, err
-		}
-		totalTokens += len(tokens)
-	}
-
-	remainingTokens := maxTokens - totalTokens
-	return &remainingTokens, nil
-}
+func (c *oaiClients) azureGptChatCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) {
+	resp, err := c.azureClient.ChatCompletion(ctx, azureopenai.ChatCompletionRequest{
+		Model: *openAIDeploymentName,
+		Messages: []azureopenai.ChatCompletionRequestMessage{
+			{
+				Role:    userRole,
+				Content: prompt.String(),
+			},
+		},
+		MaxTokens:   *maxTokens,
+		N:           1,
+		Temperature: &temp,
+	})
+	if err != nil {
+		return "", err
+	}
+
+	if len(resp.Choices) != 1 {
+		return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
+	}
+
+	return resp.Choices[0].Message.Content, nil
+}
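
All four helpers pass optional request fields as pointers, which is why `openaiGptCompletion` sets `N` with `utils.ToPtr(1)` rather than a literal. `pkg/utils` itself is not part of this diff; a plausible generic implementation is the one-liner below (an assumption about the helper, not code from this commit):

```go
package utils

// ToPtr returns a pointer to a copy of v, convenient for populating
// optional pointer-typed request fields such as CompletionRequest.N.
// (Hypothetical sketch; the real pkg/utils helper is not shown here.)
func ToPtr[T any](v T) *T {
	return &v
}
```
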
5 changes: 0 additions & 5 deletions cmd/cli/root.go
@@ -26,11 +26,6 @@ var (
 	temperature = flag.Float64("temperature", env.GetOr("TEMPERATURE", env.WithBitSize(strconv.ParseFloat, 64), 0.0), "The temperature to use for the model. Range is between 0 and 1. Set closer to 0 if you want output to be more deterministic but less creative. Defaults to 0.0.")
 )
 
-var maxTokensMap = map[string]int{
-	"text-davinci-003": 4097,
-	"code-davinci-002": 8001,
-}
-
 func InitAndExecute() {
 	flag.Parse()
 
26 changes: 25 additions & 1 deletion pkg/gpt3/gpt3.go
@@ -12,13 +12,17 @@ import (
 )
 
 const (
-	defaultAPIVersion     = "2022-12-01"
+	defaultAPIVersion     = "2023-03-15-preview"
 	defaultUserAgent      = "kubectl-openai"
 	defaultTimeoutSeconds = 30
 )
 
 // A Client is an API client to communicate with the OpenAI gpt-3 APIs.
 type Client interface {
+	// ChatCompletion creates a completion with the Chat completion endpoint which
+	// is what powers the ChatGPT experience.
+	ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error)
+
 	// Completion creates a completion with the default engine. This is the main endpoint of the API
 	// which auto-completes based on the given prompt.
 	Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error)
@@ -86,6 +90,26 @@ func (c *client) Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) {
 	return output, nil
 }
 
+func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	request.Stream = false
+
+	req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/openai/deployments/%s/chat/completions", c.deploymentName), request)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := c.performRequest(req)
+	if err != nil {
+		return nil, err
+	}
+
+	output := new(ChatCompletionResponse)
+	if err := getResponseObject(resp, output); err != nil {
+		return nil, err
+	}
+	return output, nil
+}
+
 var (
 	dataPrefix   = []byte("data: ")
 	doneSequence = []byte("[DONE]")
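
For reference, the request that `ChatCompletion` assembles hits Azure's deployment-scoped chat endpoint. Below is a hedged sketch of the equivalent raw HTTP call: the URL shape follows the `fmt.Sprintf` above and the `defaultAPIVersion` constant, while the `api-key` header and `api-version` query parameter are assumptions about what the unshown `newRequest` helper adds, based on Azure's usual REST conventions.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	endpoint := "https://my-aoi-endpoint.openai.azure.com" // example endpoint from the README
	deployment := "gpt-35-turbo-0301"

	body, err := json.Marshal(map[string]any{
		"messages": []map[string]string{
			{"role": "user", "content": "generate a manifest for an nginx pod"},
		},
		"max_tokens":  1024,
		"n":           1,
		"temperature": 0.0,
		"stream":      false,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Deployment-scoped path as in ChatCompletion above; api-version matches
	// defaultAPIVersion. The api-key header is Azure's auth convention and is
	// assumed to be what newRequest sets.
	url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview",
		endpoint, deployment)

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("api-key", "<your key>")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```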