Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion components/public-api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.24.0
require (
github.com/gin-contrib/cors v1.7.6
github.com/gin-gonic/gin v1.11.0
github.com/google/uuid v1.6.0
github.com/prometheus/client_golang v1.23.2
github.com/rs/zerolog v1.34.0
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.65.0
Expand All @@ -31,7 +32,7 @@ require (
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/google/uuid v1.6.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
Expand Down
47 changes: 45 additions & 2 deletions components/public-api/handlers/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func setupTestRouter() *gin.Engine {
v1.POST("/sessions", CreateSession)
v1.GET("/sessions/:id", GetSession)
v1.DELETE("/sessions/:id", DeleteSession)

v1.POST("/sessions/:id/runs", CreateRun)
v1.GET("/sessions/:id/runs", GetSessionRuns)
v1.POST("/sessions/:id/message", SendMessage)
v1.GET("/sessions/:id/output", GetSessionOutput)
v1.POST("/sessions/:id/start", StartSession)
v1.POST("/sessions/:id/stop", StopSession)
v1.POST("/sessions/:id/interrupt", InterruptSession)
}

return r
Expand Down Expand Up @@ -109,8 +117,43 @@ func TestE2E_CreateSession(t *testing.T) {
}

// Verify request body was transformed correctly
if !strings.Contains(requestBody, "prompt") {
t.Errorf("Expected request body to contain 'prompt', got %s", requestBody)
if !strings.Contains(requestBody, "initialPrompt") {
t.Errorf("Expected request body to contain 'initialPrompt', got %s", requestBody)
}
}

func TestE2E_CreateSession_WithDisplayName(t *testing.T) {
var receivedBody map[string]interface{}
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.Decode(&receivedBody)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{"name": "session-123"})
}))
defer backend.Close()

originalURL := BackendURL
BackendURL = backend.URL
defer func() { BackendURL = originalURL }()

router := setupTestRouter()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/sessions",
strings.NewReader(`{"task": "Fix the bug", "display_name": "Bug Fix Session"}`))
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("X-Ambient-Project", "test-project")
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)

if w.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String())
}

// Verify display_name was forwarded as displayName (camelCase) to backend
if receivedBody["displayName"] != "Bug Fix Session" {
t.Errorf("Expected displayName 'Bug Fix Session' in backend request, got %v", receivedBody["displayName"])
}
}

Expand Down
223 changes: 223 additions & 0 deletions components/public-api/handlers/lifecycle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package handlers

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"

"ambient-code-public-api/types"

"github.com/gin-gonic/gin"
)

// StartSession handles POST /v1/sessions/:id/start
//
// Defense-in-depth: The gateway fetches the session phase before forwarding.
// The backend also validates phase transitions, so this is a redundant guard
// that provides faster feedback and reduces unnecessary backend writes.
func StartSession(c *gin.Context) {
project := GetProject(c)
if !ValidateProjectName(project) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"})
return
}
sessionID := c.Param("id")
if !ValidateSessionID(sessionID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session ID"})
return
}

phase, err := getSessionPhase(c, project, sessionID)
if err != nil {
return // getSessionPhase already wrote the error response
}

if phase == "" {
log.Printf("Session %s has no phase, treating as unknown", sessionID)
c.JSON(http.StatusConflict, gin.H{"error": "Session state is unknown"})
return
}

if phase == "running" || phase == "pending" {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "Session is already running or pending"})
return
}

path := fmt.Sprintf("/api/projects/%s/agentic-sessions/%s/start", project, sessionID)
resp, cancel, err := ProxyRequest(c, http.MethodPost, path, nil)
if err != nil {
log.Printf("Backend request failed for start session %s: %v", sessionID, err)
c.JSON(http.StatusBadGateway, gin.H{"error": "Backend unavailable"})
return
}
defer cancel()
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
forwardErrorResponse(c, resp.StatusCode, body)
return
}

