diff --git a/lib/http.go b/lib/http.go index 60a1520..a2629f1 100644 --- a/lib/http.go +++ b/lib/http.go @@ -276,7 +276,7 @@ import ( // // line=25&page=2" func HTTP(client *http.Client, limit *rate.Limiter, auth *BasicAuth) cel.EnvOption { - return HTTPWithContextOpts(context.Background(), client, HTTPOptions{Limiter: limit, BasicAuth: auth}) + return HTTPWithContextFnOpts(context.Background, client, HTTPOptions{Limiter: limit, BasicAuth: auth}) } // HTTPWithContext returns a cel.EnvOption to configure extended functions @@ -289,6 +289,14 @@ func HTTPWithContext(ctx context.Context, client *http.Client, limit *rate.Limit // for HTTP requests that include a context.Context in network requests and // includes extended client options. func HTTPWithContextOpts(ctx context.Context, client *http.Client, options HTTPOptions) cel.EnvOption { + return HTTPWithContextFnOpts(func() context.Context { return ctx }, client, options) +} + +// HTTPWithContextFnOps is like HTTPWithContextOpts but takes context as a +// function rather than a static value. +// Support for context propagation https://github.com/google/cel-go/pull/925 +// would make this pattern unnecessary. +func HTTPWithContextFnOpts(ctxFn func() context.Context, client *http.Client, options HTTPOptions) cel.EnvOption { if client == nil { client = http.DefaultClient } @@ -298,7 +306,7 @@ func HTTPWithContextOpts(ctx context.Context, client *http.Client, options HTTPO return cel.Lib(httpLib{ client: client, options: options, - ctx: ctx, + ctxFn: ctxFn, }) } @@ -341,7 +349,7 @@ func (o HTTPOptions) IsZero() bool { type httpLib struct { client *http.Client - ctx context.Context + ctxFn func() context.Context options HTTPOptions } @@ -529,7 +537,7 @@ func (l httpLib) doHead(arg ref.Val) ref.Val { } func (l httpLib) head(url types.String) (*http.Response, error) { - req, err := http.NewRequestWithContext(l.ctx, http.MethodHead, string(url), nil) + req, err := http.NewRequestWithContext(l.ctxFn(), http.MethodHead, string(url), nil) if err != nil { return nil, err } @@ -564,7 +572,7 @@ func (l httpLib) doGet(arg ref.Val) ref.Val { } func (l httpLib) get(url types.String) (*http.Response, error) { - req, err := http.NewRequestWithContext(l.ctx, http.MethodGet, string(url), nil) + req, err := http.NewRequestWithContext(l.ctxFn(), http.MethodGet, string(url), nil) if err != nil { return nil, err } @@ -623,7 +631,7 @@ func (l httpLib) doPost(args ...ref.Val) ref.Val { } func (l httpLib) post(url, content types.String, body io.Reader) (*http.Response, error) { - req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, string(url), body) + req, err := http.NewRequestWithContext(l.ctxFn(), http.MethodPost, string(url), body) if err != nil { return nil, err } @@ -841,8 +849,8 @@ func (l httpLib) doRequest(arg ref.Val) ref.Val { return types.NewErr("%s", err) } // Recover the context lost during serialisation to JSON. - req = req.WithContext(l.ctx) - err = l.options.Limiter.Wait(l.ctx) + req = req.WithContext(l.ctxFn()) + err = l.options.Limiter.Wait(l.ctxFn()) if err != nil { return types.NewErr("%s", err) }