From a24c01efc3e87d5ce7b1a2ce9ec6946ad6152ec8 Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Wed, 24 Sep 2025 14:27:10 -0700 Subject: [PATCH 1/6] fix test --- server/cmd/api/api/process_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/cmd/api/api/process_test.go b/server/cmd/api/api/process_test.go index 59ecb282..2204eba9 100644 --- a/server/cmd/api/api/process_test.go +++ b/server/cmd/api/api/process_test.go @@ -24,7 +24,7 @@ func TestProcessExec(t *testing.T) { svc := &ApiService{procs: make(map[string]*processHandle)} cmd := "sh" - args := []string{"-c", "echo -n out; echo -n err 1>&2; exit 3"} + args := []string{"-c", "printf out; printf err 1>&2; exit 3"} body := &oapi.ProcessExecRequest{Command: cmd, Args: &args} resp, err := svc.ProcessExec(ctx, oapi.ProcessExecRequestObject{Body: body}) require.NoError(t, err, "ProcessExec error") From f9c9ce82e7aa6e07220e07e1b0513e0a62a6da87 Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Wed, 24 Sep 2025 14:58:01 -0700 Subject: [PATCH 2/6] swap to coder/websocket lib --- server/lib/devtoolsproxy/proxy.go | 59 ++++++++++++------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index ce803ae6..dd1ef536 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -16,7 +16,7 @@ import ( "sync/atomic" "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" ) var devtoolsListeningRegexp = regexp.MustCompile(`DevTools listening on (ws://\S+)`) @@ -161,48 +161,35 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess } // Always use the full upstream path and query, ignoring the client's request path/query upstreamURL := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String() - upgrader := websocket.Upgrader{ - ReadBufferSize: 65536, - WriteBufferSize: 65536, - EnableCompression: true, - CheckOrigin: func(r *http.Request) bool { return true }, + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + CompressionMode: websocket.CompressionContextTakeover, } - logger.Info("upgrader config", slog.Any("upgrader", upgrader)) - clientConn, err := upgrader.Upgrade(w, r, nil) + logger.Info("accept options", slog.Any("options", acceptOptions)) + clientConn, err := websocket.Accept(w, r, acceptOptions) if err != nil { - logger.Error("websocket upgrade failed", slog.String("err", err.Error())) + logger.Error("websocket accept failed", slog.String("err", err.Error())) return } - clientConn.SetReadDeadline(time.Time{}) // No timeout--hold on to connections for dear life - clientConn.SetWriteDeadline(time.Time{}) // No timeout--hold on to connections for dear life clientConn.SetReadLimit(100 * 1024 * 1024) // 100 MB. Effectively no maximum size of message from client - clientConn.EnableWriteCompression(true) - clientConn.SetCompressionLevel(6) - - dialer := websocket.Dialer{ - ReadBufferSize: 65536, - WriteBufferSize: 65536, - HandshakeTimeout: 30 * time.Second, + dialOptions := &websocket.DialOptions{ + CompressionMode: websocket.CompressionContextTakeover, } - logger.Info("dialer config", slog.Any("dialer", dialer)) - upstreamConn, _, err := dialer.Dial(upstreamURL, nil) + logger.Info("dial options", slog.Any("options", dialOptions)) + upstreamConn, _, err := websocket.Dial(r.Context(), upstreamURL, dialOptions) if err != nil { logger.Error("dial upstream failed", slog.String("err", err.Error()), slog.String("url", upstreamURL)) - _ = clientConn.Close() + _ = clientConn.Close(websocket.StatusInternalError, "failed to connect to upstream") return } upstreamConn.SetReadLimit(100 * 1024 * 1024) // 100 MB. Effectively no maximum size of message from upstream - upstreamConn.EnableWriteCompression(true) - upstreamConn.SetCompressionLevel(6) - upstreamConn.SetReadDeadline(time.Time{}) // no timeout - upstreamConn.SetWriteDeadline(time.Time{}) // no timeout logger.Debug("proxying devtools websocket", slog.String("url", upstreamURL)) var once sync.Once cleanup := func() { once.Do(func() { - _ = upstreamConn.Close() - _ = clientConn.Close() + _ = upstreamConn.Close(websocket.StatusNormalClosure, "") + _ = clientConn.Close(websocket.StatusNormalClosure, "") }) } proxyWebSocket(r.Context(), clientConn, upstreamConn, cleanup, logger, logCDPMessages) @@ -210,14 +197,14 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess } type wsConn interface { - ReadMessage() (messageType int, p []byte, err error) - WriteMessage(messageType int, data []byte) error - Close() error + Read(ctx context.Context) (websocket.MessageType, []byte, error) + Write(ctx context.Context, typ websocket.MessageType, p []byte) error + Close(statusCode websocket.StatusCode, reason string) error } // logCDPMessage logs a CDP message with its direction if logging is enabled -func logCDPMessage(logger *slog.Logger, direction string, mt int, msg []byte) { - if mt != websocket.TextMessage { +func logCDPMessage(logger *slog.Logger, direction string, mt websocket.MessageType, msg []byte) { + if mt != websocket.MessageText { return // Only log text messages (CDP messages) } @@ -298,7 +285,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos go func() { for { - mt, msg, err := clientConn.ReadMessage() + mt, msg, err := clientConn.Read(ctx) if err != nil { logger.Error("client read error", slog.String("err", err.Error())) errChan <- err @@ -310,7 +297,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos logCDPMessage(logger, "->", mt, msg) } - if err := upstreamConn.WriteMessage(mt, msg); err != nil { + if err := upstreamConn.Write(ctx, mt, msg); err != nil { logger.Error("upstream write error", slog.String("err", err.Error())) errChan <- err break @@ -319,7 +306,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos }() go func() { for { - mt, msg, err := upstreamConn.ReadMessage() + mt, msg, err := upstreamConn.Read(ctx) if err != nil { logger.Error("upstream read error", slog.String("err", err.Error())) errChan <- err @@ -331,7 +318,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos logCDPMessage(logger, "<-", mt, msg) } - if err := clientConn.WriteMessage(mt, msg); err != nil { + if err := clientConn.Write(ctx, mt, msg); err != nil { logger.Error("client write error", slog.String("err", err.Error())) errChan <- err break From 578781791214c67f6e9bbfb9f73c9156daec0936 Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Wed, 24 Sep 2025 14:58:07 -0700 Subject: [PATCH 3/6] go mod tidy --- server/go.mod | 1 + server/go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/server/go.mod b/server/go.mod index 14eee91d..ab9adcb4 100644 --- a/server/go.mod +++ b/server/go.mod @@ -3,6 +3,7 @@ module github.com/onkernel/kernel-images/server go 1.25.0 require ( + github.com/coder/websocket v1.8.14 github.com/fsnotify/fsnotify v1.9.0 github.com/getkin/kin-openapi v0.132.0 github.com/ghodss/yaml v1.0.0 diff --git a/server/go.sum b/server/go.sum index 63accb8c..8be7344b 100644 --- a/server/go.sum +++ b/server/go.sum @@ -2,6 +2,8 @@ github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMz github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From c0058b208c29d7e48e941fe86067821cab8c0e00 Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Wed, 24 Sep 2025 15:02:00 -0700 Subject: [PATCH 4/6] fix test too --- server/lib/devtoolsproxy/proxy_test.go | 27 +++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/server/lib/devtoolsproxy/proxy_test.go b/server/lib/devtoolsproxy/proxy_test.go index 3a50ca64..8b0da742 100644 --- a/server/lib/devtoolsproxy/proxy_test.go +++ b/server/lib/devtoolsproxy/proxy_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" ) func silentLogger() *slog.Logger { @@ -66,21 +66,24 @@ func TestWaitForInitialTimeoutWhenLogMissing(t *testing.T) { func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { // Start an echo websocket server as upstream echoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - c, err := upgrader.Upgrade(w, r, nil) + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) if err != nil { - t.Fatalf("upgrade failed: %v", err) + t.Fatalf("accept failed: %v", err) return } - defer c.Close() + defer c.Close(websocket.StatusNormalClosure, "") + + ctx := r.Context() for { - mt, msg, err := c.ReadMessage() + mt, msg, err := c.Read(ctx) if err != nil { return } // echo back with path+query prefixed to verify preservation payload := []byte(r.URL.Path + "?" + r.URL.RawQuery + "|" + string(msg)) - if err := c.WriteMessage(mt, payload); err != nil { + if err := c.Write(ctx, mt, payload); err != nil { return } } @@ -109,16 +112,18 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { pu.Path = "/client" pu.RawQuery = "x=y" - conn, _, err := websocket.DefaultDialer.Dial(pu.String(), nil) + ctx := context.Background() + conn, _, err := websocket.Dial(ctx, pu.String(), nil) if err != nil { t.Fatalf("dial proxy failed: %v", err) } - defer conn.Close() + defer conn.Close(websocket.StatusNormalClosure, "") + msg := "hello" - if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { + if err := conn.Write(ctx, websocket.MessageText, []byte(msg)); err != nil { t.Fatalf("write failed: %v", err) } - _, resp, err := conn.ReadMessage() + _, resp, err := conn.Read(ctx) if err != nil { t.Fatalf("read failed: %v", err) } From f3c5d00ff8eb08ffb1ff30388f6cb95a0c00ac8d Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Wed, 24 Sep 2025 15:02:05 -0700 Subject: [PATCH 5/6] go mod tidy --- server/go.mod | 1 - server/go.sum | 2 -- 2 files changed, 3 deletions(-) diff --git a/server/go.mod b/server/go.mod index ab9adcb4..15c3a663 100644 --- a/server/go.mod +++ b/server/go.mod @@ -10,7 +10,6 @@ require ( github.com/glebarez/sqlite v1.11.0 github.com/go-chi/chi/v5 v5.2.1 github.com/google/uuid v1.5.0 - github.com/gorilla/websocket v1.5.3 github.com/kelseyhightower/envconfig v1.4.0 github.com/nrednav/cuid2 v1.1.0 github.com/oapi-codegen/runtime v1.1.1 diff --git a/server/go.sum b/server/go.sum index 8be7344b..0b8d7092 100644 --- a/server/go.sum +++ b/server/go.sum @@ -31,8 +31,6 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= From 195231ea1b386ab555b627f417aa41d1d0ae4f92 Mon Sep 17 00:00:00 2001 From: Sayan Samanta Date: Thu, 25 Sep 2025 17:57:17 -0700 Subject: [PATCH 6/6] stz during cdp --- server/cmd/api/main.go | 2 +- server/lib/devtoolsproxy/proxy.go | 6 +++++- server/lib/devtoolsproxy/proxy_test.go | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 83b26f0f..0c637f96 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -141,7 +141,7 @@ func main() { }) }) rDevtools.Get("/*", func(w http.ResponseWriter, r *http.Request) { - devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages).ServeHTTP(w, r) + devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz).ServeHTTP(w, r) }) srvDevtools := &http.Server{ diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index dd1ef536..f1d1c472 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -17,6 +17,7 @@ import ( "time" "github.com/coder/websocket" + "github.com/onkernel/kernel-images/server/lib/scaletozero" ) var devtoolsListeningRegexp = regexp.MustCompile(`DevTools listening on (ws://\S+)`) @@ -147,8 +148,11 @@ func (u *UpstreamManager) runTailOnce(ctx context.Context) { // WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and // proxies them to the current upstream websocket URL. It expects only websocket requests. // If logCDPMessages is true, all CDP messages will be logged with their direction. -func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool) http.Handler { +func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctrl.Disable(context.WithoutCancel(r.Context())) + defer ctrl.Enable(context.WithoutCancel(r.Context())) + upstreamCurrent := mgr.Current() if upstreamCurrent == "" { http.Error(w, "upstream not ready", http.StatusServiceUnavailable) diff --git a/server/lib/devtoolsproxy/proxy_test.go b/server/lib/devtoolsproxy/proxy_test.go index 8b0da742..2b1899aa 100644 --- a/server/lib/devtoolsproxy/proxy_test.go +++ b/server/lib/devtoolsproxy/proxy_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/coder/websocket" + "github.com/onkernel/kernel-images/server/lib/scaletozero" ) func silentLogger() *slog.Logger { @@ -101,7 +102,7 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { // seed current upstream to echo server including path/query (bypass tailing) mgr.setCurrent((&url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path, RawQuery: u.RawQuery}).String()) - proxy := WebSocketProxyHandler(mgr, logger, false) + proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController()) proxySrv := httptest.NewServer(proxy) defer proxySrv.Close()