var backendResp map[string]interface{}
if err := json.Unmarshal(body, &backendResp); err != nil {
log.Printf("Failed to parse backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

c.JSON(http.StatusAccepted, transformSession(backendResp))
}

// StopSession handles POST /v1/sessions/:id/stop
//
// Defense-in-depth: The gateway fetches the session phase before forwarding.
// The backend also validates phase transitions, so this is a redundant guard
// that provides faster feedback and reduces unnecessary backend writes.
func StopSession(c *gin.Context) {
project := GetProject(c)
if !ValidateProjectName(project) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"})
return
}
sessionID := c.Param("id")
if !ValidateSessionID(sessionID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session ID"})
return
}

phase, err := getSessionPhase(c, project, sessionID)
if err != nil {
return
}

if phase == "" {
log.Printf("Session %s has no phase, treating as unknown", sessionID)
c.JSON(http.StatusConflict, gin.H{"error": "Session state is unknown"})
return
}

if phase == "completed" || phase == "failed" {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "Session is not in a running state"})
return
}

path := fmt.Sprintf("/api/projects/%s/agentic-sessions/%s/stop", project, sessionID)
resp, cancel, err := ProxyRequest(c, http.MethodPost, path, nil)
if err != nil {
log.Printf("Backend request failed for stop session %s: %v", sessionID, err)
c.JSON(http.StatusBadGateway, gin.H{"error": "Backend unavailable"})
return
}
defer cancel()
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
forwardErrorResponse(c, resp.StatusCode, body)
return
}

var backendResp map[string]interface{}
if err := json.Unmarshal(body, &backendResp); err != nil {
log.Printf("Failed to parse backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

c.JSON(http.StatusAccepted, transformSession(backendResp))
}

// InterruptSession handles POST /v1/sessions/:id/interrupt
func InterruptSession(c *gin.Context) {
project := GetProject(c)
if !ValidateProjectName(project) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"})
return
}
sessionID := c.Param("id")
if !ValidateSessionID(sessionID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session ID"})
return
}

path := fmt.Sprintf("/api/projects/%s/agentic-sessions/%s/agui/interrupt", project, sessionID)
resp, cancel, err := ProxyRequest(c, http.MethodPost, path, []byte("{}"))
if err != nil {
log.Printf("Backend request failed for interrupt session %s: %v", sessionID, err)
c.JSON(http.StatusBadGateway, gin.H{"error": "Backend unavailable"})
return
}
defer cancel()
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

if resp.StatusCode != http.StatusOK {
forwardErrorResponse(c, resp.StatusCode, body)
return
}

c.JSON(http.StatusOK, types.MessageResponse{Message: "Interrupt signal sent"})
}

// getSessionPhase fetches the session from the backend and returns its normalized phase.
// On error, it writes the appropriate error response to the gin context.
func getSessionPhase(c *gin.Context, project, sessionID string) (string, error) {
path := fmt.Sprintf("/api/projects/%s/agentic-sessions/%s", project, sessionID)
resp, cancel, err := ProxyRequest(c, http.MethodGet, path, nil)
if err != nil {
log.Printf("Backend request failed for get session phase %s: %v", sessionID, err)
c.JSON(http.StatusBadGateway, gin.H{"error": "Backend unavailable"})
return "", fmt.Errorf("backend unavailable")
}
defer cancel()
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return "", fmt.Errorf("internal server error")
}

if resp.StatusCode != http.StatusOK {
forwardErrorResponse(c, resp.StatusCode, body)
return "", fmt.Errorf("backend returned %d", resp.StatusCode)
}

var backendResp map[string]interface{}
if err := json.Unmarshal(body, &backendResp); err != nil {
log.Printf("Failed to parse backend response: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return "", fmt.Errorf("internal server error")
}

phase := ""
if status, ok := backendResp["status"].(map[string]interface{}); ok {
if p, ok := status["phase"].(string); ok {
phase = normalizePhase(p)
}
}

return phase, nil
}
Loading