Skip to content

Commit

Permalink
feat: add -a command on the ask subcommand to pass in a suffix
Browse files Browse the repository at this point in the history
  • Loading branch information
yiblet committed Feb 2, 2024
1 parent d003be5 commit 225f435
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions ask_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package main
import (
"context"
"fmt"
"io"
"os"
"strings"

"github.com/PullRequestInc/go-gpt3"
Expand Down Expand Up @@ -30,17 +32,48 @@ type askCmd struct {
Temperature float32 `default:"0.7"`
Bash bool `arg:"--bash" help:"output only valid bash"`
Model string `arg:"--model,-m" help:"set openai model"`
Attach []string `arg:"--attach,-a" help:"attach additional files at the end of the message"`
}

func (args *askCmd) messages() []gpt3.ChatCompletionRequestMessage {
func (args *askCmd) buildContent(ctx context.Context) (string, error) {
var sb strings.Builder
for idx, q := range args.Question {
if idx != 0 {
sb.WriteRune(' ')
}
sb.WriteString(q)
}

if len(args.Question) > 0 &&
!strings.HasSuffix(args.Question[len(args.Question)-1], "\n") {
sb.WriteRune('\n')
}

for _, a := range args.Attach {
sb.WriteRune('\n')
file, err := os.Open(a)
if err != nil {
return "", err
}
defer file.Close()
_, err = io.Copy(&sb, file)
if err != nil {
return "", err
}
}

return sb.String(), nil
}

func (args *askCmd) messages(content string) []gpt3.ChatCompletionRequestMessage {
if args.Bash {
return []gpt3.ChatCompletionRequestMessage{
{Role: "system", Content: systemMessage},
{Role: "user", Content: strings.Join(args.Question, " ")},
{Role: "user", Content: content},
}
} else {
return []gpt3.ChatCompletionRequestMessage{
{Role: "system", Content: strings.Join(args.Question, " ")},
{Role: "system", Content: content},
}
}

Expand All @@ -54,8 +87,12 @@ func (args *askCmd) Execute(ctx context.Context, config *config) error {
client := config.Client()

lastMessage := ""
err := client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{
Messages: args.messages(),
content, err := args.buildContent(ctx)
if err != nil {
return fmt.Errorf("cannot build message: %w", err)
}
err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{
Messages: args.messages(content),
MaxTokens: args.MaxTokens,
Temperature: &args.Temperature,
Stream: true,
Expand Down

0 comments on commit 225f435

Please sign in to comment.