From 13d89aec8ee40a0010b18179fb628462ea1cfed6 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Thu, 12 Dec 2024 11:16:40 -0500 Subject: [PATCH 1/4] Reorder methods in Runner --- internal/server/runner.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/server/runner.go b/internal/server/runner.go index 1980874..65bc017 100644 --- a/internal/server/runner.go +++ b/internal/server/runner.go @@ -171,23 +171,6 @@ func (r *Runner) stop() error { //////////////////// // Prediction -func (r *Runner) cancel(pid string) error { - r.mu.Lock() - defer r.mu.Unlock() - if _, ok := r.pending[pid]; !ok { - return ErrNotFound - } - if r.asyncPredict { - // Async predict, use files to cancel - p := path.Join(r.workingDir, fmt.Sprintf(CANCEL_FMT, pid)) - return os.WriteFile(p, []byte{}, 0644) - } else { - // Blocking predict, use SIGUSR1 to cancel - // FIXME: ensure only one prediction in flight? - return syscall.Kill(r.cmd.Process.Pid, syscall.SIGUSR1) - } -} - func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error) { log := logger.Sugar() if r.status == StatusSetupFailed { @@ -238,6 +221,23 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error) return pr.c, nil } +func (r *Runner) cancel(pid string) error { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.pending[pid]; !ok { + return ErrNotFound + } + if r.asyncPredict { + // Async predict, use files to cancel + p := path.Join(r.workingDir, fmt.Sprintf(CANCEL_FMT, pid)) + return os.WriteFile(p, []byte{}, 0644) + } else { + // Blocking predict, use SIGUSR1 to cancel + // FIXME: ensure only one prediction in flight? + return syscall.Kill(r.cmd.Process.Pid, syscall.SIGUSR1) + } +} + //////////////////// // Background tasks From c37f3de84565fd0003b156295be7ff029b378720 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Thu, 12 Dec 2024 11:16:17 -0500 Subject: [PATCH 2/4] Test prediction concurrency --- internal/server/http.go | 1 + internal/server/runner.go | 5 ++ internal/server/server.go | 7 ++- internal/tests/async_prediction_test.go | 62 +++++++++++++++++++++++++ internal/tests/prediction_test.go | 39 ++++++++++++++++ 5 files changed, 112 insertions(+), 2 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index c93eb94..e368a3d 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,6 +6,7 @@ import ( ) var ( + ErrConflict = errors.New("already running a prediction") ErrExists = errors.New("prediction exists") ErrNotFound = errors.New("prediction not found") ErrDefunct = errors.New("server is defunct") diff --git a/internal/server/runner.go b/internal/server/runner.go index 65bc017..d4ee6ab 100644 --- a/internal/server/runner.go +++ b/internal/server/runner.go @@ -184,6 +184,11 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error) req.CreatedAt = util.NowIso() } r.mu.Lock() + if !r.asyncPredict && req.Webhook != "" && len(r.pending) > 0 { + r.mu.Unlock() + log.Errorw("prediction rejected: Already running a prediction") + return nil, ErrConflict + } if _, ok := r.pending[req.Id]; ok { r.mu.Unlock() log.Errorw("prediction rejected: prediction exists", "id", req.Id) diff --git a/internal/server/server.go b/internal/server/server.go index c174733..82285a6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -81,11 +81,14 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) { } c, err := h.runner.predict(req) - if errors.Is(err, ErrDefunct) { + if errors.Is(err, ErrConflict) { + http.Error(w, err.Error(), http.StatusConflict) + return + } else if errors.Is(err, ErrDefunct) { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } else if errors.Is(err, ErrExists) { - http.Error(w, err.Error(), http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusConflict) return } else if errors.Is(err, ErrSetupFailed) { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/internal/tests/async_prediction_test.go b/internal/tests/async_prediction_test.go index 344d3d4..14504f3 100644 --- a/internal/tests/async_prediction_test.go +++ b/internal/tests/async_prediction_test.go @@ -1,10 +1,18 @@ package tests import ( + "bytes" + "encoding/json" + "fmt" + "net/http" "strings" "testing" "time" + "github.com/replicate/go/must" + + "github.com/replicate/cog-runtime/internal/util" + "github.com/replicate/cog-runtime/internal/server" "github.com/stretchr/testify/assert" @@ -242,3 +250,57 @@ func TestAsyncPredictionCanceled(t *testing.T) { ct.Shutdown() assert.NoError(t, ct.Cleanup()) } + +func TestAsyncPredictionConcurrency(t *testing.T) { + ct := NewCogTest(t, "sleep") + ct.StartWebhook() + assert.NoError(t, ct.Start()) + + hc := ct.WaitForSetup() + assert.Equal(t, server.StatusReady.String(), hc.Status) + assert.Equal(t, server.SetupSucceeded, hc.Setup.Status) + + ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) + + // Fail prediction requests when one is in progress + req := server.PredictionRequest{ + CreatedAt: util.NowIso(), + Input: map[string]any{"i": 1, "s": "baz"}, + Webhook: fmt.Sprintf("http://localhost:%d/webhook", ct.webhookPort), + } + data := bytes.NewReader(must.Get(json.Marshal(req))) + r := must.Get(http.NewRequest(http.MethodPost, ct.Url("/predictions"), data)) + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Prefer", "respond-async") + resp := must.Get(http.DefaultClient.Do(r)) + assert.Equal(t, http.StatusConflict, resp.StatusCode) + + wr := ct.WaitForWebhookCompletion() + if *legacyCog { + assert.Len(t, wr, 3) + logs := "" + // Compat: legacy Cog sends no "starting" event + ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) + // Compat: legacy Cog buffers logging? + logs += "starting prediction\n" + ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs) + logs += "prediction in progress 1/1\n" + logs += "completed prediction\n" + ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs) + } else { + assert.True(t, len(wr) > 0) + assert.Len(t, wr, 5) + logs := "" + ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) + logs += "starting prediction\n" + ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) + logs += "prediction in progress 1/1\n" + ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) + logs += "completed prediction\n" + ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) + ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs) + } + + ct.Shutdown() + assert.NoError(t, ct.Cleanup()) +} diff --git a/internal/tests/prediction_test.go b/internal/tests/prediction_test.go index f292e31..a81da99 100644 --- a/internal/tests/prediction_test.go +++ b/internal/tests/prediction_test.go @@ -114,3 +114,42 @@ func TestPredictionCrash(t *testing.T) { ct.Shutdown() assert.NoError(t, ct.Cleanup()) } + +func TestPredictionConcurrency(t *testing.T) { + ct := NewCogTest(t, "sleep") + assert.NoError(t, ct.Start()) + + hc := ct.WaitForSetup() + assert.Equal(t, server.StatusReady.String(), hc.Status) + assert.Equal(t, server.SetupSucceeded, hc.Setup.Status) + + var resp1 server.PredictionResponse + var resp2 server.PredictionResponse + done1 := make(chan bool, 1) + done2 := make(chan bool, 1) + + go func() { + resp1 = ct.Prediction(map[string]any{"i": 1, "s": "bar"}) + done1 <- true + }() + + time.Sleep(100 * time.Millisecond) + // Block prediction requests when one is in progress + go func() { + resp2 = ct.Prediction(map[string]any{"i": 1, "s": "baz"}) + done2 <- true + }() + + <-done1 + assert.Equal(t, server.PredictionSucceeded, resp1.Status) + assert.Equal(t, "*bar*", resp1.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", resp1.Logs) + + <-done2 + assert.Equal(t, server.PredictionSucceeded, resp2.Status) + assert.Equal(t, "*baz*", resp2.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", resp2.Logs) + + ct.Shutdown() + assert.NoError(t, ct.Cleanup()) +} From 4ab4500bd8684e9e0a3a892592c052eebc20309d Mon Sep 17 00:00:00 2001 From: Neville Li Date: Thu, 12 Dec 2024 11:55:04 -0500 Subject: [PATCH 3/4] Simplify prediction test helpers --- internal/tests/async_prediction_test.go | 10 +--- internal/tests/cog_test.go | 77 +++++++++++++++---------- internal/tests/prediction_test.go | 14 ++--- 3 files changed, 52 insertions(+), 49 deletions(-) diff --git a/internal/tests/async_prediction_test.go b/internal/tests/async_prediction_test.go index 14504f3..5c6aac9 100644 --- a/internal/tests/async_prediction_test.go +++ b/internal/tests/async_prediction_test.go @@ -1,16 +1,12 @@ package tests import ( - "bytes" - "encoding/json" "fmt" "net/http" "strings" "testing" "time" - "github.com/replicate/go/must" - "github.com/replicate/cog-runtime/internal/util" "github.com/replicate/cog-runtime/internal/server" @@ -268,11 +264,7 @@ func TestAsyncPredictionConcurrency(t *testing.T) { Input: map[string]any{"i": 1, "s": "baz"}, Webhook: fmt.Sprintf("http://localhost:%d/webhook", ct.webhookPort), } - data := bytes.NewReader(must.Get(json.Marshal(req))) - r := must.Get(http.NewRequest(http.MethodPost, ct.Url("/predictions"), data)) - r.Header.Set("Content-Type", "application/json") - r.Header.Set("Prefer", "respond-async") - resp := must.Get(http.DefaultClient.Do(r)) + resp := ct.PredictionReq(http.MethodPost, "/predictions", req) assert.Equal(t, http.StatusConflict, resp.StatusCode) wr := ct.WaitForWebhookCompletion() diff --git a/internal/tests/cog_test.go b/internal/tests/cog_test.go index 4b4a492..ad02fdc 100644 --- a/internal/tests/cog_test.go +++ b/internal/tests/cog_test.go @@ -248,6 +248,14 @@ func (ct *CogTest) Url(path string) string { return fmt.Sprintf("http://localhost:%d%s", ct.serverPort, path) } +func (ct *CogTest) WebhookUrl() string { + return fmt.Sprintf("http://localhost:%d/webhook", ct.webhookPort) +} + +func (ct *CogTest) UploadUrl() string { + return fmt.Sprintf("http://localhost:%d/upload/", ct.webhookPort) +} + func (ct *CogTest) HealthCheck() server.HealthCheck { url := fmt.Sprintf("http://localhost:%d/health-check", ct.serverPort) for { @@ -273,65 +281,74 @@ func (ct *CogTest) WaitForSetup() server.HealthCheck { func (ct *CogTest) Prediction(input map[string]any) server.PredictionResponse { req := server.PredictionRequest{Input: input} - return ct.prediction(http.MethodPost, ct.Url("/predictions"), req) + return ct.prediction(http.MethodPost, "/predictions", req) } func (ct *CogTest) PredictionWithId(pid string, input map[string]any) server.PredictionResponse { req := server.PredictionRequest{Id: pid, Input: input} - return ct.prediction(http.MethodPut, ct.Url(fmt.Sprintf("/predictions/%s", pid)), req) + return ct.prediction(http.MethodPut, fmt.Sprintf("/predictions/%s", pid), req) } func (ct *CogTest) PredictionWithUpload(input map[string]any) server.PredictionResponse { req := server.PredictionRequest{ Input: input, - OutputFilePrefix: fmt.Sprintf("http://localhost:%d/upload/", ct.webhookPort), + OutputFilePrefix: ct.UploadUrl(), } - return ct.prediction(http.MethodPost, ct.Url("/predictions"), req) -} - -func (ct *CogTest) prediction(method string, url string, req server.PredictionRequest) server.PredictionResponse { - req.CreatedAt = util.NowIso() - data := bytes.NewReader(must.Get(json.Marshal(req))) - r := must.Get(http.NewRequest(method, url, data)) - r.Header.Set("Content-Type", "application/json") - resp := must.Get(http.DefaultClient.Do(r)) - assert.Equal(ct.t, http.StatusOK, resp.StatusCode) - var pr server.PredictionResponse - must.Do(json.Unmarshal(must.Get(io.ReadAll(resp.Body)), &pr)) - return pr + return ct.prediction(http.MethodPost, "/predictions", req) } func (ct *CogTest) AsyncPrediction(input map[string]any) string { - req := server.PredictionRequest{Input: input} - return ct.asyncPrediction(http.MethodPost, ct.Url("/predictions"), req) + req := server.PredictionRequest{ + Input: input, + Webhook: ct.WebhookUrl(), + } + resp := ct.prediction(http.MethodPost, "/predictions", req) + return resp.Id } func (ct *CogTest) AsyncPredictionWithFilter(input map[string]any, filter []server.WebhookEvent) string { req := server.PredictionRequest{ Input: input, + Webhook: ct.WebhookUrl(), WebhookEventsFilter: filter, } - return ct.asyncPrediction(http.MethodPost, ct.Url("/predictions"), req) + resp := ct.prediction(http.MethodPost, "/predictions", req) + return resp.Id } func (ct *CogTest) AsyncPredictionWithId(pid string, input map[string]any) string { - req := server.PredictionRequest{Id: pid, Input: input} - return ct.asyncPrediction(http.MethodPut, ct.Url(fmt.Sprintf("/predictions/%s", pid)), req) + req := server.PredictionRequest{ + Id: pid, + Input: input, + Webhook: ct.WebhookUrl(), + } + resp := ct.prediction(http.MethodPut, fmt.Sprintf("/predictions/%s", pid), req) + return resp.Id } -func (ct *CogTest) asyncPrediction(method string, url string, req server.PredictionRequest) string { +func (ct *CogTest) prediction(method string, path string, req server.PredictionRequest) server.PredictionResponse { + resp := ct.PredictionReq(method, path, req) + if req.Webhook == "" { + + assert.Equal(ct.t, http.StatusOK, resp.StatusCode) + } else { + assert.Equal(ct.t, http.StatusAccepted, resp.StatusCode) + } ct.pending++ + var pr server.PredictionResponse + must.Do(json.Unmarshal(must.Get(io.ReadAll(resp.Body)), &pr)) + return pr +} + +func (ct *CogTest) PredictionReq(method string, path string, req server.PredictionRequest) *http.Response { req.CreatedAt = util.NowIso() - req.Webhook = fmt.Sprintf("http://localhost:%d/webhook", ct.webhookPort) data := bytes.NewReader(must.Get(json.Marshal(req))) - r := must.Get(http.NewRequest(method, url, data)) + r := must.Get(http.NewRequest(method, ct.Url(path), data)) r.Header.Set("Content-Type", "application/json") - r.Header.Set("Prefer", "respond-async") - resp := must.Get(http.DefaultClient.Do(r)) - assert.Equal(ct.t, http.StatusAccepted, resp.StatusCode) - var pr server.PredictionResponse - must.Do(json.Unmarshal(must.Get(io.ReadAll(resp.Body)), &pr)) - return pr.Id + if req.Webhook != "" { + r.Header.Set("Prefer", "respond-async") + } + return must.Get(http.DefaultClient.Do(r)) } func (ct *CogTest) Cancel(pid string) { diff --git a/internal/tests/prediction_test.go b/internal/tests/prediction_test.go index a81da99..a51f6e5 100644 --- a/internal/tests/prediction_test.go +++ b/internal/tests/prediction_test.go @@ -1,8 +1,6 @@ package tests import ( - "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -11,8 +9,6 @@ import ( "github.com/replicate/go/must" - "github.com/replicate/cog-runtime/internal/util" - "github.com/replicate/cog-runtime/internal/server" "github.com/stretchr/testify/assert" @@ -88,12 +84,10 @@ func TestPredictionCrash(t *testing.T) { assert.Equal(t, server.SetupSucceeded, hc.Setup.Status) if *legacyCog { - req := server.PredictionRequest{Input: map[string]any{"i": 1, "s": "bar"}} - req.CreatedAt = util.NowIso() - data := bytes.NewReader(must.Get(json.Marshal(req))) - r := must.Get(http.NewRequest(http.MethodPost, ct.Url("/predictions"), data)) - r.Header.Set("Content-Type", "application/json") - resp := must.Get(http.DefaultClient.Do(r)) + req := server.PredictionRequest{ + Input: map[string]any{"i": 1, "s": "bar"}, + } + resp := ct.PredictionReq(http.MethodPost, "/predictions", req) // Compat: legacy Cog returns HTTP 500 and "Internal Server Error" assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) body := string(must.Get(io.ReadAll(resp.Body))) From c65541adde45a4db3c11b193734244e819f850b8 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Thu, 12 Dec 2024 12:30:52 -0500 Subject: [PATCH 4/4] Improve validation error message --- python/coglet/runner.py | 46 +++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/python/coglet/runner.py b/python/coglet/runner.py index 1ec15c5..fc9acff 100644 --- a/python/coglet/runner.py +++ b/python/coglet/runner.py @@ -11,48 +11,54 @@ def _kwargs(adt_ins: Dict[str, adt.Input], inputs: Dict[str, Any]) -> Dict[str, Any]: kwargs: Dict[str, Any] = {} for name, value in inputs.items(): + assert name in adt_ins, f'unknown field: {name}' adt_in = adt_ins[name] cog_t = adt_in.type if adt_in.is_list: assert all( util.check_value(cog_t, v) for v in value - ), f'incompatible input for: {name}' + ), f'incompatible value for field: {name}={value}' value = [util.normalize_value(cog_t, v) for v in value] else: - assert util.check_value(cog_t, value), f'incompatible input for: {name}' + assert util.check_value( + cog_t, value + ), f'incompatible value for field: {name}={value}' value = util.normalize_value(cog_t, value) kwargs[name] = value for name, adt_in in adt_ins.items(): if name not in kwargs: - assert adt_in.default is not None, f'missing default value for: {name}' + assert ( + adt_in.default is not None + ), f'missing default value for field: {name}' kwargs[name] = adt_in.default - vals = kwargs[name] if adt_in.is_list else [kwargs[name]] + values = kwargs[name] if adt_in.is_list else [kwargs[name]] + v = kwargs[name] if adt_in.ge is not None: assert ( - x >= adt_in.ge for x in vals - ), f'not all values >= {adt_in.ge} for: {name}' + x >= adt_in.ge for x in values + ), f'validation failure: >= {adt_in.ge} for field: {name}={v}' if adt_in.le is not None: assert ( - x <= adt_in.le for x in vals - ), f'not all values <= {adt_in.le} for: {name}' + x <= adt_in.le for x in values + ), f'validation failure: <= {adt_in.le} for field: {name}={v}' if adt_in.min_length is not None: assert ( - len(x) >= adt_in.min_length for x in vals - ), f'not all values have len(x) >= {adt_in.min_length} for: {name}' + len(x) >= adt_in.min_length for x in values + ), f'validation failure: len(x) >= {adt_in.min_length} for field: {name}={v}' if adt_in.max_length is not None: assert ( - len(x) <= adt_in.max_length for x in vals - ), f'not all values have len(x) <= {adt_in.max_length} for: {name}' + len(x) <= adt_in.max_length for x in values + ), f'validation failure: len(x) <= {adt_in.max_length} for field: {name}={v}' if adt_in.regex is not None: p = re.compile(adt_in.regex) assert all( - p.match(x) is not None for x in vals - ), f'not all inputs match regex for: {name}' + p.match(x) is not None for x in values + ), f'validation failure: regex match for field: {name}={v}' if adt_in.choices is not None: assert all( - x in adt_in.choices for x in vals - ), f'not all inputs in choices for: {name}' + x in adt_in.choices for x in values + ), f'validation failure: choices for field: {name}={v}' return kwargs @@ -65,7 +71,9 @@ def _check_output(adt_out: adt.Output, output: Any) -> Any: assert adt_out.type is not None, 'missing output type' assert type(output) is list, 'output is not list' for i, x in enumerate(output): - assert util.check_value(adt_out.type, x), f'incompatible output: {x}' + assert util.check_value( + adt_out.type, x + ), f'incompatible output element: {x}' output[i] = util.normalize_value(adt_out.type, x) return output elif adt_out.kind == adt.Kind.OBJECT: @@ -73,7 +81,9 @@ def _check_output(adt_out: adt.Output, output: Any) -> Any: for name, tpe in adt_out.fields.items(): assert hasattr(output, name), f'missing output field: {name}' value = getattr(output, name) - assert util.check_value(tpe, value), f'incompatible output: {name}={value}' + assert util.check_value( + tpe, value + ), f'incompatible output for field: {name}={value}' setattr(output, name, util.normalize_value(tpe, value)) return output