diff --git a/types/controller.go b/types/controller.go index 7efa63d..8d405d8 100644 --- a/types/controller.go +++ b/types/controller.go @@ -1,6 +1,7 @@ package types import ( + "context" "fmt" "log" "sync" @@ -111,7 +112,13 @@ func (c *Controller) Subscribe(subscriber ResponseSubscriber) { // Invoke attempts to invoke any functions which match the // topic the incoming message was published on. func (c *Controller) Invoke(topic string, message *[]byte) { - c.Invoker.Invoke(c.TopicMap, topic, message) + c.InvokeWithContext(context.Background(), topic, message) +} + +// InvokeWithContext attempts to invoke any functions which match the topic +// the incoming message was published on while propagating context. +func (c *Controller) InvokeWithContext(ctx context.Context, topic string, message *[]byte) { + c.Invoker.InvokeWithContext(ctx, c.TopicMap, topic, message) } // BeginMapBuilder begins to build a map of function->topic by diff --git a/types/invoker.go b/types/invoker.go index 7946876..1101ca7 100644 --- a/types/invoker.go +++ b/types/invoker.go @@ -2,6 +2,7 @@ package types import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -19,6 +20,7 @@ type Invoker struct { } type InvokerResponse struct { + Context context.Context Body *[]byte Header *http.Header Status int @@ -38,9 +40,15 @@ func NewInvoker(gatewayURL string, client *http.Client, printResponse bool) *Inv // Invoke triggers a function by accessing the API Gateway func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) { + i.InvokeWithContext(context.Background(), topicMap, topic, message) +} + +// Invoke triggers a function by accessing the API Gateway while propagating context +func (i *Invoker) InvokeWithContext(ctx context.Context, topicMap *TopicMap, topic string, message *[]byte) { if len(*message) == 0 { i.Responses <- InvokerResponse{ - Error: fmt.Errorf("no message to send"), + Context: ctx, + Error: fmt.Errorf("no message to send"), } } @@ -51,16 +59,18 @@ func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) { gwURL := fmt.Sprintf("%s/%s", i.GatewayURL, matchedFunction) reader := bytes.NewReader(*message) - body, statusCode, header, doErr := invokefunction(i.Client, gwURL, reader) + body, statusCode, header, doErr := invokefunction(ctx, i.Client, gwURL, reader) if doErr != nil { i.Responses <- InvokerResponse{ - Error: errors.Wrap(doErr, fmt.Sprintf("unable to invoke %s", matchedFunction)), + Context: ctx, + Error: errors.Wrap(doErr, fmt.Sprintf("unable to invoke %s", matchedFunction)), } continue } i.Responses <- InvokerResponse{ + Context: ctx, Body: body, Status: statusCode, Header: header, @@ -70,9 +80,10 @@ func (i *Invoker) Invoke(topicMap *TopicMap, topic string, message *[]byte) { } } -func invokefunction(c *http.Client, gwURL string, reader io.Reader) (*[]byte, int, *http.Header, error) { +func invokefunction(ctx context.Context, c *http.Client, gwURL string, reader io.Reader) (*[]byte, int, *http.Header, error) { httpReq, _ := http.NewRequest(http.MethodPost, gwURL, reader) + httpReq.WithContext(ctx) if httpReq.Body != nil { defer httpReq.Body.Close()