Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/cmd/api/api/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestProcessExec(t *testing.T) {
svc := &ApiService{procs: make(map[string]*processHandle)}

cmd := "sh"
args := []string{"-c", "echo -n out; echo -n err 1>&2; exit 3"}
args := []string{"-c", "printf out; printf err 1>&2; exit 3"}
body := &oapi.ProcessExecRequest{Command: cmd, Args: &args}
resp, err := svc.ProcessExec(ctx, oapi.ProcessExecRequestObject{Body: body})
require.NoError(t, err, "ProcessExec error")
Expand Down
2 changes: 1 addition & 1 deletion server/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func main() {
})
})
rDevtools.Get("/*", func(w http.ResponseWriter, r *http.Request) {
devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages).ServeHTTP(w, r)
devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz).ServeHTTP(w, r)
})

srvDevtools := &http.Server{
Expand Down
2 changes: 1 addition & 1 deletion server/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ module github.com/onkernel/kernel-images/server
go 1.25.0

require (
github.com/coder/websocket v1.8.14
github.com/fsnotify/fsnotify v1.9.0
github.com/getkin/kin-openapi v0.132.0
github.com/ghodss/yaml v1.0.0
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.1
github.com/google/uuid v1.5.0
github.com/gorilla/websocket v1.5.3
github.com/kelseyhightower/envconfig v1.4.0
github.com/nrednav/cuid2 v1.1.0
github.com/oapi-codegen/runtime v1.1.1
Expand Down
4 changes: 2 additions & 2 deletions server/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMz
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -29,8 +31,6 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
Expand Down
65 changes: 28 additions & 37 deletions server/lib/devtoolsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
"sync/atomic"
"time"

"github.com/gorilla/websocket"
"github.com/coder/websocket"
"github.com/onkernel/kernel-images/server/lib/scaletozero"
)

var devtoolsListeningRegexp = regexp.MustCompile(`DevTools listening on (ws://\S+)`)
Expand Down Expand Up @@ -147,8 +148,11 @@ func (u *UpstreamManager) runTailOnce(ctx context.Context) {
// WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and
// proxies them to the current upstream websocket URL. It expects only websocket requests.
// If logCDPMessages is true, all CDP messages will be logged with their direction.
func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool) http.Handler {
func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctrl.Disable(context.WithoutCancel(r.Context()))
defer ctrl.Enable(context.WithoutCancel(r.Context()))

upstreamCurrent := mgr.Current()
if upstreamCurrent == "" {
http.Error(w, "upstream not ready", http.StatusServiceUnavailable)
Expand All @@ -161,63 +165,50 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess
}
// Always use the full upstream path and query, ignoring the client's request path/query
upstreamURL := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String()
upgrader := websocket.Upgrader{
ReadBufferSize: 65536,
WriteBufferSize: 65536,
EnableCompression: true,
CheckOrigin: func(r *http.Request) bool { return true },
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
CompressionMode: websocket.CompressionContextTakeover,
}
logger.Info("upgrader config", slog.Any("upgrader", upgrader))
clientConn, err := upgrader.Upgrade(w, r, nil)
logger.Info("accept options", slog.Any("options", acceptOptions))
clientConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
logger.Error("websocket upgrade failed", slog.String("err", err.Error()))
logger.Error("websocket accept failed", slog.String("err", err.Error()))
return
}
clientConn.SetReadDeadline(time.Time{}) // No timeout--hold on to connections for dear life
clientConn.SetWriteDeadline(time.Time{}) // No timeout--hold on to connections for dear life
clientConn.SetReadLimit(100 * 1024 * 1024) // 100 MB. Effectively no maximum size of message from client
clientConn.EnableWriteCompression(true)
clientConn.SetCompressionLevel(6)

dialer := websocket.Dialer{
ReadBufferSize: 65536,
WriteBufferSize: 65536,
HandshakeTimeout: 30 * time.Second,
dialOptions := &websocket.DialOptions{
CompressionMode: websocket.CompressionContextTakeover,
}
logger.Info("dialer config", slog.Any("dialer", dialer))
upstreamConn, _, err := dialer.Dial(upstreamURL, nil)
logger.Info("dial options", slog.Any("options", dialOptions))
upstreamConn, _, err := websocket.Dial(r.Context(), upstreamURL, dialOptions)
if err != nil {
logger.Error("dial upstream failed", slog.String("err", err.Error()), slog.String("url", upstreamURL))
_ = clientConn.Close()
_ = clientConn.Close(websocket.StatusInternalError, "failed to connect to upstream")
return
}
upstreamConn.SetReadLimit(100 * 1024 * 1024) // 100 MB. Effectively no maximum size of message from upstream
upstreamConn.EnableWriteCompression(true)
upstreamConn.SetCompressionLevel(6)
upstreamConn.SetReadDeadline(time.Time{}) // no timeout
upstreamConn.SetWriteDeadline(time.Time{}) // no timeout
logger.Debug("proxying devtools websocket", slog.String("url", upstreamURL))

var once sync.Once
cleanup := func() {
once.Do(func() {
_ = upstreamConn.Close()
_ = clientConn.Close()
_ = upstreamConn.Close(websocket.StatusNormalClosure, "")
_ = clientConn.Close(websocket.StatusNormalClosure, "")
})
}
proxyWebSocket(r.Context(), clientConn, upstreamConn, cleanup, logger, logCDPMessages)
})
}

