Skip to content

Commit

Permalink
add huggingface mutator (#308)
Browse files Browse the repository at this point in the history
* add hf mutator supporting some tasks

* add minimal test

* dont insert newline

* fix echo tests

* add bloomz shortcut

* support inference endpoints

* add an alias for bigcode/starcoder

---------

Co-authored-by: batmac <[email protected]>
  • Loading branch information
batmac and batmac committed May 13, 2023
1 parent 26be2bc commit 6f56e4a
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 0 deletions.
173 changes: 173 additions & 0 deletions pkg/mutators/single/huggingface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package mutators

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"

"github.com/batmac/ccat/pkg/log"
"github.com/batmac/ccat/pkg/utils"
)

// https://huggingface.co/docs/api-inference/detailed_parameters

// Supported tasks:

// Fill Mask task
// Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
// Recommended model: bert-base-uncased (it’s a simple model, but fun to play with).

// Summarization task
// This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input.
// That means that the summary cannot handle full books for instance. Be careful when choosing your model.
// If you want to discuss your summarization needs, please get in touch with us: <[email protected]>
// Recommended model: facebook/bart-large-cnn.

// Text Classification task
// Usually used for sentiment-analysis this will output the likelihood of classes of an input.
// Recommended model: distilbert-base-uncased-finetuned-sst-2-english

// Text Generation task
// Use to continue text from a prompt. This is a very generic task.
// Recommended model: gpt2 (it’s a simple model, but fun to play with).

// Token Classification task
// Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
// Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english

// Translation task
// This task is well known to translate text from one language to another
// Recommended model: Helsinki-NLP/opus-mt-ru-en. Helsinki-NLP uploaded many models with many language pairs.
// Recommended model: t5-base.

var HuggingFaceCommonTasks = map[string]string{
// keep the keys in lowercase
"fillmask": "bert-base-uncased",
"summarization": "facebook/bart-large-cnn",
"classification": "distilbert-base-uncased-finetuned-sst-2-english",
"text-generation": "gpt2",
"ner": "dbmdz/bert-large-cased-finetuned-conll03-english",
"translation": "t5-base",
"bloom": "bigscience/bloom",
"bloomz": "bigscience/bloomz",
"chat": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"starcoder": "bigcode/starcoder",
}

type HuggingFaceRequest struct {
Inputs string `json:"inputs"`
Options map[string]any `json:"options"`
}

func init() {
singleRegister("huggingface", huggingface,
withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN, set HUGGING_FACE_ENDPOINT to use an Inference API endpoint)"),
withConfigBuilder(stdConfigStrings(0, 1)),
withAliases("hf"),
withCategory("external APIs"),
)
}

func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) {
arg := conf.([]string)

baseURL := "https://api-inference.huggingface.co/models/"

token, source, err := getHuggingFaceToken()
if err != nil && os.Getenv("CI") != "CI" {
return 0, err
}
model := "bigscience/bloom"
if len(arg) >= 1 && arg[0] != "" {
model = arg[0]
}

log.Debugf("task aliases: %v\n", HuggingFaceCommonTasks)

if m, ok := HuggingFaceCommonTasks[strings.ToLower(model)]; ok {
model = m
}

url := baseURL + model
if os.Getenv("HUGGING_FACE_ENDPOINT") != "" {
url = os.Getenv("HUGGING_FACE_ENDPOINT")
if len(arg) >= 1 && arg[0] != "" {
log.Println("warning: HUGGING_FACE_ENDPOINT is set, ignoring model argument")
}
}

log.Debugln("token: from ", source)
log.Debugln("model: ", model)
log.Debugln("url: ", url)

input, err := io.ReadAll(r)
if err != nil {
return 0, err
}
request, err := json.Marshal(HuggingFaceRequest{Inputs: string(input), Options: map[string]any{"wait_for_model": true}})
if err != nil {
return 0, err
}

log.Debugf("request: %s\n", request)

req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(request))
if err != nil {
return 0, err
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "ccat")

if os.Getenv("CI") == "CI" {
_, _ = w.Write([]byte("fake"))
return 0, nil
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return 0, fmt.Errorf("error: %s", resp.Status)
}

got, err := io.ReadAll(resp.Body)
if err != nil {
return 0, err
}

n, err := w.Write(got)
return int64(n), err
}

//nolint:gosec
func getHuggingFaceToken() (string, string, error) {
// HUGGING_FACE_HUB_TOKEN,
// then HF_API_KEY,
// then the content of the file $HF_HOME/token
// then the content of the file ~/.huggingface/token
// and finally the content of the file ~/.cache/huggingface/token
token, source := os.Getenv("HUGGING_FACE_HUB_TOKEN"), "HUGGING_FACE_HUB_TOKEN"
if token == "" {
token, source = os.Getenv("HF_API_KEY"), "HF_API_KEY"
}
for _, path := range []string{"$HF_HOME/token", "~/.huggingface/token", "~/.cache/huggingface/token"} {
if token != "" {
break
}
content, _ := os.ReadFile(utils.ExpandPath(path))
token, source = string(content), path
}

if token == "" || os.Getenv("CI") == "CI" {
return "", "", fmt.Errorf("no HuggingFace token found")
}
return token, source, nil
}
51 changes: 51 additions & 0 deletions pkg/mutators/single/huggingface_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package mutators

import (
"testing"

"github.com/batmac/ccat/pkg/mutators"
)

func Test_huggingface(t *testing.T) {
// only test that we do not panic
t.Setenv("CI", "CI")

f := "huggingface"
t.Run("donotpanicplease", func(t *testing.T) {
if got := mutators.Run(f, "hi"); got != "fake" {
t.Errorf("%s = %v, want %v", f, got, "fake")
}
})
}

func Test_getHuggingFaceToken(t *testing.T) {
tests := []struct {
name string
want string
source string
wantErr bool
}{
{
name: "donotpanic",
want: "",
source: "",
wantErr: true,
},
}
t.Setenv("CI", "CI")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := getHuggingFaceToken()
if (err != nil) != tt.wantErr {
t.Errorf("getHuggingFaceToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getHuggingFaceToken() got = %v, want %v", got, tt.want)
}
if got1 != tt.source {
t.Errorf("getHuggingFaceToken() got1 = %v, want %v", got1, tt.source)
}
})
}
}

0 comments on commit 6f56e4a

Please sign in to comment.