diff --git a/go.mod b/go.mod index ed3a8501..6ade4286 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/disintegration/imaging v1.6.2 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/gage-technologies/mistral-go v0.1.1 // indirect github.com/gdamore/encoding v1.0.0 // indirect github.com/go-fed/httpsig v1.1.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect diff --git a/go.sum b/go.sum index cd89bb33..d486c240 100644 --- a/go.sum +++ b/go.sum @@ -108,6 +108,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gage-technologies/mistral-go v0.1.1 h1:BFSXVJoyPEr/niKbVyWl/vMBvSIEGiKp3GHqhppojcc= +github.com/gage-technologies/mistral-go v0.1.1/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= github.com/gdamore/tcell/v2 v2.7.0 h1:I5LiGTQuwrysAt1KS9wg1yFfOI3arI3ucFrxtd/xqaA= diff --git a/pkg/mutators/single/mistral.go b/pkg/mutators/single/mistral.go new file mode 100644 index 00000000..e3c186a6 --- /dev/null +++ b/pkg/mutators/single/mistral.go @@ -0,0 +1,92 @@ +package mutators + +import ( + "io" + "strconv" + "strings" + + "github.com/batmac/ccat/pkg/log" + "github.com/batmac/ccat/pkg/secretprovider" + + "github.com/gage-technologies/mistral-go" +) + +// https://platform.openai.com/docs/guides/chat + +func init() { + singleRegister("mistralai", mistralai, + withDescription("ask MistralAI, X: max replied tokens, the optional second arg is the model (Requires a valid key in $MISTRAL_API_KEY)"), + withConfigBuilder(stdConfigStrings(0, 2)), + withAliases("mistral"), + withHintSlow(), // output asap (when no other mutator is used) + withCategory("external APIs"), + ) +} + +func mistralai(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { + args := conf.([]string) + model := "mistral-tiny" + maxTokens := 4000 + var err error + if len(args) > 0 && args[0] != "" { + maxTokens, err = strconv.Atoi(args[0]) + if err != nil { + log.Println("first arg: ", err) + } + } + if len(args) >= 2 && args[1] != "" { + model = args[1] + } + + key, _ := secretprovider.GetSecret("mistralai", "MISTRAL_API_KEY") + if key == "" { + log.Fatal("MISTRAL_API_KEY environment variable is not set") + } + + log.Debugln("model: ", model) + log.Debugln("maxTokens: ", maxTokens) + + client := mistral.NewMistralClientDefault(key) + // log.Debugf("models: %+v", listModels(client)) + + prompt, err := io.ReadAll(r) + if err != nil { + return 0, err + } + + req := []mistral.ChatMessage{{Content: string(prompt), Role: mistral.RoleUser}} + log.Debugf("request: %#v", req) + if key == "CI" { + log.Println("MISTRAL_API_KEY is set to CI, using fake response") + return io.Copy(w, strings.NewReader("CI")) + } + params := mistral.DefaultChatRequestParams + params.MaxTokens = maxTokens + stream, err := client.ChatStream(model, req, ¶ms) + if err != nil { + return 0, err + } + + defer func() { + if _, err = w.Write([]byte("\n")); err != nil { + log.Println(err) + } + }() + + var totalWritten int64 + var steps int + for chunk := range stream { + if chunk.Error != nil { + return 0, chunk.Error + } + log.Debugf("chunk: %#v", chunk) + n, err := w.Write([]byte(chunk.Choices[0].Delta.Content)) + if err != nil { + return 0, err + } + totalWritten += int64(n) + steps++ + } + log.Debugf("finished after %d steps.", steps) + return totalWritten, nil +} diff --git a/pkg/mutators/single/mistral_test.go b/pkg/mutators/single/mistral_test.go new file mode 100644 index 00000000..707670ae --- /dev/null +++ b/pkg/mutators/single/mistral_test.go @@ -0,0 +1,19 @@ +package mutators_test + +import ( + "testing" + + "github.com/batmac/ccat/pkg/mutators" +) + +func Test_mistral(t *testing.T) { + // only test that we do not panic + t.Setenv("MISTRAL_API_KEY", "CI") + + f := "mistral:100:fakemodel" + t.Run("donotpanicplease", func(t *testing.T) { + if got := mutators.Run(f, "hi"); got != "CI" { + t.Errorf("%s = %v, want %v", f, got, "CI") + } + }) +}