diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..0fabde53 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,28 @@ +name: CI +on: + push: + branches: [feat/paop-steering, main] + pull_request: + branches: [feat/paop-steering, main] + +jobs: + build-and-test: + runs-on: ubuntu-latest + strategy: + matrix: + go: ["1.21", "1.22"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + cache: true + - name: Build + working-directory: components/execd + run: go build ./... + - name: Vet + working-directory: components/execd + run: go vet ./... + - name: Test + working-directory: components/execd + run: go test ./... -race -timeout 60s diff --git a/FORK.md b/FORK.md new file mode 100644 index 00000000..e65e8ebb --- /dev/null +++ b/FORK.md @@ -0,0 +1,57 @@ +# OpenSandbox Fork — PAOP WebSocket Steering Integration + +This is `danieliser/OpenSandbox`, a fork of [`alibaba/OpenSandbox`](https://github.com/alibaba/OpenSandbox). + +## Purpose + +This fork adds WebSocket-based steering support to `execd` so that PAOP (Persistent Agent +Orchestration Platform) can replace its tmux/poll executor with a push-based, in-container +execution model. + +Key goals: + +- Add `GET /ws/session/:sessionId` WebSocket endpoint to `components/execd` (Phase 1) +- Add PTY opt-in via `?pty=1` query parameter for interactive programs (Phase 2) +- Fix residual bugs from upstream PR #104 (`feat/bash-session`) (Phase 0) + +The PAOP-side counterpart lives in the `persistence` repo under `paop/executor/`. + +## Working Branch + +All active development happens on `feat/paop-steering`. 
+ +## Upstream Sync + +To pull in upstream changes from `alibaba/OpenSandbox`: + +```bash +git fetch upstream +git checkout feat/paop-steering +git merge upstream/main +# Resolve conflicts, then push +git push origin feat/paop-steering +``` + +If the `upstream` remote is not yet configured: + +```bash +git remote add upstream https://github.com/alibaba/OpenSandbox.git +``` + +## What's PAOP-Only vs. Upstream Candidates + +| Phase | Changes | Upstream candidate? | +|-------|---------|---------------------| +| Phase 0 | Bug fixes for PR #104 (TOCTOU race, stderr routing, sentinel collision, context leak, shutdown race) | **Yes** — these are correctness fixes valuable to all users | +| Phase 1 | `GET /ws/session/:sessionId` WebSocket endpoint | **Possibly** — generic enough; needs upstream discussion | +| Phase 2 | PTY opt-in (`?pty=1`) | **Possibly** — generic; needs upstream discussion | +| Phase 3 | PAOP `WSExecutor` integration (lives in `persistence` repo) | **No** — PAOP-specific, stays here / in persistence repo | + +Phase 0 bug fixes are the strongest upstream PR candidates. They fix real correctness issues +independent of any PAOP integration and should be submitted back once validated. + +## CI + +GitHub Actions runs on every push and pull request targeting `feat/paop-steering` or `main`. +Matrix: Go 1.21 and 1.22. Steps: build, vet, race-detector test suite (60s timeout). +See `.github/workflows/ci.yml`. 
diff --git a/components/execd/go.mod b/components/execd/go.mod index 0b3f39cb..7cad2fb0 100644 --- a/components/execd/go.mod +++ b/components/execd/go.mod @@ -23,6 +23,7 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/creack/pty v1.1.24 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect diff --git a/components/execd/go.sum b/components/execd/go.sum index 0ebe846c..88c4b7fe 100644 --- a/components/execd/go.sum +++ b/components/execd/go.sum @@ -11,6 +11,8 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/execd/pkg/runtime/bash_session.go b/components/execd/pkg/runtime/bash_session.go new file mode 100644 index 00000000..deacb077 --- /dev/null +++ b/components/execd/pkg/runtime/bash_session.go @@ -0,0 +1,965 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !windows
+// +build !windows
+
+package runtime
+
+import (
+	"bufio"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"sort"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/creack/pty"
+	"github.com/google/uuid"
+
+	"github.com/alibaba/opensandbox/execd/pkg/jupyter/execute"
+	"github.com/alibaba/opensandbox/execd/pkg/log"
+)
+
+// Marker lines printed by the wrapper script (see buildWrappedScript) so that
+// run() can separate the environment dump, working directory and exit code
+// from the user's own stdout.
+const (
+	envDumpStartMarker = "__EXECD_ENV_DUMP_START_8a3f__"
+	envDumpEndMarker   = "__EXECD_ENV_DUMP_END_8a3f__"
+	exitMarkerPrefix   = "__EXECD_EXIT_v1__:"
+	pwdMarkerPrefix    = "__EXECD_PWD_v1__:"
+)
+
+// createBashSession builds a session rooted at req.Cwd, marks it started and
+// registers it in the controller's session map. It returns the generated
+// session ID.
+func (c *Controller) createBashSession(req *CreateContextRequest) (string, error) {
+	sess := newBashSession(req.Cwd)
+	if err := sess.start(); err != nil {
+		return "", fmt.Errorf("failed to start bash session: %w", err)
+	}
+
+	id := sess.config.Session
+	c.bashSessionClientMap.Store(id, sess)
+	log.Info("created bash session %s", id)
+	return id, nil
+}
+
+// runBashSession executes request inside the session named by request.Context.
+// It returns ErrContextNotFound when no such session is registered.
+func (c *Controller) runBashSession(ctx context.Context, request *ExecuteCodeRequest) error {
+	if sess := c.getBashSession(request.Context); sess != nil {
+		return sess.run(ctx, request)
+	}
+	return ErrContextNotFound
+}
+
+// getBashSession looks the session up in the controller's map. It returns nil
+// when the ID is unknown or the stored value is not a *bashSession.
+func (c *Controller) getBashSession(sessionId string) *bashSession {
+	stored, found := c.bashSessionClientMap.Load(sessionId)
+	if !found {
+		return nil
+	}
+	// A failed assertion leaves sess as a nil *bashSession, which is what callers expect.
+	sess, _ := stored.(*bashSession)
+	return sess
+}
+
+// closeBashSession shuts the session down and, only if that succeeds, removes
+// it from the controller's map.
+func (c *Controller) closeBashSession(sessionId string) error {
+	sess := c.getBashSession(sessionId)
+	if sess == nil {
+		return ErrContextNotFound
+	}
+	if err := sess.close(); err != nil {
+		return err
+	}
+
+	c.bashSessionClientMap.Delete(sessionId)
+	return nil
+}
+
+// CreateBashSession is the exported wrapper around createBashSession.
+func (c *Controller) CreateBashSession(req *CreateContextRequest) (string, error) {
+	return c.createBashSession(req)
+}
+
+// RunInBashSession is the exported wrapper around runBashSession.
+func (c *Controller) RunInBashSession(ctx context.Context, req *ExecuteCodeRequest) error {
+	return c.runBashSession(ctx, req)
+}
+
+// DeleteBashSession is the exported wrapper around closeBashSession.
+func (c *Controller) DeleteBashSession(sessionID string) error {
+	return c.closeBashSession(sessionID)
+}
+
+// BashSessionStatus holds observable state for a bash session.
+type BashSessionStatus struct {
+	SessionID    string
+	Running      bool
+	OutputOffset int64
+}
+
+// WriteSessionOutput appends data to the replay buffer for the named session.
+// Used by the WebSocket handler to persist live output for reconnect replay.
+// Unknown session IDs are ignored.
+func (c *Controller) WriteSessionOutput(sessionID string, data []byte) {
+	if sess := c.getBashSession(sessionID); sess != nil {
+		sess.replay.write(data)
+	}
+}
+
+// GetBashSession retrieves a bash session by ID. Returns nil if not found.
+func (c *Controller) GetBashSession(sessionID string) BashSession {
+	if sess := c.getBashSession(sessionID); sess != nil {
+		return sess
+	}
+	// Return an untyped nil so callers comparing the interface against nil see "not found".
+	return nil
+}
+
+// ReplaySessionOutput returns buffered output bytes starting from offset.
+// Returns (data, nextOffset, nil), or ErrContextNotFound for an unknown
+// session. See replayBuffer.readFrom for the offset semantics.
+func (c *Controller) ReplaySessionOutput(sessionID string, offset int64) ([]byte, int64, error) {
+	sess := c.getBashSession(sessionID)
+	if sess == nil {
+		return nil, 0, ErrContextNotFound
+	}
+	chunk, nextOffset := sess.replay.readFrom(offset)
+	return chunk, nextOffset, nil
+}
+
+// GetBashSessionStatus returns status info for a bash session, including replay buffer offset.
+func (c *Controller) GetBashSessionStatus(sessionID string) (*BashSessionStatus, error) {
+	sess := c.getBashSession(sessionID)
+	if sess == nil {
+		return nil, ErrContextNotFound
+	}
+
+	// A non-zero wsPid means the long-lived WebSocket-mode bash is still alive.
+	sess.mu.Lock()
+	alive := sess.wsPid != 0
+	sess.mu.Unlock()
+
+	status := &BashSessionStatus{
+		SessionID:    sessionID,
+		Running:      alive,
+		OutputOffset: sess.replay.Total(),
+	}
+	return status, nil
+}
+
+// newBashSession allocates an unstarted session with a fresh ID, a snapshot of
+// the daemon's environment (entries with invalid keys are dropped), the given
+// working directory, an empty replay buffer and an exit code of -1.
+func newBashSession(cwd string) *bashSession {
+	env := make(map[string]string)
+	for _, pair := range os.Environ() {
+		key, value, ok := splitEnvPair(pair)
+		if !ok {
+			continue
+		}
+		env[key] = value
+	}
+
+	return &bashSession{
+		config: &bashSessionConfig{
+			Session:        uuidString(),
+			StartupTimeout: 5 * time.Second,
+		},
+		env:          env,
+		cwd:          cwd,
+		replay:       newReplayBuffer(defaultReplayBufSize),
+		lastExitCode: -1,
+	}
+}
+
+// start flips the session into the started state so run() will accept
+// commands. It spawns no process and fails if called twice.
+func (s *bashSession) start() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if !s.started {
+		s.started = true
+		return nil
+	}
+	return errors.New("session already started")
+}
+
+// Start launches an interactive bash process for WebSocket stdin/stdout mode.
+// It is idempotent: if the process is already running, it returns nil.
+// Unlike run(), this bash process stays alive reading from stdin until closed.
+func (s *bashSession) Start() error {
+	s.mu.Lock()
+	if s.wsPid != 0 {
+		s.mu.Unlock()
+		return nil // already running
+	}
+	if s.closing {
+		s.mu.Unlock()
+		return errors.New("session is closing")
+	}
+	// Snapshot cwd while holding the lock: run() rewrites s.cwd under s.mu.
+	cwd := s.cwd
+	s.mu.Unlock()
+
+	cmd := exec.Command("bash", "--noprofile", "--norc")
+	if cwd != "" {
+		cmd.Dir = cwd
+	}
+	// Own process group so SendSignal/close can signal bash and its children via -pid.
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+
+	// closeAll releases pipe ends on the error and cleanup paths below.
+	closeAll := func(files ...*os.File) {
+		for _, f := range files {
+			_ = f.Close()
+		}
+	}
+
+	stdinR, stdinW, err := os.Pipe()
+	if err != nil {
+		return fmt.Errorf("create stdin pipe: %w", err)
+	}
+	stdoutR, stdoutW, err := os.Pipe()
+	if err != nil {
+		closeAll(stdinR, stdinW)
+		return fmt.Errorf("create stdout pipe: %w", err)
+	}
+	stderrR, stderrW, err := os.Pipe()
+	if err != nil {
+		closeAll(stdinR, stdinW, stdoutR, stdoutW)
+		return fmt.Errorf("create stderr pipe: %w", err)
+	}
+
+	cmd.Stdin = stdinR
+	cmd.Stdout = stdoutW
+	cmd.Stderr = stderrW
+
+	if err := cmd.Start(); err != nil {
+		closeAll(stdinR, stdinW, stdoutR, stdoutW, stderrR, stderrW)
+		return fmt.Errorf("start bash: %w", err)
+	}
+
+	// Close child-side ends in the parent process.
+	closeAll(stdinR, stdoutW, stderrW)
+
+	pid := cmd.Process.Pid
+	doneCh := make(chan struct{})
+
+	s.mu.Lock()
+	if s.closing || s.wsPid != 0 {
+		// The lock was released while spawning, so re-validate before publishing:
+		// either close() ran in the meantime or a concurrent Start/StartPTY won.
+		// Tear down the process we just created instead of leaking it.
+		closing := s.closing
+		s.mu.Unlock()
+		_ = syscall.Kill(-pid, syscall.SIGKILL)
+		closeAll(stdinW, stdoutR, stderrR)
+		_ = cmd.Wait() // reap; the exit status of the discarded shell is irrelevant
+		if closing {
+			return errors.New("session is closing")
+		}
+		return nil // the other caller's process is the running one
+	}
+	// Reset stale PTY state so WriteStdin targets the correct pipe on mode switch.
+	s.isPTY = false
+	s.ptmx = nil
+	s.stdin = stdinW
+	s.doneCh = doneCh
+	s.wsPid = pid
+	s.started = true
+	s.mu.Unlock()
+
+	// Broadcast goroutine: reads real stdout, always writes to replay buffer, and
+	// fans out to the current per-connection sink when one is attached.
+	// Output produced during client downtime is preserved in the replay buffer so
+	// reconnecting clients can catch up via ?since=.
+	go func() {
+		defer stdoutR.Close() // release the OS fd when the shell's stdout closes
+		buf := make([]byte, 32*1024)
+		for {
+			n, err := stdoutR.Read(buf)
+			if n > 0 {
+				chunk := buf[:n]
+				s.replay.write(chunk)
+				s.outMu.Lock()
+				w := s.stdoutW
+				s.outMu.Unlock()
+				if w != nil {
+					// Best effort: a detached sink reports an error here, which is expected.
+					_, _ = w.Write(chunk)
+				}
+			}
+			if err != nil {
+				// Shell stdout is gone: close the sink so the attached reader sees EOF.
+				s.outMu.Lock()
+				if s.stdoutW != nil {
+					_ = s.stdoutW.Close()
+					s.stdoutW = nil
+				}
+				s.outMu.Unlock()
+				return
+			}
+		}
+	}()
+
+	// Broadcast goroutine: reads real stderr, always writes to replay buffer, and
+	// fans out to the current per-connection sink when one is attached.
+	go func() {
+		defer stderrR.Close() // release the OS fd when the shell's stderr closes
+		buf := make([]byte, 32*1024)
+		for {
+			n, err := stderrR.Read(buf)
+			if n > 0 {
+				chunk := buf[:n]
+				s.replay.write(chunk)
+				s.outMu.Lock()
+				w := s.stderrW
+				s.outMu.Unlock()
+				if w != nil {
+					_, _ = w.Write(chunk)
+				}
+			}
+			if err != nil {
+				s.outMu.Lock()
+				if s.stderrW != nil {
+					_ = s.stderrW.Close()
+					s.stderrW = nil
+				}
+				s.outMu.Unlock()
+				return
+			}
+		}
+	}()
+
+	// Reaper goroutine: records the exit code, marks the session as not running
+	// and closes doneCh so Done() waiters wake up.
+	go func() {
+		_ = cmd.Wait()
+		code := -1
+		if cmd.ProcessState != nil {
+			code = cmd.ProcessState.ExitCode()
+		}
+		_ = stdinW.Close()
+		s.mu.Lock()
+		s.lastExitCode = code
+		s.wsPid = 0
+		s.mu.Unlock()
+		close(doneCh)
+	}()
+
+	return nil
+}
+
+// StartPTY launches an interactive bash process using a PTY instead of pipes.
+// stdout and stderr arrive merged on the PTY master fd.
+// It is idempotent: if the process is already running, it returns nil.
+func (s *bashSession) StartPTY() error {
+	s.mu.Lock()
+	if s.wsPid != 0 {
+		s.mu.Unlock()
+		return nil // already running
+	}
+	if s.closing {
+		s.mu.Unlock()
+		return errors.New("session is closing")
+	}
+	// Snapshot cwd while holding the lock: run() rewrites s.cwd under s.mu.
+	cwd := s.cwd
+	s.mu.Unlock()
+
+	cmd := exec.Command("bash", "--noprofile", "--norc")
+	if cwd != "" {
+		cmd.Dir = cwd
+	}
+	// Deliberately no Setpgid here. pty.StartWithSize starts the child with
+	// Setsid+Setctty, which already makes it the leader of a new session and a
+	// new process group (pgid == pid), so Kill(-pid) in SendSignal/close still
+	// reaches the whole group. setpgid(2) is specified to fail with EPERM for a
+	// session leader, so requesting both makes the fork/exec fail.
+	cmd.Env = append(os.Environ(), "TERM=xterm-256color", "COLUMNS=80", "LINES=24")
+
+	ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Rows: 24, Cols: 80})
+	if err != nil {
+		return fmt.Errorf("pty start: %w", err)
+	}
+
+	pid := cmd.Process.Pid
+	doneCh := make(chan struct{})
+
+	s.mu.Lock()
+	if s.closing || s.wsPid != 0 {
+		// The lock was released while spawning, so re-validate before publishing:
+		// either close() ran in the meantime or a concurrent Start/StartPTY won.
+		// Tear down the process we just created instead of leaking it.
+		closing := s.closing
+		s.mu.Unlock()
+		_ = syscall.Kill(-pid, syscall.SIGKILL)
+		_ = ptmx.Close()
+		_ = cmd.Wait() // reap; the exit status of the discarded shell is irrelevant
+		if closing {
+			return errors.New("session is closing")
+		}
+		return nil // the other caller's process is the running one
+	}
+	s.ptmx = ptmx
+	s.isPTY = true
+	s.doneCh = doneCh
+	s.wsPid = pid
+	s.started = true
+	s.mu.Unlock()
+
+	// Broadcast goroutine: reads PTY master (stdout+stderr merged), always writes to
+	// replay buffer, and fans out to the current per-connection sink when one is attached.
+	go func() {
+		buf := make([]byte, 32*1024)
+		for {
+			n, err := ptmx.Read(buf)
+			if n > 0 {
+				chunk := buf[:n]
+				s.replay.write(chunk)
+				s.outMu.Lock()
+				w := s.stdoutW
+				s.outMu.Unlock()
+				if w != nil {
+					// Best effort: a detached sink reports an error here, which is expected.
+					_, _ = w.Write(chunk)
+				}
+			}
+			if err != nil {
+				// PTY master is closed or the shell is gone: close the sink so the
+				// attached reader sees EOF.
+				s.outMu.Lock()
+				if s.stdoutW != nil {
+					_ = s.stdoutW.Close()
+					s.stdoutW = nil
+				}
+				s.outMu.Unlock()
+				return
+			}
+		}
+	}()
+
+	// Reaper goroutine: records the exit code, releases the PTY master, marks the
+	// session as not running and closes doneCh so Done() waiters wake up.
+	go func() {
+		_ = cmd.Wait()
+		code := -1
+		if cmd.ProcessState != nil {
+			code = cmd.ProcessState.ExitCode()
+		}
+		_ = ptmx.Close()
+		s.mu.Lock()
+		s.lastExitCode = code
+		s.wsPid = 0
+		// Clear PTY descriptors so a subsequent Start() in pipe mode is clean.
+		s.isPTY = false
+		s.ptmx = nil
+		s.mu.Unlock()
+		close(doneCh)
+	}()
+
+	return nil
+}
+
+// ResizePTY sends a TIOCSWINSZ ioctl to the PTY master.
+// No-op if not in PTY mode.
+func (s *bashSession) ResizePTY(cols, rows uint16) error {
+	s.mu.Lock()
+	master := s.ptmx
+	s.mu.Unlock()
+
+	if master != nil {
+		return pty.Setsize(master, &pty.Winsize{Rows: rows, Cols: cols})
+	}
+	return nil
+}
+
+// SendSignal delivers the named signal (for example "SIGINT") to the process
+// group of the WebSocket-mode shell. It does nothing when no such shell is
+// running or when signalByName does not recognise the name.
+func (s *bashSession) SendSignal(name string) {
+	s.mu.Lock()
+	target := s.wsPid
+	s.mu.Unlock()
+
+	sig := signalByName(name)
+	if target == 0 || sig == 0 {
+		return
+	}
+	// A negative pid addresses the whole process group, not just bash itself.
+	_ = syscall.Kill(-target, sig)
+}
+
+// signalByName translates one of the supported signal names (SIGINT, SIGTERM,
+// SIGKILL, SIGQUIT, SIGHUP) into its syscall.Signal value.
+// Any other name yields 0.
+func signalByName(name string) syscall.Signal {
+	var sig syscall.Signal // zero value doubles as the "unknown name" result
+	switch name {
+	case "SIGINT":
+		sig = syscall.SIGINT
+	case "SIGTERM":
+		sig = syscall.SIGTERM
+	case "SIGKILL":
+		sig = syscall.SIGKILL
+	case "SIGQUIT":
+		sig = syscall.SIGQUIT
+	case "SIGHUP":
+		sig = syscall.SIGHUP
+	}
+	return sig
+}
+
+// WriteStdin forwards p to the shell's input: the PTY master when the session
+// is in PTY mode, otherwise the stdin pipe created by Start.
+// It returns an error when the selected target has not been set up, or
+// whatever error the underlying write reports.
+func (s *bashSession) WriteStdin(p []byte) (int, error) {
+	s.mu.Lock()
+	usePTY, master, pipe := s.isPTY, s.ptmx, s.stdin
+	s.mu.Unlock()
+
+	switch {
+	case usePTY && master == nil:
+		return 0, errors.New("PTY not started")
+	case usePTY:
+		return master.Write(p)
+	case pipe == nil:
+		return 0, errors.New("session not started")
+	default:
+		return pipe.Write(p)
+	}
+}
+
+// LockWS claims the session for a single WebSocket connection.
+// It reports false when another connection already holds the claim.
+func (s *bashSession) LockWS() bool {
+	return s.wsConnected.CompareAndSwap(false, true)
+}
+
+// UnlockWS gives up the claim taken by LockWS.
+func (s *bashSession) UnlockWS() {
+	s.wsConnected.Store(false)
+}
+
+// IsRunning reports whether the long-lived WS bash process is currently alive.
+func (s *bashSession) IsRunning() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.wsPid != 0
+}
+
+// ExitCode returns the exit code of the most recently completed process.
+// Returns -1 if the process has not yet exited.
+func (s *bashSession) ExitCode() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.lastExitCode
+}
+
+// AttachOutput installs a fresh per-connection pipe pair and returns readers plus a detach func.
+// The broadcast goroutine (started by Start/StartPTY) copies from the real OS pipe into the
+// current PipeWriter. Calling detach() closes the PipeWriters so the returned readers return
+// EOF, unblocking any scanner goroutines on this connection without affecting the underlying pipe.
+//
+// In PTY mode stdout and stderr are merged on the PTY master, so the returned
+// stderr reader is a nil *io.PipeReader wrapped in the interface; callers must
+// not read from it in that mode.
+func (s *bashSession) AttachOutput() (stdout io.Reader, stderr io.Reader, detach func()) {
+	stdoutR, stdoutW := io.Pipe()
+
+	s.outMu.Lock()
+	// Close any previous writer (e.g. from a stale prior connection) before swapping.
+	if s.stdoutW != nil {
+		_ = s.stdoutW.Close()
+	}
+	s.stdoutW = stdoutW
+	s.outMu.Unlock()
+
+	var stderrR *io.PipeReader
+	var stderrPW *io.PipeWriter
+
+	s.mu.Lock()
+	isPTY := s.isPTY
+	s.mu.Unlock()
+
+	if !isPTY {
+		stderrR, stderrPW = io.Pipe()
+		s.outMu.Lock()
+		if s.stderrW != nil {
+			_ = s.stderrW.Close()
+		}
+		s.stderrW = stderrPW
+		s.outMu.Unlock()
+	}
+
+	detach = func() {
+		s.outMu.Lock()
+		// Only close if we're still the active writer (guards against double-detach).
+		if s.stdoutW == stdoutW {
+			_ = stdoutW.Close()
+			s.stdoutW = nil
+		}
+		if stderrPW != nil && s.stderrW == stderrPW {
+			_ = stderrPW.Close()
+			s.stderrW = nil
+		}
+		s.outMu.Unlock()
+	}
+
+	return stdoutR, stderrR, detach
+}
+
+// Done returns a channel that is closed when the WS-mode bash process exits.
+// The channel is replaced on every Start/StartPTY, so it is read under s.mu to
+// avoid racing with those writers. It is nil until the first Start/StartPTY.
+func (s *bashSession) Done() <-chan struct{} {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.doneCh
+}
+
+// IsPTY reports whether the session is running in PTY mode.
+func (s *bashSession) IsPTY() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.isPTY
+}
+
+// trackCurrentProcess records pid as the process group of the command that
+// run() is executing so close() can kill it. If the session is already
+// closing, the group is killed immediately instead of being recorded.
+func (s *bashSession) trackCurrentProcess(pid int) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.closing {
+		// close() already ran while we were in cmd.Start(); kill immediately
+		_ = syscall.Kill(-pid, syscall.SIGKILL)
+		return
+	}
+	s.currentProcessPid = pid
+}
+
+// untrackCurrentProcess clears the pid recorded by trackCurrentProcess.
+func (s *bashSession) untrackCurrentProcess() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.currentProcessPid = 0
+}
+
+// run executes request.Code in a one-shot bash process that is primed with the
+// session's persisted environment and working directory, streams stdout and
+// stderr lines to the request hooks and the replay buffer, and afterwards
+// stores the environment and directory the script ended with.
+//
+// A non-zero user exit code is reported through OnExecuteError and run still
+// returns nil. A non-nil error is returned for infrastructure failures
+// (script file, pipes, process start, stdout read, wait) and for timeouts.
+// request.Timeout <= 0 means a 24 hour limit.
+//
+//nolint:gocognit
+func (s *bashSession) run(ctx context.Context, request *ExecuteCodeRequest) error {
+	s.mu.Lock()
+	if !s.started {
+		s.mu.Unlock()
+		return errors.New("session not started")
+	}
+
+	envSnapshot := copyEnvMap(s.env)
+
+	cwd := s.cwd
+	// override original cwd if specified
+	if request.Cwd != "" {
+		cwd = request.Cwd
+	}
+	sessionID := s.config.Session
+	s.mu.Unlock()
+
+	startAt := time.Now()
+	if request.Hooks.OnExecuteInit != nil {
+		request.Hooks.OnExecuteInit(sessionID)
+	}
+
+	wait := request.Timeout
+	if wait <= 0 {
+		wait = 24 * 3600 * time.Second // max to 24 hours
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, wait)
+	defer cancel()
+
+	script := buildWrappedScript(request.Code, envSnapshot, cwd)
+	scriptFile, err := os.CreateTemp("", "execd_bash_*.sh")
+	if err != nil {
+		return fmt.Errorf("create script file: %w", err)
+	}
+	scriptPath := scriptFile.Name()
+	defer os.Remove(scriptPath) // clean up temp script regardless of outcome
+	if _, err := scriptFile.WriteString(script); err != nil {
+		_ = scriptFile.Close()
+		return fmt.Errorf("write script file: %w", err)
+	}
+	if err := scriptFile.Close(); err != nil {
+		return fmt.Errorf("close script file: %w", err)
+	}
+
+	cmd := exec.CommandContext(ctx, "bash", "--noprofile", "--norc", scriptPath)
+	cmd.Dir = cwd // set OS-level CWD; harmless if cwd == "" (inherits daemon CWD)
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+	// On timeout/cancellation kill the whole process group, not just bash.
+	// The default CommandContext behaviour kills only the bash process; any
+	// child it spawned (e.g. `sleep`) keeps the stdout/stderr pipes open, so the
+	// scanners below would keep blocking long after the deadline has passed.
+	cmd.Cancel = func() error {
+		err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+		if errors.Is(err, syscall.ESRCH) {
+			return os.ErrProcessDone // group already gone; not a failure
+		}
+		return err
+	}
+	// Do not pass envSnapshot via cmd.Env to avoid "argument list too long" when session env is large.
+	// Child inherits parent env (nil => default in Go). The script file already has "export K=V" for
+	// all session vars at the top, so the session environment is applied when the script runs.
+	// Use OS pipes (not io.Pipe) so we can close the parent-side write ends immediately
+	// after cmd.Start() without breaking in-flight writes. The kernel buffers data
+	// independently; closing the write end in the parent just signals EOF to the reader
+	// once the child has exited and flushed, without any "write on closed pipe" errors.
+	stdoutR, stdoutW, err := os.Pipe()
+	if err != nil {
+		return fmt.Errorf("create stdout pipe: %w", err)
+	}
+	stderrR, stderrW, err := os.Pipe()
+	if err != nil {
+		_ = stdoutR.Close()
+		_ = stdoutW.Close()
+		return fmt.Errorf("create stderr pipe: %w", err)
+	}
+	cmd.Stdout = stdoutW
+	cmd.Stderr = stderrW
+
+	if err := cmd.Start(); err != nil {
+		_ = stdoutR.Close()
+		_ = stdoutW.Close()
+		_ = stderrR.Close()
+		_ = stderrW.Close()
+		log.Error("start bash session failed: %v (command: %q)", err, request.Code)
+		return fmt.Errorf("start bash: %w", err)
+	}
+	defer s.untrackCurrentProcess()
+	s.trackCurrentProcess(cmd.Process.Pid)
+
+	// Close parent-side write ends now. The child has inherited its own copies;
+	// closing ours here means the reader gets EOF as soon as the child exits,
+	// without waiting for cmd.Wait() — eliminating the scan↔Wait deadlock.
+	_ = stdoutW.Close()
+	_ = stderrW.Close()
+
+	// Drain stderr in a separate goroutine; fire OnExecuteStderr for each line.
+	stderrDone := make(chan struct{})
+	go func() {
+		defer close(stderrDone)
+		defer stderrR.Close() // release OS fd once stderr is fully drained
+		stderrScanner := bufio.NewScanner(stderrR)
+		stderrScanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
+		for stderrScanner.Scan() {
+			text := stderrScanner.Text()
+			s.replay.write([]byte(text + "\n"))
+			if request.Hooks.OnExecuteStderr != nil {
+				request.Hooks.OnExecuteStderr(text)
+			}
+		}
+		if stderrScanner.Err() != nil {
+			// The scanner gave up (e.g. an over-long line). Keep draining so the
+			// child cannot block forever on a full stderr pipe.
+			_, _ = io.Copy(io.Discard, stderrR)
+		}
+	}()
+
+	defer stdoutR.Close() // release OS fd once stdout is fully drained
+	scanner := bufio.NewScanner(stdoutR)
+	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
+
+	var (
+		envLines []string
+		pwdLine  string
+		exitCode *int
+		inEnv    bool
+	)
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		switch {
+		case line == envDumpStartMarker:
+			inEnv = true
+		case line == envDumpEndMarker:
+			inEnv = false
+		case strings.HasPrefix(line, exitMarkerPrefix):
+			if code, err := strconv.Atoi(strings.TrimPrefix(line, exitMarkerPrefix)); err == nil {
+				exitCode = &code //nolint:ineffassign
+			}
+		case strings.HasPrefix(line, pwdMarkerPrefix):
+			pwdLine = strings.TrimPrefix(line, pwdMarkerPrefix)
+		default:
+			if inEnv {
+				envLines = append(envLines, line)
+				continue
+			}
+			s.replay.write([]byte(line + "\n"))
+			if request.Hooks.OnExecuteStdout != nil {
+				request.Hooks.OnExecuteStdout(line)
+			}
+		}
+	}
+
+	scanErr := scanner.Err()
+	if scanErr != nil {
+		// Same reasoning as for stderr: never leave the child blocked on a full pipe
+		// while we sit in cmd.Wait().
+		_, _ = io.Copy(io.Discard, stdoutR)
+	}
+	waitErr := cmd.Wait()
+	// Wait for stderr goroutine to drain.
+	<-stderrDone
+
+	if scanErr != nil {
+		log.Error("read stdout failed: %v (command: %q)", scanErr, request.Code)
+		return fmt.Errorf("read stdout: %w", scanErr)
+	}
+
+	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		log.Error("timeout after %s while running command: %q", wait, request.Code)
+		return fmt.Errorf("timeout after %s while running command %q", wait, request.Code)
+	}
+
+	// The script never reached its exit marker (e.g. the user called `exit`):
+	// fall back to the process exit status.
+	if exitCode == nil && cmd.ProcessState != nil {
+		code := cmd.ProcessState.ExitCode() //nolint:staticcheck
+		exitCode = &code                    //nolint:ineffassign
+	}
+
+	updatedEnv := parseExportDump(envLines)
+	s.mu.Lock()
+	if len(updatedEnv) > 0 {
+		s.env = updatedEnv
+	}
+	if pwdLine != "" {
+		s.cwd = pwdLine
+	}
+	s.mu.Unlock()
+
+	var exitErr *exec.ExitError
+	if waitErr != nil && !errors.As(waitErr, &exitErr) {
+		log.Error("command wait failed: %v (command: %q)", waitErr, request.Code)
+		return waitErr
+	}
+
+	userExitCode := 0
+	if exitCode != nil {
+		userExitCode = *exitCode
+	}
+
+	if userExitCode != 0 {
+		errMsg := fmt.Sprintf("command exited with code %d", userExitCode)
+		if waitErr != nil {
+			errMsg = waitErr.Error()
+		}
+		if request.Hooks.OnExecuteError != nil {
+			request.Hooks.OnExecuteError(&execute.ErrorOutput{
+				EName:     "CommandExecError",
+				EValue:    strconv.Itoa(userExitCode),
+				Traceback: []string{errMsg},
+			})
+		}
+		log.Error("CommandExecError: %s (command: %q)", errMsg, request.Code)
+		return nil
+	}
+
+	if request.Hooks.OnExecuteComplete != nil {
+		request.Hooks.OnExecuteComplete(time.Since(startAt))
+	}
+
+	return nil
+}
+
+// buildWrappedScript renders the script that run() executes: exports for every
+// persistable session variable (sorted for determinism), a cd into cwd when it
+// is set, the user command, and a trailer that prints the environment dump,
+// working directory and exit code between the marker lines run() looks for.
+func buildWrappedScript(command string, env map[string]string, cwd string) string {
+	var b strings.Builder
+
+	keys := make([]string, 0, len(env))
+	for k := range env {
+		v := env[k]
+		if isValidEnvKey(k) && !envKeysNotPersisted[k] && len(v) <= maxPersistedEnvValueSize {
+			keys = append(keys, k)
+		}
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		b.WriteString("export ")
+		b.WriteString(k)
+		b.WriteString("=")
+		b.WriteString(shellEscape(env[k]))
+		b.WriteString("\n")
+	}
+
+	if cwd != "" {
+		b.WriteString("cd ")
+		b.WriteString(shellEscape(cwd))
+		b.WriteString("\n")
+	}
+
+	b.WriteString(command)
+	if !strings.HasSuffix(command, "\n") {
+		b.WriteString("\n")
+	}
+
+	b.WriteString("__USER_EXIT_CODE__=$?\n")
+	b.WriteString("printf \"\\n%s\\n\" \"" + envDumpStartMarker + "\"\n")
+	b.WriteString("export -p\n")
+	b.WriteString("printf \"%s\\n\" \"" + envDumpEndMarker + "\"\n")
+	b.WriteString("printf \"" + pwdMarkerPrefix + "%s\\n\" \"$(pwd)\"\n")
+	b.WriteString("printf \"" + exitMarkerPrefix + "%s\\n\" \"$__USER_EXIT_CODE__\"\n")
+	b.WriteString("exit \"$__USER_EXIT_CODE__\"\n")
+
+	return b.String()
+}
+
+// envKeysNotPersisted are not carried across runs (prompt/display vars).
+var envKeysNotPersisted = map[string]bool{
+	"PS1": true, "PS2": true, "PS3": true, "PS4": true,
+	"PROMPT_COMMAND": true,
+}
+
+// maxPersistedEnvValueSize caps single env value length as a safeguard.
+const maxPersistedEnvValueSize = 8 * 1024
+
+// parseExportDump turns the lines printed by `export -p` into an environment
+// map. Variables listed in envKeysNotPersisted and values larger than
+// maxPersistedEnvValueSize are dropped. It returns nil for an empty dump.
+//
+// A double-quoted value that contains newlines is printed across several
+// physical lines; those lines are stitched back together before parsing so the
+// value is neither truncated nor mistaken for further declarations.
+func parseExportDump(lines []string) map[string]string {
+	if len(lines) == 0 {
+		return nil
+	}
+	env := make(map[string]string, len(lines))
+	for i := 0; i < len(lines); i++ {
+		entry := lines[i]
+		for !exportEntryClosed(entry) && i+1 < len(lines) {
+			i++
+			entry += "\n" + lines[i]
+		}
+		k, v, ok := parseExportLine(entry)
+		if !ok || envKeysNotPersisted[k] || len(v) > maxPersistedEnvValueSize {
+			continue
+		}
+		env[k] = v
+	}
+	return env
+}
+
+// exportEntryClosed reports whether entry is a complete `declare -x` record.
+// Only a double-quoted value whose closing quote has not been seen yet counts
+// as open; anything else is treated as complete.
+func exportEntryClosed(entry string) bool {
+	const prefix = "declare -x "
+	if !strings.HasPrefix(entry, prefix) {
+		return true
+	}
+	_, raw, hasValue := strings.Cut(entry[len(prefix):], "=")
+	if !hasValue || !strings.HasPrefix(raw, `"`) {
+		return true
+	}
+	escaped := false
+	for i := 1; i < len(raw); i++ {
+		switch {
+		case escaped:
+			escaped = false
+		case raw[i] == '\\':
+			escaped = true
+		case raw[i] == '"':
+			return true
+		}
+	}
+	return false
+}
+
+// parseExportLine parses one `declare -x NAME[=VALUE]` record as printed by
+// bash's `export -p`. It returns the variable name, its unquoted value (empty
+// when the variable has no value) and whether the record was well-formed.
+func parseExportLine(line string) (string, string, bool) {
+	const prefix = "declare -x "
+	if !strings.HasPrefix(line, prefix) {
+		return "", "", false
+	}
+	rest := strings.TrimSpace(strings.TrimPrefix(line, prefix))
+	if rest == "" {
+		return "", "", false
+	}
+	name, raw, hasValue := strings.Cut(rest, "=")
+	if !isValidEnvKey(name) {
+		return "", "", false
+	}
+	if !hasValue {
+		return name, "", true
+	}
+	return name, unquoteExportValue(raw), true
+}
+
+// unquoteExportValue reverses the quoting bash applies to values in
+// `export -p` output. Go's strconv.Unquote is not suitable here: it rejects
+// bash's \$ and \` escapes, which previously left the backslashes in the
+// stored value and made such values grow on every run.
+func unquoteExportValue(raw string) string {
+	switch {
+	case len(raw) >= 3 && strings.HasPrefix(raw, "$'") && strings.HasSuffix(raw, "'"):
+		// ANSI-C quoting; newer bash versions appear to use this form for values
+		// containing control characters such as newlines.
+		return decodeANSICQuoted(raw[2 : len(raw)-1])
+	case len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"':
+		return decodeDoubleQuoted(raw[1 : len(raw)-1])
+	default:
+		return strings.Trim(raw, `"`)
+	}
+}
+
+// decodeDoubleQuoted removes the backslash from the four characters bash
+// escapes inside double quotes: ", $, ` and \. Everything else is kept as is.
+func decodeDoubleQuoted(body string) string {
+	var b strings.Builder
+	b.Grow(len(body))
+	for i := 0; i < len(body); i++ {
+		c := body[i]
+		if c == '\\' && i+1 < len(body) {
+			switch body[i+1] {
+			case '"', '$', '`', '\\':
+				i++
+				c = body[i]
+			}
+		}
+		b.WriteByte(c)
+	}
+	return b.String()
+}
+
+// decodeANSICQuoted decodes the body of a bash $'...' string: the single-letter
+// escapes, escaped quotes/backslashes and three-digit octal bytes. Unknown
+// escapes are kept verbatim.
+func decodeANSICQuoted(body string) string {
+	isOctal := func(c byte) bool { return c >= '0' && c <= '7' }
+
+	var b strings.Builder
+	b.Grow(len(body))
+	for i := 0; i < len(body); i++ {
+		c := body[i]
+		if c != '\\' || i+1 >= len(body) {
+			b.WriteByte(c)
+			continue
+		}
+		i++
+		switch e := body[i]; e {
+		case 'a':
+			b.WriteByte('\a')
+		case 'b':
+			b.WriteByte('\b')
+		case 'e', 'E':
+			b.WriteByte(0x1b)
+		case 'f':
+			b.WriteByte('\f')
+		case 'n':
+			b.WriteByte('\n')
+		case 'r':
+			b.WriteByte('\r')
+		case 't':
+			b.WriteByte('\t')
+		case 'v':
+			b.WriteByte('\v')
+		case '\\', '\'', '"':
+			b.WriteByte(e)
+		default:
+			if i+2 < len(body) && isOctal(e) && isOctal(body[i+1]) && isOctal(body[i+2]) {
+				v := int(e-'0')<<6 | int(body[i+1]-'0')<<3 | int(body[i+2]-'0')
+				b.WriteByte(byte(v))
+				i += 2
+				continue
+			}
+			b.WriteByte('\\')
+			b.WriteByte(e)
+		}
+	}
+	return b.String()
+}
+
+// shellEscape wraps value in single quotes for safe use in the generated
+// script, rewriting embedded single quotes as '"'"'.
+func shellEscape(value string) string {
+	return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
+}
+
+// isValidEnvKey reports whether key is a shell identifier: ASCII letters and
+// underscores, plus digits anywhere except the first position.
+func isValidEnvKey(key string) bool {
+	if key == "" {
+		return false
+	}
+
+	for i, r := range key {
+		letterOrUnderscore := (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || r == '_'
+		digit := r >= '0' && r <= '9'
+		if !letterOrUnderscore && !(i > 0 && digit) {
+			return false
+		}
+	}
+
+	return true
+}
+
+// copyEnvMap returns an independent copy of src. A nil src yields an empty,
+// non-nil map.
+func copyEnvMap(src map[string]string) map[string]string {
+	if src == nil {
+		return map[string]string{}
+	}
+
+	dst := make(map[string]string, len(src))
+	for k, v := range src {
+		dst[k] = v
+	}
+	return dst
+}
+
+// splitEnvPair splits a "KEY=VALUE" entry at the first '='. It reports false
+// when there is no '=' or the key is not a valid identifier.
+func splitEnvPair(kv string) (string, string, bool) {
+	key, value, found := strings.Cut(kv, "=")
+	if !found || !isValidEnvKey(key) {
+		return "", "", false
+	}
+	return key, value, true
+}
+
+// close marks the session as closing, forgets its environment and working
+// directory, SIGKILLs the process groups of the WebSocket-mode shell and of any
+// command currently executing in run(), and closes the PTY master if one is
+// open. It always returns nil; kill failures are only logged.
+func (s *bashSession) close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.closing = true
+	wsPid := s.wsPid
+	runPid := s.currentProcessPid
+	ptmx := s.ptmx
+	s.wsPid = 0
+	s.currentProcessPid = 0
+	s.started = false
+	s.env = nil
+	s.cwd = ""
+
+	for _, pid := range []int{wsPid, runPid} {
+		if pid != 0 {
+			if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
+				log.Warning("kill session process group %d: %v (process may have already exited)", pid, err)
+			}
+		}
+	}
+	if ptmx != nil {
+		_ = ptmx.Close()
+		s.ptmx = nil
+	}
+	return nil
+}
+
+// uuidString returns a freshly generated UUID in its string form.
+func uuidString() string {
+	return uuid.New().String()
+}
diff --git a/components/execd/pkg/runtime/bash_session_pty_test.go b/components/execd/pkg/runtime/bash_session_pty_test.go
new file mode 100644
index 00000000..a41bb402
--- /dev/null
+++ b/components/execd/pkg/runtime/bash_session_pty_test.go
@@ -0,0 +1,172 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !windows
+// +build !windows
+
+package runtime
+
+import (
+	"io"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/creack/pty"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// readOutputTimeout drains r for up to d, returning all collected bytes.
+//
+// r.Read blocks while the shell is idle, so the reads happen on a helper
+// goroutine and the deadline is enforced with a timer; checking the clock only
+// between reads would hang as soon as output stops. The helper goroutine ends
+// when r reports an error (for example after the caller detaches the output)
+// or, if it is holding a chunk, when this function returns.
+func readOutputTimeout(r io.Reader, d time.Duration) string {
+	chunks := make(chan []byte)
+	stop := make(chan struct{})
+	defer close(stop)
+
+	go func() {
+		defer close(chunks)
+		tmp := make([]byte, 256)
+		for {
+			n, err := r.Read(tmp)
+			if n > 0 {
+				// Copy: tmp is reused by the next Read.
+				data := append([]byte(nil), tmp[:n]...)
+				select {
+				case chunks <- data:
+				case <-stop:
+					return
+				}
+			}
+			if err != nil {
+				return
+			}
+		}
+	}()
+
+	timer := time.NewTimer(d)
+	defer timer.Stop()
+
+	var buf strings.Builder
+	for {
+		select {
+		case data, ok := <-chunks:
+			if !ok {
+				return buf.String() // reader finished before the deadline
+			}
+			buf.Write(data)
+		case <-timer.C:
+			return buf.String()
+		}
+	}
+}
+
+// TestPTY_BasicExecution verifies that a PTY session can run a command and
+// the output is received on the PTY master.
+func TestPTY_BasicExecution(t *testing.T) {
+	if _, err := exec.LookPath("bash"); err != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	s := newBashSession(t.TempDir())
+	t.Cleanup(func() { _ = s.close() })
+
+	require.NoError(t, s.StartPTY())
+	require.True(t, s.IsRunning(), "expected bash process to be running after StartPTY")
+
+	// Attach before writing so the command's output cannot reach the broadcast
+	// goroutine while no sink is installed (it would then land only in the
+	// replay buffer and never reach outR).
+	outR, _, detach := s.AttachOutput()
+	defer detach()
+
+	// Send a command via stdin.
+	_, err := s.WriteStdin([]byte("echo hi\n"))
+	require.NoError(t, err)
+
+	out := readOutputTimeout(outR, 3*time.Second)
+	assert.Contains(t, out, "hi", "expected 'hi' in PTY output, got: %q", out)
+}
+
+// TestPTY_ResizeUpdatesWinsize verifies that ResizePTY changes the terminal
+// dimensions reported by the PTY (no error path; structural change verified).
+func TestPTY_ResizeUpdatesWinsize(t *testing.T) {
+	if _, err := exec.LookPath("bash"); err != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	s := newBashSession(t.TempDir())
+	t.Cleanup(func() { _ = s.close() })
+
+	require.NoError(t, s.StartPTY())
+
+	// Resize to known dimensions.
+	require.NoError(t, s.ResizePTY(120, 40))
+
+	// Verify via pty.GetsizeFull that the kernel registered the new size.
+	s.mu.Lock()
+	ptmx := s.ptmx
+	s.mu.Unlock()
+	require.NotNil(t, ptmx)
+
+	ws, err := pty.GetsizeFull(ptmx)
+	require.NoError(t, err)
+	assert.Equal(t, uint16(120), ws.Cols, "expected cols=120 after resize")
+	assert.Equal(t, uint16(40), ws.Rows, "expected rows=40 after resize")
+}
+
+// TestPTY_AnsiSequencesPresent verifies that PTY output contains ANSI escape
+// sequences (the prompt), which distinguishes PTY mode from plain pipe mode.
+func TestPTY_AnsiSequencesPresent(t *testing.T) {
+	if _, err := exec.LookPath("bash"); err != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	s := newBashSession(t.TempDir())
+	t.Cleanup(func() { _ = s.close() })
+
+	require.NoError(t, s.StartPTY())
+
+	// Attach before writing: output produced while no sink is installed goes
+	// only to the replay buffer, which would make the assertions below flaky.
+	outR, _, detach := s.AttachOutput()
+	defer detach()
+
+	// Send a command that forces a prompt re-emission.
+	_, err := s.WriteStdin([]byte("PS1='\\e[1;32m>>\\e[0m '; echo marker\n"))
+	require.NoError(t, err)
+
+	out := readOutputTimeout(outR, 3*time.Second)
+	// ANSI escape sequences start with ESC (\x1b) followed by [
+	assert.Contains(t, out, "\x1b[", "expected ANSI escape sequence in PTY output, got: %q", out)
+	assert.Contains(t, out, "marker", "expected 'marker' in PTY output, got: %q", out)
+}
+
+// TestPTY_PipeModeUnchanged verifies that a session created without ?pty=1
+// still uses plain pipes and has no PTY fd open — regression guard.
+func TestPTY_PipeModeUnchanged(t *testing.T) {
+	_, lookErr := exec.LookPath("bash")
+	if lookErr != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	sess := newBashSession(t.TempDir())
+	t.Cleanup(func() { _ = sess.close() })
+
+	require.NoError(t, sess.Start())
+	require.True(t, sess.IsRunning(), "expected bash process to be running")
+
+	// Snapshot the PTY-related fields under the lock; both must be unset in
+	// pipe mode.
+	sess.mu.Lock()
+	ptyEnabled, master := sess.isPTY, sess.ptmx
+	sess.mu.Unlock()
+
+	assert.False(t, ptyEnabled, "isPTY must be false in pipe mode")
+	assert.Nil(t, master, "ptmx must be nil in pipe mode")
+
+	// Resizing a non-PTY session is expected to succeed without doing anything.
+	require.NoError(t, sess.ResizePTY(100, 30))
+
+	// Attach output before writing stdin so the broadcast goroutine already has
+	// a PipeWriter installed; otherwise output could land only in the replay
+	// buffer. The reader itself is never consumed here; detach cleans it up.
+	_, _, detach := sess.AttachOutput()
+	defer detach()
+
+	// Stdin must still work via pipe.
+	_, err := sess.WriteStdin([]byte("echo pipe-ok\n"))
+	require.NoError(t, err)
+
+	// Poll the replay buffer until the echo shows up. This works no matter
+	// whether the output arrived before or after the PipeWriter was installed.
+	var got string
+	for stop := time.Now().Add(5 * time.Second); time.Now().Before(stop); time.Sleep(50 * time.Millisecond) {
+		data, _ := sess.replay.readFrom(0)
+		got = string(data)
+		if strings.Contains(got, "pipe-ok") {
+			break
+		}
+	}
+	assert.Contains(t, got, "pipe-ok", "expected pipe-mode echo output in replay buffer, got: %q", got)
+}
diff --git a/components/execd/pkg/runtime/bash_session_test.go b/components/execd/pkg/runtime/bash_session_test.go
new file mode 100644
index 00000000..b18af3de
--- /dev/null
+++ b/components/execd/pkg/runtime/bash_session_test.go
@@ -0,0 +1,599 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !windows
+// +build !windows
+
+package runtime
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/alibaba/opensandbox/execd/pkg/jupyter/execute"
+)
+
+// TestBashSession_NonZeroExitEmitsError runs a script that prints one line and
+// exits with status 7, and checks that the error hook fires with that status
+// while the completion hook does not.
+func TestBashSession_NonZeroExitEmitsError(t *testing.T) {
+	if _, err := exec.LookPath("bash"); err != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	c := NewController("", "")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	var (
+		sessionID  string
+		stdoutLine string
+		errCh      = make(chan *execute.ErrorOutput, 1)
+		completeCh = make(chan struct{}, 1)
+	)
+
+	req := &ExecuteCodeRequest{
+		Language: Bash,
+		Code:     `echo "before"; exit 7`,
+		Cwd:      t.TempDir(),
+		Timeout:  5 * time.Second,
+		Hooks: ExecuteResultHook{
+			OnExecuteInit:   func(s string) { sessionID = s },
+			OnExecuteStdout: func(s string) { stdoutLine = s },
+			OnExecuteError:  func(err *execute.ErrorOutput) { errCh <- err },
+			OnExecuteComplete: func(_ time.Duration) {
+				completeCh <- struct{}{}
+			},
+		},
+	}
+
+	session, err := c.createBashSession(&CreateContextRequest{})
+	if !assert.NoError(t, err) {
+		// Without a session there is nothing to run; stop here instead of
+		// producing a misleading follow-up failure.
+		t.FailNow()
+	}
+	req.Context = session
+	require.NoError(t, c.runBashSession(ctx, req))
+
+	var gotErr *execute.ErrorOutput
+	select {
+	case gotErr = <-errCh:
+	case <-time.After(2 * time.Second):
+		require.Fail(t, "expected error hook to be called")
+	}
+	require.NotNil(t, gotErr, "expected non-nil error output")
+	require.Equal(t, "CommandExecError", gotErr.EName)
+	require.Equal(t, "7", gotErr.EValue)
+	require.NotEmpty(t, sessionID, "expected session id to be set")
+	require.Equal(t, "before", stdoutLine)
+
+	// Non-blocking check: the completion hook must not have fired.
+	select {
+	case <-completeCh:
+		require.Fail(t, "did not expect completion hook on non-zero exit")
+	default:
+	}
+}
+
+// TestBashSession_envAndExitCode checks that exported variables survive
+// between runs of the same session and that $? reflects the previous command.
+func TestBashSession_envAndExitCode(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var (
+		initCalls     int
+		completeCalls int
+		stdoutLines   []string
+	)
+
+	hooks := ExecuteResultHook{
+		OnExecuteInit: func(ctx string) {
+			require.Equal(t, session.config.Session, ctx, "unexpected session in OnExecuteInit")
+			initCalls++
+		},
+		OnExecuteStdout: func(text string) {
+			t.Log(text)
+			stdoutLines = append(stdoutLines, text)
+		},
+		OnExecuteComplete: func(_ time.Duration) {
+			completeCalls++
+		},
+	}
+
+	// 1) export an env var
+	request := &ExecuteCodeRequest{
+		Code:    "export FOO=hello",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request))
+	exportStdoutCount := len(stdoutLines)
+
+	// 2) verify env is persisted
+	request = &ExecuteCodeRequest{
+		Code:    "echo $FOO",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request))
+	echoLines := stdoutLines[exportStdoutCount:]
+	foundHello := false
+	for _, line := range echoLines {
+		if strings.TrimSpace(line) == "hello" {
+			foundHello = true
+			break
+		}
+	}
+	require.True(t, foundHello, "expected echo $FOO to output 'hello', got %v", echoLines)
+
+	// 3) ensure exit code of previous command is reflected in shell state
+	request = &ExecuteCodeRequest{
+		Code:    "false; echo EXIT:$?",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	prevCount := len(stdoutLines)
+	require.NoError(t, session.run(context.Background(), request))
+	exitLines := stdoutLines[prevCount:]
+	foundExit := false
+	for _, line := range exitLines {
+		if strings.Contains(line, "EXIT:1") {
+			foundExit = true
+			break
+		}
+	}
+	require.True(t, foundExit, "expected exit code output 'EXIT:1', got %v", exitLines)
+	require.Equal(t, 3, initCalls, "OnExecuteInit expected 3 calls")
+	require.Equal(t, 3, completeCalls, "OnExecuteComplete expected 3 calls")
+}
+
+// TestBashSession_envLargeOutputChained runs three commands that each print
+// 60+ lines and checks that environment changes made by one command are
+// visible to the next.
+func TestBashSession_envLargeOutputChained(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var (
+		initCalls     int
+		completeCalls int
+		stdoutLines   []string
+	)
+
+	hooks := ExecuteResultHook{
+		OnExecuteInit: func(ctx string) {
+			require.Equal(t, session.config.Session, ctx, "unexpected session in OnExecuteInit")
+			initCalls++
+		},
+		OnExecuteStdout: func(text string) {
+			t.Log(text)
+			stdoutLines = append(stdoutLines, text)
+		},
+		OnExecuteComplete: func(_ time.Duration) {
+			completeCalls++
+		},
+	}
+
+	// runAndCollect runs cmd and returns a copy of only the lines it produced.
+	runAndCollect := func(cmd string) []string {
+		start := len(stdoutLines)
+		request := &ExecuteCodeRequest{
+			Code:    cmd,
+			Hooks:   hooks,
+			Timeout: 10 * time.Second,
+		}
+		require.NoError(t, session.run(context.Background(), request))
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	lines1 := runAndCollect("export FOO=hello1; for i in $(seq 1 60); do echo A${i}:$FOO; done")
+	require.GreaterOrEqual(t, len(lines1), 60, "expected >=60 lines for cmd1")
+	require.True(t, containsLine(lines1, "A1:hello1") && containsLine(lines1, "A60:hello1"), "env not reflected in cmd1 output, got %v", lines1[:3])
+
+	lines2 := runAndCollect("export FOO=${FOO}_next; export BAR=bar1; for i in $(seq 1 60); do echo B${i}:$FOO:$BAR; done")
+	require.GreaterOrEqual(t, len(lines2), 60, "expected >=60 lines for cmd2")
+	require.True(t, containsLine(lines2, "B1:hello1_next:bar1") && containsLine(lines2, "B60:hello1_next:bar1"), "env not propagated to cmd2 output, sample %v", lines2[:3])
+
+	lines3 := runAndCollect("export BAR=${BAR}_last; for i in $(seq 1 60); do echo C${i}:$FOO:$BAR; done; echo FINAL_FOO=$FOO; echo FINAL_BAR=$BAR")
+	require.GreaterOrEqual(t, len(lines3), 62, "expected >=62 lines for cmd3") // 60 lines + 2 finals
+	require.True(t, containsLine(lines3, "C1:hello1_next:bar1_last") && containsLine(lines3, "C60:hello1_next:bar1_last"), "env not propagated to cmd3 output, sample %v", lines3[:3])
+	require.True(t, containsLine(lines3, "FINAL_FOO=hello1_next") && containsLine(lines3, "FINAL_BAR=bar1_last"), "final env lines missing, got %v", lines3[len(lines3)-5:])
+	require.Equal(t, 3, initCalls, "OnExecuteInit expected 3 calls")
+	require.Equal(t, 3, completeCalls, "OnExecuteComplete expected 3 calls")
+}
+
+// TestBashSession_cwdPersistsWithoutOverride checks that a `cd` performed by
+// one run is still in effect for the next run when the request sets no Cwd.
+func TestBashSession_cwdPersistsWithoutOverride(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	targetDir := t.TempDir()
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	runAndCollect := func(req *ExecuteCodeRequest) []string {
+		start := len(stdoutLines)
+		require.NoError(t, session.run(context.Background(), req))
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	firstRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    fmt.Sprintf("cd %s\npwd", targetDir),
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	require.True(t, containsLine(firstRunLines, targetDir), "expected cd to update cwd to %q, got %v", targetDir, firstRunLines)
+
+	secondRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    "pwd",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	require.True(t, containsLine(secondRunLines, targetDir), "expected subsequent run to inherit cwd %q, got %v", targetDir, secondRunLines)
+
+	session.mu.Lock()
+	finalCwd := session.cwd
+	session.mu.Unlock()
+	require.Equal(t, targetDir, finalCwd, "expected session cwd to stay at %q", targetDir)
+}
+
+// TestBashSession_requestCwdOverridesAfterCd checks that an explicit request
+// Cwd wins over the directory left behind by an earlier `cd`, and that the
+// session's recorded cwd follows the override.
+func TestBashSession_requestCwdOverridesAfterCd(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	initialDir := t.TempDir()
+	overrideDir := t.TempDir()
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	runAndCollect := func(req *ExecuteCodeRequest) []string {
+		start := len(stdoutLines)
+		require.NoError(t, session.run(context.Background(), req))
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	// First request: change session cwd via script.
+	firstRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    fmt.Sprintf("cd %s\npwd", initialDir),
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	require.True(t, containsLine(firstRunLines, initialDir), "expected cd to update cwd to %q, got %v", initialDir, firstRunLines)
+
+	// Second request: explicit Cwd overrides session cwd.
+	secondRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    "pwd",
+		Cwd:     overrideDir,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	require.True(t, containsLine(secondRunLines, overrideDir), "expected command to run in override cwd %q, got %v", overrideDir, secondRunLines)
+
+	session.mu.Lock()
+	finalCwd := session.cwd
+	session.mu.Unlock()
+	require.Equal(t, overrideDir, finalCwd, "expected session cwd updated to override dir %q", overrideDir)
+}
+
+// TestBashSession_envDumpNotLeakedWhenNoTrailingNewline checks that output
+// without a trailing newline is delivered as exactly one line and that no
+// environment-dump text appears in stdout.
+func TestBashSession_envDumpNotLeakedWhenNoTrailingNewline(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	request := &ExecuteCodeRequest{
+		Code:    `set +x; printf '{"foo":1}'`,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request))
+
+	require.Len(t, stdoutLines, 1, "expected exactly one stdout line")
+	require.Equal(t, `{"foo":1}`, strings.TrimSpace(stdoutLines[0]))
+	for _, line := range stdoutLines {
+		require.NotContains(t, line, envDumpStartMarker, "env dump leaked into stdout: %v", stdoutLines)
+		require.NotContains(t, line, "declare -x", "env dump leaked into stdout: %v", stdoutLines)
+	}
+}
+
+// TestBashSession_envDumpNotLeakedWhenNoOutput checks that a command with no
+// output yields at most one (empty) stdout line and no environment-dump text.
+func TestBashSession_envDumpNotLeakedWhenNoOutput(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	request := &ExecuteCodeRequest{
+		Code:    `set +x; true`,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request))
+
+	require.LessOrEqual(t, len(stdoutLines), 1, "expected at most one stdout line, got %v", stdoutLines)
+	if len(stdoutLines) == 1 {
+		require.Empty(t, strings.TrimSpace(stdoutLines[0]), "expected empty stdout")
+	}
+	for _, line := range stdoutLines {
+		require.NotContains(t, line, envDumpStartMarker, "env dump leaked into stdout: %v", stdoutLines)
+		require.NotContains(t, line, "declare -x", "env dump leaked into stdout: %v", stdoutLines)
+	}
+}
+
+// TestBashSession_heredoc runs a script containing a heredoc through the
+// controller API and then runs a second command to confirm the session is
+// still usable.
+func TestBashSession_heredoc(t *testing.T) {
+	rewardDir := t.TempDir()
+	controller := NewController("", "")
+
+	sessionID, err := controller.CreateBashSession(&CreateContextRequest{})
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = controller.DeleteBashSession(sessionID) })
+
+	// Route hook output through the test log so it is attributed to this test
+	// and only shown on failure or with -v.
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			t.Logf("[stdout] %s", line)
+		},
+		OnExecuteComplete: func(d time.Duration) {
+			t.Logf("[complete] %s", d)
+		},
+	}
+
+	// First run: heredoc + reward file write. The generated script lives in the
+	// per-test temp dir rather than a fixed /tmp path, so concurrent runs do not
+	// collide and nothing is left behind.
+	script := fmt.Sprintf(`
+set -x
+reward_dir=%q
+mkdir -p "$reward_dir"
+
+cat > "$reward_dir/repro_script.sh" <<'SHEOF'
+#!/usr/bin/env sh
+echo "hello heredoc"
+SHEOF
+
+chmod +x "$reward_dir/repro_script.sh"
+"$reward_dir/repro_script.sh"
+echo "after heredoc"
+echo 1 > "$reward_dir/reward.txt"
+cat "$reward_dir/reward.txt"
+`, rewardDir)
+
+	ctx := context.Background()
+	require.NoError(t, controller.RunInBashSession(ctx, &ExecuteCodeRequest{
+		Context:  sessionID,
+		Language: Bash,
+		Timeout:  10 * time.Second,
+		Code:     script,
+		Hooks:    hooks,
+	}))
+
+	// Second run: ensure the session keeps working.
+	require.NoError(t, controller.RunInBashSession(ctx, &ExecuteCodeRequest{
+		Context:  sessionID,
+		Language: Bash,
+		Timeout:  5 * time.Second,
+		Code:     "echo 'second command works'",
+		Hooks:    hooks,
+	}))
+}
+
+// TestBashSession_execReplacesShell checks that a script which `exec`s another
+// program still completes, and that the session accepts a follow-up command.
+func TestBashSession_execReplacesShell(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	// Keep the child script in the per-test temp dir instead of a fixed /tmp
+	// path. This file is built only for non-Windows targets, so "/" is safe.
+	childScript := t.TempDir() + "/exec_child.sh"
+	script := fmt.Sprintf(`
+cat > %[1]q <<'EOF'
+echo "child says hi"
+EOF
+chmod +x %[1]q
+exec %[1]q
+`, childScript)
+
+	request := &ExecuteCodeRequest{
+		Code:    script,
+		Hooks:   hooks,
+		Timeout: 5 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request), "expected exec to complete without killing the session")
+	require.True(t, containsLine(stdoutLines, "child says hi"), "expected child output, got %v", stdoutLines)
+
+	// Subsequent run should still work because we restart bash per run.
+	request = &ExecuteCodeRequest{
+		Code:    "echo still-alive",
+		Hooks:   hooks,
+		Timeout: 2 * time.Second,
+	}
+	stdoutLines = nil
+	require.NoError(t, session.run(context.Background(), request), "expected run to succeed after exec replaced the shell")
+	require.True(t, containsLine(stdoutLines, "still-alive"), "expected follow-up output, got %v", stdoutLines)
+}
+
+// TestBashSession_complexExec checks that a script which redirects and then
+// restores its own stdout/stderr via `exec` still delivers its output, and
+// that the session remains usable afterwards.
+func TestBashSession_complexExec(t *testing.T) {
+	session := newBashSession("")
+	t.Cleanup(func() { _ = session.close() })
+
+	require.NoError(t, session.start())
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	script := `
+LOG_FILE=$(mktemp)
+export LOG_FILE
+exec 3>&1 4>&2
+exec > >(tee "$LOG_FILE") 2>&1
+
+set -x
+echo "from-complex-exec"
+exec 1>&3 2>&4 # step record
+echo "after-restore"
+`
+
+	request := &ExecuteCodeRequest{
+		Code:    script,
+		Hooks:   hooks,
+		Timeout: 5 * time.Second,
+	}
+	require.NoError(t, session.run(context.Background(), request), "expected complex exec to finish")
+	require.True(t, containsLine(stdoutLines, "from-complex-exec") && containsLine(stdoutLines, "after-restore"), "expected exec outputs, got %v", stdoutLines)
+
+	// Session should still be usable.
+	request = &ExecuteCodeRequest{
+		Code:    "echo still-alive",
+		Hooks:   hooks,
+		Timeout: 2 * time.Second,
+	}
+	stdoutLines = nil
+	require.NoError(t, session.run(context.Background(), request), "expected run to succeed after complex exec")
+	require.True(t, containsLine(stdoutLines, "still-alive"), "expected follow-up output, got %v", stdoutLines)
+}
+
+// containsLine reports whether any element of lines equals target once
+// surrounding whitespace is trimmed from the element.
+func containsLine(lines []string, target string) bool {
+	for _, l := range lines {
+		if strings.TrimSpace(l) == target {
+			return true
+		}
+	}
+	return false
+}
+
+// TestBashSession_CloseKillsRunningProcess verifies that session.close() kills the active
+// process group so that a long-running command (e.g. sleep) does not keep running after close.
+func TestBashSession_CloseKillsRunningProcess(t *testing.T) {
+	_, lookErr := exec.LookPath("bash")
+	if lookErr != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	sess := newBashSession("")
+	require.NoError(t, sess.start())
+
+	// Run a long sleep in the background; its result is delivered on finished.
+	finished := make(chan error, 1)
+	go func() {
+		finished <- sess.run(context.Background(), &ExecuteCodeRequest{
+			Code:    "sleep 30",
+			Timeout: 60 * time.Second,
+			Hooks:   ExecuteResultHook{},
+		})
+	}()
+
+	// Give the child process time to start.
+	time.Sleep(200 * time.Millisecond)
+
+	// Closing is expected to kill the process group, so run() should come back
+	// quickly. Its result is deliberately not inspected: it may be nil because
+	// the code path treats non-zero exit as success after calling OnExecuteError.
+	require.NoError(t, sess.close())
+
+	limit := time.NewTimer(3 * time.Second)
+	defer limit.Stop()
+	select {
+	case <-limit.C:
+		require.Fail(t, "run did not return within 3s after close (process was not killed)")
+	case <-finished:
+		// run() returned well before the 30s sleep could finish.
+	}
+}
+
+// TestBashSession_DeleteBashSessionKillsRunningProcess verifies that DeleteBashSession
+// (close path) kills the active run and removes the session from the controller.
+func TestBashSession_DeleteBashSessionKillsRunningProcess(t *testing.T) {
+	_, lookErr := exec.LookPath("bash")
+	if lookErr != nil {
+		t.Skip("bash not found in PATH")
+	}
+
+	ctrl := NewController("", "")
+	id, err := ctrl.CreateBashSession(&CreateContextRequest{})
+	require.NoError(t, err)
+
+	// Run a long sleep in the background; its result is delivered on finished.
+	finished := make(chan error, 1)
+	go func() {
+		finished <- ctrl.RunInBashSession(context.Background(), &ExecuteCodeRequest{
+			Language: Bash,
+			Context:  id,
+			Code:     "sleep 30",
+			Timeout:  60 * time.Second,
+			Hooks:    ExecuteResultHook{},
+		})
+	}()
+
+	// Give the child process time to start.
+	time.Sleep(200 * time.Millisecond)
+
+	require.NoError(t, ctrl.DeleteBashSession(id))
+
+	limit := time.NewTimer(3 * time.Second)
+	defer limit.Stop()
+	select {
+	case <-limit.C:
+		require.Fail(t, "RunInBashSession did not return within 3s after DeleteBashSession")
+	case <-finished:
+		// RunInBashSession returned well before the 30s sleep could finish.
+	}
+
+	// The session must be gone now: a second delete reports ErrContextNotFound.
+	err = ctrl.DeleteBashSession(id)
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrContextNotFound)
+}
+
+// TestBashSession_CloseWithNoActiveRun verifies that close() with no running command
+// completes without error and does not hang.
+func TestBashSession_CloseWithNoActiveRun(t *testing.T) {
+	sess := newBashSession("")
+	require.NoError(t, sess.start())
+
+	// Close on a helper goroutine so a hang shows up as a timeout below.
+	closed := make(chan struct{}, 1)
+	go func() {
+		_ = sess.close()
+		closed <- struct{}{}
+	}()
+
+	limit := time.NewTimer(2 * time.Second)
+	defer limit.Stop()
+	select {
+	case <-limit.C:
+		require.Fail(t, "close() did not return within 2s when no run was active")
+	case <-closed:
+		// close() returned in time.
+	}
+}
diff --git a/components/execd/pkg/runtime/bash_session_windows.go b/components/execd/pkg/runtime/bash_session_windows.go
new file mode 100644
index 00000000..0891128a
--- /dev/null
+++ b/components/execd/pkg/runtime/bash_session_windows.go
@@ -0,0 +1,65 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build windows
+// +build windows
+
+package runtime
+
+import (
+	"context"
+	"errors"
+)
+
+var errBashSessionNotSupported = errors.New("bash session is not supported on windows")
+
+// BashSessionStatus describes a bash session: its ID, a running flag and an output offset.
+type BashSessionStatus struct {
+	SessionID    string
+	Running      bool
+	OutputOffset int64
+}
+
+// CreateBashSession always fails with errBashSessionNotSupported on Windows.
+func (c *Controller) CreateBashSession(_ *CreateContextRequest) (string, error) { //nolint:revive
+	return "", errBashSessionNotSupported
+}
+
+// RunInBashSession always fails with errBashSessionNotSupported on Windows.
+func (c *Controller) RunInBashSession(_ context.Context, _ *ExecuteCodeRequest) error { //nolint:revive
+	return errBashSessionNotSupported
+}
+
+// DeleteBashSession always fails with errBashSessionNotSupported on Windows.
+func (c *Controller) DeleteBashSession(_ string) error { //nolint:revive
+	return errBashSessionNotSupported
+}
+
+// GetBashSession always returns nil on Windows, where bash sessions are unsupported.
+func (c *Controller) GetBashSession(_ string) BashSession { //nolint:revive
+	return nil
+}
+
+// GetBashSessionStatus always fails with errBashSessionNotSupported on Windows.
+func (c *Controller) GetBashSessionStatus(_ string) (*BashSessionStatus, error) { //nolint:revive
+	return nil, errBashSessionNotSupported
+}
+
+// ReplaySessionOutput is not supported on Windows.
+func (c *Controller) ReplaySessionOutput(_ string, _ int64) ([]byte, int64, error) { //nolint:revive + return nil, 0, errBashSessionNotSupported +} + +// WriteSessionOutput is not supported on Windows. +func (c *Controller) WriteSessionOutput(_ string, _ []byte) {} //nolint:revive diff --git a/components/execd/pkg/runtime/command_common.go b/components/execd/pkg/runtime/command_common.go index 960ff273..03205e20 100644 --- a/components/execd/pkg/runtime/command_common.go +++ b/components/execd/pkg/runtime/command_common.go @@ -46,18 +46,17 @@ func (c *Controller) tailStdPipe(file string, onExecute func(text string), done // getCommandKernel retrieves a command execution context. func (c *Controller) getCommandKernel(sessionID string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.commandClientMap[sessionID] + if v, ok := c.commandClientMap.Load(sessionID); ok { + if kernel, ok := v.(*commandKernel); ok { + return kernel + } + } + return nil } // storeCommandKernel registers a command execution context. func (c *Controller) storeCommandKernel(sessionID string, kernel *commandKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.commandClientMap[sessionID] = kernel + c.commandClientMap.Store(sessionID, kernel) } // stdLogDescriptor creates temporary files for capturing command output. 
diff --git a/components/execd/pkg/runtime/command_status.go b/components/execd/pkg/runtime/command_status.go index 97f112b1..7bf6f58d 100644 --- a/components/execd/pkg/runtime/command_status.go +++ b/components/execd/pkg/runtime/command_status.go @@ -40,15 +40,19 @@ type CommandOutput struct { } func (c *Controller) commandSnapshot(session string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return nil } + // Hold the read lock while copying so the snapshot is consistent with + // concurrent markCommandFinished writes (which take the write lock). + c.mu.RLock() cp := *kernel + c.mu.RUnlock() return &cp } @@ -116,8 +120,11 @@ func (c *Controller) markCommandFinished(session string, exitCode int, errMsg st c.mu.Lock() defer c.mu.Unlock() - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return } diff --git a/components/execd/pkg/runtime/context.go b/components/execd/pkg/runtime/context.go index c2f9052a..42f7fe50 100644 --- a/components/execd/pkg/runtime/context.go +++ b/components/execd/pkg/runtime/context.go @@ -31,7 +31,9 @@ import ( ) // CreateContext provisions a kernel-backed session and returns its ID. +// Bash language uses Jupyter kernel like other languages; for pipe-based bash sessions use CreateBashSession (session API). func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { + // Create a new Jupyter session. 
var ( client *jupyter.Client session *jupytersession.Session @@ -42,7 +44,7 @@ func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { log.Error("failed to create session, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(*req) + client, session, err = c.createJupyterContext(*req) return err }) if err != nil { @@ -114,20 +116,11 @@ func (c *Controller) deleteSessionAndCleanup(session string) error { if c.getJupyterKernel(session) == nil { return ErrContextNotFound } - if err := c.jupyterClient().DeleteSession(session); err != nil { return err } - - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.jupyterClientMap, session) - for lang, id := range c.defaultLanguageJupyterSessions { - if id == session { - delete(c.defaultLanguageJupyterSessions, lang) - } - } + c.jupyterClientMap.Delete(session) + c.deleteDefaultSessionByID(session) return nil } @@ -146,8 +139,12 @@ func (c *Controller) newIpynbPath(sessionID, cwd string) (string, error) { return filepath.Join(cwd, fmt.Sprintf("%s.ipynb", sessionID)), nil } -// createDefaultLanguageContext prewarms a session for stateless execution. -func (c *Controller) createDefaultLanguageContext(language Language) error { +// createDefaultLanguageJupyterContext prewarms a session for stateless execution. 
+func (c *Controller) createDefaultLanguageJupyterContext(language Language) error { + if c.getDefaultLanguageSession(language) != "" { + return nil + } + var ( client *jupyter.Client session *jupytersession.Session @@ -157,7 +154,7 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { log.Error("failed to create context, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(CreateContextRequest{ + client, session, err = c.createJupyterContext(CreateContextRequest{ Language: language, Cwd: "", }) @@ -167,20 +164,17 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { return err } - c.mu.Lock() - defer c.mu.Unlock() - - c.defaultLanguageJupyterSessions[language] = session.ID - c.jupyterClientMap[session.ID] = &jupyterKernel{ + c.setDefaultLanguageSession(language, session.ID) + c.jupyterClientMap.Store(session.ID, &jupyterKernel{ kernelID: session.Kernel.ID, client: client, language: language, - } + }) return nil } -// createContext performs the actual context creation workflow. -func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { +// createJupyterContext performs the actual context creation workflow. +func (c *Controller) createJupyterContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { client := c.jupyterClient() kernel, err := c.searchKernel(client, request.Language) @@ -220,10 +214,7 @@ func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Clien // storeJupyterKernel caches a session -> kernel mapping. 
func (c *Controller) storeJupyterKernel(sessionID string, kernel *jupyterKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.jupyterClientMap[sessionID] = kernel + c.jupyterClientMap.Store(sessionID, kernel) } func (c *Controller) jupyterClient() *jupyter.Client { @@ -239,49 +230,63 @@ func (c *Controller) jupyterClient() *jupyter.Client { jupyter.WithHTTPClient(httpClient)) } -func (c *Controller) listAllContexts() ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() +func (c *Controller) getDefaultLanguageSession(language Language) string { + if v, ok := c.defaultLanguageSessions.Load(language); ok { + if session, ok := v.(string); ok { + return session + } + } + return "" +} + +func (c *Controller) setDefaultLanguageSession(language Language, sessionID string) { + c.defaultLanguageSessions.Store(language, sessionID) +} +func (c *Controller) deleteDefaultSessionByID(sessionID string) { + c.defaultLanguageSessions.Range(func(key, value any) bool { + if s, ok := value.(string); ok && s == sessionID { + c.defaultLanguageSessions.Delete(key) + } + return true + }) +} + +func (c *Controller) listAllContexts() ([]CodeContext, error) { contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil { - contexts = append(contexts, CodeContext{ - ID: session, - Language: kernel.language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil { + contexts = append(contexts, CodeContext{ID: session, Language: kernel.language}) } - } + return true + }) - for language, defaultContext := range c.defaultLanguageJupyterSessions { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) - } + c.defaultLanguageSessions.Range(func(key, value any) bool { + lang, _ := key.(Language) + session, _ := value.(string) + if session == "" { + return true + } + contexts = append(contexts, CodeContext{ID: 
session, Language: lang}) + return true + }) return contexts, nil } func (c *Controller) listLanguageContexts(language Language) ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() - contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil && kernel.language == language { - contexts = append(contexts, CodeContext{ - ID: session, - Language: language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil && kernel.language == language { + contexts = append(contexts, CodeContext{ID: session, Language: language}) } - } + return true + }) - if defaultContext := c.defaultLanguageJupyterSessions[language]; defaultContext != "" { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) + if defaultContext := c.getDefaultLanguageSession(language); defaultContext != "" { + contexts = append(contexts, CodeContext{ID: defaultContext, Language: language}) } return contexts, nil diff --git a/components/execd/pkg/runtime/context_test.go b/components/execd/pkg/runtime/context_test.go index 9a0376d0..e0cdf3e6 100644 --- a/components/execd/pkg/runtime/context_test.go +++ b/components/execd/pkg/runtime/context_test.go @@ -27,8 +27,8 @@ import ( func TestListContextsAndNewIpynbPath(t *testing.T) { c := NewController("http://example", "token") - c.jupyterClientMap["session-python"] = &jupyterKernel{language: Python} - c.defaultLanguageJupyterSessions[Go] = "session-go-default" + c.jupyterClientMap.Store("session-python", &jupyterKernel{language: Python}) + c.defaultLanguageSessions.Store(Go, "session-go-default") pyContexts, err := c.listLanguageContexts(Python) require.NoError(t, err) @@ -107,13 +107,13 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[sessionID] = &jupyterKernel{language: Python} - 
c.defaultLanguageJupyterSessions[Python] = sessionID + c.jupyterClientMap.Store(sessionID, &jupyterKernel{language: Python}) + c.defaultLanguageSessions.Store(Python, sessionID) require.NoError(t, c.DeleteContext(sessionID)) require.Nil(t, c.getJupyterKernel(sessionID), "expected cache to be cleared") - _, ok := c.defaultLanguageJupyterSessions[Python] + _, ok := c.defaultLanguageSessions.Load(Python) require.False(t, ok, "expected default session entry to be removed") } @@ -138,17 +138,17 @@ func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[session1] = &jupyterKernel{language: lang} - c.jupyterClientMap[session2] = &jupyterKernel{language: lang} - c.defaultLanguageJupyterSessions[lang] = session2 + c.jupyterClientMap.Store(session1, &jupyterKernel{language: lang}) + c.jupyterClientMap.Store(session2, &jupyterKernel{language: lang}) + c.defaultLanguageSessions.Store(lang, session2) require.NoError(t, c.DeleteLanguageContext(lang)) - _, ok := c.jupyterClientMap[session1] + _, ok := c.jupyterClientMap.Load(session1) require.False(t, ok, "expected session1 removed from cache") - _, ok = c.jupyterClientMap[session2] + _, ok = c.jupyterClientMap.Load(session2) require.False(t, ok, "expected session2 removed from cache") - _, ok = c.defaultLanguageJupyterSessions[lang] + _, ok = c.defaultLanguageSessions.Load(lang) require.False(t, ok, "expected default entry removed") require.Equal(t, 1, deleteCalls[session1]) require.Equal(t, 1, deleteCalls[session2]) diff --git a/components/execd/pkg/runtime/ctrl.go b/components/execd/pkg/runtime/ctrl.go index 36c325b4..2946fd81 100644 --- a/components/execd/pkg/runtime/ctrl.go +++ b/components/execd/pkg/runtime/ctrl.go @@ -35,14 +35,15 @@ var kernelWaitingBackoff = wait.Backoff{ // Controller manages code execution across runtimes. 
// Controller manages code execution across runtimes.
//
// The session registries are sync.Map values, whose zero value is empty and
// ready for use, so NewController does not need to initialize them.
type Controller struct {
	// baseURL and token are stored verbatim from NewController.
	baseURL string
	token   string
	// NOTE(review): the registries below are sync.Map and no longer need mu;
	// confirm whether mu still guards any other state.
	mu                      sync.RWMutex
	jupyterClientMap        sync.Map // map[sessionID]*jupyterKernel
	defaultLanguageSessions sync.Map // map[Language]string
	commandClientMap        sync.Map // map[sessionID]*commandKernel
	bashSessionClientMap    sync.Map // map[sessionID]*bashSession
	// NOTE(review): presumably db is opened lazily exactly once via dbOnce — confirm.
	db     *sql.DB
	dbOnce sync.Once
}

// NewController creates a runtime controller.
// Only baseURL and token are set; every other field starts at its zero value.
func NewController(baseURL, token string) *Controller {
	return &Controller{
		baseURL: baseURL,
		token:   token,
	}
}
c.defaultLanguageJupyterSessions[request.Language]; !exists { - err := c.createDefaultLanguageContext(request.Language) - if err != nil { + if c.getDefaultLanguageSession(request.Language) == "" { + if err := c.createDefaultLanguageJupyterContext(request.Language); err != nil { return err } } @@ -39,7 +38,7 @@ func (c *Controller) runJupyter(ctx context.Context, request *ExecuteCodeRequest var targetSessionID string if request.Context == "" { - targetSessionID = c.defaultLanguageJupyterSessions[request.Language] + targetSessionID = c.getDefaultLanguageSession(request.Language) } else { targetSessionID = request.Context } @@ -135,10 +134,12 @@ func (c *Controller) setWorkingDir(_ *jupyterKernel, _ *CreateContextRequest) er // getJupyterKernel retrieves a kernel connection from the session map. func (c *Controller) getJupyterKernel(sessionID string) *jupyterKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.jupyterClientMap[sessionID] + if v, ok := c.jupyterClientMap.Load(sessionID); ok { + if kernel, ok := v.(*jupyterKernel); ok { + return kernel + } + } + return nil } // searchKernel finds a kernel spec name for the given language. diff --git a/components/execd/pkg/runtime/replay_buffer.go b/components/execd/pkg/runtime/replay_buffer.go new file mode 100644 index 00000000..9eaa5594 --- /dev/null +++ b/components/execd/pkg/runtime/replay_buffer.go @@ -0,0 +1,80 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
const defaultReplayBufSize = 1 << 20 // 1 MiB

// replayBuffer is a bounded circular output buffer that allows reconnecting
// clients to replay missed output from a given byte offset.
//
// Invariant: the byte with absolute offset k lives at buf[k%size], hence
// head == total%size at all times.
type replayBuffer struct {
	mu    sync.Mutex
	buf   []byte // circular storage
	size  int    // capacity
	head  int    // next write position (wraps mod size)
	total int64  // total bytes ever written (monotonic offset)
}

// newReplayBuffer returns a ring buffer holding the most recent size bytes.
// A non-positive size falls back to defaultReplayBufSize so that the modulo
// arithmetic in write/readFrom can never divide by zero.
func newReplayBuffer(size int) *replayBuffer {
	if size <= 0 {
		size = defaultReplayBufSize
	}
	return &replayBuffer{buf: make([]byte, size), size: size}
}

// write appends p to the ring buffer, overwriting oldest bytes if full.
// Safe to call concurrently.
func (r *replayBuffer) write(p []byte) {
	if len(p) == 0 {
		return
	}
	r.mu.Lock()
	defer r.mu.Unlock()

	r.total += int64(len(p))

	// Only the last size bytes of an oversized write can survive; advance head
	// past the bytes that would be overwritten immediately instead of copying them.
	if skip := len(p) - r.size; skip > 0 {
		r.head = (r.head + skip) % r.size
		p = p[skip:]
	}

	// Copy in at most two segments: up to the end of the storage, then wrap.
	n := copy(r.buf[r.head:], p)
	copy(r.buf, p[n:])
	r.head = (r.head + len(p)) % r.size
}

// Total returns the total number of bytes ever written to the buffer.
// Safe to call concurrently.
func (r *replayBuffer) Total() int64 {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.total
}

// readFrom returns all bytes from offset onward (up to buffer capacity).
// Returns (data, nextOffset).
//   - If offset >= total, returns (nil, total) — client is caught up.
//   - If offset is too old (evicted) or negative, reads from the oldest
//     available byte.
//
// The returned slice is a copy and may be retained by the caller.
func (r *replayBuffer) readFrom(offset int64) ([]byte, int64) {
	r.mu.Lock()
	defer r.mu.Unlock()

	oldest := r.total - int64(r.size)
	if oldest < 0 {
		oldest = 0
	}
	if offset >= r.total {
		return nil, r.total // nothing new
	}
	if offset < oldest {
		offset = oldest // truncated — client missed some output
	}

	n := int(r.total - offset)
	out := make([]byte, n)
	start := int(offset % int64(r.size))
	// Two-segment copy: tail of the storage first, then the wrapped prefix.
	copied := copy(out, r.buf[start:])
	copy(out[copied:], r.buf)
	return out, r.total
}
+ first := []byte("AAAA") // will be evicted + second := []byte("BBBBBBBBBBBBBBBB") // 16 bytes fills the buffer + rb.write(first) + rb.write(second) + + // total == 20, oldest == 4 + got, next := rb.readFrom(0) // offset 0 is too old, should be clamped to oldest + if next != 20 { + t.Fatalf("expected next=20, got %d", next) + } + // Should get exactly 16 bytes (the second write, which overwrote first) + if len(got) != size { + t.Fatalf("expected %d bytes, got %d", size, len(got)) + } + if !bytes.Equal(got, second) { + t.Fatalf("expected %q, got %q", second, got) + } +} + +func TestReplayBuffer_OffsetCaughtUp(t *testing.T) { + rb := newReplayBuffer(64) + rb.write([]byte("some output\n")) + + total := rb.total + got, next := rb.readFrom(total) + if got != nil { + t.Fatalf("expected nil for caught-up offset, got %q", got) + } + if next != total { + t.Fatalf("expected next=%d, got %d", total, next) + } +} + +func TestReplayBuffer_LargeGap(t *testing.T) { + size := 8 + rb := newReplayBuffer(size) + + // Write 16 bytes total — first 8 are evicted. + rb.write([]byte("12345678")) // bytes 0-7 + rb.write([]byte("ABCDEFGH")) // bytes 8-15 + + // Request from offset 0 (evicted) — should return from oldest available (offset 8). 
// BashSession is the interface exposed to callers outside the runtime package.
// It covers the WebSocket steering lifecycle: exclusive connection locking,
// process start (pipe or PTY), stdin/output plumbing, signalling and exit status.
type BashSession interface {
	// LockWS atomically acquires exclusive WebSocket access. Returns false if already locked.
	LockWS() bool
	// UnlockWS releases the WebSocket connection lock.
	UnlockWS()
	// Start launches the underlying bash process (idempotent: no-op if already running).
	Start() error
	// StartPTY launches the bash process with a PTY instead of pipes (idempotent).
	StartPTY() error
	// IsRunning reports whether the bash process is currently alive.
	IsRunning() bool
	// ExitCode returns the exit code of the most recently completed process (-1 if not exited).
	ExitCode() int
	// WriteStdin writes p to the session's stdin pipe.
	WriteStdin(p []byte) (int, error)
	// AttachOutput returns per-connection pipe readers for stdout (and stderr in pipe mode)
	// plus a detach func. The broadcast goroutine copies from the real OS pipe into these
	// readers. Calling detach closes the write ends, causing the readers to return EOF and
	// unblocking any scanner goroutines without touching the underlying OS pipe.
	// The stderr reader is nil in PTY mode, where stderr is merged into the PTY stream.
	AttachOutput() (stdout io.Reader, stderr io.Reader, detach func())
	// Done returns a channel closed when the bash process exits.
	Done() <-chan struct{}
	// SendSignal sends a named signal (e.g. "SIGINT") to the process group.
	SendSignal(name string)
	// ResizePTY sends a TIOCSWINSZ ioctl to the PTY master. No-op if not in PTY mode.
	ResizePTY(cols, rows uint16) error
	// IsPTY reports whether the session is currently running in PTY mode.
	IsPTY() bool
}

// bashSessionConfig holds bash session configuration.
type bashSessionConfig struct {
	// StartupSource is a list of scripts sourced on startup.
	StartupSource []string
	// Session is the session identifier.
	Session string
	// StartupTimeout is the startup timeout.
	StartupTimeout time.Duration
	// Cwd is the working directory.
	Cwd string
}

// bashSession represents a bash session.
// It backs both short-lived run() commands and the long-lived interactive
// shell used for WebSocket steering.
type bashSession struct {
	config *bashSessionConfig
	// NOTE(review): presumably mu guards the lifecycle fields below — confirm
	// against the methods that take it.
	mu      sync.Mutex
	started bool
	closing bool
	env     map[string]string
	cwd     string

	// currentProcessPid tracks the PID of a short-lived run() command (fire-and-forget execution).
	// Set after cmd.Start(), cleared when run() returns. Used by close() for cleanup.
	currentProcessPid int

	// wsPid tracks the PID of the long-lived interactive shell started by Start/StartPTY.
	// Kept separate so run() cannot clobber it, ensuring IsRunning/close/interrupt remain correct.
	wsPid int

	// replay buffers all output so reconnecting clients can catch up on missed bytes.
	replay *replayBuffer

	// WS mode fields — set by Start/StartPTY when the interactive shell is launched.
	wsConnected  atomic.Bool    // true while a WS connection holds the session
	lastExitCode int            // stored on process exit; -1 if not yet exited
	stdin        io.WriteCloser // write end of bash's stdin pipe (WS mode)
	doneCh       chan struct{}  // closed when WS-mode bash process exits

	// Output broadcast: a goroutine reads the real OS pipe and writes to the current
	// per-connection PipeWriter. On WS disconnect, detach() closes the PipeWriter so
	// the handler's reader gets EOF. On reconnect a new PipeWriter is swapped in.
	outMu   sync.Mutex     // guards stdoutW / stderrW
	stdoutW *io.PipeWriter // current broadcast sink for stdout; nil before first attach
	stderrW *io.PipeWriter // current broadcast sink for stderr; nil in PTY mode or before attach

	// PTY mode fields — populated only when started via StartPTY().
	isPTY bool     // true when session uses a PTY instead of pipes
	ptmx  *os.File // PTY master fd (read=stdout+stderr merged, write=stdin); nil in pipe mode
}
%v", err), + ) + return + } + + sessionID, err := codeRunner.CreateBashSession(&runtime.CreateContextRequest{ + Cwd: request.Cwd, + }) + if err != nil { + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error creating session. %v", err), + ) + return + } + + c.RespondSuccess(model.CreateSessionResponse{SessionID: sessionID}) +} + +// RunInSession runs code in an existing bash session and streams output via SSE (run_in_session API). +func (c *CodeInterpretingController) RunInSession() { + sessionID := c.ctx.Param("sessionId") + if sessionID == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + var request model.RunInSessionRequest + if err := c.bindJSON(&request); err != nil && !errors.Is(err, io.EOF) { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("error parsing request. %v", err), + ) + return + } + if err := request.Validate(); err != nil { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("invalid request. %v", err), + ) + return + } + + timeout := time.Duration(request.TimeoutMs) * time.Millisecond + runReq := &runtime.ExecuteCodeRequest{ + Language: runtime.Bash, + Context: sessionID, + Code: request.Code, + Cwd: request.Cwd, + Timeout: timeout, + } + // Verify the session exists BEFORE committing the SSE response (200 + headers). + // Once setupSSEResponse() flushes, we can no longer send HTTP error codes. 
+ if _, _, err := codeRunner.ReplaySessionOutput(sessionID, 0); err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError(http.StatusNotFound, model.ErrorCodeContextNotFound, "session not found") + return + } + } + + ctx, cancel := context.WithCancel(c.ctx.Request.Context()) + defer cancel() + runReq.Hooks = c.setServerEventsHandler(ctx) + + c.setupSSEResponse() + + // If ?since= is provided, replay buffered output before live stream. + if sinceStr := c.ctx.Query("since"); sinceStr != "" { + if since, err := strconv.ParseInt(sinceStr, 10, 64); err == nil { + if replayData, _, replayErr := codeRunner.ReplaySessionOutput(sessionID, since); replayErr == nil && len(replayData) > 0 { + event := model.ServerStreamEvent{ + Type: model.StreamEventTypeReplay, + Text: string(replayData), + Timestamp: time.Now().UnixMilli(), + } + c.writeSingleEvent("Replay", event.ToJSON(), true, event.Summary()) + } + } + } + + err := codeRunner.RunInBashSession(ctx, runReq) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("session not found. %v", err), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error running in session. %v", err), + ) + return + } + + time.Sleep(flag.ApiGracefulShutdownTimeout) +} + +// GetSessionStatus returns status and replay buffer offset for a bash session. 
+func (c *CodeInterpretingController) GetSessionStatus() { + sessionID := c.ctx.Param("sessionId") + if sessionID == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + status, err := codeRunner.GetBashSessionStatus(sessionID) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("session %s not found", sessionID), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error getting session status. %v", err), + ) + return + } + + c.RespondSuccess(model.SessionStatusResponse{ + SessionID: status.SessionID, + Running: status.Running, + OutputOffset: status.OutputOffset, + }) +} + +// DeleteSession deletes a bash session (delete_session API). +func (c *CodeInterpretingController) DeleteSession() { + sessionID := c.ctx.Param("sessionId") + if sessionID == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + err := codeRunner.DeleteBashSession(sessionID) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("session %s not found", sessionID), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error deleting session %s. %v", sessionID, err), + ) + return + } + + c.RespondSuccess(nil) +} + // buildExecuteCodeRequest converts a RunCodeRequest to runtime format. 
// wsUpgrader upgrades session steering requests to WebSocket connections.
var wsUpgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	// execd runs inside a container; auth-header check is the access gate.
	CheckOrigin: func(r *http.Request) bool { return true },
}

// SessionWebSocket handles GET /ws/session/:sessionId — bidirectional stdin/stdout steering.
//
// Query parameters:
//   - pty=1   start the shell with a PTY when it is not already running.
//   - since=N replay buffered output from byte offset N before the live stream.
//
// Server frames: "error", "replay", "connected", "stdout", "stderr", "exit", "pong".
// Client frames: "stdin", "signal", "resize", "ping".
//
// Only one WebSocket connection may hold a session at a time; the lock is
// released after the output pumps have drained so a reconnect never races with
// stale readers.
func (c *CodeInterpretingController) SessionWebSocket() {
	sessionID := c.ctx.Param("sessionId")

	// 1. Look up session BEFORE upgrade so we can still return HTTP errors.
	session := codeRunner.GetBashSession(sessionID)
	if session == nil {
		c.RespondError(http.StatusNotFound, model.ErrorCodeContextNotFound, "session not found")
		return
	}

	// 2. Acquire exclusive WS lock (prevents concurrent connections).
	if !session.LockWS() {
		c.RespondError(http.StatusConflict, model.ErrorCodeRuntimeError, "session already connected")
		return
	}
	// Do NOT defer UnlockWS here — we release it manually after pump goroutines
	// finish, so a reconnecting client cannot start new pumps on the shared pipe
	// while stale pumps from the previous connection are still blocked in Read().

	// 3. Upgrade HTTP → WebSocket.
	conn, err := wsUpgrader.Upgrade(c.ctx.Writer, c.ctx.Request, nil)
	if err != nil {
		// gorilla writes the HTTP error response automatically.
		session.UnlockWS()
		return
	}
	defer conn.Close()

	// writeMu serializes all writes to conn — gorilla/websocket requires this.
	var writeMu sync.Mutex
	writeJSON := func(v any) error {
		writeMu.Lock()
		defer writeMu.Unlock()
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) //nolint:errcheck
		return conn.WriteJSON(v)
	}

	usePTY := c.ctx.Query("pty") == "1"

	// 4. Start bash if not already running.
	if !session.IsRunning() {
		var startErr error
		if usePTY {
			startErr = session.StartPTY()
		} else {
			startErr = session.Start()
		}
		if startErr != nil {
			_ = writeJSON(model.ServerFrame{
				Type:  "error",
				Error: "failed to start bash",
				Code:  model.WSErrCodeStartFailed,
			})
			// The cleanup defer below is not installed yet, so unlock explicitly.
			session.UnlockWS()
			return
		}
	}

	// ctx scopes every goroutine started for this connection.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// 5. Snapshot the replay buffer offset THEN attach the live pipe — in that order.
	//
	// Why this order matters:
	//   - Snapshotting first captures a definite "replay up to here" watermark.
	//   - AttachOutput installs the PipeWriter so the broadcast goroutine begins
	//     queuing bytes into the pipe immediately.
	//   - Any byte produced between snapshot and attach lands in the pipe only
	//     (not in the replay frame), so each byte is delivered exactly once.
	//   - If we attached first and snapshotted second, bytes produced in that
	//     window would appear in both the replay frame and the live pipe (duplicate).
	//
	// NOTE(review): the exactly-once claim depends on how the broadcast goroutine
	// treats output produced while no sink is attached — verify that bytes written
	// between the snapshot and AttachOutput are not dropped for this client.
	// An unparsable ?since= value is ignored and no replay is sent.
	var replayData []byte
	var replayNextOffset int64
	if sinceStr := c.ctx.Query("since"); sinceStr != "" {
		if since, parseErr := strconv.ParseInt(sinceStr, 10, 64); parseErr == nil {
			replayData, replayNextOffset, _ = codeRunner.ReplaySessionOutput(sessionID, since)
		}
	}

	stdout, stderr, detach := session.AttachOutput()
	var pumpWg sync.WaitGroup
	// Teardown order: stop goroutines, close the per-connection pipes so the
	// pumps see EOF, wait for them, and only then release the session lock.
	defer func() {
		cancel()
		detach()
		pumpWg.Wait()
		session.UnlockWS()
	}()

	// 6. Send replay frame now that the live sink is attached — no gap, no duplicates.
	if len(replayData) > 0 {
		_ = writeJSON(model.ServerFrame{
			Type:   "replay",
			Data:   string(replayData),
			Offset: replayNextOffset,
		})
	}

	// 7. Send connected frame — mode derived from actual session state, not the request parameter,
	// so reconnecting clients always receive the correct terminal assumptions.
	mode := "pipe"
	if session.IsPTY() {
		mode = "pty"
	}
	_ = writeJSON(model.ServerFrame{
		Type:      "connected",
		SessionID: sessionID,
		Mode:      mode,
	})

	// 8. Ping/pong keepalive — RFC 6455 control-level pings every 30s.
	// Each pong extends the read deadline by 60s.
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(60 * time.Second)) //nolint:errcheck
		return nil
	})
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// Pings share writeMu with writeJSON so frames never interleave.
				writeMu.Lock()
				conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
				writeErr := conn.WriteMessage(websocket.PingMessage, nil)
				writeMu.Unlock()
				if writeErr != nil {
					cancel()
					return
				}
			}
		}
	}()

	// streamPump reads raw byte chunks from r and forwards them as WS frames of the given type.
	// Raw reads (rather than line scanning) are required so that PTY prompts, progress spinners,
	// and cursor-control sequences — which are often written without a trailing newline — are
	// delivered immediately without buffering.
	streamPump := func(r io.Reader, frameType string) {
		defer pumpWg.Done()
		buf := make([]byte, 32*1024)
		for {
			n, readErr := r.Read(buf)
			if n > 0 {
				// Stop forwarding once the connection context is cancelled.
				if ctx.Err() != nil {
					return
				}
				if writeErr := writeJSON(model.ServerFrame{
					Type:      frameType,
					Data:      string(buf[:n]),
					Timestamp: time.Now().UnixMilli(),
				}); writeErr != nil {
					return
				}
			}
			if readErr != nil {
				// EOF is the normal result of detach(); anything else is reported once.
				if readErr != io.EOF && ctx.Err() == nil {
					_ = writeJSON(model.ServerFrame{
						Type:  "error",
						Error: frameType + " read error: " + readErr.Error(),
						Code:  model.WSErrCodeRuntimeError,
					})
					cancel()
				}
				return
			}
		}
	}

	// 9. Write pump — stdout (raw byte chunks).
	pumpWg.Add(1)
	go streamPump(stdout, "stdout")

	// 10. Write pump — stderr (pipe mode only; PTY merges stderr into ptmx; nil in PTY mode).
	if stderr != nil {
		pumpWg.Add(1)
		go streamPump(stderr, "stderr")
	}

	// 11. Exit watcher — sends exit frame when bash process ends, then closes the
	// connection so the read loop's ReadJSON unblocks immediately rather than waiting
	// up to 60s for the deadline. Without this, reconnect attempts during that window
	// hit "session already connected" even though the process is already gone.
	//
	// NOTE(review): defer cancel() also fires on the doneCh == nil early return,
	// which cancels the connection context immediately — confirm Done() is never
	// nil once the session is running.
	go func() {
		defer cancel()
		doneCh := session.Done()
		if doneCh == nil {
			return
		}
		select {
		case <-ctx.Done():
			return
		case <-doneCh:
		}
		exitCode := session.ExitCode()
		_ = writeJSON(model.ServerFrame{Type: "exit", ExitCode: &exitCode})
		// Close with a normal closure code so the read loop gets an error immediately.
		writeMu.Lock()
		_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "process exited"))
		writeMu.Unlock()
		conn.Close()
	}()

	// 12. Read pump — client → bash stdin.
	// Every successfully read frame extends the read deadline by 60s.
	conn.SetReadDeadline(time.Now().Add(60 * time.Second)) //nolint:errcheck
	for {
		var frame model.ClientFrame
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			if ctx.Err() == nil {
				cancel()
			}
			break
		}
		conn.SetReadDeadline(time.Now().Add(60 * time.Second)) //nolint:errcheck

		switch frame.Type {
		case "stdin":
			if _, writeErr := session.WriteStdin([]byte(frame.Data)); writeErr != nil {
				_ = writeJSON(model.ServerFrame{
					Type:  "error",
					Error: writeErr.Error(),
					Code:  model.WSErrCodeStdinWriteFailed,
				})
				cancel()
				return
			}
		case "signal":
			session.SendSignal(frame.Signal)
		case "resize":
			if session.IsPTY() {
				// NOTE(review): assumes Cols/Rows fit in uint16 — confirm the frame field types.
				if resizeErr := session.ResizePTY(uint16(frame.Cols), uint16(frame.Rows)); resizeErr != nil {
					_ = writeJSON(model.ServerFrame{
						Type:  "error",
						Error: "resize failed: " + resizeErr.Error(),
						Code:  model.WSErrCodeRuntimeError,
					})
				}
			}
			// Silently ignored in pipe mode; accepted to avoid client errors.
		case "ping":
			_ = writeJSON(model.ServerFrame{Type: "pong"})
		default:
			_ = writeJSON(model.ServerFrame{
				Type:  "error",
				Error: "unknown frame type",
				Code:  model.WSErrCodeInvalidFrame,
			})
		}
	}
}
+// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package controller + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + + "github.com/alibaba/opensandbox/execd/pkg/runtime" + "github.com/alibaba/opensandbox/execd/pkg/web/model" +) + +// wsTestServer spins up a real httptest.Server with the WS route wired in. +func wsTestServer(t *testing.T) *httptest.Server { + t.Helper() + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/ws/session/:sessionId", func(ctx *gin.Context) { + NewCodeInterpretingController(ctx).SessionWebSocket() + }) + return httptest.NewServer(r) +} + +// wsURL converts an http:// test-server URL to a ws:// URL. +func wsURL(srv *httptest.Server, sessionID string) string { + return "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/session/" + sessionID +} + +// wsURLWithSince appends ?since= to the WS URL. +func wsURLWithSince(srv *httptest.Server, sessionID string, since int64) string { + return wsURL(srv, sessionID) + "?since=" + strconv.FormatInt(since, 10) +} + +// dialWS opens a WebSocket connection to the test server. +func dialWS(t *testing.T, url string) *websocket.Conn { + t.Helper() + dialer := websocket.DefaultDialer + conn, resp, err := dialer.Dial(url, nil) + if err != nil { + if resp != nil { + t.Fatalf("WS dial failed: %v (HTTP %d)", err, resp.StatusCode) + } + t.Fatalf("WS dial failed: %v", err) + } + t.Cleanup(func() { conn.Close() }) + return conn +} + +// readFrame reads one JSON ServerFrame from the WebSocket connection. 
+func readFrame(t *testing.T, conn *websocket.Conn, timeout time.Duration) model.ServerFrame {
+	t.Helper()
+	// Without a deadline a silent server would block ReadMessage indefinitely.
+	require.NoError(t, conn.SetReadDeadline(time.Now().Add(timeout)), "setting WS read deadline")
+	_, msg, err := conn.ReadMessage()
+	require.NoError(t, err, "reading WS frame")
+	var frame model.ServerFrame
+	require.NoError(t, json.Unmarshal(msg, &frame), "unmarshalling WS frame")
+	return frame
+}
+
+// awaitFrame reads frames from conn until match returns true, then returns that
+// frame and true. match is called once per decoded frame, in arrival order.
+//
+// It returns false when no matching frame was seen: either the overall budget
+// (10s) ran out or a read failed, which includes the per-read deadline (5s)
+// expiring while the server stays silent.
+//
+// Read and decode errors are logged instead of being dropped so that a failed
+// assertion in the caller shows why the expected frame never arrived. Messages
+// that do not decode as a ServerFrame are skipped.
+func awaitFrame(t *testing.T, conn *websocket.Conn, match func(model.ServerFrame) bool) (model.ServerFrame, bool) {
+	t.Helper()
+	const (
+		overall = 10 * time.Second // total time allowed for the match to show up
+		perRead = 5 * time.Second  // longest silence tolerated between two frames
+	)
+	for deadline := time.Now().Add(overall); time.Now().Before(deadline); {
+		if err := conn.SetReadDeadline(time.Now().Add(perRead)); err != nil {
+			t.Logf("setting WS read deadline: %v", err)
+			break
+		}
+		_, msg, err := conn.ReadMessage()
+		if err != nil {
+			// Read errors on a gorilla connection are permanent; retrying would spin.
+			t.Logf("reading WS frame: %v", err)
+			break
+		}
+		var f model.ServerFrame
+		if err := json.Unmarshal(msg, &f); err != nil {
+			t.Logf("skipping undecodable WS frame %q: %v", msg, err)
+			continue
+		}
+		if match(f) {
+			return f, true
+		}
+	}
+	return model.ServerFrame{}, false
+}
+
+// withFreshRunner replaces the package-level codeRunner with a new
+// runtime.NewController("", "") and restores the previous value on cleanup.
+func withFreshRunner(t *testing.T) {
+	t.Helper()
+	prev := codeRunner
+	codeRunner = runtime.NewController("", "")
+	t.Cleanup(func() { codeRunner = prev })
+}
+
+// createTestSession creates a bash session on codeRunner, registers its
+// deletion as test cleanup, and returns the session ID.
+func createTestSession(t *testing.T) string {
+	t.Helper()
+	id, err := codeRunner.CreateBashSession(&runtime.CreateContextRequest{})
+	require.NoError(t, err)
+	// Best-effort teardown: the delete error is deliberately ignored, since a
+	// test may already have made bash exit (see TestSessionWS_ExitFrame).
+	t.Cleanup(func() { _ = codeRunner.DeleteBashSession(id) })
+	return id
+}
+
+// TestSessionWS_ConnectUnknownSession verifies that connecting to a non-existent
+// session returns HTTP 404 before the WebSocket upgrade.
+func TestSessionWS_ConnectUnknownSession(t *testing.T) {
+	withFreshRunner(t)
+	srv := wsTestServer(t)
+	defer srv.Close()
+
+	// http.Get uses a client without a timeout; bound the request so that a
+	// hung handler fails this test quickly.
+	client := &http.Client{Timeout: 5 * time.Second}
+	resp, err := client.Get(srv.URL + "/ws/session/does-not-exist")
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	require.Equal(t, http.StatusNotFound, resp.StatusCode)
+}
+
+// TestSessionWS_PingPong sends an application-level ping frame (a JSON
+// ClientFrame, not a WebSocket control frame) and expects a pong frame back.
+func TestSessionWS_PingPong(t *testing.T) {
+	withFreshRunner(t)
+	srv := wsTestServer(t)
+	defer srv.Close()
+
+	id := createTestSession(t)
+	conn := dialWS(t, wsURL(srv, id))
+
+	// Drain the "connected" frame.
+	connected := readFrame(t, conn, 5*time.Second)
+	require.Equal(t, "connected", connected.Type)
+
+	// Send application ping.
+	require.NoError(t, conn.WriteJSON(model.ClientFrame{Type: "ping"}))
+
+	// Expect pong.
+	frame := readFrame(t, conn, 5*time.Second)
+	require.Equal(t, "pong", frame.Type)
+}
+
+// TestSessionWS_StdinForwarding connects to a session, sends a stdin frame
+// with an echo command, and verifies that a stdout frame arrives.
+func TestSessionWS_StdinForwarding(t *testing.T) {
+	withFreshRunner(t)
+	srv := wsTestServer(t)
+	defer srv.Close()
+
+	id := createTestSession(t)
+	conn := dialWS(t, wsURL(srv, id))
+
+	// Drain "connected".
+	connected := readFrame(t, conn, 5*time.Second)
+	require.Equal(t, "connected", connected.Type)
+
+	// Send a command via stdin.
+	require.NoError(t, conn.WriteJSON(model.ClientFrame{
+		Type: "stdin",
+		Data: "echo hello_ws\n",
+	}))
+
+	// Frames that are not the expected stdout are ignored; wait for the
+	// echoed text.
+	_, found := awaitFrame(t, conn, func(f model.ServerFrame) bool {
+		return f.Type == "stdout" && strings.Contains(f.Data, "hello_ws")
+	})
+	require.True(t, found, "expected stdout frame with 'hello_ws'")
+}
+
+// TestSessionWS_ReplayOnConnect verifies that reconnecting with ?since=0
+// delivers a replay frame containing output produced earlier. It runs a
+// command on a first connection, waits for its output, closes that
+// connection, and then reconnects.
+func TestSessionWS_ReplayOnConnect(t *testing.T) {
+	withFreshRunner(t)
+	srv := wsTestServer(t)
+	defer srv.Close()
+
+	id := createTestSession(t)
+
+	// Prime the replay buffer: connect once, run a command over stdin, and
+	// wait until its output has been streamed back.
+	conn1 := dialWS(t, wsURL(srv, id))
+	connected := readFrame(t, conn1, 5*time.Second)
+	require.Equal(t, "connected", connected.Type)
+
+	require.NoError(t, conn1.WriteJSON(model.ClientFrame{
+		Type: "stdin",
+		Data: "echo replay_test\n",
+	}))
+	// Fail here, rather than later with a misleading "no replay" message, if
+	// the output never showed up and there is therefore nothing to replay.
+	_, primed := awaitFrame(t, conn1, func(f model.ServerFrame) bool {
+		return f.Type == "stdout" && strings.Contains(f.Data, "replay_test")
+	})
+	require.True(t, primed, "expected stdout frame with 'replay_test' before reconnecting")
+
+	// Close first connection to release the WS lock.
+	conn1.Close()
+	// Give the server a moment to release the lock.
+	// NOTE(review): a fixed sleep can be flaky on a loaded machine.
+	time.Sleep(100 * time.Millisecond)
+
+	// Reconnect with ?since=0 — should receive a replay frame.
+	conn2 := dialWS(t, wsURLWithSince(srv, id, 0))
+	defer conn2.Close()
+
+	// Only the presence and payload of the replay frame are checked; its
+	// position relative to the "connected" frame is not asserted.
+	replay, foundReplay := awaitFrame(t, conn2, func(f model.ServerFrame) bool {
+		return f.Type == "replay"
+	})
+	require.True(t, foundReplay, "expected replay frame with buffered output")
+	require.Contains(t, replay.Data, "replay_test", "replay frame should contain buffered output")
+}
+
+// TestSessionWS_ExitFrame runs a short-lived command and verifies that
+// an exit frame is received with code 0 after bash exits.
+func TestSessionWS_ExitFrame(t *testing.T) {
+	withFreshRunner(t)
+	srv := wsTestServer(t)
+	defer srv.Close()
+
+	id := createTestSession(t)
+	conn := dialWS(t, wsURL(srv, id))
+
+	connected := readFrame(t, conn, 5*time.Second)
+	require.Equal(t, "connected", connected.Type)
+
+	// Ask bash to exit cleanly.
+	require.NoError(t, conn.WriteJSON(model.ClientFrame{
+		Type: "stdin",
+		Data: "exit 0\n",
+	}))
+
+	// The first exit frame decides the test; frames of any other type that
+	// arrive before it are ignored.
+	exitFrame, foundExit := awaitFrame(t, conn, func(f model.ServerFrame) bool {
+		return f.Type == "exit"
+	})
+	require.True(t, foundExit, "expected exit frame with code 0")
+	require.NotNil(t, exitFrame.ExitCode, "exit frame must include exit_code")
+	require.Equal(t, 0, *exitFrame.ExitCode)
+}
diff --git a/components/execd/pkg/web/model/codeinterpreting.go b/components/execd/pkg/web/model/codeinterpreting.go
index 771b6d75..5f8bafdc 100644
--- a/components/execd/pkg/web/model/codeinterpreting.go
+++ b/components/execd/pkg/web/model/codeinterpreting.go
@@ -83,6 +83,7 @@ const (
 	StreamEventTypeComplete ServerStreamEventType = "execution_complete"
 	StreamEventTypeCount    ServerStreamEventType = "execution_count"
 	StreamEventTypePing     ServerStreamEventType = "ping"
+	StreamEventTypeReplay   ServerStreamEventType = "replay"
 )
 
 // ServerStreamEvent is emitted to clients over SSE.
