diff --git a/components/execd/pkg/runtime/bash_session.go b/components/execd/pkg/runtime/bash_session.go new file mode 100644 index 00000000..5e5d01d3 --- /dev/null +++ b/components/execd/pkg/runtime/bash_session.go @@ -0,0 +1,463 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package runtime + +import ( + "bufio" + "context" + "errors" + "fmt" + "os" + "os/exec" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/google/uuid" + + "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" + "github.com/alibaba/opensandbox/execd/pkg/log" +) + +const ( + envDumpStartMarker = "__ENV_DUMP_START__" + envDumpEndMarker = "__ENV_DUMP_END__" + exitMarkerPrefix = "__EXIT_CODE__:" + pwdMarkerPrefix = "__PWD__:" +) + +func (c *Controller) createBashSession(req *CreateContextRequest) (string, error) { + session := newBashSession(req.Cwd) + if err := session.start(); err != nil { + return "", fmt.Errorf("failed to start bash session: %w", err) + } + + c.bashSessionClientMap.Store(session.config.Session, session) + log.Info("created bash session %s", session.config.Session) + return session.config.Session, nil +} + +func (c *Controller) runBashSession(ctx context.Context, request *ExecuteCodeRequest) error { + session := c.getBashSession(request.Context) + if session == nil { + return ErrContextNotFound + } + + return session.run(ctx, request) +} + +func (c *Controller) getBashSession(sessionId string) *bashSession { + if v, ok := c.bashSessionClientMap.Load(sessionId); ok { + if s, ok := v.(*bashSession); ok { + return s + } + } + return nil +} + +func (c *Controller) closeBashSession(sessionId string) error { + session := c.getBashSession(sessionId) + if session == nil { + return ErrContextNotFound + } + + err := session.close() + if err != nil { + return err + } + + c.bashSessionClientMap.Delete(sessionId) + return nil +} + +func (c *Controller) CreateBashSession(req *CreateContextRequest) (string, error) { + return c.createBashSession(req) +} + +func (c *Controller) RunInBashSession(ctx context.Context, req *ExecuteCodeRequest) error { + return c.runBashSession(ctx, req) +} + +func (c *Controller) DeleteBashSession(sessionID string) error { + return c.closeBashSession(sessionID) +} + +// Session implementation (pipe-based, no PTY) +func newBashSession(cwd string) *bashSession { + config := &bashSessionConfig{ + Session: uuidString(), + StartupTimeout: 5 * time.Second, + } + + env := make(map[string]string) + for _, kv := range os.Environ() { + if k, v, ok := splitEnvPair(kv); ok { + env[k] = v + } + } + + return &bashSession{ + config: config, + env: env, + cwd: cwd, + } +} + +func (s *bashSession) start() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.started { + return errors.New("session already started") + } + + s.started = true + return nil +} + +func (s *bashSession) trackCurrentProcess(pid int) { + s.mu.Lock() + defer s.mu.Unlock() + s.currentProcessPid = pid +} + +func (s *bashSession) untrackCurrentProcess() { + s.mu.Lock() + defer s.mu.Unlock() + s.currentProcessPid = 0 +} + +//nolint:gocognit +func (s *bashSession) run(ctx context.Context, request *ExecuteCodeRequest) error { + s.mu.Lock() + if !s.started { + s.mu.Unlock() + return errors.New("session not started") + } + + envSnapshot := copyEnvMap(s.env) + + cwd := s.cwd + // override original cwd if specified + if request.Cwd != "" { + cwd = request.Cwd + } + sessionID := s.config.Session + s.mu.Unlock() + + startAt := time.Now() + if request.Hooks.OnExecuteInit != nil { + request.Hooks.OnExecuteInit(sessionID) + } + + wait := request.Timeout + if wait <= 0 { + wait = 24 * 3600 * time.Second // max to 24 hours + } + + ctx, cancel := context.WithTimeout(ctx, wait) + defer cancel() + + script := buildWrappedScript(request.Code, envSnapshot, cwd) + scriptFile, err := os.CreateTemp("", "execd_bash_*.sh") + if err != nil { + return fmt.Errorf("create script file: %w", err) + } + scriptPath := scriptFile.Name() + if _, err := scriptFile.WriteString(script); err != nil { + _ = scriptFile.Close() + return fmt.Errorf("write script file: %w", err) + } + if err := scriptFile.Close(); err != nil { + return fmt.Errorf("close script file: %w", err) + } + + cmd := exec.CommandContext(ctx, "bash", "--noprofile", "--norc", scriptPath) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // Do not pass envSnapshot via cmd.Env to avoid "argument list too long" when session env is large. + // Child inherits parent env (nil => default in Go). The script file already has "export K=V" for + // all session vars at the top, so the session environment is applied when the script runs. + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("stdout pipe: %w", err) + } + cmd.Stderr = cmd.Stdout + + if err := cmd.Start(); err != nil { + log.Error("start bash session failed: %v (command: %q)", err, request.Code) + return fmt.Errorf("start bash: %w", err) + } + defer s.untrackCurrentProcess() + s.trackCurrentProcess(cmd.Process.Pid) + + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024) + + var ( + envLines []string + pwdLine string + exitCode *int + inEnv bool + ) + + for scanner.Scan() { + line := scanner.Text() + switch { + case line == envDumpStartMarker: + inEnv = true + case line == envDumpEndMarker: + inEnv = false + case strings.HasPrefix(line, exitMarkerPrefix): + if code, err := strconv.Atoi(strings.TrimPrefix(line, exitMarkerPrefix)); err == nil { + exitCode = &code //nolint:ineffassign + } + case strings.HasPrefix(line, pwdMarkerPrefix): + pwdLine = strings.TrimPrefix(line, pwdMarkerPrefix) + default: + if inEnv { + envLines = append(envLines, line) + continue + } + if request.Hooks.OnExecuteStdout != nil { + request.Hooks.OnExecuteStdout(line) + } + } + } + + scanErr := scanner.Err() + waitErr := cmd.Wait() + + if scanErr != nil { + log.Error("read stdout failed: %v (command: %q)", scanErr, request.Code) + return fmt.Errorf("read stdout: %w", scanErr) + } + + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + log.Error("timeout after %s while running command: %q", wait, request.Code) + return fmt.Errorf("timeout after %s while running command %q", wait, request.Code) + } + + if exitCode == nil && cmd.ProcessState != nil { + code := cmd.ProcessState.ExitCode() //nolint:staticcheck + exitCode = &code //nolint:ineffassign + } + + updatedEnv := parseExportDump(envLines) + s.mu.Lock() + if len(updatedEnv) > 0 { + s.env = updatedEnv + } + if pwdLine != "" { + s.cwd = pwdLine + } + s.mu.Unlock() + + var exitErr *exec.ExitError + if waitErr != nil && !errors.As(waitErr, &exitErr) { + log.Error("command wait failed: %v (command: %q)", waitErr, request.Code) + return waitErr + } + + userExitCode := 0 + if exitCode != nil { + userExitCode = *exitCode + } + + if userExitCode != 0 { + errMsg := fmt.Sprintf("command exited with code %d", userExitCode) + if waitErr != nil { + errMsg = waitErr.Error() + } + if request.Hooks.OnExecuteError != nil { + request.Hooks.OnExecuteError(&execute.ErrorOutput{ + EName: "CommandExecError", + EValue: strconv.Itoa(userExitCode), + Traceback: []string{errMsg}, + }) + } + log.Error("CommandExecError: %s (command: %q)", errMsg, request.Code) + return nil + } + + if request.Hooks.OnExecuteComplete != nil { + request.Hooks.OnExecuteComplete(time.Since(startAt)) + } + + return nil +} + +func buildWrappedScript(command string, env map[string]string, cwd string) string { + var b strings.Builder + + keys := make([]string, 0, len(env)) + for k := range env { + v := env[k] + if isValidEnvKey(k) && !envKeysNotPersisted[k] && len(v) <= maxPersistedEnvValueSize { + keys = append(keys, k) + } + } + sort.Strings(keys) + for _, k := range keys { + b.WriteString("export ") + b.WriteString(k) + b.WriteString("=") + b.WriteString(shellEscape(env[k])) + b.WriteString("\n") + } + + if cwd != "" { + b.WriteString("cd ") + b.WriteString(shellEscape(cwd)) + b.WriteString("\n") + } + + b.WriteString(command) + if !strings.HasSuffix(command, "\n") { + b.WriteString("\n") + } + + b.WriteString("__USER_EXIT_CODE__=$?\n") + b.WriteString("printf \"\\n%s\\n\" \"" + envDumpStartMarker + "\"\n") + b.WriteString("export -p\n") + b.WriteString("printf \"%s\\n\" \"" + envDumpEndMarker + "\"\n") + b.WriteString("printf \"" + pwdMarkerPrefix + "%s\\n\" \"$(pwd)\"\n") + b.WriteString("printf \"" + exitMarkerPrefix + "%s\\n\" \"$__USER_EXIT_CODE__\"\n") + b.WriteString("exit \"$__USER_EXIT_CODE__\"\n") + + return b.String() +} + +// envKeysNotPersisted are not carried across runs (prompt/display vars). +var envKeysNotPersisted = map[string]bool{ + "PS1": true, "PS2": true, "PS3": true, "PS4": true, + "PROMPT_COMMAND": true, +} + +// maxPersistedEnvValueSize caps single env value length as a safeguard. +const maxPersistedEnvValueSize = 8 * 1024 + +func parseExportDump(lines []string) map[string]string { + if len(lines) == 0 { + return nil + } + env := make(map[string]string, len(lines)) + for _, line := range lines { + k, v, ok := parseExportLine(line) + if !ok || envKeysNotPersisted[k] || len(v) > maxPersistedEnvValueSize { + continue + } + env[k] = v + } + return env +} + +func parseExportLine(line string) (string, string, bool) { + const prefix = "declare -x " + if !strings.HasPrefix(line, prefix) { + return "", "", false + } + rest := strings.TrimSpace(strings.TrimPrefix(line, prefix)) + if rest == "" { + return "", "", false + } + name, value := rest, "" + if eq := strings.Index(rest, "="); eq >= 0 { + name = rest[:eq] + raw := rest[eq+1:] + if unquoted, err := strconv.Unquote(raw); err == nil { + value = unquoted + } else { + value = strings.Trim(raw, `"`) + } + } + if !isValidEnvKey(name) { + return "", "", false + } + return name, value, true +} + +func shellEscape(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} + +func isValidEnvKey(key string) bool { + if key == "" { + return false + } + + for i, r := range key { + if i == 0 { + if (r < 'A' || (r > 'Z' && r < 'a') || r > 'z') && r != '_' { + return false + } + continue + } + if (r < 'A' || (r > 'Z' && r < 'a') || r > 'z') && (r < '0' || r > '9') && r != '_' { + return false + } + } + + return true +} + +func copyEnvMap(src map[string]string) map[string]string { + if src == nil { + return map[string]string{} + } + + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func splitEnvPair(kv string) (string, string, bool) { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + return "", "", false + } + if !isValidEnvKey(parts[0]) { + return "", "", false + } + return parts[0], parts[1], true +} + +func (s *bashSession) close() error { + s.mu.Lock() + defer s.mu.Unlock() + + pid := s.currentProcessPid + s.currentProcessPid = 0 + s.started = false + s.env = nil + s.cwd = "" + + if pid != 0 { + if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil { + log.Warning("kill session process group %d: %v (process may have already exited)", pid, err) + } + } + return nil +} + +func uuidString() string { + return uuid.New().String() +} diff --git a/components/execd/pkg/runtime/bash_session_test.go b/components/execd/pkg/runtime/bash_session_test.go new file mode 100644 index 00000000..b18af3de --- /dev/null +++ b/components/execd/pkg/runtime/bash_session_test.go @@ -0,0 +1,599 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package runtime + +import ( + "context" + "fmt" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" +) + +func TestBashSession_NonZeroExitEmitsError(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found in PATH") + } + + c := NewController("", "") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var ( + sessionID string + stdoutLine string + errCh = make(chan *execute.ErrorOutput, 1) + completeCh = make(chan struct{}, 1) + ) + + req := &ExecuteCodeRequest{ + Language: Bash, + Code: `echo "before"; exit 7`, + Cwd: t.TempDir(), + Timeout: 5 * time.Second, + Hooks: ExecuteResultHook{ + OnExecuteInit: func(s string) { sessionID = s }, + OnExecuteStdout: func(s string) { stdoutLine = s }, + OnExecuteError: func(err *execute.ErrorOutput) { errCh <- err }, + OnExecuteComplete: func(_ time.Duration) { + completeCh <- struct{}{} + }, + }, + } + + session, err := c.createBashSession(&CreateContextRequest{}) + assert.NoError(t, err) + req.Context = session + require.NoError(t, c.runBashSession(ctx, req)) + + var gotErr *execute.ErrorOutput + select { + case gotErr = <-errCh: + case <-time.After(2 * time.Second): + require.Fail(t, "expected error hook to be called") + } + require.NotNil(t, gotErr, "expected non-nil error output") + require.Equal(t, "CommandExecError", gotErr.EName) + require.Equal(t, "7", gotErr.EValue) + require.NotEmpty(t, sessionID, "expected session id to be set") + require.Equal(t, "before", stdoutLine) + + select { + case <-completeCh: + require.Fail(t, "did not expect completion hook on non-zero exit") + default: + } +} + +func TestBashSession_envAndExitCode(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var ( + initCalls int + completeCalls int + stdoutLines []string + ) + + hooks := ExecuteResultHook{ + OnExecuteInit: func(ctx string) { + require.Equal(t, session.config.Session, ctx, "unexpected session in OnExecuteInit") + initCalls++ + }, + OnExecuteStdout: func(text string) { + t.Log(text) + stdoutLines = append(stdoutLines, text) + }, + OnExecuteComplete: func(_ time.Duration) { + completeCalls++ + }, + } + + // 1) export an env var + request := &ExecuteCodeRequest{ + Code: "export FOO=hello", + Hooks: hooks, + Timeout: 3 * time.Second, + } + require.NoError(t, session.run(context.Background(), request)) + exportStdoutCount := len(stdoutLines) + + // 2) verify env is persisted + request = &ExecuteCodeRequest{ + Code: "echo $FOO", + Hooks: hooks, + Timeout: 3 * time.Second, + } + require.NoError(t, session.run(context.Background(), request)) + echoLines := stdoutLines[exportStdoutCount:] + foundHello := false + for _, line := range echoLines { + if strings.TrimSpace(line) == "hello" { + foundHello = true + break + } + } + require.True(t, foundHello, "expected echo $FOO to output 'hello', got %v", echoLines) + + // 3) ensure exit code of previous command is reflected in shell state + request = &ExecuteCodeRequest{ + Code: "false; echo EXIT:$?", + Hooks: hooks, + Timeout: 3 * time.Second, + } + prevCount := len(stdoutLines) + require.NoError(t, session.run(context.Background(), request)) + exitLines := stdoutLines[prevCount:] + foundExit := false + for _, line := range exitLines { + if strings.Contains(line, "EXIT:1") { + foundExit = true + break + } + } + require.True(t, foundExit, "expected exit code output 'EXIT:1', got %v", exitLines) + require.Equal(t, 3, initCalls, "OnExecuteInit expected 3 calls") + require.Equal(t, 3, completeCalls, "OnExecuteComplete expected 3 calls") +} + +func TestBashSession_envLargeOutputChained(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var ( + initCalls int + completeCalls int + stdoutLines []string + ) + + hooks := ExecuteResultHook{ + OnExecuteInit: func(ctx string) { + require.Equal(t, session.config.Session, ctx, "unexpected session in OnExecuteInit") + initCalls++ + }, + OnExecuteStdout: func(text string) { + t.Log(text) + stdoutLines = append(stdoutLines, text) + }, + OnExecuteComplete: func(_ time.Duration) { + completeCalls++ + }, + } + + runAndCollect := func(cmd string) []string { + start := len(stdoutLines) + request := &ExecuteCodeRequest{ + Code: cmd, + Hooks: hooks, + Timeout: 10 * time.Second, + } + require.NoError(t, session.run(context.Background(), request)) + return append([]string(nil), stdoutLines[start:]...) + } + + lines1 := runAndCollect("export FOO=hello1; for i in $(seq 1 60); do echo A${i}:$FOO; done") + require.GreaterOrEqual(t, len(lines1), 60, "expected >=60 lines for cmd1") + require.True(t, containsLine(lines1, "A1:hello1") && containsLine(lines1, "A60:hello1"), "env not reflected in cmd1 output, got %v", lines1[:3]) + + lines2 := runAndCollect("export FOO=${FOO}_next; export BAR=bar1; for i in $(seq 1 60); do echo B${i}:$FOO:$BAR; done") + require.GreaterOrEqual(t, len(lines2), 60, "expected >=60 lines for cmd2") + require.True(t, containsLine(lines2, "B1:hello1_next:bar1") && containsLine(lines2, "B60:hello1_next:bar1"), "env not propagated to cmd2 output, sample %v", lines2[:3]) + + lines3 := runAndCollect("export BAR=${BAR}_last; for i in $(seq 1 60); do echo C${i}:$FOO:$BAR; done; echo FINAL_FOO=$FOO; echo FINAL_BAR=$BAR") + require.GreaterOrEqual(t, len(lines3), 62, "expected >=62 lines for cmd3") // 60 lines + 2 finals + require.True(t, containsLine(lines3, "C1:hello1_next:bar1_last") && containsLine(lines3, "C60:hello1_next:bar1_last"), "env not propagated to cmd3 output, sample %v", lines3[:3]) + require.True(t, containsLine(lines3, "FINAL_FOO=hello1_next") && containsLine(lines3, "FINAL_BAR=bar1_last"), "final env lines missing, got %v", lines3[len(lines3)-5:]) + require.Equal(t, 3, initCalls, "OnExecuteInit expected 3 calls") + require.Equal(t, 3, completeCalls, "OnExecuteComplete expected 3 calls") +} + +func TestBashSession_cwdPersistsWithoutOverride(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + targetDir := t.TempDir() + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + runAndCollect := func(req *ExecuteCodeRequest) []string { + start := len(stdoutLines) + require.NoError(t, session.run(context.Background(), req)) + return append([]string(nil), stdoutLines[start:]...) + } + + firstRunLines := runAndCollect(&ExecuteCodeRequest{ + Code: fmt.Sprintf("cd %s\npwd", targetDir), + Hooks: hooks, + Timeout: 3 * time.Second, + }) + require.True(t, containsLine(firstRunLines, targetDir), "expected cd to update cwd to %q, got %v", targetDir, firstRunLines) + + secondRunLines := runAndCollect(&ExecuteCodeRequest{ + Code: "pwd", + Hooks: hooks, + Timeout: 3 * time.Second, + }) + require.True(t, containsLine(secondRunLines, targetDir), "expected subsequent run to inherit cwd %q, got %v", targetDir, secondRunLines) + + session.mu.Lock() + finalCwd := session.cwd + session.mu.Unlock() + require.Equal(t, targetDir, finalCwd, "expected session cwd to stay at %q", targetDir) +} + +func TestBashSession_requestCwdOverridesAfterCd(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + initialDir := t.TempDir() + overrideDir := t.TempDir() + + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + runAndCollect := func(req *ExecuteCodeRequest) []string { + start := len(stdoutLines) + require.NoError(t, session.run(context.Background(), req)) + return append([]string(nil), stdoutLines[start:]...) + } + + // First request: change session cwd via script. + firstRunLines := runAndCollect(&ExecuteCodeRequest{ + Code: fmt.Sprintf("cd %s\npwd", initialDir), + Hooks: hooks, + Timeout: 3 * time.Second, + }) + require.True(t, containsLine(firstRunLines, initialDir), "expected cd to update cwd to %q, got %v", initialDir, firstRunLines) + + // Second request: explicit Cwd overrides session cwd. + secondRunLines := runAndCollect(&ExecuteCodeRequest{ + Code: "pwd", + Cwd: overrideDir, + Hooks: hooks, + Timeout: 3 * time.Second, + }) + require.True(t, containsLine(secondRunLines, overrideDir), "expected command to run in override cwd %q, got %v", overrideDir, secondRunLines) + + session.mu.Lock() + finalCwd := session.cwd + session.mu.Unlock() + require.Equal(t, overrideDir, finalCwd, "expected session cwd updated to override dir %q", overrideDir) +} + +func TestBashSession_envDumpNotLeakedWhenNoTrailingNewline(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + request := &ExecuteCodeRequest{ + Code: `set +x; printf '{"foo":1}'`, + Hooks: hooks, + Timeout: 3 * time.Second, + } + require.NoError(t, session.run(context.Background(), request)) + + require.Len(t, stdoutLines, 1, "expected exactly one stdout line") + require.Equal(t, `{"foo":1}`, strings.TrimSpace(stdoutLines[0])) + for _, line := range stdoutLines { + require.NotContains(t, line, envDumpStartMarker, "env dump leaked into stdout: %v", stdoutLines) + require.NotContains(t, line, "declare -x", "env dump leaked into stdout: %v", stdoutLines) + } +} + +func TestBashSession_envDumpNotLeakedWhenNoOutput(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + request := &ExecuteCodeRequest{ + Code: `set +x; true`, + Hooks: hooks, + Timeout: 3 * time.Second, + } + require.NoError(t, session.run(context.Background(), request)) + + require.LessOrEqual(t, len(stdoutLines), 1, "expected at most one stdout line, got %v", stdoutLines) + if len(stdoutLines) == 1 { + require.Empty(t, strings.TrimSpace(stdoutLines[0]), "expected empty stdout") + } + for _, line := range stdoutLines { + require.NotContains(t, line, envDumpStartMarker, "env dump leaked into stdout: %v", stdoutLines) + require.NotContains(t, line, "declare -x", "env dump leaked into stdout: %v", stdoutLines) + } +} + +func TestBashSession_heredoc(t *testing.T) { + rewardDir := t.TempDir() + controller := NewController("", "") + + sessionID, err := controller.CreateBashSession(&CreateContextRequest{}) + require.NoError(t, err) + t.Cleanup(func() { _ = controller.DeleteBashSession(sessionID) }) + + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + fmt.Printf("[stdout] %s\n", line) + }, + OnExecuteComplete: func(d time.Duration) { + fmt.Printf("[complete] %s\n", d) + }, + } + + // First run: heredoc + reward file write. + script := fmt.Sprintf(` +set -x +reward_dir=%q +mkdir -p "$reward_dir" + +cat > /tmp/repro_script.sh <<'SHEOF' +#!/usr/bin/env sh +echo "hello heredoc" +SHEOF + +chmod +x /tmp/repro_script.sh +/tmp/repro_script.sh +echo "after heredoc" +echo 1 > "$reward_dir/reward.txt" +cat "$reward_dir/reward.txt" +`, rewardDir) + + ctx := context.Background() + require.NoError(t, controller.RunInBashSession(ctx, &ExecuteCodeRequest{ + Context: sessionID, + Language: Bash, + Timeout: 10 * time.Second, + Code: script, + Hooks: hooks, + })) + + // Second run: ensure the session keeps working. + require.NoError(t, controller.RunInBashSession(ctx, &ExecuteCodeRequest{ + Context: sessionID, + Language: Bash, + Timeout: 5 * time.Second, + Code: "echo 'second command works'", + Hooks: hooks, + })) +} + +func TestBashSession_execReplacesShell(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + script := ` +cat > /tmp/exec_child.sh <<'EOF' +echo "child says hi" +EOF +chmod +x /tmp/exec_child.sh +exec /tmp/exec_child.sh +` + + request := &ExecuteCodeRequest{ + Code: script, + Hooks: hooks, + Timeout: 5 * time.Second, + } + require.NoError(t, session.run(context.Background(), request), "expected exec to complete without killing the session") + require.True(t, containsLine(stdoutLines, "child says hi"), "expected child output, got %v", stdoutLines) + + // Subsequent run should still work because we restart bash per run. + request = &ExecuteCodeRequest{ + Code: "echo still-alive", + Hooks: hooks, + Timeout: 2 * time.Second, + } + stdoutLines = nil + require.NoError(t, session.run(context.Background(), request), "expected run to succeed after exec replaced the shell") + require.True(t, containsLine(stdoutLines, "still-alive"), "expected follow-up output, got %v", stdoutLines) +} + +func TestBashSession_complexExec(t *testing.T) { + session := newBashSession("") + t.Cleanup(func() { _ = session.close() }) + + require.NoError(t, session.start()) + + var stdoutLines []string + hooks := ExecuteResultHook{ + OnExecuteStdout: func(line string) { + stdoutLines = append(stdoutLines, line) + }, + } + + script := ` +LOG_FILE=$(mktemp) +export LOG_FILE +exec 3>&1 4>&2 +exec > >(tee "$LOG_FILE") 2>&1 + +set -x +echo "from-complex-exec" +exec 1>&3 2>&4 # step record +echo "after-restore" +` + + request := &ExecuteCodeRequest{ + Code: script, + Hooks: hooks, + Timeout: 5 * time.Second, + } + require.NoError(t, session.run(context.Background(), request), "expected complex exec to finish") + require.True(t, containsLine(stdoutLines, "from-complex-exec") && containsLine(stdoutLines, "after-restore"), "expected exec outputs, got %v", stdoutLines) + + // Session should still be usable. + request = &ExecuteCodeRequest{ + Code: "echo still-alive", + Hooks: hooks, + Timeout: 2 * time.Second, + } + stdoutLines = nil + require.NoError(t, session.run(context.Background(), request), "expected run to succeed after complex exec") + require.True(t, containsLine(stdoutLines, "still-alive"), "expected follow-up output, got %v", stdoutLines) +} + +func containsLine(lines []string, target string) bool { + for _, l := range lines { + if strings.TrimSpace(l) == target { + return true + } + } + return false +} + +// TestBashSession_CloseKillsRunningProcess verifies that session.close() kills the active +// process group so that a long-running command (e.g. sleep) does not keep running after close. +func TestBashSession_CloseKillsRunningProcess(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found in PATH") + } + + session := newBashSession("") + require.NoError(t, session.start()) + + runDone := make(chan error, 1) + req := &ExecuteCodeRequest{ + Code: "sleep 30", + Timeout: 60 * time.Second, + Hooks: ExecuteResultHook{}, + } + go func() { + runDone <- session.run(context.Background(), req) + }() + + // Give the child process time to start. + time.Sleep(200 * time.Millisecond) + + // Close should kill the process group; run() should return soon (it may return nil + // because the code path treats non-zero exit as success after calling OnExecuteError). + require.NoError(t, session.close()) + + select { + case <-runDone: + // run() returned; process was killed so we did not wait 30s + case <-time.After(3 * time.Second): + require.Fail(t, "run did not return within 3s after close (process was not killed)") + } +} + +// TestBashSession_DeleteBashSessionKillsRunningProcess verifies that DeleteBashSession +// (close path) kills the active run and removes the session from the controller. +func TestBashSession_DeleteBashSessionKillsRunningProcess(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found in PATH") + } + + c := NewController("", "") + sessionID, err := c.CreateBashSession(&CreateContextRequest{}) + require.NoError(t, err) + + runDone := make(chan error, 1) + req := &ExecuteCodeRequest{ + Language: Bash, + Context: sessionID, + Code: "sleep 30", + Timeout: 60 * time.Second, + Hooks: ExecuteResultHook{}, + } + go func() { + runDone <- c.RunInBashSession(context.Background(), req) + }() + + time.Sleep(200 * time.Millisecond) + + require.NoError(t, c.DeleteBashSession(sessionID)) + + select { + case <-runDone: + // RunInBashSession returned; process was killed + case <-time.After(3 * time.Second): + require.Fail(t, "RunInBashSession did not return within 3s after DeleteBashSession") + } + + // Session should be gone; deleting again should return ErrContextNotFound. + err = c.DeleteBashSession(sessionID) + require.Error(t, err) + require.ErrorIs(t, err, ErrContextNotFound) +} + +// TestBashSession_CloseWithNoActiveRun verifies that close() with no running command +// completes without error and does not hang. +func TestBashSession_CloseWithNoActiveRun(t *testing.T) { + session := newBashSession("") + require.NoError(t, session.start()) + + done := make(chan struct{}, 1) + go func() { + _ = session.close() + done <- struct{}{} + }() + + select { + case <-done: + // close() returned + case <-time.After(2 * time.Second): + require.Fail(t, "close() did not return within 2s when no run was active") + } +} diff --git a/components/execd/pkg/runtime/bash_session_windows.go b/components/execd/pkg/runtime/bash_session_windows.go new file mode 100644 index 00000000..d0cd008b --- /dev/null +++ b/components/execd/pkg/runtime/bash_session_windows.go @@ -0,0 +1,40 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows +// +build windows + +package runtime + +import ( + "context" + "errors" +) + +var errBashSessionNotSupported = errors.New("bash session is not supported on windows") + +// CreateBashSession is not supported on Windows. +func (c *Controller) CreateBashSession(_ *CreateContextRequest) (string, error) { //nolint:revive + return "", errBashSessionNotSupported +} + +// RunInBashSession is not supported on Windows. +func (c *Controller) RunInBashSession(_ context.Context, _ *ExecuteCodeRequest) error { //nolint:revive + return errBashSessionNotSupported +} + +// DeleteBashSession is not supported on Windows. +func (c *Controller) DeleteBashSession(_ string) error { //nolint:revive + return errBashSessionNotSupported +} diff --git a/components/execd/pkg/runtime/command_common.go b/components/execd/pkg/runtime/command_common.go index 960ff273..03205e20 100644 --- a/components/execd/pkg/runtime/command_common.go +++ b/components/execd/pkg/runtime/command_common.go @@ -46,18 +46,17 @@ func (c *Controller) tailStdPipe(file string, onExecute func(text string), done // getCommandKernel retrieves a command execution context. func (c *Controller) getCommandKernel(sessionID string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.commandClientMap[sessionID] + if v, ok := c.commandClientMap.Load(sessionID); ok { + if kernel, ok := v.(*commandKernel); ok { + return kernel + } + } + return nil } // storeCommandKernel registers a command execution context. func (c *Controller) storeCommandKernel(sessionID string, kernel *commandKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.commandClientMap[sessionID] = kernel + c.commandClientMap.Store(sessionID, kernel) } // stdLogDescriptor creates temporary files for capturing command output. diff --git a/components/execd/pkg/runtime/command_status.go b/components/execd/pkg/runtime/command_status.go index 97f112b1..6dbc6d4f 100644 --- a/components/execd/pkg/runtime/command_status.go +++ b/components/execd/pkg/runtime/command_status.go @@ -40,11 +40,11 @@ type CommandOutput struct { } func (c *Controller) commandSnapshot(session string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return nil } @@ -116,8 +116,11 @@ func (c *Controller) markCommandFinished(session string, exitCode int, errMsg st c.mu.Lock() defer c.mu.Unlock() - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return } diff --git a/components/execd/pkg/runtime/context.go b/components/execd/pkg/runtime/context.go index c2f9052a..42f7fe50 100644 --- a/components/execd/pkg/runtime/context.go +++ b/components/execd/pkg/runtime/context.go @@ -31,7 +31,9 @@ import ( ) // CreateContext provisions a kernel-backed session and returns its ID. +// Bash language uses Jupyter kernel like other languages; for pipe-based bash sessions use CreateBashSession (session API). func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { + // Create a new Jupyter session. var ( client *jupyter.Client session *jupytersession.Session @@ -42,7 +44,7 @@ func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { log.Error("failed to create session, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(*req) + client, session, err = c.createJupyterContext(*req) return err }) if err != nil { @@ -114,20 +116,11 @@ func (c *Controller) deleteSessionAndCleanup(session string) error { if c.getJupyterKernel(session) == nil { return ErrContextNotFound } - if err := c.jupyterClient().DeleteSession(session); err != nil { return err } - - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.jupyterClientMap, session) - for lang, id := range c.defaultLanguageJupyterSessions { - if id == session { - delete(c.defaultLanguageJupyterSessions, lang) - } - } + c.jupyterClientMap.Delete(session) + c.deleteDefaultSessionByID(session) return nil } @@ -146,8 +139,12 @@ func (c *Controller) newIpynbPath(sessionID, cwd string) (string, error) { return filepath.Join(cwd, fmt.Sprintf("%s.ipynb", sessionID)), nil } -// createDefaultLanguageContext prewarms a session for stateless execution. -func (c *Controller) createDefaultLanguageContext(language Language) error { +// createDefaultLanguageJupyterContext prewarms a session for stateless execution. +func (c *Controller) createDefaultLanguageJupyterContext(language Language) error { + if c.getDefaultLanguageSession(language) != "" { + return nil + } + var ( client *jupyter.Client session *jupytersession.Session @@ -157,7 +154,7 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { log.Error("failed to create context, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(CreateContextRequest{ + client, session, err = c.createJupyterContext(CreateContextRequest{ Language: language, Cwd: "", }) @@ -167,20 +164,17 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { return err } - c.mu.Lock() - defer c.mu.Unlock() - - c.defaultLanguageJupyterSessions[language] = session.ID - c.jupyterClientMap[session.ID] = &jupyterKernel{ + c.setDefaultLanguageSession(language, session.ID) + c.jupyterClientMap.Store(session.ID, &jupyterKernel{ kernelID: session.Kernel.ID, client: client, language: language, - } + }) return nil } -// createContext performs the actual context creation workflow. -func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { +// createJupyterContext performs the actual context creation workflow. +func (c *Controller) createJupyterContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { client := c.jupyterClient() kernel, err := c.searchKernel(client, request.Language) @@ -220,10 +214,7 @@ func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Clien // storeJupyterKernel caches a session -> kernel mapping. func (c *Controller) storeJupyterKernel(sessionID string, kernel *jupyterKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.jupyterClientMap[sessionID] = kernel + c.jupyterClientMap.Store(sessionID, kernel) } func (c *Controller) jupyterClient() *jupyter.Client { @@ -239,49 +230,63 @@ func (c *Controller) jupyterClient() *jupyter.Client { jupyter.WithHTTPClient(httpClient)) } -func (c *Controller) listAllContexts() ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() +func (c *Controller) getDefaultLanguageSession(language Language) string { + if v, ok := c.defaultLanguageSessions.Load(language); ok { + if session, ok := v.(string); ok { + return session + } + } + return "" +} + +func (c *Controller) setDefaultLanguageSession(language Language, sessionID string) { + c.defaultLanguageSessions.Store(language, sessionID) +} +func (c *Controller) deleteDefaultSessionByID(sessionID string) { + c.defaultLanguageSessions.Range(func(key, value any) bool { + if s, ok := value.(string); ok && s == sessionID { + c.defaultLanguageSessions.Delete(key) + } + return true + }) +} + +func (c *Controller) listAllContexts() ([]CodeContext, error) { contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil { - contexts = append(contexts, CodeContext{ - ID: session, - Language: kernel.language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil { + contexts = append(contexts, CodeContext{ID: session, Language: kernel.language}) } - } + return true + }) - for language, defaultContext := range c.defaultLanguageJupyterSessions { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) - } + c.defaultLanguageSessions.Range(func(key, value any) bool { + lang, _ := key.(Language) + session, _ := value.(string) + if session == "" { + return true + } + contexts = append(contexts, CodeContext{ID: session, Language: lang}) + return true + }) return contexts, nil } func (c *Controller) listLanguageContexts(language Language) ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() - contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil && kernel.language == language { - contexts = append(contexts, CodeContext{ - ID: session, - Language: language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil && kernel.language == language { + contexts = append(contexts, CodeContext{ID: session, Language: language}) } - } + return true + }) - if defaultContext := c.defaultLanguageJupyterSessions[language]; defaultContext != "" { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) + if defaultContext := c.getDefaultLanguageSession(language); defaultContext != "" { + contexts = append(contexts, CodeContext{ID: defaultContext, Language: language}) } return contexts, nil diff --git a/components/execd/pkg/runtime/context_test.go b/components/execd/pkg/runtime/context_test.go index 9a0376d0..e0cdf3e6 100644 --- a/components/execd/pkg/runtime/context_test.go +++ b/components/execd/pkg/runtime/context_test.go @@ -27,8 +27,8 @@ import ( func TestListContextsAndNewIpynbPath(t *testing.T) { c := NewController("http://example", "token") - c.jupyterClientMap["session-python"] = &jupyterKernel{language: Python} - c.defaultLanguageJupyterSessions[Go] = "session-go-default" + c.jupyterClientMap.Store("session-python", &jupyterKernel{language: Python}) + c.defaultLanguageSessions.Store(Go, "session-go-default") pyContexts, err := c.listLanguageContexts(Python) require.NoError(t, err) @@ -107,13 +107,13 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[sessionID] = &jupyterKernel{language: Python} - c.defaultLanguageJupyterSessions[Python] = sessionID + c.jupyterClientMap.Store(sessionID, &jupyterKernel{language: Python}) + c.defaultLanguageSessions.Store(Python, sessionID) require.NoError(t, c.DeleteContext(sessionID)) require.Nil(t, c.getJupyterKernel(sessionID), "expected cache to be cleared") - _, ok := c.defaultLanguageJupyterSessions[Python] + _, ok := c.defaultLanguageSessions.Load(Python) require.False(t, ok, "expected default session entry to be removed") } @@ -138,17 +138,17 @@ func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[session1] = &jupyterKernel{language: lang} - c.jupyterClientMap[session2] = &jupyterKernel{language: lang} - c.defaultLanguageJupyterSessions[lang] = session2 + c.jupyterClientMap.Store(session1, &jupyterKernel{language: lang}) + c.jupyterClientMap.Store(session2, &jupyterKernel{language: lang}) + c.defaultLanguageSessions.Store(lang, session2) require.NoError(t, c.DeleteLanguageContext(lang)) - _, ok := c.jupyterClientMap[session1] + _, ok := c.jupyterClientMap.Load(session1) require.False(t, ok, "expected session1 removed from cache") - _, ok = c.jupyterClientMap[session2] + _, ok = c.jupyterClientMap.Load(session2) require.False(t, ok, "expected session2 removed from cache") - _, ok = c.defaultLanguageJupyterSessions[lang] + _, ok = c.defaultLanguageSessions.Load(lang) require.False(t, ok, "expected default entry removed") require.Equal(t, 1, deleteCalls[session1]) require.Equal(t, 1, deleteCalls[session2]) diff --git a/components/execd/pkg/runtime/ctrl.go b/components/execd/pkg/runtime/ctrl.go index 36c325b4..3a835b31 100644 --- a/components/execd/pkg/runtime/ctrl.go +++ b/components/execd/pkg/runtime/ctrl.go @@ -35,14 +35,15 @@ var kernelWaitingBackoff = wait.Backoff{ // Controller manages code execution across runtimes. type Controller struct { - baseURL string - token string - mu sync.RWMutex - jupyterClientMap map[string]*jupyterKernel - defaultLanguageJupyterSessions map[Language]string - commandClientMap map[string]*commandKernel - db *sql.DB - dbOnce sync.Once + baseURL string + token string + mu sync.RWMutex + jupyterClientMap sync.Map // map[sessionID]*jupyterKernel + defaultLanguageSessions sync.Map // map[Language]string + commandClientMap sync.Map // map[sessionID]*commandKernel + bashSessionClientMap sync.Map // map[sessionID]*bashSession + db *sql.DB + dbOnce sync.Once } type jupyterKernel struct { @@ -70,10 +71,6 @@ func NewController(baseURL, token string) *Controller { return &Controller{ baseURL: baseURL, token: token, - - jupyterClientMap: make(map[string]*jupyterKernel), - defaultLanguageJupyterSessions: make(map[Language]string), - commandClientMap: make(map[string]*commandKernel), } } diff --git a/components/execd/pkg/runtime/interrupt.go b/components/execd/pkg/runtime/interrupt.go index 1a9515fa..67902a3d 100644 --- a/components/execd/pkg/runtime/interrupt.go +++ b/components/execd/pkg/runtime/interrupt.go @@ -38,6 +38,8 @@ func (c *Controller) Interrupt(sessionID string) error { case c.getCommandKernel(sessionID) != nil: kernel := c.getCommandKernel(sessionID) return c.killPid(kernel.pid) + case c.getBashSession(sessionID) != nil: + return c.closeBashSession(sessionID) default: return errors.New("no such session") } diff --git a/components/execd/pkg/runtime/jupyter.go b/components/execd/pkg/runtime/jupyter.go index cdc0a6cc..9ea33b13 100644 --- a/components/execd/pkg/runtime/jupyter.go +++ b/components/execd/pkg/runtime/jupyter.go @@ -29,9 +29,8 @@ func (c *Controller) runJupyter(ctx context.Context, request *ExecuteCodeRequest return errors.New("language runtime server not configured, please check your image runtime") } if request.Context == "" { - if _, exists := c.defaultLanguageJupyterSessions[request.Language]; !exists { - err := c.createDefaultLanguageContext(request.Language) - if err != nil { + if c.getDefaultLanguageSession(request.Language) == "" { + if err := c.createDefaultLanguageJupyterContext(request.Language); err != nil { return err } } @@ -39,7 +38,7 @@ func (c *Controller) runJupyter(ctx context.Context, request *ExecuteCodeRequest var targetSessionID string if request.Context == "" { - targetSessionID = c.defaultLanguageJupyterSessions[request.Language] + targetSessionID = c.getDefaultLanguageSession(request.Language) } else { targetSessionID = request.Context } @@ -135,10 +134,12 @@ func (c *Controller) setWorkingDir(_ *jupyterKernel, _ *CreateContextRequest) er // getJupyterKernel retrieves a kernel connection from the session map. func (c *Controller) getJupyterKernel(sessionID string) *jupyterKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.jupyterClientMap[sessionID] + if v, ok := c.jupyterClientMap.Load(sessionID); ok { + if kernel, ok := v.(*jupyterKernel); ok { + return kernel + } + } + return nil } // searchKernel finds a kernel spec name for the given language. diff --git a/components/execd/pkg/runtime/types.go b/components/execd/pkg/runtime/types.go index 4dc459b3..cd0615c6 100644 --- a/components/execd/pkg/runtime/types.go +++ b/components/execd/pkg/runtime/types.go @@ -16,6 +16,7 @@ package runtime import ( "fmt" + "sync" "time" "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" @@ -82,3 +83,28 @@ type CodeContext struct { ID string `json:"id,omitempty"` Language Language `json:"language"` } + +// bashSessionConfig holds bash session configuration. +type bashSessionConfig struct { + // StartupSource is a list of scripts sourced on startup. + StartupSource []string + // Session is the session identifier. + Session string + // StartupTimeout is the startup timeout. + StartupTimeout time.Duration + // Cwd is the working directory. + Cwd string +} + +// bashSession represents a bash session. +type bashSession struct { + config *bashSessionConfig + mu sync.Mutex + started bool + env map[string]string + cwd string + + // currentProcessPid is the pid of the active run's process group leader (bash). + // Set after cmd.Start(), cleared when run() returns. Used by close() to kill the process group. + currentProcessPid int +} diff --git a/components/execd/pkg/web/controller/codeinterpreting.go b/components/execd/pkg/web/controller/codeinterpreting.go index df4a28db..c95d83f1 100644 --- a/components/execd/pkg/web/controller/codeinterpreting.go +++ b/components/execd/pkg/web/controller/codeinterpreting.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "net/http" "sync" "time" @@ -236,6 +237,123 @@ func (c *CodeInterpretingController) DeleteContext() { c.RespondSuccess(nil) } +// CreateSession creates a new bash session (create_session API). +// An empty body is allowed and is treated as default options (no cwd override). +func (c *CodeInterpretingController) CreateSession() { + var request model.CreateSessionRequest + if err := c.bindJSON(&request); err != nil && !errors.Is(err, io.EOF) { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("error parsing request. %v", err), + ) + return + } + + sessionID, err := codeRunner.CreateBashSession(&runtime.CreateContextRequest{ + Cwd: request.Cwd, + }) + if err != nil { + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error creating session. %v", err), + ) + return + } + + c.RespondSuccess(model.CreateSessionResponse{SessionID: sessionID}) +} + +// RunInSession runs code in an existing bash session and streams output via SSE (run_in_session API). +func (c *CodeInterpretingController) RunInSession() { + sessionID := c.ctx.Param("sessionId") + if sessionID == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + var request model.RunInSessionRequest + if err := c.bindJSON(&request); err != nil { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("error parsing request. %v", err), + ) + return + } + if err := request.Validate(); err != nil { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("invalid request. %v", err), + ) + return + } + + timeout := time.Duration(request.TimeoutMs) * time.Millisecond + runReq := &runtime.ExecuteCodeRequest{ + Language: runtime.Bash, + Context: sessionID, + Code: request.Code, + Cwd: request.Cwd, + Timeout: timeout, + } + ctx, cancel := context.WithCancel(c.ctx.Request.Context()) + defer cancel() + runReq.Hooks = c.setServerEventsHandler(ctx) + + c.setupSSEResponse() + err := codeRunner.RunInBashSession(ctx, runReq) + if err != nil { + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error running in session. %v", err), + ) + return + } + + time.Sleep(flag.ApiGracefulShutdownTimeout) +} + +// DeleteSession deletes a bash session (delete_session API). +func (c *CodeInterpretingController) DeleteSession() { + sessionID := c.ctx.Param("sessionId") + if sessionID == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + err := codeRunner.DeleteBashSession(sessionID) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("session %s not found", sessionID), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error deleting session %s. %v", sessionID, err), + ) + return + } + + c.RespondSuccess(nil) +} + // buildExecuteCodeRequest converts a RunCodeRequest to runtime format. func (c *CodeInterpretingController) buildExecuteCodeRequest(request model.RunCodeRequest) *runtime.ExecuteCodeRequest { req := &runtime.ExecuteCodeRequest{ diff --git a/components/execd/pkg/web/model/session.go b/components/execd/pkg/web/model/session.go new file mode 100644 index 00000000..0b4f598b --- /dev/null +++ b/components/execd/pkg/web/model/session.go @@ -0,0 +1,42 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "github.com/go-playground/validator/v10" +) + +// CreateSessionRequest is the request body for creating a bash session. +type CreateSessionRequest struct { + Cwd string `json:"cwd,omitempty"` +} + +// CreateSessionResponse is the response for create_session. +type CreateSessionResponse struct { + SessionID string `json:"session_id"` +} + +// RunInSessionRequest is the request body for running code in an existing session. +type RunInSessionRequest struct { + Code string `json:"code" validate:"required"` + Cwd string `json:"cwd,omitempty"` + TimeoutMs int64 `json:"timeout_ms,omitempty" validate:"omitempty,gte=0"` +} + +// Validate validates RunInSessionRequest. +func (r *RunInSessionRequest) Validate() error { + validate := validator.New() + return validate.Struct(r) +} diff --git a/components/execd/pkg/web/router.go b/components/execd/pkg/web/router.go index dfca0821..8894257d 100644 --- a/components/execd/pkg/web/router.go +++ b/components/execd/pkg/web/router.go @@ -62,6 +62,13 @@ func NewRouter(accessToken string) *gin.Engine { code.GET("/contexts/:contextId", withCode(func(c *controller.CodeInterpretingController) { c.GetContext() })) } + session := r.Group("/session") + { + session.POST("", withCode(func(c *controller.CodeInterpretingController) { c.CreateSession() })) + session.POST("/:sessionId/run", withCode(func(c *controller.CodeInterpretingController) { c.RunInSession() })) + session.DELETE("/:sessionId", withCode(func(c *controller.CodeInterpretingController) { c.DeleteSession() })) + } + command := r.Group("/command") { command.POST("", withCode(func(c *controller.CodeInterpretingController) { c.RunCommand() }))