-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add huggingface mutator #308
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
79f49e1
add hf mutator supporting some tasks
batmac 6144aef
add minimal test
batmac a4f6e6d
dont insert newline
batmac ecbdc0a
fix echo tests
batmac bf0f776
add bloomz shortcut
batmac 18dc85b
support inference endpoints
880d83a
add an alias for bigcode/starcoder
900eb36
Merge branch 'main' into huggingface
batmac File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
package mutators | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"os" | ||
"strings" | ||
|
||
"github.com/batmac/ccat/pkg/log" | ||
"github.com/batmac/ccat/pkg/utils" | ||
) | ||
|
||
// https://huggingface.co/docs/api-inference/detailed_parameters | ||
|
||
// Supported tasks: | ||
|
||
// Fill Mask task | ||
// Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models. | ||
// Recommended model: bert-base-uncased (it’s a simple model, but fun to play with). | ||
|
||
// Summarization task | ||
// This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. | ||
// That means that the summary cannot handle full books for instance. Be careful when choosing your model. | ||
// If you want to discuss your summarization needs, please get in touch with us: <api-enterprise@huggingface.co> | ||
// Recommended model: facebook/bart-large-cnn. | ||
|
||
// Text Classification task | ||
// Usually used for sentiment analysis, this will output the likelihood of classes of an input. | ||
// Recommended model: distilbert-base-uncased-finetuned-sst-2-english | ||
|
||
// Text Generation task | ||
// Used to continue text from a prompt. This is a very generic task. | ||
// Recommended model: gpt2 (it’s a simple model, but fun to play with). | ||
|
||
// Token Classification task | ||
// Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. | ||
// Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english | ||
|
||
// Translation task | ||
// This task is well known to translate text from one language to another | ||
// Recommended model: Helsinki-NLP/opus-mt-ru-en. Helsinki-NLP uploaded many models with many language pairs. | ||
// Recommended model: t5-base. | ||
|
||
// HuggingFaceCommonTasks maps short, easy-to-type names to the identifier of
// a HuggingFace model. Keys must stay in lowercase.
var HuggingFaceCommonTasks = map[string]string{
	// shortcuts named after a task, pointing at the recommended model
	"classification":  "distilbert-base-uncased-finetuned-sst-2-english",
	"fillmask":        "bert-base-uncased",
	"ner":             "dbmdz/bert-large-cased-finetuned-conll03-english",
	"summarization":   "facebook/bart-large-cnn",
	"text-generation": "gpt2",
	"translation":     "t5-base",

	// shortcuts named after a model
	"bloom":     "bigscience/bloom",
	"bloomz":    "bigscience/bloomz",
	"chat":      "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
	"starcoder": "bigcode/starcoder",
}
|
||
// HuggingFaceRequest is the JSON payload posted to the inference API.
type HuggingFaceRequest struct {
	// Inputs is the text to process, serialized as "inputs".
	Inputs string `json:"inputs"`
	// Options carries free-form request options, serialized as "options".
	Options map[string]any `json:"options"`
}
|
||
func init() { | ||
singleRegister("huggingface", huggingface, | ||
withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN, set HUGGING_FACE_ENDPOINT to use an Inference API endpoint)"), | ||
withConfigBuilder(stdConfigStrings(0, 1)), | ||
withAliases("hf"), | ||
withCategory("external APIs"), | ||
) | ||
} | ||
|
||
func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { | ||
arg := conf.([]string) | ||
|
||
baseURL := "https://api-inference.huggingface.co/models/" | ||
|
||
token, source, err := getHuggingFaceToken() | ||
if err != nil && os.Getenv("CI") != "CI" { | ||
return 0, err | ||
} | ||
model := "bigscience/bloom" | ||
if len(arg) >= 1 && arg[0] != "" { | ||
model = arg[0] | ||
} | ||
|
||
log.Debugf("task aliases: %v\n", HuggingFaceCommonTasks) | ||
|
||
if m, ok := HuggingFaceCommonTasks[strings.ToLower(model)]; ok { | ||
model = m | ||
} | ||
|
||
url := baseURL + model | ||
if os.Getenv("HUGGING_FACE_ENDPOINT") != "" { | ||
url = os.Getenv("HUGGING_FACE_ENDPOINT") | ||
if len(arg) >= 1 && arg[0] != "" { | ||
log.Println("warning: HUGGING_FACE_ENDPOINT is set, ignoring model argument") | ||
} | ||
} | ||
|
||
log.Debugln("token: from ", source) | ||
log.Debugln("model: ", model) | ||
log.Debugln("url: ", url) | ||
|
||
input, err := io.ReadAll(r) | ||
if err != nil { | ||
return 0, err | ||
} | ||
request, err := json.Marshal(HuggingFaceRequest{Inputs: string(input), Options: map[string]any{"wait_for_model": true}}) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
log.Debugf("request: %s\n", request) | ||
|
||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(request)) | ||
if err != nil { | ||
return 0, err | ||
} | ||
req.Header.Set("Authorization", "Bearer "+token) | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", "application/json") | ||
req.Header.Set("User-Agent", "ccat") | ||
|
||
if os.Getenv("CI") == "CI" { | ||
_, _ = w.Write([]byte("fake")) | ||
return 0, nil | ||
} | ||
|
||
resp, err := http.DefaultClient.Do(req) | ||
if err != nil { | ||
return 0, err | ||
} | ||
defer resp.Body.Close() | ||
if resp.StatusCode < 200 || resp.StatusCode >= 300 { | ||
return 0, fmt.Errorf("error: %s", resp.Status) | ||
} | ||
|
||
got, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
n, err := w.Write(got) | ||
return int64(n), err | ||
} | ||
|
||
//nolint:gosec | ||
func getHuggingFaceToken() (string, string, error) { | ||
// HUGGING_FACE_HUB_TOKEN, | ||
// then HF_API_KEY, | ||
// then the content of the file $HF_HOME/token | ||
// then the content of the file ~/.huggingface/token | ||
// and finally the content of the file ~/.cache/huggingface/token | ||
token, source := os.Getenv("HUGGING_FACE_HUB_TOKEN"), "HUGGING_FACE_HUB_TOKEN" | ||
batmac marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if token == "" { | ||
token, source = os.Getenv("HF_API_KEY"), "HF_API_KEY" | ||
Check failure Code scanning / gosec Potential hardcoded credentials
Potential hardcoded credentials
batmac marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
for _, path := range []string{"$HF_HOME/token", "~/.huggingface/token", "~/.cache/huggingface/token"} { | ||
if token != "" { | ||
break | ||
} | ||
content, _ := os.ReadFile(utils.ExpandPath(path)) | ||
token, source = string(content), path | ||
} | ||
|
||
if token == "" || os.Getenv("CI") == "CI" { | ||
return "", "", fmt.Errorf("no HuggingFace token found") | ||
} | ||
return token, source, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package mutators | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/batmac/ccat/pkg/mutators" | ||
) | ||
|
||
func Test_huggingface(t *testing.T) { | ||
// only test that we do not panic | ||
t.Setenv("CI", "CI") | ||
|
||
f := "huggingface" | ||
t.Run("donotpanicplease", func(t *testing.T) { | ||
if got := mutators.Run(f, "hi"); got != "fake" { | ||
t.Errorf("%s = %v, want %v", f, got, "fake") | ||
} | ||
}) | ||
} | ||
|
||
func Test_getHuggingFaceToken(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
want string | ||
source string | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "donotpanic", | ||
want: "", | ||
source: "", | ||
wantErr: true, | ||
}, | ||
} | ||
t.Setenv("CI", "CI") | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, got1, err := getHuggingFaceToken() | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("getHuggingFaceToken() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if got != tt.want { | ||
t.Errorf("getHuggingFaceToken() got = %v, want %v", got, tt.want) | ||
} | ||
if got1 != tt.source { | ||
t.Errorf("getHuggingFaceToken() got1 = %v, want %v", got1, tt.source) | ||
} | ||
}) | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check failure
Code scanning / gosec
Potential hardcoded credentials