diff --git a/components/execd/pkg/web/model/session.go b/components/execd/pkg/web/model/session.go
new file mode 100644
index 00000000..0acba83e
--- /dev/null
+++ b/components/execd/pkg/web/model/session.go
@@ -0,0 +1,49 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package model
+
+import "github.com/go-playground/validator/v10"
+
+var runInSessionValidator = validator.New() // shared by all RunInSessionRequest.Validate calls
+
+// CreateSessionRequest is the request body for creating a bash session.
+type CreateSessionRequest struct {
+	Cwd string `json:"cwd,omitempty"`
+}
+
+// CreateSessionResponse is the response for create_session.
+type CreateSessionResponse struct {
+	SessionID string `json:"session_id"`
+}
+
+// RunInSessionRequest is the request body for running code in an existing session.
+type RunInSessionRequest struct {
+	Code      string `json:"code" validate:"required"` // must be non-empty, enforced by Validate
+	Cwd       string `json:"cwd,omitempty"`
+	TimeoutMs int64  `json:"timeout_ms,omitempty" validate:"omitempty,gte=0"` // negative values are rejected by Validate
+}
+
+// Validate checks r against its `validate` tags (Code required; TimeoutMs >= 0 when set).
+// One validator is reused: instances are concurrency-safe and cache parsed struct tags.
+func (r *RunInSessionRequest) Validate() error {
+	return runInSessionValidator.Struct(r)
+}
+
+// SessionStatusResponse is the response for GET /session/:id.
+type SessionStatusResponse struct {
+	SessionID    string `json:"session_id"`
+	Running      bool   `json:"running"`
+	OutputOffset int64  `json:"output_offset"`
+}
diff --git a/components/execd/pkg/web/model/session_ws.go b/components/execd/pkg/web/model/session_ws.go
new file mode 100644
index 00000000..d8f77f5f
--- /dev/null
+++ b/components/execd/pkg/web/model/session_ws.go
@@ -0,0 +1,47 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package model
+
+// ClientFrame is one JSON message sent by the WebSocket client; Type names the frame kind.
+type ClientFrame struct {
+	Type   string `json:"type"`             // frame kind, e.g. "stdin" or "ping"
+	Data   string `json:"data,omitempty"`   // stdin payload (plain text)
+	Cols   int    `json:"cols,omitempty"`   // resize — PTY mode only
+	Rows   int    `json:"rows,omitempty"`   // resize — PTY mode only
+	Signal string `json:"signal,omitempty"` // signal name, e.g. "SIGINT"
+}
+
+// ServerFrame is one JSON message sent by the server; fields other than Type are omitted when empty.
+type ServerFrame struct {
+	Type      string `json:"type"`                 // frame kind, e.g. "connected", "stdout", "replay", "exit", "pong"
+	SessionID string `json:"session_id,omitempty"` // connected
+	Mode      string `json:"mode,omitempty"`       // connected: "pipe" | "pty"
+	Data      string `json:"data,omitempty"`       // stdout/stderr/replay payload
+	Offset    int64  `json:"offset,omitempty"`     // replay: next byte offset (left out of the JSON when 0)
+	ExitCode  *int   `json:"exit_code,omitempty"`  // exit — pointer so 0 is marshalled
+	Error     string `json:"error,omitempty"`      // error description
+	Code      string `json:"code,omitempty"`       // machine-readable error code
+	Timestamp int64  `json:"timestamp,omitempty"`  // NOTE(review): unit and epoch are not visible here — confirm
+}
+
+// WebSocket error codes. NOTE(review): presumably the values sent in ServerFrame.Code — confirm.
+const (
+	WSErrCodeSessionGone      = "SESSION_GONE"
+	WSErrCodeStartFailed      = "START_FAILED"
+	WSErrCodeStdinWriteFailed = "STDIN_WRITE_FAILED"
+	WSErrCodeInvalidFrame     = "INVALID_FRAME"
+	WSErrCodeAlreadyConnected = "ALREADY_CONNECTED"
+	WSErrCodeRuntimeError     = "RUNTIME_ERROR"
+)
diff --git a/components/execd/pkg/web/router.go b/components/execd/pkg/web/router.go
index dfca0821..34114c3b 100644
--- a/components/execd/pkg/web/router.go
+++ b/components/execd/pkg/web/router.go
@@ -62,6 +62,19 @@ func NewRouter(accessToken string) *gin.Engine {
 		code.GET("/contexts/:contextId", withCode(func(c *controller.CodeInterpretingController) { c.GetContext() }))
 	}
 
+	session := r.Group("/session")
+	{
+		session.POST("", withCode(func(c *controller.CodeInterpretingController) { c.CreateSession() }))
+		session.GET("/:sessionId", withCode(func(c *controller.CodeInterpretingController) { c.GetSessionStatus() }))
+		session.POST("/:sessionId/run", withCode(func(c *controller.CodeInterpretingController) { c.RunInSession() }))
+		session.DELETE("/:sessionId", withCode(func(c *controller.CodeInterpretingController) { c.DeleteSession() }))
+	}
+
+	ws := r.Group("/ws")
+	{
+		ws.GET("/session/:sessionId", withCode(func(c *controller.CodeInterpretingController) { c.SessionWebSocket() }))
+	}
+
 	command := r.Group("/command")
 	{
 		command.POST("", withCode(func(c *controller.CodeInterpretingController) { c.RunCommand() }))