type wsConn interface {
ReadMessage() (messageType int, p []byte, err error)
WriteMessage(messageType int, data []byte) error
Close() error
Read(ctx context.Context) (websocket.MessageType, []byte, error)
Write(ctx context.Context, typ websocket.MessageType, p []byte) error
Close(statusCode websocket.StatusCode, reason string) error
}

// logCDPMessage logs a CDP message with its direction if logging is enabled
func logCDPMessage(logger *slog.Logger, direction string, mt int, msg []byte) {
if mt != websocket.TextMessage {
func logCDPMessage(logger *slog.Logger, direction string, mt websocket.MessageType, msg []byte) {
if mt != websocket.MessageText {
return // Only log text messages (CDP messages)
}

Expand Down Expand Up @@ -298,7 +289,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos

go func() {
for {
mt, msg, err := clientConn.ReadMessage()
mt, msg, err := clientConn.Read(ctx)
if err != nil {
logger.Error("client read error", slog.String("err", err.Error()))
errChan <- err
Expand All @@ -310,7 +301,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos
logCDPMessage(logger, "->", mt, msg)
}

if err := upstreamConn.WriteMessage(mt, msg); err != nil {
if err := upstreamConn.Write(ctx, mt, msg); err != nil {
logger.Error("upstream write error", slog.String("err", err.Error()))
errChan <- err
break
Expand All @@ -319,7 +310,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos
}()
go func() {
for {
mt, msg, err := upstreamConn.ReadMessage()
mt, msg, err := upstreamConn.Read(ctx)
if err != nil {
logger.Error("upstream read error", slog.String("err", err.Error()))
errChan <- err
Expand All @@ -331,7 +322,7 @@ func proxyWebSocket(ctx context.Context, clientConn, upstreamConn wsConn, onClos
logCDPMessage(logger, "<-", mt, msg)
}

if err := clientConn.WriteMessage(mt, msg); err != nil {
if err := clientConn.Write(ctx, mt, msg); err != nil {
logger.Error("client write error", slog.String("err", err.Error()))
errChan <- err
break
Expand Down
30 changes: 18 additions & 12 deletions server/lib/devtoolsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/coder/websocket"
"github.com/onkernel/kernel-images/server/lib/scaletozero"
)

func silentLogger() *slog.Logger {
Expand Down Expand Up @@ -66,21 +67,24 @@ func TestWaitForInitialTimeoutWhenLogMissing(t *testing.T) {
func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) {
// Start an echo websocket server as upstream
echoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
c, err := upgrader.Upgrade(w, r, nil)
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
if err != nil {
t.Fatalf("upgrade failed: %v", err)
t.Fatalf("accept failed: %v", err)
return
}
defer c.Close()
defer c.Close(websocket.StatusNormalClosure, "")

ctx := r.Context()
for {
mt, msg, err := c.ReadMessage()
mt, msg, err := c.Read(ctx)
if err != nil {
return
}
// echo back with path+query prefixed to verify preservation
payload := []byte(r.URL.Path + "?" + r.URL.RawQuery + "|" + string(msg))
if err := c.WriteMessage(mt, payload); err != nil {
if err := c.Write(ctx, mt, payload); err != nil {
return
}
}
Expand All @@ -98,7 +102,7 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) {
// seed current upstream to echo server including path/query (bypass tailing)
mgr.setCurrent((&url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path, RawQuery: u.RawQuery}).String())

proxy := WebSocketProxyHandler(mgr, logger, false)
proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController())
proxySrv := httptest.NewServer(proxy)
defer proxySrv.Close()

Expand All @@ -109,16 +113,18 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) {
pu.Path = "/client"
pu.RawQuery = "x=y"

conn, _, err := websocket.DefaultDialer.Dial(pu.String(), nil)
ctx := context.Background()
conn, _, err := websocket.Dial(ctx, pu.String(), nil)
if err != nil {
t.Fatalf("dial proxy failed: %v", err)
}
defer conn.Close()
defer conn.Close(websocket.StatusNormalClosure, "")

msg := "hello"
if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
if err := conn.Write(ctx, websocket.MessageText, []byte(msg)); err != nil {
t.Fatalf("write failed: %v", err)
}
_, resp, err := conn.ReadMessage()
_, resp, err := conn.Read(ctx)
if err != nil {
t.Fatalf("read failed: %v", err)
}
Expand Down
Loading