diff --git a/types/invoker.go b/types/invoker.go index 1170972..5bcd233 100644 --- a/types/invoker.go +++ b/types/invoker.go @@ -15,45 +15,67 @@ type Invoker struct { GatewayURL string } -func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) { - if len(*message) > 0 { +type InvocationResult struct { + StatusCode int + Body *[]byte + Error error +} - matchedFunctions := topicMap.Match(topic) - for _, matchedFunction := range matchedFunctions { +func NewInvocationResult(statusCode int, body *[]byte, error error) *InvocationResult { + return &InvocationResult{StatusCode: statusCode, Body: body, Error: error} +} - log.Printf("Invoke function: %s", matchedFunction) +type InvocationResponse struct { + StatusCode int + Body *[]byte + Headers http.Header +} - gwURL := fmt.Sprintf("%s/function/%s", i.GatewayURL, matchedFunction) - reader := bytes.NewReader(*message) +func NewInvocationResponse(statusCode int, body *[]byte, headers http.Header) *InvocationResponse { + return &InvocationResponse{StatusCode: statusCode, Body: body, Headers: headers} +} - body, statusCode, doErr := invokefunction(i.Client, gwURL, reader) +func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) (result map[string]*InvocationResult) { + result = make(map[string]*InvocationResult) - if doErr != nil { - log.Printf("Unable to invoke from %s, error: %s\n", matchedFunction, doErr) - return - } + if message != nil && len(*message) > 0 { - printBody := false - stringOutput := "" + matchedFunctions := topicMap.Match(topic) + for _, matchedFunction := range matchedFunctions { - if body != nil && i.PrintResponse { - stringOutput = string(*body) - printBody = true - } + log.Printf("Invoke function: %s", matchedFunction) + functionURL := fmt.Sprintf("%s/function/%s", i.GatewayURL, matchedFunction) + reader := bytes.NewReader(*message) - if printBody { - log.Printf("Response [%d] from %s %s", statusCode, matchedFunction, stringOutput) + response, err := i.performInvocation(functionURL, reader) + if err != nil { + if response != nil { + result[matchedFunction] = NewInvocationResult(response.StatusCode, nil, err) + } else { + result[matchedFunction] = NewInvocationResult(-1, nil, err) + } } else { - log.Printf("Response [%d] from %s", statusCode, matchedFunction) + result[matchedFunction] = NewInvocationResult(response.StatusCode, response.Body, err) + } + + if response != nil && response.Body != nil && i.PrintResponse { + stringOutput := string(*response.Body) + log.Printf("Headers: %s", response.Headers) + log.Printf("Response: [%d] from %s %s", response.StatusCode, matchedFunction, stringOutput) } } } + return result } -func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, int, error) { +func (i *Invoker) performInvocation(functionURL string, bodyReader io.Reader) (*InvocationResponse, error) { - httpReq, _ := http.NewRequest(http.MethodPost, gwURL, reader) + httpReq, requestErr := http.NewRequest(http.MethodPost, functionURL, bodyReader) + + if requestErr != nil { + return nil, requestErr + } if httpReq.Body != nil { defer httpReq.Body.Close() @@ -61,9 +83,9 @@ func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, in var body *[]byte - res, doErr := c.Do(httpReq) + res, doErr := i.Client.Do(httpReq) if doErr != nil { - return nil, http.StatusServiceUnavailable, doErr + return nil, doErr } if res.Body != nil { @@ -72,11 +94,11 @@ func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, in bytesOut, readErr := ioutil.ReadAll(res.Body) if readErr != nil { log.Printf("Error reading body") - return nil, http.StatusServiceUnavailable, doErr + return NewInvocationResponse(res.StatusCode, nil, res.Header), readErr } body = &bytesOut } - return body, res.StatusCode, doErr + return NewInvocationResponse(res.StatusCode, body, res.Header), nil } diff --git a/types/invoker_test.go b/types/invoker_test.go new file mode 100644 index 0000000..8862852 --- /dev/null +++ b/types/invoker_test.go @@ -0,0 +1,147 @@ +// Copyright (c) OpenFaaS Project 2018. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +package types + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestInvoker_Invoke(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + function := strings.Split(r.URL.Path, "/")[2] + + switch function { + case "success": + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(200) + w.Write([]byte("Hello World")) + break + case "headers": + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Charset", "utf-8") + w.WriteHeader(200) + w.Write([]byte("

Hello World

")) + break + case "wrong_payload": + w.WriteHeader(400) + w.Write([]byte("")) + break + case "aborts": + if wr, ok := w.(http.Hijacker); ok { + conn, _, err := wr.Hijack() + if err != nil{ + fmt.Printf("Recieved %s",err) + }else{ + conn.Close() + } + + } + break + case "server_error": + w.WriteHeader(500) + w.Write([]byte("")) + break + } + })) + + client := srv.Client() + topicMap := NewTopicMap() + + sampleFunc := map[string][]string{ + "All": []string{"success", "headers", "wrong_payload"}, + "Contains_Fail": []string{"success", "server_error", "aborts", "headers"}, + "NOP": []string{}, + } + + topicMap.Sync(&sampleFunc) + + t.Run("Should invoke no function when body is nil", func(t *testing.T) { + target := &Invoker{ + PrintResponse:false, + Client:client, + GatewayURL: srv.URL, + } + + results := target.Invoke(&topicMap, "NOP", nil) + + if len(results) != 0 { + t.Errorf("When body is nil it should perform a request") + } + }) + + t.Run("Should invoke no function when body is empty", func(t *testing.T) { + target := &Invoker{ + true, + client, + srv.URL, + } + + body := []byte("") + results := target.Invoke(&topicMap, "NOP", &body) + + if len(results) != 0 { + t.Errorf("When body is empty it should perform a request") + } + }) + + t.Run("Should invoke all functions", func(t *testing.T) { + target := &Invoker{ + true, + client, + srv.URL, + } + + body := []byte("Some Input") + results := target.Invoke(&topicMap, "All", &body) + + const ExpectedResults = 3 + if len(results) != ExpectedResults { + t.Errorf("Expected %d results recieved %d", ExpectedResults, len(results)) + } + + for name, result := range results { + if result.Error != nil { + t.Errorf("Received unexpected error %s for %s", result.Error, name) + } + + if result.StatusCode != 200 && result.StatusCode != 400 { + t.Errorf("Received unexpected status code %d for %s", result.StatusCode, name) + } + } + + }) + + t.Run("Should invoke all functions even if one request fails", func(t *testing.T) { + target := &Invoker{ + true, + client, + srv.URL, + } + + body := []byte("Hello World") + results := target.Invoke(&topicMap, "Contains_Fail", &body) + + const ExpectedResults = 4 + if len(results) != ExpectedResults { + t.Errorf("Expected %d results recieved %d", ExpectedResults, len(results)) + } + + for name, result := range results { + if name == "aborts" { + if result.Error == nil { + t.Errorf("Expected call for %s to fail", name) + } + } else { + if result.Error != nil { + t.Errorf("Received unexpected error %s for %s", result.Error, name) + } + } + } + }) +}