From 18dc85b59358f0c45cef34d9f85001d783072b46 Mon Sep 17 00:00:00 2001 From: batmac Date: Sat, 29 Apr 2023 03:24:49 +0200 Subject: [PATCH] support inference endpoints --- pkg/mutators/single/huggingface.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pkg/mutators/single/huggingface.go b/pkg/mutators/single/huggingface.go index bb56e2a6..c7f1b9e1 100644 --- a/pkg/mutators/single/huggingface.go +++ b/pkg/mutators/single/huggingface.go @@ -61,7 +61,7 @@ type HuggingFaceRequest struct { func init() { singleRegister("huggingface", huggingface, - withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN)"), + withDescription("ask HuggingFace for simple tasks, optional arg is the model (needs a valid key in $HUGGING_FACE_HUB_TOKEN, set HUGGING_FACE_ENDPOINT to use an Inference API endpoint)"), withConfigBuilder(stdConfigStrings(0, 1)), withAliases("hf"), withCategory("external APIs"), @@ -70,6 +70,7 @@ func init() { func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { arg := conf.([]string) + baseURL := "https://api-inference.huggingface.co/models/" token, source, err := getHuggingFaceToken() @@ -86,8 +87,17 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { model = m } + url := baseURL + model + if os.Getenv("HUGGING_FACE_ENDPOINT") != "" { + url = os.Getenv("HUGGING_FACE_ENDPOINT") + if len(arg) >= 1 && arg[0] != "" { + log.Println("warning: HUGGING_FACE_ENDPOINT is set, ignoring model argument") + } + } + log.Debugln("token: from ", source) log.Debugln("model: ", model) + log.Debugln("url: ", url) input, err := io.ReadAll(r) if err != nil { @@ -100,7 +110,7 @@ func huggingface(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) { log.Debugf("request: %s\n", request) - req, err := http.NewRequest(http.MethodPost, baseURL+model, bytes.NewReader(request)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(request)) if err != nil { return 0, err }