From 79f49e18de74841213b7c9ea28b80da2fdc135f3 Mon Sep 17 00:00:00 2001 From: Baptiste Canton Date: Sat, 15 Apr 2023 13:20:38 +0200 Subject: [PATCH 1/7] add hf mutator supporting some tasks --- pkg/mutators/single/huggingface.go | 151 +++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 pkg/mutators/single/huggingface.go diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go new file mode 100644 index 00000000..ffc7037f --- /dev/null +++ b/pkg/mutators/single/huggingface.go @@ -0,0 +1,151 @@ +package mutators + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + + "github.com/batmac/ccat/pkg/log" + "github.com/batmac/ccat/pkg/utils" +) + +// https://huggingface.co/docs/api-inference/detailed_parameters + +// Supported tasks: + +// Fill Mask task +// Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models. +// Recommended model: bert-base-uncased (it’s a simple model, but fun to play with). + +// Summarization task +// This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. +// That means that the summary cannot handle full books for instance. Be careful when choosing your model. +// If you want to discuss your summarization needs, please get in touch with us: +// Recommended model: facebook/bart-large-cnn. + +// Text Classification task +// Usually used for sentiment-analysis this will output the likelihood of classes of an input. +// Recommended model: distilbert-base-uncased-finetuned-sst-2-english + +// Text Generation task +// Use to continue text from a prompt. This is a very generic task. +// Recommended model: gpt2 (it’s a simple model, but fun to play with). + +// Token Classification task +// Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. +// Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english + +// Translation task +// This task is well known to translate text from one language to another +// Recommended model: Helsinki-NLP/opus-mt-ru-en. Helsinki-NLP uploaded many models with many language pairs. +// Recommended model: t5-base. + +var HuggingFaceCommonTasks = map[string]string{ + "fillmask": "bert-base-uncased", + "summarization": "facebook/bart-large-cnn", + "classification": "distilbert-base-uncased-finetuned-sst-2-english", + "text-generation": "gpt2", + "ner": "dbmdz/bert-large-cased-finetuned-conll03-english", + "translation": "t5-base", + "bloom": "bigscience/bloom", +} + +type HuggingFaceRequest struct { + Inputs string `json:"inputs"` + Options map[string]any `json:"options"` +} + +func init() { + singleRegister("huggingface", huggingface, + withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN)"), + withConfigBuilder(stdConfigStrings(0, 1)), + withAliases("hf"), + withCategory("external APIs"), + ) +} + +func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { + arg := conf.([]string) + baseURL := "https://api-inference.huggingface.co/models/" + + token, source, err := getHuggingFaceToken() + if err != nil { + return 0, err + } + model := "bigscience/bloom" + if len(arg) >= 1 && arg[0] != "" { + model = arg[0] + } + + log.Debugf("task aliases: %v\n", HuggingFaceCommonTasks) + if m, ok := HuggingFaceCommonTasks[model]; ok { + model = m + } + + log.Debugln("token: from ", source) + log.Debugln("model: ", model) + + input, err := io.ReadAll(r) + if err != nil { + return 0, err + } + request, err := json.Marshal(HuggingFaceRequest{Inputs: string(input), Options: map[string]any{"wait_for_model": true}}) + if err != nil { + return 0, err + } + + log.Debugf("request: %s\n", request) + + req, err := http.NewRequest(http.MethodPost, baseURL+model, bytes.NewReader(request)) + if err != nil { + return 0, err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "ccat") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return 0, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return 0, fmt.Errorf("error: %s", resp.Status) + } + + got, err := io.ReadAll(resp.Body) + if err != nil { + return 0, err + } + + n, err := w.Write(got) + return int64(n), err +} + +//nolint:gosec +func getHuggingFaceToken() (string, string, error) { + // HUGGING_FACE_HUB_TOKEN, + // then HF_API_KEY, + // then the content of the file $HF_HOME/token + // then the content of the file ~/.huggingface/token + // and finally the content of the file ~/.cache/huggingface/token + token, source := os.Getenv("HUGGING_FACE_HUB_TOKEN"), "HUGGING_FACE_HUB_TOKEN" + if token == "" { + token, source = os.Getenv("HF_API_KEY"), "HF_API_KEY" + } + for _, path := range []string{"$HF_HOME/token", "~/.huggingface/token", "~/.cache/huggingface/token"} { + if token != "" { + break + } + content, _ := os.ReadFile(utils.ExpandPath(path)) + token, source = string(content), path + } + + if token == "" { + return "", "", fmt.Errorf("no HuggingFace token found") + } + return token, source, nil +} From 6144aef305e477c41e7e36fdcee0cbb832038f79 Mon Sep 17 00:00:00 2001 From: Baptiste Canton Date: Sat, 15 Apr 2023 17:58:43 +0200 Subject: [PATCH 2/7] add minimal test --- pkg/mutators/single/huggingface.go | 10 +++- .../single/huggingface_internal_test.go | 51 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 pkg/mutators/single/huggingface_internal_test.go diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go index ffc7037f..bf084e3f 100644 --- a/pkg/mutators/single/huggingface.go +++ b/pkg/mutators/single/huggingface.go @@ -72,7 +72,7 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { baseURL := "https://api-inference.huggingface.co/models/" token, source, err := getHuggingFaceToken() - if err != nil { + if err != nil && os.Getenv("CI") != "CI" { return 0, err } model := "bigscience/bloom" @@ -107,6 +107,12 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", "ccat") + + if os.Getenv("CI") == "CI" { + _, _ = w.Write([]byte("fake")) + return 0, nil + } + resp, err := http.DefaultClient.Do(req) if err != nil { return 0, err @@ -144,7 +150,7 @@ func getHuggingFaceToken() (string, string, error) { token, source = string(content), path } - if token == "" { + if token == "" || os.Getenv("CI") == "CI" { return "", "", fmt.Errorf("no HuggingFace token found") } return token, source, nil diff --git a/pkg/mutators/single/huggingface_internal_test.go b/pkg/mutators/single/huggingface_internal_test.go new file mode 100644 index 00000000..aea1d83f --- /dev/null +++ b/pkg/mutators/single/huggingface_internal_test.go @@ -0,0 +1,51 @@ +package mutators + +import ( + "testing" + + "github.com/batmac/ccat/pkg/mutators" +) + +func Test_huggingface(t *testing.T) { + // only test that we do not panic + t.Setenv("CI", "CI") + + f := "huggingface" + t.Run("donotpanicplease", func(t *testing.T) { + if got := mutators.Run(f, "hi"); got != "fake" { + t.Errorf("%s = %v, want %v", f, got, "fake") + } + }) +} + +func Test_getHuggingFaceToken(t *testing.T) { + tests := []struct { + name string + want string + source string + wantErr bool + }{ + { + name: "donotpanic", + want: "", + source: "", + wantErr: true, + }, + } + t.Setenv("CI", "CI") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := getHuggingFaceToken() + if (err != nil) != tt.wantErr { + t.Errorf("getHuggingFaceToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getHuggingFaceToken() got = %v, want %v", got, tt.want) + } + if got1 != tt.source { + t.Errorf("getHuggingFaceToken() got1 = %v, want %v", got1, tt.source) + } + }) + } +} From a4f6e6d45651405c40a2d83e337e8beb338d0eac Mon Sep 17 00:00:00 2001 From: Baptiste Canton Date: Mon, 17 Apr 2023 13:10:48 +0200 Subject: [PATCH 3/7] dont insert newline --- pkg/openers/echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/openers/echo.go b/pkg/openers/echo.go index 611d18d2..13b37e87 100644 --- a/pkg/openers/echo.go +++ b/pkg/openers/echo.go @@ -46,7 +46,7 @@ func (f echoOpener) Open(s string, _ bool) (io.ReadCloser, error) { return nil, fmt.Errorf("no data given") } - datareader := io.NopCloser(io.MultiReader(strings.NewReader(s), strings.NewReader("\n"))) + datareader := io.NopCloser(strings.NewReader(s)) return datareader, nil } From ecbdc0ac231dfeb5cd1ee902adf9777b7ba8c4c4 Mon Sep 17 00:00:00 2001 From: Baptiste Canton Date: Mon, 17 Apr 2023 19:30:22 +0200 Subject: [PATCH 4/7] fix echo tests --- pkg/openers/echo_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/openers/echo_test.go b/pkg/openers/echo_test.go index b88cefe7..5e12bf6f 100644 --- a/pkg/openers/echo_test.go +++ b/pkg/openers/echo_test.go @@ -17,14 +17,14 @@ func Test_echoOpener_Open(t *testing.T) { wantErr bool wanted string }{ - {name: "ok", args: "echo://salut", wantErr: false, wanted: "salut\n"}, + {name: "ok", args: "echo://salut", wantErr: false, wanted: "salut"}, {name: "ko", args: "echo://", wantErr: true, wanted: ""}, - {name: "ok double-quoted", args: "echo://\"salut yo\"", wantErr: false, wanted: "salut yo\n"}, - {name: "ok single-quoted", args: "echo://'salut again'", wantErr: false, wanted: "salut again\n"}, + {name: "ok double-quoted", args: "echo://\"salut yo\"", wantErr: false, wanted: "salut yo"}, + {name: "ok single-quoted", args: "echo://'salut again'", wantErr: false, wanted: "salut again"}, {name: "ko double-quoted", args: "echo://\"\"", wantErr: true, wanted: ""}, {name: "ko single-quoted", args: "echo://''", wantErr: true, wanted: ""}, - {name: "emoji", args: "echo://👍 🥳", wantErr: false, wanted: "👍 🥳\n"}, - {name: "with 'zero' byte in it", args: "echo://salut\x00and again", wantErr: false, wanted: "salut\x00and again\n"}, + {name: "emoji", args: "echo://👍 🥳", wantErr: false, wanted: "👍 🥳"}, + {name: "with 'zero' byte in it", args: "echo://salut\x00and again", wantErr: false, wanted: "salut\x00and again"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From bf0f776fce99c40689239eefa53558f5b8f85a94 Mon Sep 17 00:00:00 2001 From: Baptiste Canton Date: Fri, 21 Apr 2023 19:52:33 +0200 Subject: [PATCH 5/7] add bloomz shortcut --- pkg/mutators/single/huggingface.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go index bf084e3f..bb56e2a6 100644 --- a/pkg/mutators/single/huggingface.go +++ b/pkg/mutators/single/huggingface.go @@ -51,6 +51,7 @@ var HuggingFaceCommonTasks = map[string]string{ "ner": "dbmdz/bert-large-cased-finetuned-conll03-english", "translation": "t5-base", "bloom": "bigscience/bloom", + "bloomz": "bigscience/bloomz", } type HuggingFaceRequest struct { From 18dc85b59358f0c45cef34d9f85001d783072b46 Mon Sep 17 00:00:00 2001 From: batmac Date: Sat, 29 Apr 2023 03:24:49 +0200 Subject: [PATCH 6/7] support inference endpoints --- pkg/mutators/single/huggingface.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go index bb56e2a6..c7f1b9e1 100644 --- a/pkg/mutators/single/huggingface.go +++ b/pkg/mutators/single/huggingface.go @@ -61,7 +61,7 @@ type HuggingFaceRequest struct { func init() { singleRegister("huggingface", huggingface, - withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN)"), + withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN, set HUGGING_FACE_ENDPOINT to use an Inference API endpoint)"), withConfigBuilder(stdConfigStrings(0, 1)), withAliases("hf"), withCategory("external APIs"), @@ -70,6 +70,7 @@ func init() { func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { arg := conf.([]string) + baseURL := "https://api-inference.huggingface.co/models/" token, source, err := getHuggingFaceToken() @@ -86,8 +87,17 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { model = m } + url := baseURL + model + if os.Getenv("HUGGING_FACE_ENDPOINT") != "" { + url = os.Getenv("HUGGING_FACE_ENDPOINT") + if len(arg) >= 1 && arg[0] != "" { + log.Println("warning: HUGGING_FACE_ENDPOINT is set, ignoring model argument") + } + } + log.Debugln("token: from ", source) log.Debugln("model: ", model) + log.Debugln("url: ", url) input, err := io.ReadAll(r) if err != nil { @@ -100,7 +110,7 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { log.Debugf("request: %s\n", request) - req, err := http.NewRequest(http.MethodPost, baseURL+model, bytes.NewReader(request)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(request)) if err != nil { return 0, err } From 880d83aa541d7ba6f7aa2b9f2b56f548c4d2c9f4 Mon Sep 17 00:00:00 2001 From: batmac Date: Sat, 6 May 2023 02:20:22 +0200 Subject: [PATCH 7/7] add an alias for bigcode/starcoder --- pkg/mutators/single/huggingface.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go index c7f1b9e1..8505ab7d 100644 --- a/pkg/mutators/single/huggingface.go +++ b/pkg/mutators/single/huggingface.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "strings" "github.com/batmac/ccat/pkg/log" "github.com/batmac/ccat/pkg/utils" @@ -44,6 +45,7 @@ import ( // Recommended model: t5-base. var HuggingFaceCommonTasks = map[string]string{ + // keep the keys in lowercase "fillmask": "bert-base-uncased", "summarization": "facebook/bart-large-cnn", "classification": "distilbert-base-uncased-finetuned-sst-2-english", @@ -52,6 +54,8 @@ var HuggingFaceCommonTasks = map[string]string{ "translation": "t5-base", "bloom": "bigscience/bloom", "bloomz": "bigscience/bloomz", + "chat": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", + "starcoder": "bigcode/starcoder", } type HuggingFaceRequest struct { @@ -83,7 +87,8 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { } log.Debugf("task aliases: %v\n", HuggingFaceCommonTasks) - if m, ok := HuggingFaceCommonTasks[model]; ok { + + if m, ok := HuggingFaceCommonTasks[strings.ToLower(model)]; ok { model = m }