diff --git a/types/invoker.go b/types/invoker.go index 1170972..5bcd233 100644 --- a/types/invoker.go +++ b/types/invoker.go @@ -15,45 +15,67 @@ type Invoker struct { GatewayURL string } -func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) { - if len(*message) > 0 { +type InvocationResult struct { + StatusCode int + Body *[]byte + Error error +} - matchedFunctions := topicMap.Match(topic) - for _, matchedFunction := range matchedFunctions { +func NewInvocationResult(statusCode int, body *[]byte, error error) *InvocationResult { + return &InvocationResult{StatusCode: statusCode, Body: body, Error: error} +} - log.Printf("Invoke function: %s", matchedFunction) +type InvocationResponse struct { + StatusCode int + Body *[]byte + Headers http.Header +} - gwURL := fmt.Sprintf("%s/function/%s", i.GatewayURL, matchedFunction) - reader := bytes.NewReader(*message) +func NewInvocationResponse(statusCode int, body *[]byte, headers http.Header) *InvocationResponse { + return &InvocationResponse{StatusCode: statusCode, Body: body, Headers: headers} +} - body, statusCode, doErr := invokefunction(i.Client, gwURL, reader) +func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) (result map[string]*InvocationResult) { + result = make(map[string]*InvocationResult) - if doErr != nil { - log.Printf("Unable to invoke from %s, error: %s\n", matchedFunction, doErr) - return - } + if message != nil && len(*message) > 0 { - printBody := false - stringOutput := "" + matchedFunctions := topicMap.Match(topic) + for _, matchedFunction := range matchedFunctions { - if body != nil && i.PrintResponse { - stringOutput = string(*body) - printBody = true - } + log.Printf("Invoke function: %s", matchedFunction) + functionURL := fmt.Sprintf("%s/function/%s", i.GatewayURL, matchedFunction) + reader := bytes.NewReader(*message) - if printBody { - log.Printf("Response [%d] from %s %s", statusCode, matchedFunction, stringOutput) + response, err := i.performInvocation(functionURL, reader) + if err != nil { + if response != nil { + result[matchedFunction] = NewInvocationResult(response.StatusCode, nil, err) + } else { + result[matchedFunction] = NewInvocationResult(-1, nil, err) + } } else { - log.Printf("Response [%d] from %s", statusCode, matchedFunction) + result[matchedFunction] = NewInvocationResult(response.StatusCode, response.Body, err) + } + + if response != nil && response.Body != nil && i.PrintResponse { + stringOutput := string(*response.Body) + log.Printf("Headers: %s", response.Headers) + log.Printf("Response: [%d] from %s %s", response.StatusCode, matchedFunction, stringOutput) } } } + return result } -func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, int, error) { +func (i *Invoker) performInvocation(functionURL string, bodyReader io.Reader) (*InvocationResponse, error) { - httpReq, _ := http.NewRequest(http.MethodPost, gwURL, reader) + httpReq, requestErr := http.NewRequest(http.MethodPost, functionURL, bodyReader) + + if requestErr != nil { + return nil, requestErr + } if httpReq.Body != nil { defer httpReq.Body.Close() @@ -61,9 +83,9 @@ func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, in var body *[]byte - res, doErr := c.Do(httpReq) + res, doErr := i.Client.Do(httpReq) if doErr != nil { - return nil, http.StatusServiceUnavailable, doErr + return nil, doErr } if res.Body != nil { @@ -72,11 +94,11 @@ func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, in bytesOut, readErr := ioutil.ReadAll(res.Body) if readErr != nil { log.Printf("Error reading body") - return nil, http.StatusServiceUnavailable, doErr + return NewInvocationResponse(res.StatusCode, nil, res.Header), readErr } body = &bytesOut } - return body, res.StatusCode, doErr + return NewInvocationResponse(res.StatusCode, body, res.Header), nil } diff --git a/types/invoker_test.go b/types/invoker_test.go new file mode 100644 index 0000000..8862852 --- /dev/null +++ b/types/invoker_test.go @@ -0,0 +1,147 @@ +// Copyright (c) OpenFaaS Project 2018. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +package types + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestInvoker_Invoke(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + function := strings.Split(r.URL.Path, "/")[2] + + switch function { + case "success": + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(200) + w.Write([]byte("Hello World")) + break + case "headers": + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Charset", "utf-8") + w.WriteHeader(200) + w.Write([]byte("