From 6cdac1221297489a5de2232c386704c4098533db Mon Sep 17 00:00:00 2001 From: Neville Li Date: Wed, 11 Dec 2024 18:48:49 -0500 Subject: [PATCH] Improve HTTP error handling --- internal/server/http.go | 5 ++++- internal/server/runner.go | 8 ++++---- internal/server/server.go | 13 ++++++++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index f5d91c3..c93eb94 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,7 +6,10 @@ import ( ) var ( - ErrNotFound = errors.New("prediction ID not found") + ErrExists = errors.New("prediction exists") + ErrNotFound = errors.New("prediction not found") + ErrDefunct = errors.New("server is defunct") + ErrSetupFailed = errors.New("setup failed") ) func NewServer(addr string, runner *Runner) *http.Server { diff --git a/internal/server/runner.go b/internal/server/runner.go index 9765e8c..1980874 100644 --- a/internal/server/runner.go +++ b/internal/server/runner.go @@ -192,10 +192,10 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error) log := logger.Sugar() if r.status == StatusSetupFailed { log.Errorw("prediction rejected: setup failed") - return nil, fmt.Errorf("setup failed") + return nil, ErrSetupFailed } else if r.status == StatusDefunct { log.Errorw("prediction rejected: server is defunct") - return nil, fmt.Errorf("server is defunct") + return nil, ErrDefunct } if req.CreatedAt == "" { req.CreatedAt = util.NowIso() @@ -203,8 +203,8 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error) r.mu.Lock() if _, ok := r.pending[req.Id]; ok { r.mu.Unlock() - log.Errorw("prediction rejected: prediction ID exists", "id", req.Id) - return nil, fmt.Errorf("prediction ID exists") + log.Errorw("prediction rejected: prediction exists", "id", req.Id) + return nil, ErrExists } r.mu.Unlock() diff --git a/internal/server/server.go b/internal/server/server.go index 1f092f7..c174733 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -81,9 +81,20 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) { } c, err := h.runner.predict(req) - if err != nil { + if errors.Is(err, ErrDefunct) { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } else if errors.Is(err, ErrExists) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } else if errors.Is(err, ErrSetupFailed) { http.Error(w, err.Error(), http.StatusInternalServerError) + return + } else if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + if c == nil { w.WriteHeader(http.StatusAccepted) resp := PredictionResponse{Id: req.Id, Status: "starting"}