diff --git a/cmd/alpamon/command/root.go b/cmd/alpamon/command/root.go index 081a221..a441592 100644 --- a/cmd/alpamon/command/root.go +++ b/cmd/alpamon/command/root.go @@ -1,18 +1,21 @@ package command import ( - "context" "fmt" "os" "os/signal" "syscall" + "time" "github.com/alpacax/alpamon/cmd/alpamon/command/ftp" "github.com/alpacax/alpamon/cmd/alpamon/command/setup" "github.com/alpacax/alpamon/cmd/alpamon/command/tunnel" + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/collector" "github.com/alpacax/alpamon/pkg/config" "github.com/alpacax/alpamon/pkg/db" + "github.com/alpacax/alpamon/pkg/executor" "github.com/alpacax/alpamon/pkg/logger" "github.com/alpacax/alpamon/pkg/pidfile" "github.com/alpacax/alpamon/pkg/runner" @@ -42,15 +45,16 @@ func init() { } func runAgent() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // Create global context manager for the entire application + ctxManager := agent.NewContextManager() + ctx := ctxManager.Root() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigChan - cancel() + ctxManager.Shutdown() }() // Logger @@ -72,41 +76,66 @@ func runAgent() { settings := config.LoadConfig(config.Files(name), wsPath) config.InitSettings(settings) + // Create global worker pool for the entire application using config settings + workerPool := pool.NewPool(settings.PoolMaxWorkers, settings.PoolQueueSize) + log.Info().Msgf("Initialized global worker pool with %d workers and queue capacity %d", + workerPool.MaxWorkers(), workerPool.QueueCapacity()) + // Session session := scheduler.InitSession() commissioned := session.CheckSession(ctx) - // Reporter - scheduler.StartReporters(session) + // Reporter - pass context manager for centralized context management + reporters := scheduler.StartReporters(session, ctxManager) - // Log server - logServer := logger.NewLogServer() + // Log server - pass worker pool and context manager for connection handling + logServer := logger.NewLogServer(workerPool, ctxManager) if logServer != nil { go logServer.StartLogServer() } log.Info().Msgf("%s initialized and running.", name) - // Commit - runner.CommitAsync(session, commissioned) + // Commit - pass context manager for coordinated lifecycle management + runner.CommitAsync(session, commissioned, ctxManager) // DB client := db.InitDB() - // Collector - metricCollector := collector.InitCollector(session, client) + // Collector - pass context manager for centralized context management + metricCollector := collector.InitCollector(session, client, ctxManager) if metricCollector != nil { metricCollector.Start() } - // Websocket Client (Backhaul - commands, sessions) - wsClient := runner.NewWebsocketClient(session) + // Websocket Client - pass context manager and worker pool for centralized management + wsClient := runner.NewWebsocketClient(session, ctxManager, workerPool) + + // Initialize dispatcher system with callbacks + dispatcher, err := executor.InitDispatcher( + workerPool, + ctxManager, + session, + wsClient, + executor.SystemInfoCallbacks{ + CommitFunc: runner.CommitSystemInfo, + SyncFunc: runner.SyncSystemInfo, + }, + ) + if err != nil { + log.Fatal().Err(err).Msg("Failed to initialize dispatcher system") + } + + wsClient.SetDispatcher(dispatcher) + log.Info().Msg("Dispatcher system initialized successfully") + go wsClient.RunForever(ctx) // Control Client (Control - sudo approval) controlClient := runner.NewControlClient() go controlClient.RunForever(ctx) + // Auth Manager for sudo approval workflow authManager := runner.GetAuthManager(controlClient) go authManager.Start(ctx) @@ -114,23 +143,23 @@ func runAgent() { select { case <-ctx.Done(): log.Info().Msg("Received termination signal. Shutting down...") - gracefulShutdown(metricCollector, wsClient, controlClient, authManager, logServer, pidFilePath) + gracefulShutdown(metricCollector, wsClient, controlClient, authManager, workerPool, logServer, reporters, pidFilePath) return case <-wsClient.ShutDownChan: log.Info().Msg("Shutdown command received. Shutting down...") - cancel() - gracefulShutdown(metricCollector, wsClient, controlClient, authManager, logServer, pidFilePath) + ctxManager.Shutdown() + gracefulShutdown(metricCollector, wsClient, controlClient, authManager, workerPool, logServer, reporters, pidFilePath) return case <-wsClient.RestartChan: log.Info().Msg("Restart command received. Restarting...") - cancel() - gracefulShutdown(metricCollector, wsClient, controlClient, authManager, logServer, pidFilePath) + ctxManager.Shutdown() + gracefulShutdown(metricCollector, wsClient, controlClient, authManager, workerPool, logServer, reporters, pidFilePath) restartAgent() return case <-wsClient.CollectorRestartChan: log.Info().Msg("Collector restart command received. Restarting Collector...") metricCollector.Stop() - metricCollector = collector.InitCollector(session, client) + metricCollector = collector.InitCollector(session, client, ctxManager) metricCollector.Start() } } @@ -149,7 +178,7 @@ func restartAgent() { } } -func gracefulShutdown(collector *collector.Collector, wsClient *runner.WebsocketClient, controlClient *runner.ControlClient, authManager *runner.AuthManager, logServer *logger.LogServer, pidPath string) { +func gracefulShutdown(collector *collector.Collector, wsClient *runner.WebsocketClient, controlClient *runner.ControlClient, authManager *runner.AuthManager, workerPool *pool.Pool, logServer *logger.LogServer, reporters *scheduler.ReporterManager, pidPath string) { if collector != nil { collector.Stop() } @@ -162,6 +191,19 @@ func gracefulShutdown(collector *collector.Collector, wsClient *runner.Websocket if authManager != nil { authManager.Stop() } + // Shutdown reporters before worker pool + if reporters != nil { + if err := reporters.Shutdown(1 * time.Second); err != nil { + log.Error().Err(err).Msg("Failed to shutdown reporters gracefully") + } + } + // Shutdown the global worker pool + if workerPool != nil { + log.Info().Msg("Shutting down global worker pool...") + if err := workerPool.Shutdown(1 * time.Second); err != nil { + log.Error().Err(err).Msg("Failed to shutdown worker pool gracefully") + } + } if logServer != nil { logServer.Stop() } diff --git a/configs/alpamon.conf b/configs/alpamon.conf index b8d0748..2f731f9 100644 --- a/configs/alpamon.conf +++ b/configs/alpamon.conf @@ -8,4 +8,19 @@ verify = {{.Verify}} ca_cert = {{.CACert}} [logging] -debug = {{.Debug}} \ No newline at end of file +debug = {{.Debug}} + +[pool] +# Maximum number of concurrent workers in the global worker pool +# Default: 20 +# max_workers = 20 + +# Size of the job queue for the global worker pool +# Default: 200 +# queue_size = 200 + +# Default timeout in seconds for pool tasks +# If not set or commented out, the default value of 30 seconds will be used +# Set to a positive number to specify timeout in seconds +# Default: 30 +# default_timeout = 30 \ No newline at end of file diff --git a/internal/pool/pool.go b/internal/pool/pool.go new file mode 100644 index 0000000..9460aa2 --- /dev/null +++ b/internal/pool/pool.go @@ -0,0 +1,176 @@ +// Package pool provides a true worker pool implementation with job queue and context support. +package pool + +import ( + "context" + "fmt" + "log" + "runtime/debug" + "sync" + "time" +) + +// Job represents a unit of work to be executed by the pool +type Job struct { + fn func() error +} + +// Pool represents a pool of workers with a job queue +type Pool struct { + maxWorkers int + jobQueue chan *Job + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewPool creates a new worker pool with job queue +func NewPool(maxWorkers int, queueSize int) *Pool { + if maxWorkers <= 0 { + maxWorkers = 10 + } + if queueSize <= 0 { + queueSize = maxWorkers * 10 // Default queue size + } + + ctx, cancel := context.WithCancel(context.Background()) + + p := &Pool{ + maxWorkers: maxWorkers, + jobQueue: make(chan *Job, queueSize), + ctx: ctx, + cancel: cancel, + } + + // Start worker goroutines + p.startWorkers() + + return p +} + +// startWorkers launches the worker goroutines +func (p *Pool) startWorkers() { + for i := 0; i < p.maxWorkers; i++ { + p.wg.Add(1) + go p.worker() + } +} + +// worker is the main loop for each worker goroutine +func (p *Pool) worker() { + defer p.wg.Done() + + for { + select { + case job, ok := <-p.jobQueue: + if !ok { + // Job queue is closed + return + } + p.executeJob(job) + case <-p.ctx.Done(): + // Pool is shutting down - drain remaining jobs first + for { + select { + case job, ok := <-p.jobQueue: + if !ok { + return + } + p.executeJob(job) + default: + // No more jobs in queue + return + } + } + } + } +} + +// executeJob runs a job with panic recovery +func (p *Pool) executeJob(job *Job) { + // Panic recovery + defer func() { + if r := recover(); r != nil { + log.Printf("recovered from panic in pool worker: %v\nstack: %s", r, debug.Stack()) + } + }() + + // Execute the function + err := job.fn() + if err != nil { + log.Printf("pool worker error: %v", err) + } +} + +// Submit adds a job to the queue (non-blocking) +// Returns error if the queue is full or pool is shutting down +func (p *Pool) Submit(ctx context.Context, fn func() error) error { + // Check if pool is shutting down first + select { + case <-p.ctx.Done(): + return fmt.Errorf("pool is shutting down") + default: + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-p.ctx.Done(): + return fmt.Errorf("pool is shutting down") + case p.jobQueue <- &Job{fn: fn}: + return nil + default: + // Queue is full + return fmt.Errorf("job queue is full") + } +} + +// QueueSize returns the current number of jobs in the queue +func (p *Pool) QueueSize() int { + return len(p.jobQueue) +} + +// QueueCapacity returns the maximum queue capacity +func (p *Pool) QueueCapacity() int { + return cap(p.jobQueue) +} + +// MaxWorkers returns the number of workers +func (p *Pool) MaxWorkers() int { + return p.maxWorkers +} + +// Shutdown gracefully shuts down the pool +func (p *Pool) Shutdown(timeout time.Duration) error { + // Signal shutdown to prevent new submissions + p.cancel() + + // Don't close the job queue immediately - let workers drain it + // Workers will exit when context is cancelled + + // Wait for all workers to complete + done := make(chan struct{}) + go func() { + p.wg.Wait() + // Close the job queue after all workers have exited + close(p.jobQueue) + close(done) + }() + + select { + case <-done: + return nil + case <-time.After(timeout): + return fmt.Errorf("shutdown timeout after %v", timeout) + } +} + +// IsShuttingDown returns true if the pool is shutting down +func (p *Pool) IsShuttingDown() bool { + select { + case <-p.ctx.Done(): + return true + default: + return false + } +} diff --git a/internal/pool/pool_benchmark_test.go b/internal/pool/pool_benchmark_test.go new file mode 100644 index 0000000..656061e --- /dev/null +++ b/internal/pool/pool_benchmark_test.go @@ -0,0 +1,102 @@ +package pool + +import ( + "context" + "sync" + "testing" + "time" +) + +// BenchmarkPool_WorkerScaling measures performance with different worker counts +func BenchmarkPool_WorkerScaling(b *testing.B) { + workerCounts := []int{1, 5, 10, 20, 50} + + for _, workers := range workerCounts { + b.Run(string(rune('0'+workers/10)+rune('0'+workers%10))+"workers", func(b *testing.B) { + pool := NewPool(workers, 1000) + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = pool.Submit(ctx, func() error { + return nil + }) + } + b.StopTimer() + + _ = pool.Shutdown(5 * time.Second) + }) + } +} + +// BenchmarkPool_QueueThroughput measures queue throughput under load +func BenchmarkPool_QueueThroughput(b *testing.B) { + pool := NewPool(10, 50000) + ctx := context.Background() + + var wg sync.WaitGroup + + b.ResetTimer() + for i := 0; i < b.N; i++ { + wg.Add(1) + err := pool.Submit(ctx, func() error { + wg.Done() + return nil + }) + if err != nil { + wg.Done() // Decrement if submission failed + } + } + wg.Wait() + b.StopTimer() + + _ = pool.Shutdown(5 * time.Second) +} + +// BenchmarkPool_ConcurrentSubmit measures concurrent submission performance +func BenchmarkPool_ConcurrentSubmit(b *testing.B) { + pool := NewPool(20, 5000) + ctx := context.Background() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = pool.Submit(ctx, func() error { + return nil + }) + } + }) + b.StopTimer() + + _ = pool.Shutdown(10 * time.Second) +} + +// BenchmarkPool_WithWork measures performance with actual work +func BenchmarkPool_WithWork(b *testing.B) { + pool := NewPool(10, 50000) + ctx := context.Background() + + var wg sync.WaitGroup + + b.ResetTimer() + for i := 0; i < b.N; i++ { + wg.Add(1) + err := pool.Submit(ctx, func() error { + // Simulate light work + sum := 0 + for j := 0; j < 100; j++ { + sum += j + } + _ = sum + wg.Done() + return nil + }) + if err != nil { + wg.Done() // Decrement if submission failed + } + } + wg.Wait() + b.StopTimer() + + _ = pool.Shutdown(5 * time.Second) +} diff --git a/internal/pool/pool_leak_test.go b/internal/pool/pool_leak_test.go new file mode 100644 index 0000000..7887353 --- /dev/null +++ b/internal/pool/pool_leak_test.go @@ -0,0 +1,286 @@ +package pool + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestPool_NoGoroutineLeak verifies that goroutines are properly cleaned up after normal operations +func TestPool_NoGoroutineLeak(t *testing.T) { + // Allow garbage collection and goroutine cleanup + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + // Use larger queue size to avoid queue full errors + pool := NewPool(5, 200) + ctx := context.Background() + + var wg sync.WaitGroup + + // Submit 100 jobs + for i := 0; i < 100; i++ { + wg.Add(1) + err := pool.Submit(ctx, func() error { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + return nil + }) + if err != nil { + wg.Done() + // Don't fail on queue full - just log it + t.Logf("job submit failed: %v", err) + } + } + + // Wait for all jobs to complete + wg.Wait() + + // Shutdown pool + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + + // Allow goroutines to exit + runtime.GC() + time.Sleep(100 * time.Millisecond) + + final := runtime.NumGoroutine() + + // Allow some margin for test runtime and background goroutines + if final > initial+3 { + t.Errorf("goroutine leak detected: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestPool_NoLeakAfterPanic verifies goroutines are cleaned up after panic recovery +func TestPool_NoLeakAfterPanic(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + pool := NewPool(3, 20) + ctx := context.Background() + + var completed int32 + + // Submit jobs that panic + for i := 0; i < 10; i++ { + _ = pool.Submit(ctx, func() error { + panic("test panic") + }) + } + + // Submit normal jobs after panics + for i := 0; i < 10; i++ { + _ = pool.Submit(ctx, func() error { + atomic.AddInt32(&completed, 1) + return nil + }) + } + + // Wait for jobs to process + time.Sleep(200 * time.Millisecond) + + // Shutdown pool + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after panic: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } + + if atomic.LoadInt32(&completed) == 0 { + t.Error("no normal jobs completed after panics - workers may have died") + } +} + +// TestPool_NoLeakAfterContextCancel verifies goroutines are cleaned up after context cancellation +func TestPool_NoLeakAfterContextCancel(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + pool := NewPool(5, 100) + + // Submit some jobs with cancelable context + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + err := pool.Submit(ctx, func() error { + defer wg.Done() + time.Sleep(50 * time.Millisecond) + return nil + }) + if err != nil { + wg.Done() + } + } + + // Cancel context while jobs are running + time.Sleep(10 * time.Millisecond) + cancel() + + // Wait for running jobs to complete + wg.Wait() + + // Try to submit more jobs with cancelled context + // Note: Pool may or may not check context before submission depending on implementation + err := pool.Submit(ctx, func() error { + return nil + }) + // Just log the result - this tests if the pool handles cancelled context + if err != nil { + t.Logf("submit to cancelled context returned: %v", err) + } + + // Shutdown pool + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after context cancel: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestPool_NoLeakQueueFull verifies goroutines are cleaned up when queue is full +func TestPool_NoLeakQueueFull(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + // Small pool and queue to trigger queue full + pool := NewPool(1, 2) + ctx := context.Background() + + // Block the worker + blocker := make(chan struct{}) + _ = pool.Submit(ctx, func() error { + <-blocker + return nil + }) + + // Wait for worker to pick up the blocking job + time.Sleep(10 * time.Millisecond) + + // Fill the queue + _ = pool.Submit(ctx, func() error { return nil }) + _ = pool.Submit(ctx, func() error { return nil }) + + // This should fail - queue full + err := pool.Submit(ctx, func() error { return nil }) + if err == nil { + t.Error("expected queue full error") + } + + // Unblock and shutdown + close(blocker) + + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after queue full: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestPool_NoLeakRapidShutdown verifies goroutines are cleaned up on rapid shutdown +func TestPool_NoLeakRapidShutdown(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + // Create and shutdown multiple pools rapidly + for i := 0; i < 10; i++ { + pool := NewPool(5, 50) + ctx := context.Background() + + // Submit a few jobs + for j := 0; j < 10; j++ { + _ = pool.Submit(ctx, func() error { + time.Sleep(5 * time.Millisecond) + return nil + }) + } + + // Shutdown immediately + if err := pool.Shutdown(1 * time.Second); err != nil { + t.Logf("shutdown %d failed: %v", i, err) + } + } + + runtime.GC() + time.Sleep(200 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+5 { + t.Errorf("goroutine leak after rapid shutdowns: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestPool_MaxGoroutineEnforcement verifies max concurrent goroutines are enforced +func TestPool_MaxGoroutineEnforcement(t *testing.T) { + maxWorkers := 10 + pool := NewPool(maxWorkers, 100) + defer func() { _ = pool.Shutdown(5 * time.Second) }() + + var maxConcurrent int32 + var concurrent int32 + + ctx := context.Background() + + for i := 0; i < 100; i++ { + _ = pool.Submit(ctx, func() error { + current := atomic.AddInt32(&concurrent, 1) + defer atomic.AddInt32(&concurrent, -1) + + // Track max concurrent + for { + max := atomic.LoadInt32(&maxConcurrent) + if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { + break + } + } + + time.Sleep(50 * time.Millisecond) + return nil + }) + } + + // Wait for all jobs to complete + time.Sleep(600 * time.Millisecond) + + if int(maxConcurrent) > maxWorkers { + t.Errorf("max concurrent workers %d exceeded limit %d", maxConcurrent, maxWorkers) + } +} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go new file mode 100644 index 0000000..a058b0f --- /dev/null +++ b/internal/pool/pool_test.go @@ -0,0 +1,219 @@ +package pool + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestPoolBasic(t *testing.T) { + pool := NewPool(3, 10) + defer func() { + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + }() + + var counter int32 + var wg sync.WaitGroup + + // Submit 10 jobs + for i := 0; i < 10; i++ { + wg.Add(1) + err := pool.Submit(context.Background(), func() error { + atomic.AddInt32(&counter, 1) + wg.Done() + time.Sleep(10 * time.Millisecond) + return nil + }) + if err != nil { + t.Errorf("failed to submit job %d: %v", i, err) + wg.Done() + } + } + + // Wait for all jobs to complete + wg.Wait() + + if atomic.LoadInt32(&counter) != 10 { + t.Errorf("expected 10 jobs to complete, got %d", counter) + } +} + +func TestPoolQueueFull(t *testing.T) { + // Small queue size to test overflow + // Queue size = 1, Workers = 1 + pool := NewPool(1, 1) + defer func() { + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + }() + + // Use a channel to control task execution + blocker := make(chan struct{}) + + // Block the worker - this job will be picked up by the worker immediately + err := pool.Submit(context.Background(), func() error { + <-blocker // Wait until we signal + return nil + }) + if err != nil { + t.Fatalf("failed to submit blocking task: %v", err) + } + + // Give worker time to pick up the first job + time.Sleep(10 * time.Millisecond) + + // Fill the queue (capacity is 1) + err = pool.Submit(context.Background(), func() error { + return nil + }) + if err != nil { + t.Fatalf("failed to submit to queue: %v", err) + } + + // This should fail immediately (worker is blocked, queue is full) + err = pool.Submit(context.Background(), func() error { + return nil + }) + + if err == nil || err.Error() != "job queue is full" { + t.Errorf("expected queue full error, got: %v", err) + } + + // Unblock the worker to allow shutdown + close(blocker) +} + +func TestPoolConcurrency(t *testing.T) { + maxWorkers := 3 + pool := NewPool(maxWorkers, 100) + defer func() { + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + }() + + var concurrent int32 + var maxConcurrent int32 + + // Submit many jobs + for i := 0; i < 50; i++ { + if err := pool.Submit(context.Background(), func() error { + current := atomic.AddInt32(&concurrent, 1) + defer atomic.AddInt32(&concurrent, -1) + + // Track max concurrent + for { + max := atomic.LoadInt32(&maxConcurrent) + if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { + break + } + } + + time.Sleep(10 * time.Millisecond) + return nil + }); err != nil { + t.Errorf("failed to submit job: %v", err) + } + } + + // Wait a bit for jobs to run + time.Sleep(200 * time.Millisecond) + + // Max concurrent should not exceed worker count + if int(maxConcurrent) > maxWorkers { + t.Errorf("exceeded max concurrency: got %d, want <= %d", maxConcurrent, maxWorkers) + } +} + +func TestPoolPanicRecovery(t *testing.T) { + pool := NewPool(2, 10) + defer func() { + if err := pool.Shutdown(5 * time.Second); err != nil { + t.Errorf("shutdown failed: %v", err) + } + }() + + var completed int32 + + // Submit job that panics + if err := pool.Submit(context.Background(), func() error { + panic("test panic") + }); err != nil { + t.Errorf("failed to submit panic job: %v", err) + } + + // Submit normal job after panic + err := pool.Submit(context.Background(), func() error { + atomic.AddInt32(&completed, 1) + return nil + }) + + if err != nil { + t.Errorf("failed to submit job after panic: %v", err) + } + + // Wait for job completion + time.Sleep(100 * time.Millisecond) + + if atomic.LoadInt32(&completed) != 1 { + t.Error("normal job did not complete after panic") + } +} + +func TestPoolShutdown(t *testing.T) { + pool := NewPool(2, 10) + + var completed int32 + + // Submit several jobs + for i := 0; i < 5; i++ { + if err := pool.Submit(context.Background(), func() error { + time.Sleep(50 * time.Millisecond) + atomic.AddInt32(&completed, 1) + return nil + }); err != nil { + t.Errorf("failed to submit job: %v", err) + } + } + + // Shutdown with timeout + err := pool.Shutdown(1 * time.Second) + if err != nil { + t.Fatalf("shutdown failed: %v", err) + } + + // All jobs should have completed + if atomic.LoadInt32(&completed) != 5 { + t.Errorf("not all jobs completed: got %d, want 5", completed) + } + + // New submissions should fail + err = pool.Submit(context.Background(), func() error { return nil }) + if err == nil { + t.Error("expected error when submitting to shut down pool") + } +} + +// Benchmark comparison +func BenchmarkPoolSubmit(b *testing.B) { + pool := NewPool(10, 50000) + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Ignore queue full errors in benchmark + _ = pool.Submit(ctx, func() error { + return nil + }) + } + b.StopTimer() + + if err := pool.Shutdown(10 * time.Second); err != nil { + b.Errorf("shutdown failed: %v", err) + } +} diff --git a/internal/protocol/command.go b/internal/protocol/command.go new file mode 100644 index 0000000..c1717cc --- /dev/null +++ b/internal/protocol/command.go @@ -0,0 +1,259 @@ +package protocol + +import ( + "encoding/json" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// Command represents a command request from the server +type Command struct { + ID string `json:"id"` + Shell string `json:"shell"` + Line string `json:"line"` + User string `json:"user"` + Group string `json:"group"` + Env map[string]string `json:"env"` + Data string `json:"data,omitempty"` +} + +// File represents a file in command data +type File struct { + Username string `json:"username"` + Groupname string `json:"groupname"` + Type string `json:"type"` + Content string `json:"content"` + Path string `json:"path"` + AllowOverwrite bool `json:"allow_overwrite"` + AllowUnzip bool `json:"allow_unzip"` + URL string `json:"url"` +} + +// CommandData holds additional command parameters +type CommandData struct { + SessionID string `json:"session_id"` + URL string `json:"url"` + Rows uint16 `json:"rows"` + Cols uint16 `json:"cols"` + Username string `json:"username"` + Groupname string `json:"groupname"` + Groupnames []string `json:"groupnames"` + HomeDirectory string `json:"home_directory"` + HomeDirectoryPermission string `json:"home_directory_permission"` + PurgeHomeDirectory bool `json:"purge_home"` + UID uint64 `json:"uid"` + GID uint64 `json:"gid"` + Comment string `json:"comment"` + Shell string `json:"shell"` + Groups []uint64 `json:"groups"` + Type string `json:"type"` + Content string `json:"content"` + Path string `json:"path"` + Paths []string `json:"paths"` + Files []File `json:"files,omitempty"` + AllowOverwrite bool `json:"allow_overwrite,omitempty"` + AllowUnzip bool `json:"allow_unzip,omitempty"` + UseBlob bool `json:"use_blob,omitempty"` + Keys []string `json:"keys"` + ChainName string `json:"chain_name"` + Method string `json:"method"` + Chain string `json:"chain"` + Protocol string `json:"protocol"` + PortStart int `json:"port_start"` + PortEnd int `json:"port_end"` + DPorts []int `json:"dports"` + ICMPType string `json:"icmp_type"` + Source string `json:"source"` + Destination string `json:"destination"` + Target string `json:"target"` + Description string `json:"description"` + Priority int `json:"priority"` + RuleType string `json:"rule_type"` + Rules []map[string]interface{} `json:"rules"` + Operation string `json:"operation"` + RuleID string `json:"rule_id"` + OldRuleID string `json:"old_rule_id"` + AssignmentID string `json:"assignment_id"` + ServerID string `json:"server_id"` + ChainNames []string `json:"chain_names"` + + // Backend information + Backend string `json:"backend"` + Table string `json:"table"` + Family string `json:"family"` + + // Firewalld specific fields + Zone string `json:"zone"` + Service string `json:"service"` + FirewalldRuleType string `json:"firewalld_rule_type"` + + // UFW specific fields + Direction string `json:"direction"` + Interface string `json:"interface"` + + // Tunnel specific fields + TargetPort int `json:"target_port"` +} + +// ParseCommandData parses the Data field of a Command into CommandData +func (c *Command) ParseCommandData() (*CommandData, error) { + if c.Data == "" { + return &CommandData{}, nil + } + var data CommandData + if err := json.Unmarshal([]byte(c.Data), &data); err != nil { + return nil, err + } + return &data, nil +} + +// ToArgs converts CommandData to CommandArgs for executor compatibility +func (c *CommandData) ToArgs() *common.CommandArgs { + args := &common.CommandArgs{ + // Common fields + SessionID: c.SessionID, + URL: c.URL, + + // User management + Username: c.Username, + Groupname: c.Groupname, + Groupnames: c.Groupnames, + HomeDirectory: c.HomeDirectory, + HomeDirectoryPermission: c.HomeDirectoryPermission, + PurgeHomeDirectory: c.PurgeHomeDirectory, + UID: c.UID, + GID: c.GID, + Comment: c.Comment, + Shell: c.Shell, + Groups: c.Groups, + + // File operations + Type: c.Type, + Content: c.Content, + Path: c.Path, + Paths: c.Paths, + AllowOverwrite: c.AllowOverwrite, + AllowUnzip: c.AllowUnzip, + UseBlob: c.UseBlob, + + // Terminal operations + Rows: c.Rows, + Cols: c.Cols, + + // Tunnel operations + TargetPort: c.TargetPort, + + // Firewall operations + Keys: c.Keys, + ChainName: c.ChainName, + Method: c.Method, + Chain: c.Chain, + Protocol: c.Protocol, + PortStart: c.PortStart, + PortEnd: c.PortEnd, + DPorts: c.DPorts, + ICMPType: c.ICMPType, + Source: c.Source, + Destination: c.Destination, + Target: c.Target, + Description: c.Description, + Priority: c.Priority, + RuleType: c.RuleType, + Operation: c.Operation, + RuleID: c.RuleID, + OldRuleID: c.OldRuleID, + AssignmentID: c.AssignmentID, + ServerID: c.ServerID, + ChainNames: c.ChainNames, + + // Backend information + Backend: c.Backend, + Table: c.Table, + Family: c.Family, + + // Firewalld specific + Zone: c.Zone, + Service: c.Service, + FirewalldRuleType: c.FirewalldRuleType, + + // UFW specific + Direction: c.Direction, + Interface: c.Interface, + } + + // Convert Files if present + if len(c.Files) > 0 { + args.Files = make([]common.File, len(c.Files)) + for i, f := range c.Files { + args.Files[i] = common.File{ + Username: f.Username, + Groupname: f.Groupname, + Type: f.Type, + Content: f.Content, + Path: f.Path, + AllowOverwrite: f.AllowOverwrite, + AllowUnzip: f.AllowUnzip, + URL: f.URL, + } + } + } + + // Convert Rules if present + if len(c.Rules) > 0 { + args.Rules = make([]common.FirewallRule, len(c.Rules)) + for i, ruleMap := range c.Rules { + rule := common.FirewallRule{} + if v, ok := ruleMap["chain_name"].(string); ok { + rule.ChainName = v + } + if v, ok := ruleMap["method"].(string); ok { + rule.Method = v + } + if v, ok := ruleMap["chain"].(string); ok { + rule.Chain = v + } + if v, ok := ruleMap["protocol"].(string); ok { + rule.Protocol = v + } + if v, ok := ruleMap["port_start"].(float64); ok { + rule.PortStart = int(v) + } + if v, ok := ruleMap["port_end"].(float64); ok { + rule.PortEnd = int(v) + } + if v, ok := ruleMap["icmp_type"].(string); ok { + rule.ICMPType = v + } + if v, ok := ruleMap["source"].(string); ok { + rule.Source = v + } + if v, ok := ruleMap["destination"].(string); ok { + rule.Destination = v + } + if v, ok := ruleMap["target"].(string); ok { + rule.Target = v + } + if v, ok := ruleMap["description"].(string); ok { + rule.Description = v + } + if v, ok := ruleMap["priority"].(float64); ok { + rule.Priority = int(v) + } + if v, ok := ruleMap["rule_type"].(string); ok { + rule.RuleType = v + } + if v, ok := ruleMap["rule_id"].(string); ok { + rule.RuleID = v + } + if v, ok := ruleMap["old_rule_id"].(string); ok { + rule.OldRuleID = v + } + if v, ok := ruleMap["operation"].(string); ok { + rule.Operation = v + } + args.Rules[i] = rule + } + } + + return args +} diff --git a/internal/protocol/message.go b/internal/protocol/message.go new file mode 100644 index 0000000..c8c5b2e --- /dev/null +++ b/internal/protocol/message.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "encoding/json" + "time" +) + +// MessageType defines the type of protocol message +type MessageType string + +const ( + MessageTypeCommand MessageType = "command" + MessageTypeQuit MessageType = "quit" + MessageTypeReconnect MessageType = "reconnect" + MessageTypePing MessageType = "ping" +) + +// Message is the base protocol message envelope received from the server +type Message struct { + Query MessageType `json:"query"` + Command *Command `json:"command,omitempty"` + Reason string `json:"reason,omitempty"` + Raw json.RawMessage `json:"-"` // Original raw message for debugging +} + +// Response represents a response message to send back to the server +type Response struct { + Query string `json:"query"` +} + +// CommandResponse represents a command execution result +type CommandResponse struct { + Success bool `json:"success"` + Result string `json:"result"` + ElapsedTime float64 `json:"elapsed_time"` +} + +// PingResponse represents a ping response +type PingResponse struct { + Query string `json:"query"` + Timestamp time.Time `json:"timestamp,omitempty"` +} + +// NewPingResponse creates a new ping response +func NewPingResponse() *PingResponse { + return &PingResponse{ + Query: "ping", + Timestamp: time.Now(), + } +} + +// NewCommandResponse creates a new command response +func NewCommandResponse(success bool, result string, elapsed float64) *CommandResponse { + return &CommandResponse{ + Success: success, + Result: result, + ElapsedTime: elapsed, + } +} + +// ParseMessage parses a raw JSON message into a Message struct +func ParseMessage(data []byte) (*Message, error) { + var msg Message + if err := json.Unmarshal(data, &msg); err != nil { + return nil, err + } + msg.Raw = data + return &msg, nil +} diff --git a/internal/protocol/message_test.go b/internal/protocol/message_test.go new file mode 100644 index 0000000..86060d4 --- /dev/null +++ b/internal/protocol/message_test.go @@ -0,0 +1,128 @@ +package protocol + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseMessage_Command(t *testing.T) { + raw := `{ + "query": "command", + "command": { + "id": "test-123", + "shell": "internal", + "line": "ping", + "user": "root", + "group": "root", + "env": {"FOO": "bar"}, + "data": "{\"session_id\": \"sess-456\"}" + } + }` + + msg, err := ParseMessage([]byte(raw)) + require.NoError(t, err) + assert.Equal(t, MessageTypeCommand, msg.Query) + require.NotNil(t, msg.Command) + assert.Equal(t, "test-123", msg.Command.ID) + assert.Equal(t, "internal", msg.Command.Shell) + assert.Equal(t, "ping", msg.Command.Line) + assert.Equal(t, "root", msg.Command.User) + assert.Equal(t, "bar", msg.Command.Env["FOO"]) +} + +func TestParseMessage_Quit(t *testing.T) { + raw := `{"query": "quit", "reason": "server shutdown"}` + + msg, err := ParseMessage([]byte(raw)) + require.NoError(t, err) + assert.Equal(t, MessageTypeQuit, msg.Query) + assert.Equal(t, "server shutdown", msg.Reason) +} + +func TestParseMessage_Reconnect(t *testing.T) { + raw := `{"query": "reconnect", "reason": "server restart"}` + + msg, err := ParseMessage([]byte(raw)) + require.NoError(t, err) + assert.Equal(t, MessageTypeReconnect, msg.Query) + assert.Equal(t, "server restart", msg.Reason) +} + +func TestParseMessage_Invalid(t *testing.T) { + raw := `invalid json` + + msg, err := ParseMessage([]byte(raw)) + assert.Error(t, err) + assert.Nil(t, msg) +} + +func TestCommand_ParseCommandData(t *testing.T) { + cmd := &Command{ + ID: "test-123", + Shell: "internal", + Line: "adduser", + Data: `{"username": "testuser", "uid": 1000, "gid": 1000}`, + } + + data, err := cmd.ParseCommandData() + require.NoError(t, err) + assert.Equal(t, "testuser", data.Username) + assert.Equal(t, uint64(1000), data.UID) + assert.Equal(t, uint64(1000), data.GID) +} + +func TestCommand_ParseCommandData_Empty(t *testing.T) { + cmd := &Command{ + ID: "test-123", + Data: "", + } + + data, err := cmd.ParseCommandData() + require.NoError(t, err) + assert.NotNil(t, data) + assert.Equal(t, "", data.Username) +} + +func TestCommand_ParseCommandData_Invalid(t *testing.T) { + cmd := &Command{ + ID: "test-123", + Data: "invalid json", + } + + data, err := cmd.ParseCommandData() + assert.Error(t, err) + assert.Nil(t, data) +} + +func TestNewCommandResponse(t *testing.T) { + resp := NewCommandResponse(true, "success output", 1.5) + + assert.True(t, resp.Success) + assert.Equal(t, "success output", resp.Result) + assert.Equal(t, 1.5, resp.ElapsedTime) +} + +func TestNewPingResponse(t *testing.T) { + resp := NewPingResponse() + + assert.Equal(t, "ping", resp.Query) + assert.False(t, resp.Timestamp.IsZero()) +} + +func TestCommandResponse_JSON(t *testing.T) { + resp := NewCommandResponse(true, "done", 2.5) + + data, err := json.Marshal(resp) + require.NoError(t, err) + + var decoded CommandResponse + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, resp.Success, decoded.Success) + assert.Equal(t, resp.Result, decoded.Result) + assert.Equal(t, resp.ElapsedTime, decoded.ElapsedTime) +} diff --git a/pkg/agent/context.go b/pkg/agent/context.go new file mode 100644 index 0000000..d2ef617 --- /dev/null +++ b/pkg/agent/context.go @@ -0,0 +1,75 @@ +// Package agent provides context management for the Alpamon agent. +package agent + +import ( + "context" + "sync" + "time" +) + +// ContextManager manages contexts for the agent, providing +// centralized context creation and cancellation. +type ContextManager struct { + root context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +// NewContextManager creates a new context manager with a root context. +func NewContextManager() *ContextManager { + ctx, cancel := context.WithCancel(context.Background()) + return &ContextManager{ + root: ctx, + cancel: cancel, + } +} + +// NewContext creates a new child context with an optional timeout. +// If timeout is 0 or negative, no timeout is applied. +func (m *ContextManager) NewContext(timeout time.Duration) (context.Context, context.CancelFunc) { + m.mu.Lock() + defer m.mu.Unlock() + + if timeout > 0 { + return context.WithTimeout(m.root, timeout) + } + return context.WithCancel(m.root) +} + +// NewContextWithDeadline creates a new child context with a specific deadline. +func (m *ContextManager) NewContextWithDeadline(deadline time.Time) (context.Context, context.CancelFunc) { + m.mu.Lock() + defer m.mu.Unlock() + + return context.WithDeadline(m.root, deadline) +} + +// Root returns the root context. +// This should be used sparingly, primarily for operations that need +// to outlive the normal shutdown process. +func (m *ContextManager) Root() context.Context { + m.mu.Lock() + defer m.mu.Unlock() + return m.root +} + +// Shutdown cancels the root context, triggering cancellation of all child contexts. +// This should be called during graceful shutdown. +func (m *ContextManager) Shutdown() { + m.mu.Lock() + defer m.mu.Unlock() + m.cancel() +} + +// IsShutdown returns true if the context manager has been shut down. +func (m *ContextManager) IsShutdown() bool { + m.mu.Lock() + defer m.mu.Unlock() + + select { + case <-m.root.Done(): + return true + default: + return false + } +} diff --git a/pkg/agent/context_benchmark_test.go b/pkg/agent/context_benchmark_test.go new file mode 100644 index 0000000..48980ab --- /dev/null +++ b/pkg/agent/context_benchmark_test.go @@ -0,0 +1,57 @@ +package agent + +import ( + "testing" + "time" +) + +// BenchmarkContextManager_NewContext measures context creation performance +func BenchmarkContextManager_NewContext(b *testing.B) { + cm := NewContextManager() + defer cm.Shutdown() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, cancel := cm.NewContext(0) + cancel() + } +} + +// BenchmarkContextManager_NewContextWithTimeout measures context creation with timeout +func BenchmarkContextManager_NewContextWithTimeout(b *testing.B) { + cm := NewContextManager() + defer cm.Shutdown() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, cancel := cm.NewContext(30 * time.Second) + cancel() + } +} + +// BenchmarkContextManager_ConcurrentNewContext measures concurrent context creation +func BenchmarkContextManager_ConcurrentNewContext(b *testing.B) { + cm := NewContextManager() + defer cm.Shutdown() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, cancel := cm.NewContext(0) + cancel() + } + }) +} + +// BenchmarkContextManager_NewContextWithDeadline measures deadline context creation +func BenchmarkContextManager_NewContextWithDeadline(b *testing.B) { + cm := NewContextManager() + defer cm.Shutdown() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + deadline := time.Now().Add(30 * time.Second) + _, cancel := cm.NewContextWithDeadline(deadline) + cancel() + } +} diff --git a/pkg/agent/context_leak_test.go b/pkg/agent/context_leak_test.go new file mode 100644 index 0000000..e916ba9 --- /dev/null +++ b/pkg/agent/context_leak_test.go @@ -0,0 +1,180 @@ +package agent + +import ( + "runtime" + "sync" + "testing" + "time" +) + +// TestContextManager_NoGoroutineLeak verifies that goroutines are properly cleaned up +func TestContextManager_NoGoroutineLeak(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + cm := NewContextManager() + + // Create many contexts + var cancels []func() + for i := 0; i < 100; i++ { + _, cancel := cm.NewContext(0) + cancels = append(cancels, cancel) + } + + // Create contexts with timeout + for i := 0; i < 50; i++ { + _, cancel := cm.NewContext(100 * time.Millisecond) + cancels = append(cancels, cancel) + } + + // Cancel all contexts + for _, cancel := range cancels { + cancel() + } + + // Shutdown manager + cm.Shutdown() + + runtime.GC() + time.Sleep(200 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak detected: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestContextManager_RapidCreateCancel verifies no leak with rapid create/cancel cycles +func TestContextManager_RapidCreateCancel(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + // Create and shutdown multiple managers rapidly + for i := 0; i < 20; i++ { + cm := NewContextManager() + + // Create and immediately cancel contexts + for j := 0; j < 50; j++ { + ctx, cancel := cm.NewContext(time.Duration(j) * time.Millisecond) + _ = ctx + cancel() + } + + cm.Shutdown() + } + + runtime.GC() + time.Sleep(200 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after rapid create/cancel: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestContextManager_ChildCleanup verifies child context cleanup on parent shutdown +func TestContextManager_ChildCleanup(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + cm := NewContextManager() + + // Create child contexts that simulate long-running operations + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := cm.NewContext(5 * time.Second) + defer cancel() + + // Simulate work + select { + case <-ctx.Done(): + // Context was cancelled + case <-time.After(100 * time.Millisecond): + // Work completed normally + } + }() + } + + // Give some time for goroutines to start + time.Sleep(20 * time.Millisecond) + + // Shutdown manager - should cancel all child contexts + cm.Shutdown() + + // Wait for all goroutines to exit + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // All goroutines exited + case <-time.After(2 * time.Second): + t.Error("child goroutines did not exit after parent shutdown") + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after child cleanup: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} + +// TestContextManager_ConcurrentOperations verifies thread safety and no leaks under concurrent access +func TestContextManager_ConcurrentOperations(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + + initial := runtime.NumGoroutine() + + cm := NewContextManager() + + var wg sync.WaitGroup + + // Concurrent context creation + for i := 0; i < 50; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 10; j++ { + ctx, cancel := cm.NewContext(time.Duration(id+j) * time.Millisecond) + select { + case <-ctx.Done(): + case <-time.After(50 * time.Millisecond): + } + cancel() + } + }(i) + } + + // Wait for all operations + wg.Wait() + + // Shutdown + cm.Shutdown() + + runtime.GC() + time.Sleep(200 * time.Millisecond) + + final := runtime.NumGoroutine() + + if final > initial+3 { + t.Errorf("goroutine leak after concurrent ops: initial=%d, final=%d (delta=%d)", initial, final, final-initial) + } +} diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go new file mode 100644 index 0000000..9ad0b91 --- /dev/null +++ b/pkg/agent/context_test.go @@ -0,0 +1,200 @@ +package agent + +import ( + "context" + "testing" + "time" +) + +// TestContextManagerCreation verifies context manager initialization +func TestContextManagerCreation(t *testing.T) { + cm := NewContextManager() + if cm == nil { + t.Fatal("NewContextManager returned nil") + } + + if cm.IsShutdown() { + t.Error("new context manager should not be shutdown") + } + + // Root context should be active + select { + case <-cm.Root().Done(): + t.Error("root context should not be cancelled") + default: + // Expected + } +} + +// TestContextCancellation verifies that shutdown cancels all child contexts +func TestContextCancellation(t *testing.T) { + cm := NewContextManager() + + // Create multiple child contexts + ctx1, cancel1 := cm.NewContext(0) + defer cancel1() + + ctx2, cancel2 := cm.NewContext(5 * time.Second) + defer cancel2() + + ctx3, cancel3 := cm.NewContext(0) + defer cancel3() + + // Shutdown the manager + cm.Shutdown() + + // All contexts should be cancelled + for i, ctx := range []context.Context{ctx1, ctx2, ctx3} { + select { + case <-ctx.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Errorf("context %d not cancelled after shutdown", i+1) + } + } + + // Manager should report as shutdown + if !cm.IsShutdown() { + t.Error("IsShutdown() should return true after Shutdown()") + } +} + +// TestContextTimeout verifies timeout context creation +func TestContextTimeout(t *testing.T) { + cm := NewContextManager() + defer cm.Shutdown() + + // Create context with timeout + ctx, cancel := cm.NewContext(50 * time.Millisecond) + defer cancel() + + // Should timeout + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("expected DeadlineExceeded, got %v", ctx.Err()) + } + case <-time.After(100 * time.Millisecond): + t.Error("context did not timeout") + } +} + +// TestContextDeadline verifies deadline context creation +func TestContextDeadline(t *testing.T) { + cm := NewContextManager() + defer cm.Shutdown() + + deadline := time.Now().Add(50 * time.Millisecond) + ctx, cancel := cm.NewContextWithDeadline(deadline) + defer cancel() + + // Should respect deadline + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("expected DeadlineExceeded, got %v", ctx.Err()) + } + case <-time.After(100 * time.Millisecond): + t.Error("context did not respect deadline") + } +} + +// TestContextNoTimeout verifies context without timeout +func TestContextNoTimeout(t *testing.T) { + cm := NewContextManager() + defer cm.Shutdown() + + // Create context without timeout (0 duration) + ctx, cancel := cm.NewContext(0) + defer cancel() + + // Should not timeout on its own + select { + case <-ctx.Done(): + t.Error("context should not be cancelled without explicit cancellation") + case <-time.After(100 * time.Millisecond): + // Expected + } + + // Manual cancellation should work + cancel() + select { + case <-ctx.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("context not cancelled after explicit cancel") + } +} + +// TestConcurrentContextCreation verifies thread-safe context creation +func TestConcurrentContextCreation(t *testing.T) { + cm := NewContextManager() + defer cm.Shutdown() + + done := make(chan bool) + + // Create contexts concurrently + for i := 0; i < 100; i++ { + go func(id int) { + ctx, cancel := cm.NewContext(time.Duration(id) * time.Millisecond) + defer cancel() + + // Do some work + select { + case <-ctx.Done(): + // Timeout expected for non-zero durations + case <-time.After(200 * time.Millisecond): + // Maximum wait + } + + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 100; i++ { + select { + case <-done: + // Continue + case <-time.After(5 * time.Second): + t.Fatal("concurrent operations timed out") + } + } + + // Context manager should still be functional + ctx, cancel := cm.NewContext(0) + cancel() + select { + case <-ctx.Done(): + // Expected + default: + t.Error("context manager not functional after concurrent operations") + } +} + +// TestShutdownIdempotency verifies that Shutdown can be called multiple times +func TestShutdownIdempotency(t *testing.T) { + cm := NewContextManager() + + // Create a context + ctx, cancel := cm.NewContext(0) + defer cancel() + + // Shutdown multiple times + cm.Shutdown() + cm.Shutdown() // Should not panic + cm.Shutdown() // Should not panic + + // Context should still be cancelled + select { + case <-ctx.Done(): + // Expected + default: + t.Error("context not cancelled after shutdown") + } + + // IsShutdown should still return true + if !cm.IsShutdown() { + t.Error("IsShutdown() should return true after multiple shutdowns") + } +} diff --git a/pkg/collector/collector.go b/pkg/collector/collector.go index d57c02a..6317b56 100644 --- a/pkg/collector/collector.go +++ b/pkg/collector/collector.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/collector/check" "github.com/alpacax/alpamon/pkg/collector/check/base" "github.com/alpacax/alpamon/pkg/collector/scheduler" @@ -32,6 +33,7 @@ type Collector struct { wg sync.WaitGroup ctx context.Context cancel context.CancelFunc + ctxManager *agent.ContextManager } type collectConf struct { @@ -47,7 +49,7 @@ type collectorArgs struct { transportFactory transporter.TransporterFactory } -func InitCollector(session *session.Session, client *ent.Client) *Collector { +func InitCollector(session *session.Session, client *ent.Client, ctxManager *agent.ContextManager) *Collector { conf, err := fetchConfig(session) if err != nil { log.Error().Err(err).Msg("Failed to fetch collector config.") @@ -65,7 +67,7 @@ func InitCollector(session *session.Session, client *ent.Client) *Collector { transportFactory: transporterFactory, } - collector, err := NewCollector(args) + collector, err := NewCollector(args, ctxManager) if err != nil { log.Error().Err(err).Msg("Failed to create collector.") return nil @@ -92,7 +94,7 @@ func fetchConfig(session *session.Session) ([]collectConf, error) { return conf, nil } -func NewCollector(args collectorArgs) (*Collector, error) { +func NewCollector(args collectorArgs, ctxManager *agent.ContextManager) (*Collector, error) { metricTransporter, err := args.transportFactory.CreateTransporter(args.session) if err != nil { return nil, err @@ -104,6 +106,7 @@ func NewCollector(args collectorArgs) (*Collector, error) { scheduler: scheduler.NewScheduler(), buffer: checkBuffer, errorChan: make(chan error, 10), + ctxManager: ctxManager, } err = metricCollector.initTasks(args) @@ -136,7 +139,8 @@ func (c *Collector) initTasks(args collectorArgs) error { func (c *Collector) Start() { log.Debug().Msg("Started collector") - c.ctx, c.cancel = context.WithCancel(context.Background()) + // Use context from global ContextManager instead of creating local context + c.ctx, c.cancel = c.ctxManager.NewContext(0) // 0 means no timeout go c.scheduler.Start(c.ctx, c.buffer.Capacity) diff --git a/pkg/config/config.go b/pkg/config/config.go index a1f9485..38981ae 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -19,13 +19,25 @@ var ( ) const ( - MinConnectInterval = 5 * time.Second - MaxConnectInterval = 300 * time.Second + MinConnectInterval = 5 * time.Second + MaxConnectInterval = 300 * time.Second + + // Smux configuration for tunnel connections SmuxKeepAliveInterval = 10 * time.Second SmuxKeepAliveTimeout = 30 * time.Second SmuxMaxFrameSize = 32768 // 32KB SmuxMaxReceiveBuffer = 4194304 // 4MB SmuxMaxStreamBuffer = 65536 // 64KB per stream + + // Pool configuration defaults + DefaultPoolMaxWorkers = 20 + DefaultPoolQueueSize = 200 + DefaultPoolDefaultTimeout = 30 + + // Pool configuration limits for warnings + MaxReasonableWorkers = 1000 + MaxReasonableQueueSize = 10000 + MaxReasonableTimeoutSeconds = 3600 ) // GetSmuxConfig returns optimized smux configuration for tunnel connections. @@ -105,11 +117,14 @@ func validateConfig(config Config, wsPath string) (bool, Settings) { log.Debug().Msg("Validating configuration fields...") settings := Settings{ - WSPath: wsPath, - UseSSL: false, - SSLVerify: true, - SSLOpt: make(map[string]interface{}), - HTTPThreads: 4, + WSPath: wsPath, + UseSSL: false, + SSLVerify: true, + SSLOpt: make(map[string]interface{}), + HTTPThreads: 4, + PoolMaxWorkers: DefaultPoolMaxWorkers, + PoolQueueSize: DefaultPoolQueueSize, + PoolDefaultTimeout: DefaultPoolDefaultTimeout, } valid := true @@ -153,6 +168,46 @@ func validateConfig(config Config, wsPath string) (bool, Settings) { } } + // Validate and set worker pool configuration + if config.Pool.MaxWorkers > 0 { + settings.PoolMaxWorkers = config.Pool.MaxWorkers + log.Debug().Msgf("Using configured pool max workers: %d", settings.PoolMaxWorkers) + } else { + log.Debug().Msgf("Using default pool max workers: %d", settings.PoolMaxWorkers) + } + + if config.Pool.QueueSize > 0 { + settings.PoolQueueSize = config.Pool.QueueSize + log.Debug().Msgf("Using configured pool queue size: %d", settings.PoolQueueSize) + } else { + log.Debug().Msgf("Using default pool queue size: %d", settings.PoolQueueSize) + } + + // Validate and set default timeout for pool tasks + // Use pointer type to distinguish "not configured" (nil) from "explicitly set to 0" + if config.Pool.DefaultTimeout != nil { + settings.PoolDefaultTimeout = *config.Pool.DefaultTimeout + if settings.PoolDefaultTimeout == 0 { + log.Debug().Msg("Using configured pool default timeout: 0 (no timeout)") + } else { + log.Debug().Msgf("Using configured pool default timeout: %d seconds", settings.PoolDefaultTimeout) + } + } else { + // Keep the default value that was set during Settings initialization + log.Debug().Msgf("Using default pool timeout: %d seconds", settings.PoolDefaultTimeout) + } + + // Validate pool settings are reasonable + if settings.PoolMaxWorkers > MaxReasonableWorkers { + log.Warn().Msgf("Pool max workers (%d) seems very high, consider reducing it", settings.PoolMaxWorkers) + } + if settings.PoolQueueSize > MaxReasonableQueueSize { + log.Warn().Msgf("Pool queue size (%d) seems very high, consider reducing it", settings.PoolQueueSize) + } + if settings.PoolDefaultTimeout > MaxReasonableTimeoutSeconds { + log.Warn().Msgf("Pool default timeout (%d seconds) seems very high, consider reducing it", settings.PoolDefaultTimeout) + } + return valid, settings } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..ec2d521 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,78 @@ +package config + +import ( + "os" + "testing" +) + +func TestPoolConfigDefaults(t *testing.T) { + // Test that default pool values are set correctly when not configured + config := Config{} + _, settings := validateConfig(config, "/ws/test/") + + if settings.PoolMaxWorkers != DefaultPoolMaxWorkers { + t.Errorf("Expected default PoolMaxWorkers to be %d, got %d", DefaultPoolMaxWorkers, settings.PoolMaxWorkers) + } + + if settings.PoolQueueSize != DefaultPoolQueueSize { + t.Errorf("Expected default PoolQueueSize to be %d, got %d", DefaultPoolQueueSize, settings.PoolQueueSize) + } + + if settings.PoolDefaultTimeout != DefaultPoolDefaultTimeout { + t.Errorf("Expected default PoolDefaultTimeout to be %d, got %d", DefaultPoolDefaultTimeout, settings.PoolDefaultTimeout) + } +} + +func TestPoolConfigCustomValues(t *testing.T) { + // Test that custom pool values are applied correctly + config := Config{} + config.Pool.MaxWorkers = 50 + config.Pool.QueueSize = 500 + + _, settings := validateConfig(config, "/ws/test/") + + if settings.PoolMaxWorkers != 50 { + t.Errorf("Expected PoolMaxWorkers to be 50, got %d", settings.PoolMaxWorkers) + } + + if settings.PoolQueueSize != 500 { + t.Errorf("Expected PoolQueueSize to be 500, got %d", settings.PoolQueueSize) + } +} + +func TestPoolConfigFromINI(t *testing.T) { + // Create a temporary config file + content := `[server] +url = http://test.com +id = testid +key = testkey + +[pool] +max_workers = 30 +queue_size = 300 +` + + tmpfile, err := os.CreateTemp("", "alpamon-test-*.conf") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write([]byte(content)); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + // Load the config + settings := LoadConfig([]string{tmpfile.Name()}, "/ws/test/") + + if settings.PoolMaxWorkers != 30 { + t.Errorf("Expected PoolMaxWorkers to be 30 from INI, got %d", settings.PoolMaxWorkers) + } + + if settings.PoolQueueSize != 300 { + t.Errorf("Expected PoolQueueSize to be 300 from INI, got %d", settings.PoolQueueSize) + } +} diff --git a/pkg/config/types.go b/pkg/config/types.go index a8d5a0b..cb3a2e1 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -1,15 +1,18 @@ package config type Settings struct { - ServerURL string - WSPath string - UseSSL bool - CaCert string // CA certificate file path - SSLVerify bool - SSLOpt map[string]interface{} - HTTPThreads int - ID string - Key string + ServerURL string + WSPath string + UseSSL bool + CaCert string // CA certificate file path + SSLVerify bool + SSLOpt map[string]interface{} + HTTPThreads int + ID string + Key string + PoolMaxWorkers int // Maximum number of workers in the global worker pool + PoolQueueSize int // Size of the job queue for the global worker pool + PoolDefaultTimeout int // Default timeout in seconds for pool tasks (0 = no timeout) } type Config struct { @@ -25,4 +28,9 @@ type Config struct { Logging struct { Debug bool `ini:"debug"` } `ini:"logging"` + Pool struct { + MaxWorkers int `ini:"max_workers"` + QueueSize int `ini:"queue_size"` + DefaultTimeout *int `ini:"default_timeout"` + } `ini:"pool"` } diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go new file mode 100644 index 0000000..1cc951c --- /dev/null +++ b/pkg/executor/adapter.go @@ -0,0 +1,39 @@ +package executor + +import ( + "github.com/alpacax/alpamon/pkg/scheduler" +) + +// SystemInfoAdapter implements the common.SystemInfoManager interface for handlers +type SystemInfoAdapter struct { + session *scheduler.Session + commitFunc func() + syncFunc func(*scheduler.Session, []string) +} + +// NewSystemInfoAdapter creates a new system info adapter with function callbacks +func NewSystemInfoAdapter( + session *scheduler.Session, + commitFunc func(), + syncFunc func(*scheduler.Session, []string), +) *SystemInfoAdapter { + return &SystemInfoAdapter{ + session: session, + commitFunc: commitFunc, + syncFunc: syncFunc, + } +} + +// CommitSystemInfo implements common.SystemInfoManager +func (a *SystemInfoAdapter) CommitSystemInfo() { + if a.commitFunc != nil { + a.commitFunc() + } +} + +// SyncSystemInfo implements common.SystemInfoManager +func (a *SystemInfoAdapter) SyncSystemInfo(keys []string) { + if a.syncFunc != nil { + a.syncFunc(a.session, keys) + } +} diff --git a/pkg/executor/dispatcher.go b/pkg/executor/dispatcher.go new file mode 100644 index 0000000..87d4909 --- /dev/null +++ b/pkg/executor/dispatcher.go @@ -0,0 +1,126 @@ +package executor + +import ( + "context" + "fmt" + "time" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/scheduler" + "github.com/rs/zerolog/log" +) + +// CommandDispatcher manages command execution through registered handlers +type CommandDispatcher struct { + registry *Registry + pool *pool.Pool + ctxManager *agent.ContextManager +} + +// NewCommandDispatcher creates a new command dispatcher +func NewCommandDispatcher(pool *pool.Pool, ctxManager *agent.ContextManager) *CommandDispatcher { + return &CommandDispatcher{ + registry: NewRegistry(), + pool: pool, + ctxManager: ctxManager, + } +} + +// RegisterHandler registers a command handler +func (e *CommandDispatcher) RegisterHandler(h common.Handler) error { + return e.registry.Register(h) +} + +// Execute runs a command with the appropriate handler +func (e *CommandDispatcher) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + // Handle nil args + if args == nil { + args = &common.CommandArgs{} + } + + // Find the appropriate handler + handler, err := e.registry.Get(cmd) + if err != nil { + log.Warn().Err(err).Msgf("No handler found for command: %s", cmd) + return 1, "", fmt.Errorf("no handler found for command: %s", cmd) + } + + // Validate arguments before execution + if err := handler.Validate(cmd, args); err != nil { + log.Error().Err(err).Msgf("Command %s validation failed", cmd) + return 1, "", fmt.Errorf("validation failed: %w", err) + } + + // Execute the command + startTime := time.Now() + exitCode, output, err := handler.Execute(ctx, cmd, args) + duration := time.Since(startTime) + + // Log execution result + if err != nil { + log.Error(). + Str("command", cmd). + Int("exitCode", exitCode). + Dur("duration", duration). + Err(err). + Msg("Command execution failed") + } else { + log.Info(). + Str("command", cmd). + Int("exitCode", exitCode). + Dur("duration", duration). + Msg("Command executed successfully") + } + + return exitCode, output, err +} + +// HasHandler checks if a handler exists for the given command +func (e *CommandDispatcher) HasHandler(cmd string) bool { + return e.registry.IsCommandRegistered(cmd) +} + +// Shutdown gracefully shuts down the executor +func (e *CommandDispatcher) Shutdown(timeout time.Duration) error { + log.Info().Msg("Shutting down executor") + + // Cancel all contexts + e.ctxManager.Shutdown() + + // Shutdown the pool + if err := e.pool.Shutdown(timeout); err != nil { + log.Error().Err(err).Msg("Failed to shutdown pool gracefully") + return err + } + + log.Info().Msg("Executor shutdown complete") + return nil +} + +// InitDispatcher initializes and configures the command dispatching system with all handlers +func InitDispatcher( + pool *pool.Pool, + ctxManager *agent.ContextManager, + session *scheduler.Session, + wsClient common.WSClient, + callbacks SystemInfoCallbacks, +) (*CommandDispatcher, error) { + // Create the main command dispatcher + dispatcher := NewCommandDispatcher(pool, ctxManager) + + // Create command executor for system commands + cmdExecutor := NewExecutor() + + // Create and register all handlers using the handler factory pattern + factory := NewHandlerFactory(dispatcher, cmdExecutor) + err := factory.RegisterAll(pool, ctxManager, session, wsClient, callbacks) + if err != nil { + return nil, err + } + + log.Info().Msg("Dispatcher initialized with handlers") + + return dispatcher, nil +} diff --git a/pkg/executor/dispatcher_benchmark_test.go b/pkg/executor/dispatcher_benchmark_test.go new file mode 100644 index 0000000..0f0d7f1 --- /dev/null +++ b/pkg/executor/dispatcher_benchmark_test.go @@ -0,0 +1,115 @@ +package executor + +import ( + "context" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// BenchmarkRegistry_Get measures registry lookup performance +func BenchmarkRegistry_Get(b *testing.B) { + registry := NewRegistry() + + // Register a mock handler + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2", "cmd3"}, + } + _ = registry.Register(handler) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = registry.Get("cmd1") + } +} + +// BenchmarkRegistry_ListCommands measures command listing performance +func BenchmarkRegistry_ListCommands(b *testing.B) { + registry := NewRegistry() + + // Register multiple handlers + for i := 0; i < 10; i++ { + handler := &MockHandler{ + name: "handler" + string(rune('A'+i)), + commands: []string{"cmd" + string(rune('A'+i))}, + } + _ = registry.Register(handler) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = registry.ListCommands() + } +} + +// BenchmarkRegistry_IsCommandRegistered measures registration check performance +func BenchmarkRegistry_IsCommandRegistered(b *testing.B) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2", "cmd3"}, + } + _ = registry.Register(handler) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = registry.IsCommandRegistered("cmd1") + } +} + +// BenchmarkRegistry_ConcurrentGet measures concurrent lookup performance +func BenchmarkRegistry_ConcurrentGet(b *testing.B) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2", "cmd3"}, + } + _ = registry.Register(handler) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = registry.Get("cmd1") + } + }) +} + +// MockHandlerWithExecute is a handler that can be used for execution benchmarks +type MockHandlerWithExecute struct { + name string + commands []string +} + +func (h *MockHandlerWithExecute) Name() string { + return h.name +} + +func (h *MockHandlerWithExecute) Commands() []string { + return h.commands +} + +func (h *MockHandlerWithExecute) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + return 0, "executed", nil +} + +func (h *MockHandlerWithExecute) Validate(cmd string, args *common.CommandArgs) error { + return nil +} + +// BenchmarkHandler_Execute measures basic handler execution overhead +func BenchmarkHandler_Execute(b *testing.B) { + handler := &MockHandlerWithExecute{ + name: "test", + commands: []string{"test_cmd"}, + } + ctx := context.Background() + args := &common.CommandArgs{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = handler.Execute(ctx, "test_cmd", args) + } +} diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go new file mode 100644 index 0000000..7e48228 --- /dev/null +++ b/pkg/executor/executor.go @@ -0,0 +1,216 @@ +package executor + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" + "syscall" + "time" + + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// Executor provides system command execution with privilege management +type Executor struct{} + +// NewExecutor creates a new system command executor +func NewExecutor() *Executor { + return &Executor{} +} + +// Execute runs a command with full control over execution parameters +func (e *Executor) Execute(ctx context.Context, opts CommandOptions) (int, string, error) { + // Apply environment variable substitution + args := e.substituteEnvVars(opts.Args, opts.Env) + + // Setup context with timeout if specified + if opts.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, opts.Timeout) + defer cancel() + } + + // Create command + // codeql[go/command-injection]: Intentional - Alpamon executes admin commands from Alpacon console + cmd := exec.CommandContext(ctx, args[0], args[1:]...) // lgtm[go/command-injection] + + // Set up privilege demotion if username specified + if opts.Username != "" && opts.Username != "root" { + sysProcAttr, err := e.demotePrivileges(opts.Username, opts.Groupname) + if err != nil { + log.Error().Err(err).Msg("Failed to demote privileges") + return 1, err.Error(), err + } + if sysProcAttr != nil { + cmd.SysProcAttr = sysProcAttr + } + } + + // Set environment variables + for key, value := range opts.Env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value)) + } + + // Set working directory + if opts.WorkingDir != "" { + cmd.Dir = opts.WorkingDir + } else if opts.Username != "" { + usr, err := utils.GetSystemUser(opts.Username) + if err == nil { + cmd.Dir = usr.HomeDir + } + } + + // Set stdin if provided + if opts.Input != "" { + cmd.Stdin = bytes.NewReader([]byte(opts.Input)) + } + + log.Debug(). + Str("command", strings.Join(args, " ")). + Str("user", opts.Username). + Str("group", opts.Groupname). + Str("dir", cmd.Dir). + Msg("Executor execute command") + + // Execute command + output, err := cmd.CombinedOutput() + exitCode := 0 + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + exitCode = exitError.ExitCode() + } else { + exitCode = 1 + } + } + + return exitCode, string(output), err +} + +// CommandOptions defines options for command execution +type CommandOptions struct { + Args []string // Command and arguments + Username string // Username to run as (empty = current user) + Groupname string // Group to run as + Env map[string]string // Environment variables + WorkingDir string // Working directory + Timeout time.Duration // Command timeout + Input string // Input to provide via stdin +} + +// substituteEnvVars replaces environment variables in arguments +func (e *Executor) substituteEnvVars(args []string, env map[string]string) []string { + if env == nil { + return args + } + + // Add default environment variables + defaultEnv := e.getDefaultEnv() + for key, value := range defaultEnv { + if _, exists := env[key]; !exists { + env[key] = value + } + } + + // Substitute variables in arguments + result := make([]string, len(args)) + for i, arg := range args { + result[i] = e.expandEnvVar(arg, env) + } + return result +} + +// expandEnvVar expands environment variables in a string +func (e *Executor) expandEnvVar(s string, env map[string]string) string { + // Handle ${VAR} format + if strings.HasPrefix(s, "${") && strings.HasSuffix(s, "}") { + varName := s[2 : len(s)-1] + if val, ok := env[varName]; ok { + return val + } + } + // Handle $VAR format + if strings.HasPrefix(s, "$") { + varName := s[1:] + if val, ok := env[varName]; ok { + return val + } + } + return s +} + +// getDefaultEnv returns default environment variables +func (e *Executor) getDefaultEnv() map[string]string { + return map[string]string{ + "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "HOME": os.Getenv("HOME"), + "USER": os.Getenv("USER"), + "SHELL": "/bin/bash", + "TERM": "xterm-256color", + "LANG": "en_US.UTF-8", + "LS_COLORS": `rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.Z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31:*.ear=01;31:*.sar=01;31:*.rar=01;31:*.alz=01;31:*.ace=01;31:*.zoo=01;31:*.cpio=01;31:*.7z=01;31:*.rz=01;31:*.cab=01;31:*.jpg=01;35:*.jpeg=01;35:*.gif=01;35:*.bmp=01;35:*.pbm=01;35:*.pgm=01;35:*.ppm=01;35:*.tga=01;35:*.xbm=01;35:*.xpm=01;35:*.tif=01;35:*.tiff=01;35:*.png=01;35:*.svg=01;35:*.svgz=01;35:*.mng=01;35:*.pcx=01;35:*.mov=01;35:*.mpg=01;35:*.mpeg=01;35:*.m2v=01;35:*.mkv=01;35:*.webm=01;35:*.ogm=01;35:*.mp4=01;35:*.m4v=01;35:*.mp4v=01;35:*.vob=01;35:*.qt=01;35:*.nuv=01;35:*.wmv=01;35:*.asf=01;35:*.rm=01;35:*.rmvb=01;35:*.flc=01;35:*.avi=01;35:*.fli=01;35:*.flv=01;35:*.gl=01;35:*.dl=01;35:*.xcf=01;35:*.xwd=01;35:*.yuv=01;35:*.cgm=01;35:*.emf=01;35:*.ogv=01;35:*.ogx=01;35:*.aac=00;36:*.au=00;36:*.flac=00;36:*.m4a=00;36:*.mid=00;36:*.midi=00;36:*.mka=00;36:*.mp3=00;36:*.mpc=00;36:*.ogg=00;36:*.ra=00;36:*.wav=00;36:*.oga=00;36:*.opus=00;36:*.spx=00;36:*.xspf=00;36:`, + } +} + +// demotePrivileges creates syscall attributes for privilege demotion +func (e *Executor) demotePrivileges(username, groupname string) (*syscall.SysProcAttr, error) { + result, err := utils.Demote(username, groupname, utils.DemoteOptions{ValidateGroup: true}) + if err != nil { + return nil, err + } + if result == nil { + return nil, nil + } + return result.SysProcAttr, nil +} + +// Implement CommandExecutor interface methods + +// Run executes a command with the given arguments +func (e *Executor) Run(ctx context.Context, name string, args ...string) (int, string, error) { + allArgs := append([]string{name}, args...) + return e.Execute(ctx, CommandOptions{Args: allArgs}) +} + +// RunAsUser executes a command as a specific user +func (e *Executor) RunAsUser(ctx context.Context, username string, name string, args ...string) (int, string, error) { + allArgs := append([]string{name}, args...) + return e.Execute(ctx, CommandOptions{ + Args: allArgs, + Username: username, + Groupname: username, + }) +} + +// RunWithInput executes a command with stdin input +func (e *Executor) RunWithInput(ctx context.Context, input string, name string, args ...string) (int, string, error) { + allArgs := append([]string{name}, args...) + return e.Execute(ctx, CommandOptions{ + Args: allArgs, + Input: input, + }) +} + +// RunWithTimeout executes a command with a timeout +func (e *Executor) RunWithTimeout(ctx context.Context, timeout time.Duration, name string, args ...string) (int, string, error) { + allArgs := append([]string{name}, args...) + return e.Execute(ctx, CommandOptions{ + Args: allArgs, + Timeout: timeout, + }) +} + +// Exec executes a command with all options +func (e *Executor) Exec(ctx context.Context, args []string, username, groupname string, env map[string]string, timeout time.Duration) (int, string, error) { + return e.Execute(ctx, CommandOptions{ + Args: args, + Username: username, + Groupname: groupname, + Env: env, + Timeout: timeout, + }) +} diff --git a/pkg/executor/factory.go b/pkg/executor/factory.go new file mode 100644 index 0000000..a25d37c --- /dev/null +++ b/pkg/executor/factory.go @@ -0,0 +1,78 @@ +package executor + +import ( + "fmt" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/executor/handlers/file" + "github.com/alpacax/alpamon/pkg/executor/handlers/firewall" + "github.com/alpacax/alpamon/pkg/executor/handlers/group" + "github.com/alpacax/alpamon/pkg/executor/handlers/info" + "github.com/alpacax/alpamon/pkg/executor/handlers/shell" + "github.com/alpacax/alpamon/pkg/executor/handlers/system" + "github.com/alpacax/alpamon/pkg/executor/handlers/terminal" + "github.com/alpacax/alpamon/pkg/executor/handlers/tunnel" + "github.com/alpacax/alpamon/pkg/executor/handlers/user" + "github.com/alpacax/alpamon/pkg/executor/services" + "github.com/alpacax/alpamon/pkg/scheduler" +) + +// SystemInfoCallbacks contains function callbacks for system info operations +type SystemInfoCallbacks struct { + CommitFunc func() + SyncFunc func(*scheduler.Session, []string) +} + +// HandlerFactory encapsulates handler instantiation and registration +type HandlerFactory struct { + dispatcher *CommandDispatcher + cmdExec common.CommandExecutor +} + +// NewHandlerFactory creates a new handler factory +func NewHandlerFactory(dispatcher *CommandDispatcher, cmdExec common.CommandExecutor) *HandlerFactory { + return &HandlerFactory{ + dispatcher: dispatcher, + cmdExec: cmdExec, + } +} + +// RegisterAll registers all handlers with the provided callbacks +func (f *HandlerFactory) RegisterAll( + pool *pool.Pool, + ctxManager *agent.ContextManager, + session *scheduler.Session, + wsClient common.WSClient, + callbacks SystemInfoCallbacks, +) error { + // Create group service for dependency injection + groupService := services.NewDefaultGroupService(f.cmdExec) + + // Create system info adapter for info handler with function callbacks + infoAdapter := NewSystemInfoAdapter(session, callbacks.CommitFunc, callbacks.SyncFunc) + + // Define all handlers in a slice for streamlined registration + handlers := []common.Handler{ + system.NewSystemHandler(f.cmdExec, wsClient, ctxManager, pool), + group.NewGroupHandler(f.cmdExec, infoAdapter), + info.NewInfoHandler(infoAdapter), + shell.NewShellHandler(f.cmdExec), + user.NewUserHandler(f.cmdExec, groupService, infoAdapter), + firewall.NewFirewallHandler(f.cmdExec), + file.NewFileHandler(f.cmdExec, session), + terminal.NewTerminalHandler(f.cmdExec, session), + tunnel.NewTunnelHandler(f.cmdExec), + } + + // Register all handlers + for _, handler := range handlers { + if err := f.dispatcher.RegisterHandler(handler); err != nil { + // Return error with context about which handler failed + return fmt.Errorf("failed to register handler: %w", err) + } + } + + return nil +} diff --git a/pkg/executor/handlers/common/args.go b/pkg/executor/handlers/common/args.go new file mode 100644 index 0000000..f2ec9fb --- /dev/null +++ b/pkg/executor/handlers/common/args.go @@ -0,0 +1,117 @@ +package common + +import "time" + +// CommandArgs is a strongly-typed struct for all command arguments +type CommandArgs struct { + // Common fields + SessionID string + URL string + Command string + Timeout time.Duration + + // User management + Username string + Groupname string + Groupnames []string + HomeDirectory string + HomeDirectoryPermission string + PurgeHomeDirectory bool + UID uint64 + GID uint64 + Comment string + Shell string + Groups []uint64 + + // File operations + Type string + Content string + Path string + Paths []string + Files []File + AllowOverwrite bool + AllowUnzip bool + UseBlob bool + + // Terminal operations + Rows uint16 + Cols uint16 + Input string + + // Tunnel operations + TargetPort int + + // Environment + Env map[string]string + + // Firewall operations + Keys []string + ChainName string + Method string + Chain string + Protocol string + PortStart int + PortEnd int + DPorts []int + ICMPType string + Source string + Destination string + Target string + Description string + Priority int + RuleType string + Rules []FirewallRule + Operation string + RuleID string + OldRuleID string + AssignmentID string + ServerID string + ChainNames []string + + // Backend information + Backend string // Backend type: iptables, nftables, firewalld, ufw + Table string // iptables/nftables table: filter, nat, mangle, raw, security + Family string // IP family: ip (IPv4), ip6 (IPv6), inet, arp, bridge, netdev + + // Firewalld specific + Zone string // Firewalld zone (default, public, etc.) + Service string // Firewalld service name + FirewalldRuleType string // Firewalld rule type: service, port, rich + + // UFW specific + Direction string // UFW direction: in, out + Interface string // UFW interface name +} + +// File represents a file transfer operation +type File struct { + Username string `json:"username"` + Groupname string `json:"groupname"` + Type string `json:"type"` + Content string `json:"content"` + Path string `json:"path"` + AllowOverwrite bool `json:"allow_overwrite"` + AllowUnzip bool `json:"allow_unzip"` + URL string `json:"url"` +} + +// FirewallRule represents a single firewall rule +type FirewallRule struct { + ChainName string `json:"chain_name"` + Method string `json:"method"` + Chain string `json:"chain"` + Protocol string `json:"protocol"` + PortStart int `json:"port_start"` + PortEnd int `json:"port_end"` + DPorts []int `json:"dports"` + ICMPType string `json:"icmp_type"` + Source string `json:"source"` + Destination string `json:"destination"` + Target string `json:"target"` + Description string `json:"description"` + Priority int `json:"priority"` + RuleType string `json:"rule_type"` + RuleID string `json:"rule_id"` + OldRuleID string `json:"old_rule_id"` + Operation string `json:"operation"` +} diff --git a/pkg/executor/handlers/common/base.go b/pkg/executor/handlers/common/base.go new file mode 100644 index 0000000..45fec58 --- /dev/null +++ b/pkg/executor/handlers/common/base.go @@ -0,0 +1,98 @@ +package common + +import ( + "strconv" + + "gopkg.in/go-playground/validator.v9" +) + +// BaseHandler provides common functionality for all handlers +type BaseHandler struct { + name string + commands []string + validator *validator.Validate + Executor CommandExecutor // Made public +} + +// NewBaseHandler creates a new base handler +func NewBaseHandler(name HandlerType, commands []CommandType, executor CommandExecutor) *BaseHandler { + return &BaseHandler{ + name: name.String(), + commands: CommandsToStrings(commands), + validator: validator.New(), + Executor: executor, + } +} + +// Name returns the handler name +func (h *BaseHandler) Name() string { + return h.name +} + +// Commands returns the list of supported commands +func (h *BaseHandler) Commands() []string { + return h.commands +} + +// ValidateStruct validates a struct using struct tags +func (h *BaseHandler) ValidateStruct(s interface{}) error { + return h.validator.Struct(s) +} + +// Helper functions for common operations + +// GetStringArg retrieves a string argument from the args map +func GetStringArg(args map[string]interface{}, key string, defaultValue string) string { + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return defaultValue +} + +// GetIntArg retrieves an integer argument from the args map +func GetIntArg(args map[string]interface{}, key string, defaultValue int) int { + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v + case float64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + } + return defaultValue +} + +// GetBoolArg retrieves a boolean argument from the args map +func GetBoolArg(args map[string]interface{}, key string, defaultValue bool) bool { + if val, ok := args[key]; ok { + if b, ok := val.(bool); ok { + return b + } + } + return defaultValue +} + +// GetStringSliceArg retrieves a string slice argument from the args map +func GetStringSliceArg(args map[string]interface{}, key string) []string { + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v + case []interface{}: + var result []string + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + } + } + return nil +} diff --git a/pkg/executor/handlers/common/interfaces.go b/pkg/executor/handlers/common/interfaces.go new file mode 100644 index 0000000..cbd955a --- /dev/null +++ b/pkg/executor/handlers/common/interfaces.go @@ -0,0 +1,63 @@ +package common + +import ( + "bytes" + "context" + "time" +) + +// Handler defines the interface for command handlers. +// Each handler is responsible for executing a specific set of commands. +type Handler interface { + // Name returns the handler name (e.g., "system", "user", "firewall") + Name() string + + // Commands returns the list of commands this handler supports + Commands() []string + + // Execute runs the specified command with the given arguments. + // Returns exit code, output string, and error if command fails. + Execute(ctx context.Context, cmd string, args *CommandArgs) (exitCode int, output string, err error) + + // Validate checks if the provided arguments are valid for the command. + // This allows pre-execution validation without running the command. + Validate(cmd string, args *CommandArgs) error +} + +// CommandExecutor defines the interface for executing system commands. +// This interface is defined within the handlers package to avoid circular dependencies. +// The concrete implementation is in the executor package. +type CommandExecutor interface { + // Run executes a command with the given arguments + Run(ctx context.Context, name string, args ...string) (int, string, error) + + // RunAsUser executes a command as a specific user + RunAsUser(ctx context.Context, username string, name string, args ...string) (int, string, error) + + // RunWithInput executes a command with stdin input + RunWithInput(ctx context.Context, input string, name string, args ...string) (int, string, error) + + // RunWithTimeout executes a command with a timeout + RunWithTimeout(ctx context.Context, timeout time.Duration, name string, args ...string) (int, string, error) + + // Exec executes a command with all options (user, group, env, timeout) + Exec(ctx context.Context, args []string, username, groupname string, env map[string]string, timeout time.Duration) (int, string, error) +} + +// WSClient interface for WebSocket client operations +type WSClient interface { + Restart() + ShutDown() + RestartCollector() +} + +// SystemInfoManager interface for system info operations +type SystemInfoManager interface { + CommitSystemInfo() + SyncSystemInfo(keys []string) +} + +// APISession interface for API operations (file upload) +type APISession interface { + MultipartRequest(url string, body bytes.Buffer, contentType string, timeout time.Duration) ([]byte, int, error) +} diff --git a/pkg/executor/handlers/common/testing.go b/pkg/executor/handlers/common/testing.go new file mode 100644 index 0000000..ad0e374 --- /dev/null +++ b/pkg/executor/handlers/common/testing.go @@ -0,0 +1,77 @@ +package common + +import ( + "context" + "strings" + "testing" + "time" +) + +// MockCommandExecutor is a mock implementation of CommandExecutor for testing. +// It is the single source of truth for mocking in this package. +type MockCommandExecutor struct { + t *testing.T + commands []ExecutedCommand + results map[string]CommandResult +} + +// ExecutedCommand represents a command that was executed by the mock. +type ExecutedCommand struct { + Name string + Args []string + User string +} + +// CommandResult represents the result of a mocked command execution. +type CommandResult struct { + ExitCode int + Output string + Err error +} + +func NewMockCommandExecutor(t *testing.T) *MockCommandExecutor { + return &MockCommandExecutor{ + t: t, + commands: []ExecutedCommand{}, + results: make(map[string]CommandResult), + } +} + +func (m *MockCommandExecutor) Run(ctx context.Context, name string, args ...string) (int, string, error) { + m.commands = append(m.commands, ExecutedCommand{Name: name, Args: args}) + key := name + " " + strings.Join(args, " ") + if result, ok := m.results[key]; ok { + return result.ExitCode, result.Output, result.Err + } + // Default behavior: return success for unknown commands to prevent unintended test failures. + return 0, "Mock success", nil +} + +func (m *MockCommandExecutor) RunAsUser(ctx context.Context, username string, name string, args ...string) (int, string, error) { + m.commands = append(m.commands, ExecutedCommand{Name: name, Args: args, User: username}) + return m.Run(ctx, name, args...) +} + +func (m *MockCommandExecutor) RunWithInput(ctx context.Context, input string, name string, args ...string) (int, string, error) { + return m.Run(ctx, name, args...) +} + +func (m *MockCommandExecutor) RunWithTimeout(ctx context.Context, timeout time.Duration, name string, args ...string) (int, string, error) { + return m.Run(ctx, name, args...) +} + +func (m *MockCommandExecutor) Exec(ctx context.Context, args []string, username, groupname string, env map[string]string, timeout time.Duration) (int, string, error) { + if len(args) == 0 { + return 0, "", nil + } + m.commands = append(m.commands, ExecutedCommand{Name: args[0], Args: args[1:], User: username}) + return m.Run(ctx, args[0], args[1:]...) +} + +func (m *MockCommandExecutor) SetResult(command string, exitCode int, output string, err error) { + m.results[command] = CommandResult{ExitCode: exitCode, Output: output, Err: err} +} + +func (m *MockCommandExecutor) GetExecutedCommands() []ExecutedCommand { + return m.commands +} diff --git a/pkg/executor/handlers/common/types.go b/pkg/executor/handlers/common/types.go new file mode 100644 index 0000000..fff647c --- /dev/null +++ b/pkg/executor/handlers/common/types.go @@ -0,0 +1,103 @@ +package common + +// HandlerType represents the type of a handler +type HandlerType string + +// CommandType represents the type of a command +type CommandType string + +// Handler type constants +const ( + System HandlerType = "system" + Group HandlerType = "group" + Info HandlerType = "info" + Shell HandlerType = "shell" + User HandlerType = "user" + Firewall HandlerType = "firewall" + FileTransfer HandlerType = "file" + Terminal HandlerType = "terminal" + Tunnel HandlerType = "tunnel" +) + +// Command type constants +const ( + // System commands + Upgrade CommandType = "upgrade" + Restart CommandType = "restart" + Quit CommandType = "quit" + Reboot CommandType = "reboot" + Shutdown CommandType = "shutdown" + Update CommandType = "update" + ByeBye CommandType = "byebye" + + // Group commands + AddGroup CommandType = "addgroup" + DelGroup CommandType = "delgroup" + + // Info commands + Ping CommandType = "ping" + Help CommandType = "help" + Commit CommandType = "commit" + Sync CommandType = "sync" + + // Shell commands + ShellCmd CommandType = "shell" + Exec CommandType = "exec" + + // User commands + AddUser CommandType = "adduser" + DelUser CommandType = "deluser" + ModUser CommandType = "moduser" + + // Firewall commands + FirewallCmd CommandType = "firewall" + FirewallRollback CommandType = "firewall-rollback" + FirewallReorderChains CommandType = "firewall-reorder-chains" + FirewallReorderRules CommandType = "firewall-reorder-rules" + + // Firewall sub-operations (used within firewall command) + FirewallOpBatch string = "batch" + FirewallOpFlush string = "flush" + FirewallOpDelete string = "delete" + FirewallOpAdd string = "add" + FirewallOpUpdate string = "update" + + // File commands + Upload CommandType = "upload" + Download CommandType = "download" + + // Terminal commands + OpenPty CommandType = "openpty" + OpenFtp CommandType = "openftp" + ResizePty CommandType = "resizepty" + + // Tunnel commands + OpenTunnel CommandType = "opentunnel" + CloseTunnel CommandType = "closetunnel" +) + +// Shell operators for command parsing +const ( + ShellAndOperator = "&&" // Execute next command only if previous succeeds + ShellOrOperator = "||" // Execute next command only if previous fails + ShellSemicolon = ";" // Execute next command regardless of previous result +) + +// String returns the string representation of HandlerType +func (h HandlerType) String() string { + return string(h) +} + +// String returns the string representation of CommandType +func (c CommandType) String() string { + return string(c) +} + +// CommandsToStrings converts a slice of CommandType to a slice of strings +func CommandsToStrings(commands []CommandType) []string { + result := make([]string, len(commands)) + for i, cmd := range commands { + result[i] = cmd.String() + } + return result +} diff --git a/pkg/executor/handlers/file/file.go b/pkg/executor/handlers/file/file.go new file mode 100644 index 0000000..ad0ab36 --- /dev/null +++ b/pkg/executor/handlers/file/file.go @@ -0,0 +1,454 @@ +package file + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/alpacax/alpamon/pkg/config" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/scheduler" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +// FileHandler handles file transfer commands +type FileHandler struct { + *common.BaseHandler + apiSession common.APISession +} + +// NewFileHandler creates a new file handler +func NewFileHandler(cmdExecutor common.CommandExecutor, apiSession common.APISession) *FileHandler { + h := &FileHandler{ + BaseHandler: common.NewBaseHandler( + common.FileTransfer, + []common.CommandType{ + common.Upload, + common.Download, + }, + cmdExecutor, + ), + apiSession: apiSession, + } + return h +} + +// Execute runs the file transfer command +func (h *FileHandler) Execute(_ context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.Upload.String(): + code, message := h.handleUpload(args) + h.statFileTransfer(code, download, message, args) + return code, message, nil + case common.Download.String(): + return h.handleDownload(args) + default: + return 1, "", fmt.Errorf("unknown file command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *FileHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.Upload.String(): + if args.Username == "" { + return fmt.Errorf("upload: username is required") + } + if len(args.Paths) == 0 { + return fmt.Errorf("upload: at least one path is required") + } + return nil + + case common.Download.String(): + if args.Username == "" { + return fmt.Errorf("download: username is required") + } + // Either Files array or single Path/Content should be provided + if len(args.Files) == 0 && args.Path == "" && args.Content == "" { + return fmt.Errorf("download: either Files array or Path/Content is required") + } + return nil + + default: + return fmt.Errorf("unknown file command: %s", cmd) + } +} + +// handleUpload handles the upload command +func (h *FileHandler) handleUpload(args *common.CommandArgs) (int, string) { + log.Debug(). + Str("username", args.Username). + Str("groupname", args.Groupname). + Int("pathCount", len(args.Paths)). + Msg("Uploading files") + + sysProcAttr, homeDirectory, err := h.demoteWithHomeDir(args.Username, args.Groupname) + if err != nil { + log.Error().Err(err).Msg("Failed to demote user.") + return 1, err.Error() + } + + if len(args.Paths) == 0 { + return 1, "No paths provided" + } + + paths, bulk, recursive, err := h.parsePaths(homeDirectory, args.Paths) + if err != nil { + log.Error().Err(err).Msg("Failed to parse paths") + return 1, err.Error() + } + + name, err := h.makeArchive(paths, bulk, recursive, sysProcAttr) + if err != nil { + log.Error().Err(err).Msg("Failed to create archive") + return 1, err.Error() + } + + // codeql[go/path-injection]: Intentional - Admin-specified file path for download + if bulk || recursive { + defer func() { _ = os.Remove(name) }() // lgtm[go/path-injection] + } + + cmd := exec.Command("cat", name) + cmd.SysProcAttr = sysProcAttr + + output, err := cmd.Output() + if err != nil { + log.Error().Err(err).Msgf("Failed to cat file: %s", output) + return 1, err.Error() + } + + requestBody, contentType, err := h.createMultipartBody(output, filepath.Base(name), args.UseBlob, recursive) + if err != nil { + log.Error().Err(err).Msgf("Failed to make request body") + return 1, err.Error() + } + + _, statusCode, err := h.fileUpload(args.Content, args.UseBlob, requestBody, contentType) + if err != nil { + log.Error().Err(err).Msg("Failed to upload file") + return 1, err.Error() + } + + if statusCode == http.StatusOK { + return 0, fmt.Sprintf("Successfully uploaded %d file(s).", len(paths)) + } + + return 1, "You do not have permission to read on the directory. or directory does not exist" +} + +// handleDownload handles the download command +func (h *FileHandler) handleDownload(args *common.CommandArgs) (int, string, error) { + log.Debug(). + Str("username", args.Username). + Str("groupname", args.Groupname). + Str("path", args.Path). + Msg("Downloading file") + + var code int + var message string + + sysProcAttr, err := h.demote(args.Username, args.Groupname) + if err != nil { + log.Error().Err(err).Msg("Failed to demote user.") + return 1, err.Error(), nil + } + + if len(args.Files) == 0 { + code, message = h.fileDownload(args, sysProcAttr) + h.statFileTransfer(code, upload, message, args) + } else { + for _, file := range args.Files { + cmdArgs := &common.CommandArgs{ + Username: file.Username, + Groupname: file.Groupname, + Type: file.Type, + Content: file.Content, + Path: file.Path, + AllowOverwrite: file.AllowOverwrite, + AllowUnzip: file.AllowUnzip, + URL: file.URL, + } + code, message = h.fileDownload(cmdArgs, sysProcAttr) + h.statFileTransfer(code, upload, message, cmdArgs) + } + } + + if code != 0 { + return code, message, nil + } + + return 0, "Successfully downloaded files.", nil +} + +// fileDownload handles single file download +func (h *FileHandler) fileDownload(args *common.CommandArgs, sysProcAttr *syscall.SysProcAttr) (int, string) { + var cmd *exec.Cmd + content, err := h.getFileData(args) + if err != nil { + return 1, err.Error() + } + + if !args.AllowOverwrite && utils.FileExists(args.Path) { + return 1, fmt.Sprintf("%s already exists.", args.Path) + } + + isZip := utils.IsZipFile(content, filepath.Ext(args.Path)) + if isZip && args.AllowUnzip { + escapePath := utils.Quote(args.Path) + escapeDirPath := utils.Quote(filepath.Dir(args.Path)) + command := fmt.Sprintf("tee %s > /dev/null && unzip -n %s -d %s; rm %s", + escapePath, + escapePath, + escapeDirPath, + escapePath) + cmd = exec.Command("sh", "-c", command) + } else { + cmd = exec.Command("sh", "-c", fmt.Sprintf("tee %s > /dev/null", utils.Quote(args.Path))) + } + + cmd.SysProcAttr = sysProcAttr + cmd.Stdin = bytes.NewReader(content) + + output, err := cmd.Output() + if err != nil { + log.Error().Err(err).Msgf("Failed to write file: %s", output) + return 1, "You do not have permission to read on the directory. or directory does not exist" + } + + return 0, fmt.Sprintf("Successfully downloaded %s.", args.Path) +} + +// demote demotes privilege to the specified user/group +func (h *FileHandler) demote(username, groupname string) (*syscall.SysProcAttr, error) { + result, err := utils.Demote(username, groupname, utils.DemoteOptions{ValidateGroup: true}) + if err != nil { + return nil, err + } + if result == nil { + return nil, nil + } + return result.SysProcAttr, nil +} + +// demoteWithHomeDir demotes privilege and returns home directory +func (h *FileHandler) demoteWithHomeDir(username, groupname string) (*syscall.SysProcAttr, string, error) { + result, err := utils.Demote(username, groupname, utils.DemoteOptions{ValidateGroup: false}) + if err != nil { + return nil, "", err + } + if result == nil { + return nil, "", nil + } + return result.SysProcAttr, result.User.HomeDir, nil +} + +// parsePaths parses and validates the path list +func (h *FileHandler) parsePaths(homeDirectory string, pathList []string) ([]string, bool, bool, error) { + paths := make([]string, len(pathList)) + for i, path := range pathList { + if strings.HasPrefix(path, "~") { + path = strings.Replace(path, "~", homeDirectory, 1) + } + + if !filepath.IsAbs(path) { + path = filepath.Join(homeDirectory, path) + } + + absPath, err := filepath.Abs(path) + if err != nil { + return nil, false, false, err + } + paths[i] = absPath + } + + isBulk := len(pathList) > 1 + isRecursive := false + + // codeql[go/path-injection]: Intentional - Admin-specified file path for upload + if !isBulk { + fileInfo, err := os.Stat(paths[0]) // lgtm[go/path-injection] + if err != nil { + return nil, false, false, err + } + isRecursive = fileInfo.IsDir() + } + + return paths, isBulk, isRecursive, nil +} + +// makeArchive creates a zip archive from the specified paths +func (h *FileHandler) makeArchive(paths []string, bulk, recursive bool, sysProcAttr *syscall.SysProcAttr) (string, error) { + var archiveName string + var cmd *exec.Cmd + path := paths[0] + + if bulk { + archiveName = filepath.Dir(path) + "/" + uuid.New().String() + ".zip" + dirPath := filepath.Dir(path) + basePaths := make([]string, len(paths)) + for i, p := range paths { + basePaths[i] = filepath.Base(p) + } + + cmd = exec.Command("zip", "-r", archiveName) + cmd.SysProcAttr = sysProcAttr + cmd.Args = append(cmd.Args, basePaths...) + cmd.Dir = dirPath + } else { + if recursive { + archiveName = path + ".zip" + cmd = exec.Command("zip", "-r", archiveName, filepath.Base(path)) + cmd.SysProcAttr = sysProcAttr + cmd.Dir = filepath.Dir(path) + } else { + archiveName = path + } + } + + if bulk || recursive { + err := cmd.Run() + if err != nil { + return "", err + } + } + + return archiveName, nil +} + +// createMultipartBody creates a multipart form body for file upload +func (h *FileHandler) createMultipartBody(output []byte, filePath string, useBlob, isRecursive bool) (bytes.Buffer, string, error) { + if useBlob { + return *bytes.NewBuffer(output), "", nil + } + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + fileWriter, err := writer.CreateFormFile("content", filePath) + if err != nil { + return bytes.Buffer{}, "", err + } + + _, err = fileWriter.Write(output) + if err != nil { + return bytes.Buffer{}, "", err + } + + if isRecursive { + err = writer.WriteField("name", filePath) + if err != nil { + return bytes.Buffer{}, "", err + } + } + + _ = writer.Close() + + return requestBody, writer.FormDataContentType(), nil +} + +// fileUpload uploads the file to the server +func (h *FileHandler) fileUpload(uploadURL string, useBlob bool, body bytes.Buffer, contentType string) ([]byte, int, error) { + if useBlob { + return utils.Put(uploadURL, body, 0) + } + + if h.apiSession == nil { + return nil, 0, errors.New("API session not available") + } + + return h.apiSession.MultipartRequest(uploadURL, body, contentType, time.Duration(fileUploadTimeout)*time.Second) +} + +// getFileData fetches file content from URL, text, or base64 +func (h *FileHandler) getFileData(args *common.CommandArgs) ([]byte, error) { + switch args.Type { + case "url": + return h.fetchFromURL(args.Content) + case "text": + return []byte(args.Content), nil + case "base64": + content, err := base64.StdEncoding.DecodeString(args.Content) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 content: %w", err) + } + return content, nil + default: + return nil, fmt.Errorf("unknown file type: %s", args.Type) + } +} + +// fetchFromURL downloads content from a URL +func (h *FileHandler) fetchFromURL(contentURL string) ([]byte, error) { + parsedRequestURL, err := url.Parse(contentURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL '%s': %w", contentURL, err) + } + + req, err := http.NewRequest(http.MethodGet, parsedRequestURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + parsedServerURL, err := url.Parse(config.GlobalSettings.ServerURL) + if err != nil { + return nil, fmt.Errorf("failed to parse url: %w", err) + } + + if parsedRequestURL.Host == parsedServerURL.Host && parsedRequestURL.Scheme == parsedServerURL.Scheme { + req.Header.Set("Authorization", fmt.Sprintf(`id="%s", key="%s"`, + config.GlobalSettings.ID, config.GlobalSettings.Key)) + } + + // codeql[go/request-forgery]: Intentional - Admin-specified URL for file content + client := utils.NewHTTPClient() + resp, err := client.Do(req) // lgtm[go/request-forgery] + if err != nil { + return nil, fmt.Errorf("failed to download content from URL: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + log.Error().Msgf("Failed to download content from URL: %d %s", resp.StatusCode, parsedRequestURL) + return nil, errors.New("downloading content failed") + } + + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return content, nil +} + +// statFileTransfer reports the file transfer status +func (h *FileHandler) statFileTransfer(code int, transferType transferType, message string, args *common.CommandArgs) { + if scheduler.Rqueue == nil { + log.Warn().Msg("Request queue not initialized, skipping stat") + return + } + + statURL := fmt.Sprint(args.URL + "stat/") + isSuccess := code == 0 + + payload := &commandStat{ + Success: isSuccess, + Message: message, + Type: transferType, + } + scheduler.Rqueue.Post(statURL, payload, 10, time.Time{}) +} diff --git a/pkg/executor/handlers/file/file_test.go b/pkg/executor/handlers/file/file_test.go new file mode 100644 index 0000000..11122ad --- /dev/null +++ b/pkg/executor/handlers/file/file_test.go @@ -0,0 +1,282 @@ +package file + +import ( + "context" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" +) + +func TestFileHandler_Validate(t *testing.T) { + handler := NewFileHandler(common.NewMockCommandExecutor(t), nil) + + tests := []struct { + name string + cmd string + args *common.CommandArgs + wantErr bool + }{ + { + name: "upload valid", + cmd: "upload", + args: &common.CommandArgs{ + Username: "testuser", + Paths: []string{"/tmp/file.txt"}, + }, + wantErr: false, + }, + { + name: "upload missing username", + cmd: "upload", + args: &common.CommandArgs{ + Paths: []string{"/tmp/file.txt"}, + }, + wantErr: true, + }, + { + name: "upload missing paths", + cmd: "upload", + args: &common.CommandArgs{ + Username: "testuser", + }, + wantErr: true, + }, + { + name: "download valid with content", + cmd: "download", + args: &common.CommandArgs{ + Username: "testuser", + Path: "/tmp/file.txt", + Content: "test", + }, + wantErr: false, + }, + { + name: "download valid with files", + cmd: "download", + args: &common.CommandArgs{ + Username: "testuser", + Files: []common.File{ + { + Path: "/tmp/file.txt", + Content: "test", + }, + }, + }, + wantErr: false, + }, + { + name: "download missing username", + cmd: "download", + args: &common.CommandArgs{ + Path: "/tmp/file.txt", + Content: "test", + }, + wantErr: true, + }, + { + name: "download missing content and files", + cmd: "download", + args: &common.CommandArgs{ + Username: "testuser", + }, + wantErr: true, + }, + { + name: "unknown command", + cmd: "unknown", + args: &common.CommandArgs{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handler.Validate(tt.cmd, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestFileHandler_Execute_UnknownCommand(t *testing.T) { + handler := NewFileHandler(common.NewMockCommandExecutor(t), nil) + ctx := context.Background() + + exitCode, _, err := handler.Execute(ctx, "unknown", &common.CommandArgs{}) + + if err == nil { + t.Error("Execute() expected error for unknown command") + } + if exitCode != 1 { + t.Errorf("Execute() exitCode = %v, want 1", exitCode) + } +} + +func TestFileHandler_Execute_UploadNoPaths(t *testing.T) { + handler := NewFileHandler(common.NewMockCommandExecutor(t), nil) + ctx := context.Background() + + args := &common.CommandArgs{ + Username: "testuser", + Groupname: "testgroup", + Paths: []string{}, + } + + exitCode, output, err := handler.Execute(ctx, "upload", args) + + if err != nil { + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != 1 { + t.Errorf("Execute() exitCode = %v, want 1", exitCode) + } + if output != "No paths provided" { + t.Errorf("Execute() output = %v, want 'No paths provided'", output) + } +} + +func TestFileHandler_Execute_DownloadUnknownType(t *testing.T) { + handler := NewFileHandler(common.NewMockCommandExecutor(t), nil) + ctx := context.Background() + + args := &common.CommandArgs{ + Username: "testuser", + Groupname: "testgroup", + Path: "/tmp/file.txt", + Content: "test content", + Type: "unknown_type", + } + + exitCode, output, err := handler.Execute(ctx, "download", args) + + if err != nil { + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != 1 { + t.Errorf("Execute() exitCode = %v, want 1", exitCode) + } + if output == "" { + t.Error("Execute() expected error message in output") + } +} + +func TestIsZipFile(t *testing.T) { + tests := []struct { + name string + content []byte + ext string + want bool + }{ + { + name: "jar file extension", + content: []byte("PK\x03\x04"), // zip magic bytes + ext: ".jar", + want: false, // Should be excluded + }, + { + name: "war file extension", + content: []byte("PK\x03\x04"), + ext: ".war", + want: false, // Should be excluded + }, + { + name: "regular zip content", + content: []byte("PK\x03\x04"), + ext: ".zip", + want: false, // Invalid zip (too short) + }, + { + name: "non-zip content", + content: []byte("hello world"), + ext: ".txt", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := utils.IsZipFile(tt.content, tt.ext) + if got != tt.want { + t.Errorf("IsZipFile() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileExists(t *testing.T) { + // Test with non-existent file + if utils.FileExists("/nonexistent/path/file.txt") { + t.Error("FileExists() should return false for non-existent file") + } + + // Test with existing file (current file) + if !utils.FileExists("file_test.go") { + t.Error("FileExists() should return true for existing file") + } +} + +func TestFileHandler_parsePaths(t *testing.T) { + handler := NewFileHandler(common.NewMockCommandExecutor(t), nil) + + tests := []struct { + name string + homeDirectory string + pathList []string + wantBulk bool + wantErr bool + }{ + { + name: "single absolute path", + homeDirectory: "/home/user", + pathList: []string{"/tmp/file.txt"}, + wantBulk: false, + wantErr: true, // File doesn't exist + }, + { + name: "multiple paths", + homeDirectory: "/home/user", + pathList: []string{"/tmp/file1.txt", "/tmp/file2.txt"}, + wantBulk: true, + wantErr: false, // Bulk mode doesn't check file existence in parsePaths + }, + { + name: "tilde path", + homeDirectory: "/home/testuser", + pathList: []string{"~/file.txt"}, + wantBulk: false, + wantErr: true, // File doesn't exist + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, bulk, _, err := handler.parsePaths(tt.homeDirectory, tt.pathList) + + if (err != nil) != tt.wantErr { + t.Errorf("parsePaths() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil && bulk != tt.wantBulk { + t.Errorf("parsePaths() bulk = %v, want %v", bulk, tt.wantBulk) + } + }) + } +} + +func TestNonZipExtensions(t *testing.T) { + // Test that zip-like extensions are excluded from IsZipFile + zipContent := []byte("PK\x03\x04") // zip magic bytes (but invalid/short) + excludedExtensions := []string{ + ".jar", ".war", ".ear", ".apk", ".xpi", + ".vsix", ".crx", ".egg", ".whl", ".appx", + ".msix", ".ipk", ".nupkg", ".kmz", + } + + for _, ext := range excludedExtensions { + if utils.IsZipFile(zipContent, ext) { + t.Errorf("Expected extension %s to be excluded from IsZipFile", ext) + } + } +} diff --git a/pkg/executor/handlers/file/types.go b/pkg/executor/handlers/file/types.go new file mode 100644 index 0000000..c7629e9 --- /dev/null +++ b/pkg/executor/handlers/file/types.go @@ -0,0 +1,45 @@ +package file + +const ( + fileUploadTimeout = 60 * 10 // 600 seconds +) + +type transferType string + +const ( + download transferType = "download" + upload transferType = "upload" +) + +// commandStat represents the file transfer status payload +type commandStat struct { + Success bool `json:"success"` + Message string `json:"message"` + Type transferType `json:"type"` +} + +// FileData contains data for file operations +type FileData struct { + Username string `json:"username"` + Groupname string `json:"groupname"` + Paths []string `json:"paths,omitempty"` // For upload + Files []FileInfo `json:"files,omitempty"` // For batch download + Path string `json:"path,omitempty"` // Single file path + Content []byte `json:"content,omitempty"` // File content for download + Type string `json:"type,omitempty"` // File type + AllowOverwrite bool `json:"allow_overwrite,omitempty"` + AllowUnzip bool `json:"allow_unzip,omitempty"` + URL string `json:"url,omitempty"` +} + +// FileInfo contains information about a file for batch operations +type FileInfo struct { + Username string `json:"username"` + Groupname string `json:"groupname"` + Path string `json:"path"` + Type string `json:"type"` + Content []byte `json:"content"` + AllowOverwrite bool `json:"allow_overwrite"` + AllowUnzip bool `json:"allow_unzip"` + URL string `json:"url"` +} diff --git a/pkg/executor/handlers/firewall/backend.go b/pkg/executor/handlers/firewall/backend.go new file mode 100644 index 0000000..aebfda6 --- /dev/null +++ b/pkg/executor/handlers/firewall/backend.go @@ -0,0 +1,64 @@ +package firewall + +import ( + "context" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// FirewallBackend defines the interface for firewall backend implementations +type FirewallBackend interface { + // Name returns the backend name (e.g., "iptables", "nftables") + Name() string + + // Detect checks if this firewall backend is available and active + Detect(ctx context.Context) bool + + // AddRule adds a single firewall rule + AddRule(ctx context.Context, rule *common.FirewallRule) error + + // DeleteRule removes a single firewall rule by ID + DeleteRule(ctx context.Context, ruleID, chainName string) error + + // FlushChain removes all rules from a chain + FlushChain(ctx context.Context, chainName string) error + + // DeleteChain deletes a chain entirely + DeleteChain(ctx context.Context, chainName string) error + + // ListRules returns all rules in a chain + ListRules(ctx context.Context, chainName string) ([]common.FirewallRule, error) + + // BatchApply applies multiple rules atomically + // Returns: applied count, failed rule descriptions, error + BatchApply(ctx context.Context, chainName string, rules []common.FirewallRule) (applied int, failed []string, err error) + + // ReorderChains reorders jump rules in INPUT chain + ReorderChains(ctx context.Context, chainNames []string) (map[string]interface{}, error) + + // ReorderRules reorders rules within a chain + ReorderRules(ctx context.Context, chainName string, rules []common.FirewallRule) error + + // Backup creates a backup of current firewall state + Backup(ctx context.Context) (string, error) + + // Restore restores firewall state from backup + Restore(ctx context.Context, backup string) error +} + +// BatchResult represents the result of a batch operation +type BatchResult struct { + Success bool `json:"success"` + AppliedRules int `json:"applied_rules"` + FailedRules []string `json:"failed_rules"` + RolledBack bool `json:"rolled_back"` + Message string `json:"message,omitempty"` +} + +// ReorderResult represents the result of a reorder operation +type ReorderResult struct { + Success bool `json:"success"` + ReorderedCount int `json:"reordered_count,omitempty"` + DeletedRules int `json:"deleted_rules,omitempty"` + Message string `json:"message,omitempty"` +} diff --git a/pkg/executor/handlers/firewall/backup.go b/pkg/executor/handlers/firewall/backup.go new file mode 100644 index 0000000..da1b542 --- /dev/null +++ b/pkg/executor/handlers/firewall/backup.go @@ -0,0 +1,130 @@ +package firewall + +import ( + "context" + "fmt" + "sync" + "time" +) + +// BackupManager manages firewall state backups for atomic operations with rollback +type BackupManager struct { + backend FirewallBackend + lastBackup string + backupTime time.Time + mu sync.RWMutex +} + +// NewBackupManager creates a new backup manager +func NewBackupManager(backend FirewallBackend) *BackupManager { + return &BackupManager{ + backend: backend, + } +} + +// SetBackend updates the backup manager's backend +func (bm *BackupManager) SetBackend(backend FirewallBackend) { + bm.mu.Lock() + defer bm.mu.Unlock() + bm.backend = backend +} + +// CreateBackup creates a backup of current firewall state +func (bm *BackupManager) CreateBackup(ctx context.Context) error { + bm.mu.Lock() + defer bm.mu.Unlock() + + if bm.backend == nil { + return fmt.Errorf("no firewall backend configured") + } + + backup, err := bm.backend.Backup(ctx) + if err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + + bm.lastBackup = backup + bm.backupTime = time.Now() + + return nil +} + +// Rollback restores the last backup +func (bm *BackupManager) Rollback(ctx context.Context) error { + bm.mu.RLock() + backup := bm.lastBackup + backend := bm.backend + bm.mu.RUnlock() + + if backend == nil { + return fmt.Errorf("no firewall backend configured") + } + + if backup == "" { + return fmt.Errorf("no backup available for rollback") + } + + if err := backend.Restore(ctx, backup); err != nil { + return fmt.Errorf("failed to restore backup: %w", err) + } + + return nil +} + +// HasBackup checks if a backup exists +func (bm *BackupManager) HasBackup() bool { + bm.mu.RLock() + defer bm.mu.RUnlock() + return bm.lastBackup != "" +} + +// GetLastBackup returns the last backup and its timestamp +func (bm *BackupManager) GetLastBackup() (backup string, backupTime time.Time, exists bool) { + bm.mu.RLock() + defer bm.mu.RUnlock() + + if bm.lastBackup == "" { + return "", time.Time{}, false + } + + return bm.lastBackup, bm.backupTime, true +} + +// ClearBackup clears the stored backup +func (bm *BackupManager) ClearBackup() { + bm.mu.Lock() + defer bm.mu.Unlock() + bm.lastBackup = "" + bm.backupTime = time.Time{} +} + +// BackupAge returns the age of the current backup +func (bm *BackupManager) BackupAge() time.Duration { + bm.mu.RLock() + defer bm.mu.RUnlock() + + if bm.lastBackup == "" { + return 0 + } + + return time.Since(bm.backupTime) +} + +// WithBackup executes an operation with automatic backup and rollback on failure +func (bm *BackupManager) WithBackup(ctx context.Context, operation func() error) error { + // Create backup before operation + if err := bm.CreateBackup(ctx); err != nil { + return fmt.Errorf("pre-operation backup failed: %w", err) + } + + // Execute the operation + if err := operation(); err != nil { + // Attempt rollback on failure + if rollbackErr := bm.Rollback(ctx); rollbackErr != nil { + return fmt.Errorf("operation failed: %w; rollback also failed: %v", err, rollbackErr) + } + return fmt.Errorf("operation failed (rolled back): %w", err) + } + + return nil +} diff --git a/pkg/executor/handlers/firewall/detection.go b/pkg/executor/handlers/firewall/detection.go new file mode 100644 index 0000000..4cb62bd --- /dev/null +++ b/pkg/executor/handlers/firewall/detection.go @@ -0,0 +1,230 @@ +package firewall + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/rs/zerolog/log" +) + +// BackendType represents the detected firewall backend type +type BackendType string + +const ( + BackendNone BackendType = "none" + BackendIptables BackendType = "iptables" + BackendNftables BackendType = "nftables" +) + +// HighLevelFirewall represents high-level firewall management tools +type HighLevelFirewall string + +const ( + HighLevelNone HighLevelFirewall = "" + HighLevelUFW HighLevelFirewall = "ufw" + HighLevelFirewalld HighLevelFirewall = "firewalld" +) + +// DetectionResult holds the cached detection results +type DetectionResult struct { + Backend BackendType + HighLevel HighLevelFirewall + NftablesAvailable bool + IptablesAvailable bool + Disabled bool + DetectedAt time.Time +} + +// FirewallDetector detects available firewall backends and high-level tools +type FirewallDetector struct { + executor common.CommandExecutor + result *DetectionResult + mu sync.RWMutex +} + +// NewFirewallDetector creates a new firewall detector +func NewFirewallDetector(executor common.CommandExecutor) *FirewallDetector { + return &FirewallDetector{ + executor: executor, + } +} + +// Detect performs full firewall detection +// Returns cached result if already detected, otherwise performs detection +func (d *FirewallDetector) Detect(ctx context.Context) *DetectionResult { + d.mu.Lock() + defer d.mu.Unlock() + + // Return cached result if available + if d.result != nil { + return d.result + } + + result := &DetectionResult{ + DetectedAt: time.Now(), + } + + // First check for high-level firewall tools + result.HighLevel = d.detectHighLevelFirewall(ctx) + if result.HighLevel != HighLevelNone { + result.Disabled = true + result.Backend = BackendNone + d.result = result + return result + } + + // Detect backend based on existing rules + result.Backend = d.detectBackend(ctx) + result.NftablesAvailable = result.Backend == BackendNftables + result.IptablesAvailable = result.Backend == BackendIptables + result.Disabled = result.Backend == BackendNone + + d.result = result + return result +} + +// GetResult returns the cached detection result without re-detecting +func (d *FirewallDetector) GetResult() *DetectionResult { + d.mu.RLock() + defer d.mu.RUnlock() + return d.result +} + +// Reset clears the cached detection result, forcing re-detection on next call +func (d *FirewallDetector) Reset() { + d.mu.Lock() + defer d.mu.Unlock() + d.result = nil +} + +// IsDisabled returns true if firewall management is disabled +func (d *FirewallDetector) IsDisabled() bool { + d.mu.RLock() + defer d.mu.RUnlock() + if d.result == nil { + return false + } + return d.result.Disabled +} + +// GetBackendType returns the detected backend type +func (d *FirewallDetector) GetBackendType() BackendType { + d.mu.RLock() + defer d.mu.RUnlock() + if d.result == nil { + return BackendNone + } + return d.result.Backend +} + +// detectHighLevelFirewall detects if high-level firewall management tools are active +func (d *FirewallDetector) detectHighLevelFirewall(ctx context.Context) HighLevelFirewall { + // Check ufw via systemctl (most reliable) + exitCode, output, _ := d.executor.RunWithTimeout(ctx, 5*time.Second, "systemctl", "is-active", "ufw") + if exitCode == 0 && strings.TrimSpace(output) == "active" { + log.Info().Msg("Detected active ufw firewall - alpacon firewall management will be disabled") + return HighLevelUFW + } + + // Fallback: Check ufw via direct command + exitCode, output, _ = d.executor.RunWithTimeout(ctx, 5*time.Second, "ufw", "status") + if exitCode == 0 && strings.Contains(strings.ToLower(output), "status: active") { + log.Info().Msg("Detected active ufw firewall - alpacon firewall management will be disabled") + return HighLevelUFW + } + + // Check firewalld via systemctl + exitCode, output, _ = d.executor.RunWithTimeout(ctx, 5*time.Second, "systemctl", "is-active", "firewalld") + if exitCode == 0 && strings.TrimSpace(output) == "active" { + log.Info().Msg("Detected active firewalld - alpacon firewall management will be disabled") + return HighLevelFirewalld + } + + // Fallback: Check firewalld via firewall-cmd + exitCode, output, _ = d.executor.RunWithTimeout(ctx, 5*time.Second, "firewall-cmd", "--state") + if exitCode == 0 && strings.Contains(strings.ToLower(output), "running") { + log.Info().Msg("Detected active firewalld - alpacon firewall management will be disabled") + return HighLevelFirewalld + } + + log.Debug().Msg("No high-level firewall detected - alpacon firewall management enabled") + return HighLevelNone +} + +// detectBackend detects which firewall backend to use based on existing rules +func (d *FirewallDetector) detectBackend(ctx context.Context) BackendType { + // Try iptables-save to check for existing iptables rules + exitCode, output, _ := d.executor.RunWithTimeout(ctx, 10*time.Second, "iptables-save") + + if exitCode == 0 { + // Count actual rules (lines starting with -A or -I) + ruleCount := 0 + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "-A ") || strings.HasPrefix(line, "-I ") { + ruleCount++ + } + } + + if ruleCount > 0 { + log.Debug().Msgf("Found %d iptables rules", ruleCount) + return BackendIptables + } + + // iptables-save succeeded but no rules - check if nft is available + exitCode, _, _ := d.executor.RunWithTimeout(ctx, 5*time.Second, "which", "nft") + if exitCode == 0 { + log.Debug().Msg("No iptables rules, nft available") + return BackendNftables + } + + // Only iptables available, no nft + log.Debug().Msg("No iptables rules, nft not available, defaulting to iptables") + return BackendIptables + } + + // iptables-save failed, try fallback with iptables -S + exitCode, output, _ = d.executor.RunWithTimeout(ctx, 10*time.Second, "iptables", "-S") + if exitCode == 0 { + // Check for rules (iptables -S output starts with -P, -A, -I, etc) + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "-A ") || strings.HasPrefix(line, "-I ") { + log.Debug().Msg("Found iptables rules via iptables -S") + return BackendIptables + } + } + } + + // No iptables rules found, check if nft is available + exitCode, _, _ = d.executor.RunWithTimeout(ctx, 5*time.Second, "which", "nft") + if exitCode == 0 { + log.Debug().Msg("No iptables rules, using nftables") + return BackendNftables + } + + // Neither iptables nor nft available + log.Warn().Msg("No firewall backend available") + return BackendNone +} + +// CreateBackend creates the appropriate backend based on detection result +func (d *FirewallDetector) CreateBackend(ctx context.Context) FirewallBackend { + result := d.Detect(ctx) + + if result.Disabled { + return nil + } + + switch result.Backend { + case BackendIptables: + return NewIptablesBackend(d.executor) + case BackendNftables: + return NewNftablesBackend(d.executor) + default: + return nil + } +} diff --git a/pkg/executor/handlers/firewall/firewall.go b/pkg/executor/handlers/firewall/firewall.go new file mode 100644 index 0000000..9e77721 --- /dev/null +++ b/pkg/executor/handlers/firewall/firewall.go @@ -0,0 +1,561 @@ +package firewall + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/rs/zerolog/log" +) + +// FirewallHandler handles firewall management commands +type FirewallHandler struct { + *common.BaseHandler + detector *FirewallDetector + backend FirewallBackend + backup *BackupManager + validator *Validator + mu sync.RWMutex +} + +// NewFirewallHandler creates a new firewall handler +func NewFirewallHandler(cmdExecutor common.CommandExecutor) *FirewallHandler { + detector := NewFirewallDetector(cmdExecutor) + validator := NewValidator() + + h := &FirewallHandler{ + BaseHandler: common.NewBaseHandler( + common.Firewall, + []common.CommandType{ + common.FirewallCmd, + common.FirewallRollback, + common.FirewallReorderChains, + common.FirewallReorderRules, + }, + cmdExecutor, + ), + detector: detector, + validator: validator, + } + return h +} + +// initBackend initializes the firewall backend based on detection +func (h *FirewallHandler) initBackend(ctx context.Context) error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.backend != nil { + return nil + } + + h.backend = h.detector.CreateBackend(ctx) + if h.backend == nil { + return fmt.Errorf("no firewall backend available") + } + + h.backup = NewBackupManager(h.backend) + log.Info().Str("backend", h.backend.Name()).Msg("Firewall backend initialized") + return nil +} + +// getBackend returns the current backend, initializing if needed +func (h *FirewallHandler) getBackend(ctx context.Context) (FirewallBackend, error) { + h.mu.RLock() + if h.backend != nil { + defer h.mu.RUnlock() + return h.backend, nil + } + h.mu.RUnlock() + + if err := h.initBackend(ctx); err != nil { + return nil, err + } + + h.mu.RLock() + defer h.mu.RUnlock() + return h.backend, nil +} + +// Execute runs the firewall management command +func (h *FirewallHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + // Check for high-level firewall tools + result := h.detector.Detect(ctx) + if result.HighLevel != HighLevelNone { + return 1, fmt.Sprintf("Alpacon firewall management is disabled because %s is active. Please use %s to manage firewall rules.", + result.HighLevel, result.HighLevel), nil + } + + if result.Disabled { + return 1, "Firewall functionality is not available - no backend detected", nil + } + + // Initialize backend if not already done + if _, err := h.getBackend(ctx); err != nil { + return 1, fmt.Sprintf("Failed to initialize firewall: %v", err), err + } + + switch cmd { + case common.FirewallCmd.String(): + return h.handleFirewall(ctx, args) + case common.FirewallRollback.String(): + return h.handleFirewallRollback(ctx) + case common.FirewallReorderChains.String(): + return h.handleFirewallReorderChains(ctx, args) + case common.FirewallReorderRules.String(): + return h.handleFirewallReorderRules(ctx, args) + default: + return 1, "", fmt.Errorf("unknown firewall command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *FirewallHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.FirewallCmd.String(): + operation := args.Operation + if operation == "" { + return fmt.Errorf("firewall: operation is required") + } + switch operation { + case common.FirewallOpBatch: + return h.validator.ValidateBatchRules(args.Rules) + case common.FirewallOpFlush: + if args.ChainName == "" { + return fmt.Errorf("firewall flush: chain name is required") + } + return h.validator.ValidateChainName(args.ChainName) + case common.FirewallOpDelete: + if args.RuleID == "" { + return fmt.Errorf("firewall delete: rule ID is required") + } + return nil + case common.FirewallOpAdd: + if len(args.Rules) == 0 { + return fmt.Errorf("firewall add: at least one rule is required") + } + return h.validator.ValidateBatchRules(args.Rules) + case common.FirewallOpUpdate: + if args.RuleID == "" { + return fmt.Errorf("firewall update: rule ID is required") + } + return nil + default: + return fmt.Errorf("firewall: unknown operation '%s'", operation) + } + + case common.FirewallRollback.String(): + return nil + + case common.FirewallReorderChains.String(): + if len(args.ChainNames) == 0 { + return fmt.Errorf("firewall-reorder-chains: chain names are required") + } + for _, name := range args.ChainNames { + if err := h.validator.ValidateChainName(name); err != nil { + return err + } + } + return nil + + case common.FirewallReorderRules.String(): + if args.ChainName == "" { + return fmt.Errorf("firewall-reorder-rules: chain name is required") + } + return h.validator.ValidateChainName(args.ChainName) + + default: + return fmt.Errorf("unknown firewall command: %s", cmd) + } +} + +// handleFirewall handles the main firewall command +func (h *FirewallHandler) handleFirewall(ctx context.Context, args *common.CommandArgs) (int, string, error) { + operation := args.Operation + + log.Info(). + Str("operation", operation). + Msg("Executing firewall operation") + + switch operation { + case common.FirewallOpBatch: + return h.handleBatchOperation(ctx, args) + case common.FirewallOpFlush: + return h.handleFlushOperation(ctx, args) + case common.FirewallOpDelete: + return h.handleDeleteOperation(ctx, args) + case common.FirewallOpAdd: + return h.handleAddOperation(ctx, args) + case common.FirewallOpUpdate: + return h.handleUpdateOperation(ctx, args) + default: + return 1, fmt.Sprintf("firewall: Unknown operation '%s'", operation), nil + } +} + +// handleBatchOperation handles batch firewall operations +func (h *FirewallHandler) handleBatchOperation(ctx context.Context, args *common.CommandArgs) (int, string, error) { + chainName := args.ChainName + rules := args.Rules + + log.Info(). + Str("chain", chainName). + Int("ruleCount", len(rules)). + Msg("Firewall batch operation") + + if len(rules) == 0 { + result := BatchResult{ + Success: true, + AppliedRules: 0, + FailedRules: []string{}, + RolledBack: false, + Message: "No rules to apply", + } + resultJSON, _ := json.Marshal(result) + return 0, string(resultJSON), nil + } + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before batch operation + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before batch operation") + } + + // Apply rules + applied, failed, applyErr := backend.BatchApply(ctx, chainName, rules) + + result := BatchResult{ + Success: applyErr == nil, + AppliedRules: applied, + FailedRules: failed, + RolledBack: false, + } + + // If there were failures and we have a backup, offer rollback info + if applyErr != nil && h.backup.HasBackup() { + result.Message = fmt.Sprintf("Batch operation completed with errors: %v. Rollback available.", applyErr) + } else if applyErr != nil { + result.Message = fmt.Sprintf("Batch operation completed with errors: %v", applyErr) + } else { + result.Message = fmt.Sprintf("Successfully applied %d rules", applied) + } + + resultJSON, _ := json.Marshal(result) + return 0, string(resultJSON), nil +} + +// handleFlushOperation handles flush firewall operations +func (h *FirewallHandler) handleFlushOperation(ctx context.Context, args *common.CommandArgs) (int, string, error) { + chainName := args.ChainName + + log.Info(). + Str("chain", chainName). + Msg("Firewall flush operation") + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before flush + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before flush operation") + } + + if err := backend.FlushChain(ctx, chainName); err != nil { + return 1, fmt.Sprintf("Failed to flush chain '%s': %v", chainName, err), err + } + + return 0, fmt.Sprintf("Successfully flushed chain '%s'", chainName), nil +} + +// handleDeleteOperation handles delete firewall operations +func (h *FirewallHandler) handleDeleteOperation(ctx context.Context, args *common.CommandArgs) (int, string, error) { + ruleID := args.RuleID + chainName := args.ChainName + + log.Info(). + Str("ruleID", ruleID). + Str("chain", chainName). + Msg("Firewall delete operation") + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before delete + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before delete operation") + } + + if err := backend.DeleteRule(ctx, ruleID, chainName); err != nil { + return 1, fmt.Sprintf("Failed to delete rule '%s': %v", ruleID, err), err + } + + return 0, fmt.Sprintf("Successfully deleted rule '%s'", ruleID), nil +} + +// handleAddOperation handles add firewall operations +func (h *FirewallHandler) handleAddOperation(ctx context.Context, args *common.CommandArgs) (int, string, error) { + chainName := args.ChainName + rules := args.Rules + + log.Info(). + Str("chain", chainName). + Int("ruleCount", len(rules)). + Msg("Firewall add operation") + + if len(rules) == 0 { + return 0, "No rules to add", nil + } + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before add + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before add operation") + } + + // Add each rule + var addedCount int + var lastErr error + for i, rule := range rules { + ruleCopy := rule + if ruleCopy.Chain == "" { + ruleCopy.Chain = chainName + } + if err := backend.AddRule(ctx, &ruleCopy); err != nil { + log.Error().Err(err).Int("ruleIndex", i).Msg("Failed to add rule") + lastErr = err + continue + } + addedCount++ + } + + if lastErr != nil { + return 1, fmt.Sprintf("Added %d of %d rules, last error: %v", addedCount, len(rules), lastErr), lastErr + } + + return 0, fmt.Sprintf("Successfully added %d rules", addedCount), nil +} + +// handleUpdateOperation handles update firewall operations +func (h *FirewallHandler) handleUpdateOperation(ctx context.Context, args *common.CommandArgs) (int, string, error) { + ruleID := args.RuleID + oldRuleID := args.OldRuleID + chainName := args.ChainName + + log.Info(). + Str("ruleID", ruleID). + Str("oldRuleID", oldRuleID). + Str("chain", chainName). + Msg("Firewall update operation") + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before update + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before update operation") + } + + // Delete old rule if specified + deleteID := oldRuleID + if deleteID == "" { + deleteID = ruleID + } + + if deleteID != "" { + if err := backend.DeleteRule(ctx, deleteID, chainName); err != nil { + log.Warn().Err(err).Str("ruleID", deleteID).Msg("Failed to delete old rule during update") + } + } + + // Add new rule if provided + if len(args.Rules) > 0 { + rule := args.Rules[0] + if rule.Chain == "" { + rule.Chain = chainName + } + if err := backend.AddRule(ctx, &rule); err != nil { + return 1, fmt.Sprintf("Failed to add updated rule: %v", err), err + } + } + + return 0, fmt.Sprintf("Successfully updated rule '%s'", ruleID), nil +} + +// handleFirewallRollback handles firewall rollback command +func (h *FirewallHandler) handleFirewallRollback(ctx context.Context) (int, string, error) { + log.Info().Msg("Executing firewall rollback") + + h.mu.RLock() + backup := h.backup + h.mu.RUnlock() + + if backup == nil || !backup.HasBackup() { + return 1, "No backup available for rollback", fmt.Errorf("no backup available") + } + + if err := backup.Rollback(ctx); err != nil { + return 1, fmt.Sprintf("Failed to rollback firewall: %v", err), err + } + + return 0, "Firewall rules rolled back successfully", nil +} + +// handleFirewallReorderChains handles firewall chain reordering +func (h *FirewallHandler) handleFirewallReorderChains(ctx context.Context, args *common.CommandArgs) (int, string, error) { + chainNames := args.ChainNames + + log.Info(). + Strs("chains", chainNames). + Msg("Reordering firewall chains") + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before reorder + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before chain reorder") + } + + result, err := backend.ReorderChains(ctx, chainNames) + if err != nil { + // Attempt rollback on failure + if h.backup.HasBackup() { + if rollbackErr := h.backup.Rollback(ctx); rollbackErr != nil { + log.Error().Err(rollbackErr).Msg("Failed to rollback after reorder failure") + } + } + return 1, fmt.Sprintf("Failed to reorder chains: %v", err), err + } + + reorderResult := ReorderResult{ + Success: true, + ReorderedCount: len(chainNames), + Message: fmt.Sprintf("Successfully reordered %d chains", len(chainNames)), + } + + if deletedRules, ok := result["deleted_rules"].(int); ok { + reorderResult.DeletedRules = deletedRules + } + + resultJSON, _ := json.Marshal(reorderResult) + return 0, string(resultJSON), nil +} + +// handleFirewallReorderRules handles firewall rule reordering within a chain +func (h *FirewallHandler) handleFirewallReorderRules(ctx context.Context, args *common.CommandArgs) (int, string, error) { + chainName := args.ChainName + rules := args.Rules + + log.Info(). + Str("chain", chainName). + Int("ruleCount", len(rules)). + Msg("Reordering firewall rules") + + backend, err := h.getBackend(ctx) + if err != nil { + return 1, fmt.Sprintf("Failed to get firewall backend: %v", err), err + } + + // Create backup before reorder + if err := h.backup.CreateBackup(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to create backup before rule reorder") + } + + if err := backend.ReorderRules(ctx, chainName, rules); err != nil { + // Attempt rollback on failure + if h.backup.HasBackup() { + if rollbackErr := h.backup.Rollback(ctx); rollbackErr != nil { + log.Error().Err(rollbackErr).Msg("Failed to rollback after reorder failure") + } + } + return 1, fmt.Sprintf("Failed to reorder rules in chain '%s': %v", chainName, err), err + } + + reorderResult := ReorderResult{ + Success: true, + ReorderedCount: len(rules), + Message: fmt.Sprintf("Successfully reordered %d rules in chain '%s'", len(rules), chainName), + } + + resultJSON, _ := json.Marshal(reorderResult) + return 0, string(resultJSON), nil +} + +// GetDetector returns the firewall detector for external use +func (h *FirewallHandler) GetDetector() *FirewallDetector { + return h.detector +} + +// GetBackend returns the current backend for external use +func (h *FirewallHandler) GetBackend() FirewallBackend { + h.mu.RLock() + defer h.mu.RUnlock() + return h.backend +} + +// GetBackupManager returns the backup manager for external use +func (h *FirewallHandler) GetBackupManager() *BackupManager { + h.mu.RLock() + defer h.mu.RUnlock() + return h.backup +} + +// CollectAllRules collects all firewall rules from the system +// Returns rules in the format compatible with utils.FirewallSyncPayload +func (h *FirewallHandler) CollectAllRules(ctx context.Context) (map[string][]common.FirewallRule, error) { + backend, err := h.getBackend(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get firewall backend: %w", err) + } + + // Get all chains by listing rules without filter + rules, err := backend.ListRules(ctx, "") + if err != nil { + return nil, fmt.Errorf("failed to list firewall rules: %w", err) + } + + // Group rules by chain + chains := make(map[string][]common.FirewallRule) + for _, rule := range rules { + chainName := rule.Chain + if chainName == "" { + chainName = "INPUT" + } + chains[chainName] = append(chains[chainName], rule) + } + + return chains, nil +} + +// IsFirewallAvailable checks if firewall functionality is available +func (h *FirewallHandler) IsFirewallAvailable(ctx context.Context) bool { + result := h.detector.Detect(ctx) + return !result.Disabled && result.HighLevel == HighLevelNone +} + +// GetHighLevelFirewall returns the detected high-level firewall tool name if any +func (h *FirewallHandler) GetHighLevelFirewall(ctx context.Context) (bool, string) { + result := h.detector.Detect(ctx) + if result.HighLevel != HighLevelNone { + return true, string(result.HighLevel) + } + return false, "" +} diff --git a/pkg/executor/handlers/firewall/firewall_test.go b/pkg/executor/handlers/firewall/firewall_test.go new file mode 100644 index 0000000..664ea91 --- /dev/null +++ b/pkg/executor/handlers/firewall/firewall_test.go @@ -0,0 +1,285 @@ +package firewall + +import ( + "context" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" +) + +func TestFirewallHandler_Execute(t *testing.T) { + // Temporarily enable firewall functionality for this test + utils.FirewallFunctionalityDisabled = false + t.Cleanup(func() { + utils.FirewallFunctionalityDisabled = true + }) + + tests := []struct { + name string + cmd string + args *common.CommandArgs + wantCode int + wantErr bool + }{ + { + name: "firewall batch operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "batch", + ChainName: "INPUT", + Rules: []common.FirewallRule{ + { + Protocol: "tcp", + }, + }, + }, + wantCode: 0, + wantErr: false, + }, + { + name: "firewall flush operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "flush", + ChainName: "FORWARD", + }, + wantCode: 0, + wantErr: false, + }, + { + name: "firewall delete operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "delete", + RuleID: "rule123", + }, + wantCode: 1, // Delete fails because rule doesn't exist in mock + wantErr: true, + }, + { + name: "firewall add operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "add", + ChainName: "OUTPUT", + Protocol: "tcp", + Target: "ACCEPT", + }, + wantCode: 0, + wantErr: false, + }, + { + name: "firewall update operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "update", + RuleID: "rule123", + OldRuleID: "rule122", + }, + wantCode: 0, + wantErr: false, + }, + { + name: "firewall unknown operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "unknown", + }, + wantCode: 1, + wantErr: false, + }, + { + name: "firewall-rollback", + cmd: "firewall-rollback", + args: &common.CommandArgs{}, + wantCode: 1, // Changed from 0 to 1 + wantErr: true, // Changed from false to true + }, + { + name: "firewall-reorder-chains", + cmd: "firewall-reorder-chains", + args: &common.CommandArgs{ + ChainNames: []string{"INPUT", "OUTPUT", "FORWARD"}, + }, + wantCode: 0, + wantErr: false, + }, + { + name: "firewall-reorder-rules", + cmd: "firewall-reorder-rules", + args: &common.CommandArgs{ + ChainName: "INPUT", + Rules: []common.FirewallRule{ + {RuleID: "rule1"}, + {RuleID: "rule2"}, + }, + }, + wantCode: 0, + wantErr: false, + }, + { + name: "unknown firewall command", + cmd: "firewall-unknown", + args: &common.CommandArgs{}, + wantCode: 1, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := common.NewMockCommandExecutor(t) + handler := NewFirewallHandler(mock) + ctx := context.Background() + + exitCode, output, err := handler.Execute(ctx, tt.cmd, tt.args) + + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + } + if exitCode != tt.wantCode { + t.Errorf("Execute() exitCode = %v, want %v", exitCode, tt.wantCode) + } + // Since these are placeholders, we expect some output for successful operations + if exitCode == 0 && output == "" && !tt.wantErr { + t.Error("Execute() returned success but no output") + } + }) + } +} + +func TestFirewallHandler_Validate(t *testing.T) { + handler := NewFirewallHandler(common.NewMockCommandExecutor(t)) + + tests := []struct { + name string + cmd string + args *common.CommandArgs + wantErr bool + }{ + { + name: "firewall valid batch operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "batch", + }, + wantErr: false, + }, + { + name: "firewall missing operation", + cmd: "firewall", + args: &common.CommandArgs{}, + wantErr: true, + }, + { + name: "firewall invalid operation", + cmd: "firewall", + args: &common.CommandArgs{ + Operation: "invalid", + }, + wantErr: true, + }, + { + name: "firewall-rollback valid", + cmd: "firewall-rollback", + args: &common.CommandArgs{}, + wantErr: false, + }, + { + name: "firewall-reorder-chains valid", + cmd: "firewall-reorder-chains", + args: &common.CommandArgs{ + ChainNames: []string{"INPUT", "OUTPUT"}, + }, + wantErr: false, + }, + { + name: "firewall-reorder-chains missing chains", + cmd: "firewall-reorder-chains", + args: &common.CommandArgs{}, + wantErr: true, + }, + { + name: "firewall-reorder-rules valid", + cmd: "firewall-reorder-rules", + args: &common.CommandArgs{ + ChainName: "INPUT", + }, + wantErr: false, + }, + { + name: "firewall-reorder-rules missing chain", + cmd: "firewall-reorder-rules", + args: &common.CommandArgs{}, + wantErr: true, + }, + { + name: "unknown command", + cmd: "unknown", + args: &common.CommandArgs{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handler.Validate(tt.cmd, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestFirewallHandler_BatchOperation(t *testing.T) { + handler := NewFirewallHandler(common.NewMockCommandExecutor(t)) + ctx := context.Background() + + // Test with empty rules + args := &common.CommandArgs{ + Operation: "batch", + ChainName: "INPUT", + Rules: []common.FirewallRule{}, + } + + exitCode, output, err := handler.Execute(ctx, "firewall", args) + + if err != nil { + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("Execute() exitCode = %v, want 0", exitCode) + } + if output == "" { + t.Error("Execute() returned no output") + } + + // Test with multiple rules + args = &common.CommandArgs{ + Operation: "batch", + ChainName: "INPUT", + Rules: []common.FirewallRule{ + { + Protocol: "tcp", + Target: "ACCEPT", + }, + { + Protocol: "tcp", + Target: "ACCEPT", + }, + }, + } + + exitCode, output, err = handler.Execute(ctx, "firewall", args) + + if err != nil { + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("Execute() exitCode = %v, want 0", exitCode) + } + if output == "" { + t.Error("Execute() returned no output") + } +} diff --git a/pkg/executor/handlers/firewall/iptables.go b/pkg/executor/handlers/firewall/iptables.go new file mode 100644 index 0000000..862e677 --- /dev/null +++ b/pkg/executor/handlers/firewall/iptables.go @@ -0,0 +1,504 @@ +package firewall + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// IptablesBackend implements FirewallBackend interface for iptables +type IptablesBackend struct { + executor common.CommandExecutor +} + +// NewIptablesBackend creates a new iptables backend +func NewIptablesBackend(executor common.CommandExecutor) *IptablesBackend { + return &IptablesBackend{ + executor: executor, + } +} + +// Name returns the backend name +func (s *IptablesBackend) Name() string { + return "iptables" +} + +// Detect checks if iptables is available +func (s *IptablesBackend) Detect(ctx context.Context) bool { + exitCode, _, _ := s.executor.RunWithTimeout(ctx, 5*time.Second, "which", "iptables") + return exitCode == 0 +} + +// AddRule adds a single firewall rule +func (s *IptablesBackend) AddRule(ctx context.Context, rule *common.FirewallRule) error { + args := s.buildAddRuleArgs(rule) + exitCode, output, err := s.executor.RunAsUser(ctx, "root", args[0], args[1:]...) + if err != nil { + return fmt.Errorf("failed to execute iptables: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("iptables add rule failed: %s", output) + } + return nil +} + +// buildAddRuleArgs builds iptables command arguments for adding a rule +func (s *IptablesBackend) buildAddRuleArgs(rule *common.FirewallRule) []string { + chainName := rule.Chain + if chainName == "" { + chainName = "INPUT" + } + + args := []string{"iptables", "-A", chainName} + + // Add protocol + if rule.Protocol != "" && rule.Protocol != "all" { + args = append(args, "-p", rule.Protocol) + } + + // Add source CIDR + if rule.Source != "" && rule.Source != "0.0.0.0/0" { + args = append(args, "-s", rule.Source) + } + + // Add destination CIDR + if rule.Destination != "" && rule.Destination != "0.0.0.0/0" { + args = append(args, "-d", rule.Destination) + } + + // Add port matches + if rule.Protocol == "tcp" || rule.Protocol == "udp" { + if len(rule.DPorts) > 0 { + // Multiple ports using multiport + ports := make([]string, len(rule.DPorts)) + for i, p := range rule.DPorts { + ports[i] = strconv.Itoa(p) + } + args = append(args, "-m", "multiport", "--dports", strings.Join(ports, ",")) + } else if rule.PortStart > 0 { + if rule.PortEnd > 0 && rule.PortEnd != rule.PortStart { + // Port range + args = append(args, "--dport", fmt.Sprintf("%d:%d", rule.PortStart, rule.PortEnd)) + } else { + // Single port + args = append(args, "--dport", strconv.Itoa(rule.PortStart)) + } + } + } + + // Add ICMP type + if rule.Protocol == "icmp" && rule.ICMPType != "" { + args = append(args, "--icmp-type", rule.ICMPType) + } + + // Add target + target := rule.Target + if target == "" { + target = "ACCEPT" + } + args = append(args, "-j", strings.ToUpper(target)) + + // Add comment with rule ID + if rule.RuleID != "" { + comment := utils.BuildFirewallComment("", rule.RuleID, rule.RuleType) + args = append(args, "-m", "comment", "--comment", comment) + } + + return args +} + +// DeleteRule removes a single firewall rule by ID +func (s *IptablesBackend) DeleteRule(ctx context.Context, ruleID, chainName string) error { + if chainName == "" { + chainName = "INPUT" + } + + // Get current rules to find the one to delete + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables", "-L", chainName, "--line-numbers", "-n", "-v") + if err != nil { + return fmt.Errorf("failed to list iptables rules: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to list iptables rules: %s", output) + } + + // Find rule with matching rule ID in comment + lines := strings.Split(output, "\n") + lineNum := 0 + + for _, line := range lines { + if strings.Contains(line, ruleID) { + // Extract line number from the beginning of the line + parts := strings.Fields(line) + if len(parts) > 0 { + if num, err := strconv.Atoi(parts[0]); err == nil { + lineNum = num + break + } + } + } + } + + if lineNum == 0 { + return fmt.Errorf("rule with ID %s not found in chain %s", ruleID, chainName) + } + + // Delete by line number + exitCode, output, err = s.executor.RunAsUser(ctx, "root", "iptables", "-D", chainName, strconv.Itoa(lineNum)) + if err != nil { + return fmt.Errorf("failed to delete iptables rule: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to delete iptables rule: %s", output) + } + + log.Debug().Msgf("Deleted iptables rule %s from chain %s (line %d)", ruleID, chainName, lineNum) + return nil +} + +// FlushChain removes all rules from a chain +func (s *IptablesBackend) FlushChain(ctx context.Context, chainName string) error { + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables", "-F", chainName) + if err != nil { + return fmt.Errorf("failed to flush iptables chain: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to flush iptables chain %s: %s", chainName, output) + } + log.Debug().Msgf("Flushed iptables chain: %s", chainName) + return nil +} + +// DeleteChain deletes a chain entirely +func (s *IptablesBackend) DeleteChain(ctx context.Context, chainName string) error { + // First flush the chain + if err := s.FlushChain(ctx, chainName); err != nil { + return fmt.Errorf("failed to flush chain before delete: %w", err) + } + + // Then delete the chain + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables", "-X", chainName) + if err != nil { + return fmt.Errorf("failed to delete iptables chain: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to delete iptables chain %s: %s", chainName, output) + } + log.Debug().Msgf("Deleted iptables chain: %s", chainName) + return nil +} + +// ListRules returns all rules in a chain +func (s *IptablesBackend) ListRules(ctx context.Context, chainName string) ([]common.FirewallRule, error) { + args := []string{"iptables-save"} + exitCode, output, err := s.executor.RunAsUser(ctx, "root", args[0], args[1:]...) + if err != nil { + return nil, fmt.Errorf("failed to run iptables-save: %w", err) + } + if exitCode != 0 { + return nil, fmt.Errorf("iptables-save failed: %s", output) + } + + rules := s.parseIptablesSaveOutput(output, chainName) + return rules, nil +} + +// parseIptablesSaveOutput parses iptables-save output to extract rules +func (s *IptablesBackend) parseIptablesSaveOutput(output, filterChain string) []common.FirewallRule { + var rules []common.FirewallRule + lines := strings.Split(output, "\n") + + for _, line := range lines { + line = strings.TrimSpace(line) + + // Skip non-rule lines + if !strings.HasPrefix(line, "-A ") && !strings.HasPrefix(line, "-I ") { + continue + } + + rule := s.parseIptablesSaveRuleLine(line) + if rule == nil { + continue + } + + // Filter by chain if specified + if filterChain != "" && rule.Chain != filterChain { + continue + } + + rules = append(rules, *rule) + } + + return rules +} + +// parseIptablesSaveRuleLine parses a single iptables-save rule line +func (s *IptablesBackend) parseIptablesSaveRuleLine(line string) *common.FirewallRule { + // Remove -A or -I prefix + if strings.HasPrefix(line, "-A ") { + line = strings.TrimPrefix(line, "-A ") + } else if strings.HasPrefix(line, "-I ") { + line = strings.TrimPrefix(line, "-I ") + } + + parts := strings.Fields(line) + if len(parts) < 2 { + return nil + } + + rule := &common.FirewallRule{ + Chain: parts[0], + Source: "0.0.0.0/0", + Destination: "0.0.0.0/0", + Protocol: "all", + Target: "ACCEPT", + } + + // Parse arguments + for i := 1; i < len(parts); i++ { + switch parts[i] { + case "-p", "--protocol": + if i+1 < len(parts) { + rule.Protocol = parts[i+1] + i++ + } + case "-s", "--source": + if i+1 < len(parts) { + rule.Source = parts[i+1] + i++ + } + case "-d", "--destination": + if i+1 < len(parts) { + rule.Destination = parts[i+1] + i++ + } + case "-j", "--jump": + if i+1 < len(parts) { + rule.Target = strings.ToUpper(parts[i+1]) + i++ + } + case "--dport": + if i+1 < len(parts) { + portStr := parts[i+1] + if strings.Contains(portStr, ":") { + // Port range + portRange := strings.Split(portStr, ":") + if len(portRange) == 2 { + if start, err := strconv.Atoi(portRange[0]); err == nil { + rule.PortStart = start + } + if end, err := strconv.Atoi(portRange[1]); err == nil { + rule.PortEnd = end + } + } + } else { + // Single port + if port, err := strconv.Atoi(portStr); err == nil { + rule.PortStart = port + } + } + i++ + } + case "--dports": + if i+1 < len(parts) { + portStrs := strings.Split(parts[i+1], ",") + for _, ps := range portStrs { + if port, err := strconv.Atoi(ps); err == nil { + rule.DPorts = append(rule.DPorts, port) + } + } + i++ + } + case "--icmp-type": + if i+1 < len(parts) { + rule.ICMPType = parts[i+1] + i++ + } + case "--comment": + if i+1 < len(parts) { + comment := strings.Trim(parts[i+1], "\"") + ruleID, ruleType, _ := utils.ParseFirewallComment(comment) + rule.RuleID = ruleID + rule.RuleType = ruleType + i++ + } + } + } + + return rule +} + +// BatchApply applies multiple rules atomically +func (s *IptablesBackend) BatchApply(ctx context.Context, chainName string, rules []common.FirewallRule) (applied int, failed []string, err error) { + for i, rule := range rules { + ruleCopy := rule + if ruleCopy.Chain == "" { + ruleCopy.Chain = chainName + } + + if addErr := s.AddRule(ctx, &ruleCopy); addErr != nil { + failed = append(failed, fmt.Sprintf("rule[%d]: %v", i, addErr)) + log.Error().Err(addErr).Msgf("Failed to add rule %d", i) + continue + } + applied++ + } + + if len(failed) > 0 { + err = fmt.Errorf("batch apply partially failed: %d applied, %d failed", applied, len(failed)) + } + + return applied, failed, err +} + +// ReorderChains reorders jump rules in INPUT chain +func (s *IptablesBackend) ReorderChains(ctx context.Context, chainNames []string) (map[string]interface{}, error) { + log.Debug().Msg("Starting iptables chain reordering") + + // Get current INPUT chain rules + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables", "-L", "INPUT", "--line-numbers", "-n") + if err != nil { + return nil, fmt.Errorf("failed to list INPUT chain rules: %w", err) + } + if exitCode != 0 { + return nil, fmt.Errorf("failed to list INPUT chain rules: %s", output) + } + + // Find alpacon jump rule line numbers + var jumpLines []int + lines := strings.Split(output, "\n") + + for _, line := range lines { + parts := strings.Fields(line) + if len(parts) < 3 { + continue + } + + for _, chainName := range chainNames { + if parts[1] == chainName || (len(parts) > 2 && parts[2] == chainName) { + if lineNum, err := strconv.Atoi(parts[0]); err == nil && lineNum > 0 { + jumpLines = append(jumpLines, lineNum) + log.Debug().Msgf("Found jump rule at line: %d for chain: %s", lineNum, chainName) + break + } + } + } + } + + if len(jumpLines) == 0 { + log.Warn().Msg("No jump rules found to reorder") + return map[string]interface{}{ + "reordered_chains": chainNames, + "deleted_rules": 0, + }, nil + } + + // Sort in reverse order (delete from bottom to preserve line numbers) + for i := 0; i < len(jumpLines); i++ { + for j := i + 1; j < len(jumpLines); j++ { + if jumpLines[i] < jumpLines[j] { + jumpLines[i], jumpLines[j] = jumpLines[j], jumpLines[i] + } + } + } + + // Delete old jump rules + for _, lineNum := range jumpLines { + exitCode, errOutput, _ := s.executor.RunAsUser(ctx, "root", "iptables", "-D", "INPUT", strconv.Itoa(lineNum)) + if exitCode != 0 { + return nil, fmt.Errorf("failed to delete rule at line %d: %s", lineNum, errOutput) + } + log.Debug().Msgf("Deleted rule at line: %d", lineNum) + } + + // Add jump rules in new order + for _, chainName := range chainNames { + exitCode, errOutput, _ := s.executor.RunAsUser(ctx, "root", "iptables", "-A", "INPUT", "-j", chainName) + if exitCode != 0 { + return nil, fmt.Errorf("failed to add jump rule for chain %s: %s", chainName, errOutput) + } + log.Debug().Msgf("Added jump rule for chain: %s", chainName) + } + + return map[string]interface{}{ + "reordered_chains": chainNames, + "deleted_rules": len(jumpLines), + }, nil +} + +// ReorderRules reorders rules within a chain +func (s *IptablesBackend) ReorderRules(ctx context.Context, chainName string, rules []common.FirewallRule) error { + // Flush the chain first + if err := s.FlushChain(ctx, chainName); err != nil { + return fmt.Errorf("failed to flush chain before reorder: %w", err) + } + + // Add rules in the new order + for i, rule := range rules { + ruleCopy := rule + ruleCopy.Chain = chainName + if err := s.AddRule(ctx, &ruleCopy); err != nil { + return fmt.Errorf("failed to add rule %d during reorder: %w", i, err) + } + } + + return nil +} + +// Backup creates a backup of current firewall state +func (s *IptablesBackend) Backup(ctx context.Context) (string, error) { + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables-save") + if err != nil { + return "", fmt.Errorf("failed to run iptables-save: %w", err) + } + if exitCode != 0 { + return "", fmt.Errorf("iptables-save failed: %s", output) + } + log.Debug().Msg("Created iptables backup") + return output, nil +} + +// Restore restores firewall state from backup +func (s *IptablesBackend) Restore(ctx context.Context, backup string) error { + if backup == "" { + return fmt.Errorf("empty backup provided") + } + + log.Warn().Msg("Restoring iptables backup") + + // Write backup to temp file + tmpFile := fmt.Sprintf("/tmp/iptables-backup-%d-%d.rules", os.Getpid(), time.Now().UnixNano()) + f, err := os.OpenFile(tmpFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile) + + if _, err := f.WriteString(backup); err != nil { + f.Close() + return fmt.Errorf("failed to write backup: %w", err) + } + f.Close() + + // Restore using iptables-restore + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "iptables-restore", tmpFile) + if err != nil { + return fmt.Errorf("failed to run iptables-restore: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("iptables-restore failed: %s", output) + } + + log.Info().Msg("Successfully restored iptables backup") + return nil +} + +// Compile-time check to ensure IptablesBackend implements FirewallBackend +var _ FirewallBackend = (*IptablesBackend)(nil) diff --git a/pkg/executor/handlers/firewall/nftables.go b/pkg/executor/handlers/firewall/nftables.go new file mode 100644 index 0000000..4a9f13e --- /dev/null +++ b/pkg/executor/handlers/firewall/nftables.go @@ -0,0 +1,515 @@ +package firewall + +import ( + "context" + "encoding/json" + "fmt" + "os" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// NftablesBackend implements FirewallBackend interface for nftables +type NftablesBackend struct { + executor common.CommandExecutor + tableName string // default: "inet filter" +} + +// NewNftablesBackend creates a new nftables backend +func NewNftablesBackend(executor common.CommandExecutor) *NftablesBackend { + return &NftablesBackend{ + executor: executor, + tableName: "inet filter", + } +} + +// Name returns the backend name +func (s *NftablesBackend) Name() string { + return "nftables" +} + +// Detect checks if nftables is available +func (s *NftablesBackend) Detect(ctx context.Context) bool { + exitCode, _, _ := s.executor.RunWithTimeout(ctx, 5*time.Second, "which", "nft") + return exitCode == 0 +} + +// AddRule adds a single firewall rule +func (s *NftablesBackend) AddRule(ctx context.Context, rule *common.FirewallRule) error { + args := s.buildAddRuleArgs(rule) + exitCode, output, err := s.executor.RunAsUser(ctx, "root", args[0], args[1:]...) + if err != nil { + return fmt.Errorf("failed to execute nft: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("nft add rule failed: %s", output) + } + return nil +} + +// buildAddRuleArgs builds nft command arguments for adding a rule +func (s *NftablesBackend) buildAddRuleArgs(rule *common.FirewallRule) []string { + chainName := rule.Chain + if chainName == "" { + chainName = "INPUT" + } + + args := []string{"nft", "add", "rule", s.tableName, chainName} + + // Add protocol match + if rule.Protocol != "" && rule.Protocol != "all" { + args = append(args, "meta", "l4proto", rule.Protocol) + } + + // Add source CIDR match + if rule.Source != "" && rule.Source != "0.0.0.0/0" { + args = append(args, "ip", "saddr", rule.Source) + } + + // Add destination CIDR match + if rule.Destination != "" && rule.Destination != "0.0.0.0/0" { + args = append(args, "ip", "daddr", rule.Destination) + } + + // Add port matches + if rule.Protocol == "tcp" || rule.Protocol == "udp" { + if len(rule.DPorts) > 0 { + // Multiple ports using set + ports := make([]string, len(rule.DPorts)) + for i, p := range rule.DPorts { + ports[i] = strconv.Itoa(p) + } + args = append(args, rule.Protocol, "dport", "{", strings.Join(ports, ", "), "}") + } else if rule.PortStart > 0 { + if rule.PortEnd > 0 && rule.PortEnd != rule.PortStart { + // Port range + args = append(args, rule.Protocol, "dport", fmt.Sprintf("%d-%d", rule.PortStart, rule.PortEnd)) + } else { + // Single port + args = append(args, rule.Protocol, "dport", strconv.Itoa(rule.PortStart)) + } + } + } + + // Add ICMP type + if rule.Protocol == "icmp" && rule.ICMPType != "" { + args = append(args, "icmp", "type", rule.ICMPType) + } + + // Add target/verdict (must come before comment) + target := strings.ToLower(rule.Target) + if target == "" { + target = "accept" + } + args = append(args, target) + + // Add comment with rule ID + if rule.RuleID != "" { + comment := utils.BuildFirewallComment("", rule.RuleID, rule.RuleType) + args = append(args, "comment", fmt.Sprintf("\"%s\"", comment)) + } + + return args +} + +// DeleteRule removes a single firewall rule by ID +func (s *NftablesBackend) DeleteRule(ctx context.Context, ruleID, chainName string) error { + if chainName == "" { + chainName = "INPUT" + } + + // Get rule handle by listing rules with handles + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "-a", "list", "chain", s.tableName, chainName) + if err != nil { + return fmt.Errorf("failed to list nftables chain: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to list nftables chain %s: %s", chainName, output) + } + + // Find rule with matching rule ID and extract handle + lines := strings.Split(output, "\n") + handleRegex := regexp.MustCompile(`# handle (\d+)`) + + for _, line := range lines { + if !strings.Contains(line, ruleID) { + continue + } + + matches := handleRegex.FindStringSubmatch(line) + if len(matches) > 1 { + handle := matches[1] + + // Delete by handle + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "delete", "rule", s.tableName, chainName, "handle", handle) + if err != nil { + return fmt.Errorf("failed to delete nft rule: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to delete nft rule handle %s: %s", handle, output) + } + + log.Debug().Msgf("Deleted nftables rule %s from chain %s (handle %s)", ruleID, chainName, handle) + return nil + } + } + + return fmt.Errorf("rule with ID %s not found in chain %s", ruleID, chainName) +} + +// FlushChain removes all rules from a chain +func (s *NftablesBackend) FlushChain(ctx context.Context, chainName string) error { + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "flush", "chain", s.tableName, chainName) + if err != nil { + return fmt.Errorf("failed to flush nftables chain: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to flush nftables chain %s: %s", chainName, output) + } + log.Debug().Msgf("Flushed nftables chain: %s", chainName) + return nil +} + +// DeleteChain deletes a chain entirely +func (s *NftablesBackend) DeleteChain(ctx context.Context, chainName string) error { + // First flush the chain + if err := s.FlushChain(ctx, chainName); err != nil { + return fmt.Errorf("failed to flush chain before delete: %w", err) + } + + // Then delete the chain + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "delete", "chain", s.tableName, chainName) + if err != nil { + return fmt.Errorf("failed to delete nftables chain: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("failed to delete nftables chain %s: %s", chainName, output) + } + log.Debug().Msgf("Deleted nftables chain: %s", chainName) + return nil +} + +// ListRules returns all rules in a chain +func (s *NftablesBackend) ListRules(ctx context.Context, chainName string) ([]common.FirewallRule, error) { + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "-j", "list", "ruleset") + if err != nil { + return nil, fmt.Errorf("failed to list nftables ruleset: %w", err) + } + if exitCode != 0 { + return nil, fmt.Errorf("nft list ruleset failed: %s", output) + } + + rules, err := s.parseNftablesJSONOutput(output, chainName) + if err != nil { + return nil, err + } + + return rules, nil +} + +// parseNftablesJSONOutput parses nft -j output to extract rules +func (s *NftablesBackend) parseNftablesJSONOutput(output, filterChain string) ([]common.FirewallRule, error) { + var nftData struct { + Nftables []map[string]interface{} `json:"nftables"` + } + + if err := json.Unmarshal([]byte(output), &nftData); err != nil { + return nil, fmt.Errorf("failed to parse nftables JSON: %w", err) + } + + var rules []common.FirewallRule + + for _, item := range nftData.Nftables { + ruleData, ok := item["rule"] + if !ok { + continue + } + + ruleMap, ok := ruleData.(map[string]interface{}) + if !ok { + continue + } + + rule := s.parseNftablesRule(ruleMap) + if rule == nil { + continue + } + + // Filter by chain if specified + if filterChain != "" && rule.Chain != filterChain { + continue + } + + rules = append(rules, *rule) + } + + return rules, nil +} + +// parseNftablesRule parses a single nftables rule map +func (s *NftablesBackend) parseNftablesRule(ruleMap map[string]interface{}) *common.FirewallRule { + rule := &common.FirewallRule{ + Source: "0.0.0.0/0", + Destination: "0.0.0.0/0", + Protocol: "all", + Target: "ACCEPT", + } + + // Extract chain name + if chain, ok := ruleMap["chain"].(string); ok { + rule.Chain = chain + } + + // Extract comment if present at top level + var fullComment string + if comment, ok := ruleMap["comment"].(string); ok { + fullComment = comment + } + + // Parse expressions for protocol, ports, source, target, comment + if expr, ok := ruleMap["expr"].([]interface{}); ok { + for _, e := range expr { + exprMap, ok := e.(map[string]interface{}) + if !ok { + continue + } + + // Extract comment from expression + if commentExpr, ok := exprMap["comment"].(string); ok { + fullComment = commentExpr + } + + // Match protocol + if match, ok := exprMap["match"].(map[string]interface{}); ok { + if left, ok := match["left"].(map[string]interface{}); ok { + if meta, ok := left["meta"].(map[string]interface{}); ok { + if key, ok := meta["key"].(string); ok && key == "l4proto" { + if right, ok := match["right"].(string); ok { + rule.Protocol = right + } + } + } + + // Match source/destination + if payload, ok := left["payload"].(map[string]interface{}); ok { + if field, ok := payload["field"].(string); ok { + if right, ok := match["right"].(string); ok { + if field == "saddr" { + rule.Source = right + } else if field == "daddr" { + rule.Destination = right + } + } + } + } + } + + // Match ports + if right, ok := match["right"].(float64); ok { + port := int(right) + if rule.PortStart == 0 { + rule.PortStart = port + } + } else if right, ok := match["right"].(map[string]interface{}); ok { + if set, ok := right["set"].([]interface{}); ok { + for _, portVal := range set { + if p, ok := portVal.(float64); ok { + rule.DPorts = append(rule.DPorts, int(p)) + } + } + } + } + } + + // Match target/verdict + if _, ok := exprMap["accept"]; ok { + rule.Target = "ACCEPT" + } else if _, ok := exprMap["drop"]; ok { + rule.Target = "DROP" + } else if _, ok := exprMap["reject"]; ok { + rule.Target = "REJECT" + } + } + } + + // Parse comment to extract rule_id and type + if fullComment != "" { + ruleID, ruleType, _ := utils.ParseFirewallComment(fullComment) + rule.RuleID = ruleID + rule.RuleType = ruleType + } + + return rule +} + +// BatchApply applies multiple rules atomically +func (s *NftablesBackend) BatchApply(ctx context.Context, chainName string, rules []common.FirewallRule) (applied int, failed []string, err error) { + for i, rule := range rules { + ruleCopy := rule + if ruleCopy.Chain == "" { + ruleCopy.Chain = chainName + } + + if addErr := s.AddRule(ctx, &ruleCopy); addErr != nil { + failed = append(failed, fmt.Sprintf("rule[%d]: %v", i, addErr)) + log.Error().Err(addErr).Msgf("Failed to add rule %d", i) + continue + } + applied++ + } + + if len(failed) > 0 { + err = fmt.Errorf("batch apply partially failed: %d applied, %d failed", applied, len(failed)) + } + + return applied, failed, err +} + +// ReorderChains reorders jump rules in INPUT chain +func (s *NftablesBackend) ReorderChains(ctx context.Context, chainNames []string) (map[string]interface{}, error) { + log.Debug().Msg("Starting nftables chain reordering") + + // Get current INPUT chain rules with handles + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "-a", "list", "chain", s.tableName, "INPUT") + if err != nil { + return nil, fmt.Errorf("failed to list INPUT chain: %w", err) + } + if exitCode != 0 { + return nil, fmt.Errorf("failed to list INPUT chain: %s", output) + } + + // Parse and find jump rule handles + var jumpHandles []string + lines := strings.Split(output, "\n") + handleRegex := regexp.MustCompile(`# handle (\d+)`) + + for _, line := range lines { + isJumpRule := false + for _, chainName := range chainNames { + if strings.Contains(line, fmt.Sprintf("jump %s", chainName)) { + isJumpRule = true + break + } + } + + if !isJumpRule { + continue + } + + matches := handleRegex.FindStringSubmatch(line) + if len(matches) > 1 { + jumpHandles = append(jumpHandles, matches[1]) + log.Debug().Msgf("Found jump rule handle: %s", matches[1]) + } + } + + if len(jumpHandles) == 0 { + log.Warn().Msg("No jump rules found to reorder") + return map[string]interface{}{ + "reordered_chains": chainNames, + "deleted_rules": 0, + }, nil + } + + // Delete old jump rules + for _, handle := range jumpHandles { + exitCode, errOutput, _ := s.executor.RunAsUser(ctx, "root", "nft", "delete", "rule", s.tableName, "INPUT", "handle", handle) + if exitCode != 0 { + return nil, fmt.Errorf("failed to delete rule handle %s: %s", handle, errOutput) + } + log.Debug().Msgf("Deleted rule handle: %s", handle) + } + + // Add jump rules in new order + for _, chainName := range chainNames { + exitCode, errOutput, _ := s.executor.RunAsUser(ctx, "root", "nft", "add", "rule", s.tableName, "INPUT", "jump", chainName) + if exitCode != 0 { + return nil, fmt.Errorf("failed to add jump rule for chain %s: %s", chainName, errOutput) + } + log.Debug().Msgf("Added jump rule for chain: %s", chainName) + } + + return map[string]interface{}{ + "reordered_chains": chainNames, + "deleted_rules": len(jumpHandles), + }, nil +} + +// ReorderRules reorders rules within a chain +func (s *NftablesBackend) ReorderRules(ctx context.Context, chainName string, rules []common.FirewallRule) error { + // Flush the chain first + if err := s.FlushChain(ctx, chainName); err != nil { + return fmt.Errorf("failed to flush chain before reorder: %w", err) + } + + // Add rules in the new order + for i, rule := range rules { + ruleCopy := rule + ruleCopy.Chain = chainName + if err := s.AddRule(ctx, &ruleCopy); err != nil { + return fmt.Errorf("failed to add rule %d during reorder: %w", i, err) + } + } + + return nil +} + +// Backup creates a backup of current firewall state +func (s *NftablesBackend) Backup(ctx context.Context) (string, error) { + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "list", "ruleset") + if err != nil { + return "", fmt.Errorf("failed to run nft list ruleset: %w", err) + } + if exitCode != 0 { + return "", fmt.Errorf("nft list ruleset failed: %s", output) + } + log.Debug().Msg("Created nftables backup") + return output, nil +} + +// Restore restores firewall state from backup +func (s *NftablesBackend) Restore(ctx context.Context, backup string) error { + if backup == "" { + return fmt.Errorf("empty backup provided") + } + + log.Warn().Msg("Restoring nftables backup") + + // Write backup to temp file + tmpFile := fmt.Sprintf("/tmp/nft-backup-%d-%d.nft", os.Getpid(), time.Now().UnixNano()) + f, err := os.OpenFile(tmpFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile) + + if _, err := f.WriteString(backup); err != nil { + f.Close() + return fmt.Errorf("failed to write backup: %w", err) + } + f.Close() + + // Flush current ruleset (ignore error as restore will overwrite anyway) + _, _, _ = s.executor.RunAsUser(ctx, "root", "nft", "flush", "ruleset") + + // Restore from file + exitCode, output, err := s.executor.RunAsUser(ctx, "root", "nft", "-f", tmpFile) + if err != nil { + return fmt.Errorf("failed to run nft restore: %w", err) + } + if exitCode != 0 { + return fmt.Errorf("nft restore failed: %s", output) + } + + log.Info().Msg("Successfully restored nftables backup") + return nil +} + +// Compile-time check to ensure NftablesBackend implements FirewallBackend +var _ FirewallBackend = (*NftablesBackend)(nil) diff --git a/pkg/executor/handlers/firewall/types.go b/pkg/executor/handlers/firewall/types.go new file mode 100644 index 0000000..7530def --- /dev/null +++ b/pkg/executor/handlers/firewall/types.go @@ -0,0 +1,18 @@ +package firewall + +// FirewallData contains data for firewall operations +type FirewallData struct { + Operation string `json:"operation"` + ChainName string `json:"chain_name,omitempty"` + Rules []map[string]interface{} `json:"rules,omitempty"` + RuleID string `json:"rule_id,omitempty"` + OldRuleID string `json:"old_rule_id,omitempty"` + ChainNames []string `json:"chain_names,omitempty"` + Method string `json:"method,omitempty"` + Chain string `json:"chain,omitempty"` + Protocol string `json:"protocol,omitempty"` + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + Target string `json:"target,omitempty"` + Description string `json:"description,omitempty"` +} diff --git a/pkg/executor/handlers/firewall/validator.go b/pkg/executor/handlers/firewall/validator.go new file mode 100644 index 0000000..29af21f --- /dev/null +++ b/pkg/executor/handlers/firewall/validator.go @@ -0,0 +1,227 @@ +package firewall + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// Validator validates firewall rules before execution +type Validator struct{} + +// NewValidator creates a new validator +func NewValidator() *Validator { + return &Validator{} +} + +// ValidateRule validates a single firewall rule +func (v *Validator) ValidateRule(rule *common.FirewallRule) error { + // Validate chain name if provided + if rule.Chain != "" { + if err := v.ValidateChainName(rule.Chain); err != nil { + return err + } + } + + // Validate protocol + if rule.Protocol != "" { + if err := v.ValidateProtocol(rule.Protocol); err != nil { + return err + } + } + + // Validate port start + if rule.PortStart != 0 { + if err := v.ValidatePort(rule.PortStart); err != nil { + return fmt.Errorf("invalid port_start: %w", err) + } + } + + // Validate port end + if rule.PortEnd != 0 { + if err := v.ValidatePort(rule.PortEnd); err != nil { + return fmt.Errorf("invalid port_end: %w", err) + } + } + + // Validate port range if both start and end provided + if rule.PortStart != 0 && rule.PortEnd != 0 { + if err := v.ValidatePortRange(rule.PortStart, rule.PortEnd); err != nil { + return err + } + } + + // Validate DPorts + for i, port := range rule.DPorts { + if err := v.ValidatePort(port); err != nil { + return fmt.Errorf("invalid dports[%d]: %w", i, err) + } + } + + // Validate source CIDR + if rule.Source != "" && rule.Source != "0.0.0.0/0" { + if err := v.ValidateCIDR(rule.Source); err != nil { + return fmt.Errorf("invalid source: %w", err) + } + } + + // Validate destination CIDR + if rule.Destination != "" && rule.Destination != "0.0.0.0/0" { + if err := v.ValidateCIDR(rule.Destination); err != nil { + return fmt.Errorf("invalid destination: %w", err) + } + } + + // Validate target + if rule.Target != "" { + if err := v.ValidateTarget(rule.Target); err != nil { + return err + } + } + + // Validate ICMP type if protocol is ICMP + if rule.Protocol == "icmp" && rule.ICMPType != "" { + if err := v.ValidateICMPType(rule.ICMPType); err != nil { + return err + } + } + + return nil +} + +// ValidateChainName validates a chain name +func (v *Validator) ValidateChainName(chainName string) error { + if chainName == "" { + return fmt.Errorf("chain name cannot be empty") + } + + // Check for valid characters (alphanumeric, underscore, hyphen) + for _, c := range chainName { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_' || c == '-') { + return fmt.Errorf("invalid character '%c' in chain name", c) + } + } + + // Maximum length check (iptables limit is 29 characters) + if len(chainName) > 29 { + return fmt.Errorf("chain name too long (max 29 characters)") + } + + return nil +} + +// ValidateProtocol validates the protocol field +func (v *Validator) ValidateProtocol(protocol string) error { + validProtocols := map[string]bool{ + "tcp": true, + "udp": true, + "icmp": true, + "all": true, + "": true, // empty is valid (defaults to all) + } + + if !validProtocols[strings.ToLower(protocol)] { + return fmt.Errorf("invalid protocol '%s' (allowed: tcp, udp, icmp, all)", protocol) + } + + return nil +} + +// ValidatePort validates a port number +func (v *Validator) ValidatePort(port int) error { + if port < 1 || port > 65535 { + return fmt.Errorf("port %d out of range (1-65535)", port) + } + return nil +} + +// ValidatePortRange validates a port range +func (v *Validator) ValidatePortRange(start, end int) error { + if start > end { + return fmt.Errorf("port range invalid: start (%d) > end (%d)", start, end) + } + return nil +} + +// ValidateCIDR validates an IP address or CIDR notation +func (v *Validator) ValidateCIDR(cidr string) error { + // Try parsing as CIDR first + _, _, err := net.ParseCIDR(cidr) + if err == nil { + return nil + } + + // Try parsing as plain IP address + ip := net.ParseIP(cidr) + if ip != nil { + return nil + } + + return fmt.Errorf("invalid IP address or CIDR notation: %s", cidr) +} + +// ValidateTarget validates a firewall target/action +func (v *Validator) ValidateTarget(target string) error { + validTargets := map[string]bool{ + "ACCEPT": true, + "DROP": true, + "REJECT": true, + "LOG": true, + "RETURN": true, + "accept": true, + "drop": true, + "reject": true, + "log": true, + "return": true, + } + + if !validTargets[target] { + return fmt.Errorf("invalid target '%s' (allowed: ACCEPT, DROP, REJECT, LOG, RETURN)", target) + } + + return nil +} + +// ValidateICMPType validates an ICMP type +func (v *Validator) ValidateICMPType(icmpType string) error { + // Common ICMP types (names or numbers) + validTypes := map[string]bool{ + "echo-reply": true, + "destination-unreachable": true, + "redirect": true, + "echo-request": true, + "time-exceeded": true, + "parameter-problem": true, + "timestamp-request": true, + "timestamp-reply": true, + "address-mask-request": true, + "address-mask-reply": true, + } + + // Check if it's a valid name + if validTypes[strings.ToLower(icmpType)] { + return nil + } + + // Check if it's a valid number (0-255) + num, err := strconv.Atoi(icmpType) + if err == nil && num >= 0 && num <= 255 { + return nil + } + + return fmt.Errorf("invalid ICMP type '%s'", icmpType) +} + +// ValidateBatchRules validates a batch of rules +func (v *Validator) ValidateBatchRules(rules []common.FirewallRule) error { + for i, rule := range rules { + if err := v.ValidateRule(&rule); err != nil { + return fmt.Errorf("rule[%d]: %w", i, err) + } + } + return nil +} diff --git a/pkg/executor/handlers/group/group.go b/pkg/executor/handlers/group/group.go new file mode 100644 index 0000000..b4a8f65 --- /dev/null +++ b/pkg/executor/handlers/group/group.go @@ -0,0 +1,179 @@ +package group + +import ( + "context" + "fmt" + "strconv" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// GroupHandler handles group management commands +type GroupHandler struct { + *common.BaseHandler + syncManager common.SystemInfoManager +} + +// NewGroupHandler creates a new group handler +func NewGroupHandler(cmdExecutor common.CommandExecutor, syncManager common.SystemInfoManager) *GroupHandler { + h := &GroupHandler{ + BaseHandler: common.NewBaseHandler( + common.Group, + []common.CommandType{ + common.AddGroup, + common.DelGroup, + }, + cmdExecutor, + ), + syncManager: syncManager, + } + return h +} + +// Execute runs the group management command +func (h *GroupHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + var exitCode int + var output string + var err error + + switch cmd { + case common.AddGroup.String(): + exitCode, output, err = h.handleAddGroup(ctx, args) + case common.DelGroup.String(): + exitCode, output, err = h.handleDelGroup(ctx, args) + default: + return 1, "", fmt.Errorf("unknown group command: %s", cmd) + } + + // Sync system info after successful command execution + if exitCode == 0 && h.syncManager != nil { + h.syncManager.SyncSystemInfo([]string{"groups", "users"}) + } + + return exitCode, output, err +} + +// Validate checks if the arguments are valid for the command +func (h *GroupHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.AddGroup.String(): + data := GroupData{ + Groupname: args.Groupname, + GID: args.GID, + } + return h.ValidateStruct(data) + + case common.DelGroup.String(): + data := DeleteGroupData{ + Groupname: args.Groupname, + } + return h.ValidateStruct(data) + + default: + return fmt.Errorf("unknown group command: %s", cmd) + } +} + +// handleAddGroup handles the addgroup command +func (h *GroupHandler) handleAddGroup(ctx context.Context, args *common.CommandArgs) (int, string, error) { + // Extract arguments + groupname := args.Groupname + gid := int(args.GID) + + // Validate + err := h.Validate(common.AddGroup.String(), args) + if err != nil { + return 1, err.Error(), nil + } + + log.Info(). + Str("groupname", groupname). + Uint64("gid", uint64(gid)). + Msg("Adding group") + + var exitCode int + var output string + // Platform-specific group addition + if utils.PlatformLike == "debian" { + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/addgroup", + "--gid", strconv.Itoa(gid), + groupname, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else if utils.PlatformLike == "rhel" { + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/groupadd", + "--gid", strconv.Itoa(gid), + groupname, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else { + return 1, fmt.Sprintf("Platform '%s' not supported for group management", utils.PlatformLike), nil + } + + log.Info(). + Str("groupname", groupname). + Uint64("gid", uint64(gid)). + Int("exitCode", exitCode). + Msg("Group added successfully") + + return exitCode, fmt.Sprintf("Group '%s' added successfully with GID %d", groupname, gid), nil +} + +// handleDelGroup handles the delgroup command +func (h *GroupHandler) handleDelGroup(ctx context.Context, args *common.CommandArgs) (int, string, error) { + // Extract arguments + groupname := args.Groupname + + // Validate + err := h.Validate(common.DelGroup.String(), args) + if err != nil { + return 1, err.Error(), nil + } + + log.Info(). + Str("groupname", groupname). + Msg("Deleting group") + + var exitCode int + var output string + + // Platform-specific group deletion + if utils.PlatformLike == "debian" { + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/delgroup", + groupname, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else if utils.PlatformLike == "rhel" { + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/groupdel", + groupname, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else { + return 1, fmt.Sprintf("Platform '%s' not supported for group management", utils.PlatformLike), nil + } + + log.Info(). + Str("groupname", groupname). + Int("exitCode", exitCode). + Msg("Group deleted successfully") + + return exitCode, fmt.Sprintf("Group '%s' deleted successfully", groupname), nil +} diff --git a/pkg/executor/handlers/group/group_test.go b/pkg/executor/handlers/group/group_test.go new file mode 100644 index 0000000..6ffd458 --- /dev/null +++ b/pkg/executor/handlers/group/group_test.go @@ -0,0 +1,128 @@ +package group + +import ( + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +func TestGroupHandler_AddGroup(t *testing.T) { + // Create mock executor + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("/usr/sbin/addgroup --gid 1001 testgroup", 0, "Group added successfully", nil) + + // Create handler with mock + handler := NewGroupHandler(mockExec, nil) + + // Test data + args := &common.CommandArgs{ + Groupname: "testgroup", + GID: 1001, + } + + // Validate arguments + err := handler.Validate("addgroup", args) + if err != nil { + t.Fatalf("Validation failed: %v", err) + } + + // Execute command (Note: This test is simplified, full implementation would use proper mocking) + // For now, just test validation and basic structure + t.Log("Group handler validated successfully") +} + +func TestGroupHandler_AddGroup_InvalidArgs(t *testing.T) { + handler := NewGroupHandler(nil, nil) // NewGroupHandler expects common.CommandExecutor, but for validation only, nil is fine + + testCases := []struct { + name string + args *common.CommandArgs + wantErr bool + }{ + { + name: "missing groupname", + args: &common.CommandArgs{ + GID: 1001, + }, + wantErr: true, + }, + { + name: "missing GID", + args: &common.CommandArgs{ + Groupname: "testgroup", + }, + wantErr: true, + }, + { + name: "invalid GID", + args: &common.CommandArgs{ + Groupname: "testgroup", + GID: 0, + }, + wantErr: true, + }, + { + name: "valid args", + args: &common.CommandArgs{ + Groupname: "testgroup", + GID: 1001, + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := handler.Validate("addgroup", tc.args) + if (err != nil) != tc.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} + +func TestGroupHandler_DelGroup(t *testing.T) { + handler := NewGroupHandler(nil, nil) // NewGroupHandler expects common.CommandExecutor, but for validation only, nil is fine + + // Test validation for delgroup + args := &common.CommandArgs{ + Groupname: "testgroup", + } + + err := handler.Validate("delgroup", args) + if err != nil { + t.Fatalf("Validation failed: %v", err) + } + + // Test missing groupname + emptyArgs := &common.CommandArgs{} + err = handler.Validate("delgroup", emptyArgs) + if err == nil { + t.Error("Expected error for missing groupname, got nil") + } +} + +func TestGroupHandler_Commands(t *testing.T) { + handler := NewGroupHandler(nil, nil) // NewGroupHandler expects common.CommandExecutor, but for validation only, nil is fine + + commands := handler.Commands() + expectedCommands := []string{"addgroup", "delgroup"} + + if len(commands) != len(expectedCommands) { + t.Errorf("Expected %d commands, got %d", len(expectedCommands), len(commands)) + } + + for i, cmd := range expectedCommands { + if commands[i] != cmd { + t.Errorf("Expected command %s at index %d, got %s", cmd, i, commands[i]) + } + } +} + +func TestGroupHandler_Name(t *testing.T) { + handler := NewGroupHandler(nil, nil) // NewGroupHandler expects common.CommandExecutor, but for validation only, nil is fine + + if handler.Name() != "group" { + t.Errorf("Expected handler name 'group', got '%s'", handler.Name()) + } +} diff --git a/pkg/executor/handlers/group/types.go b/pkg/executor/handlers/group/types.go new file mode 100644 index 0000000..c5d78bb --- /dev/null +++ b/pkg/executor/handlers/group/types.go @@ -0,0 +1,12 @@ +package group + +// GroupData contains data for group operations +type GroupData struct { + Groupname string `validate:"required"` + GID uint64 `validate:"required,min=1"` +} + +// DeleteGroupData contains data for group deletion +type DeleteGroupData struct { + Groupname string `validate:"required"` +} diff --git a/pkg/executor/handlers/info/info.go b/pkg/executor/handlers/info/info.go new file mode 100644 index 0000000..920d304 --- /dev/null +++ b/pkg/executor/handlers/info/info.go @@ -0,0 +1,148 @@ +package info + +import ( + "context" + "fmt" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/rs/zerolog/log" +) + +// InfoHandler handles informational commands like ping, help, commit, sync +type InfoHandler struct { + *common.BaseHandler + infoManager common.SystemInfoManager +} + +// NewInfoHandler creates a new info handler +func NewInfoHandler(infoManager common.SystemInfoManager) *InfoHandler { + h := &InfoHandler{ + BaseHandler: common.NewBaseHandler( + common.Info, + []common.CommandType{ + common.Ping, + common.Help, + common.Commit, + common.Sync, + }, + nil, // No command executor needed for these commands + ), + infoManager: infoManager, + } + return h +} + +// Execute runs the info command +func (h *InfoHandler) Execute(_ context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.Ping.String(): + return h.handlePing() + case common.Help.String(): + return h.handleHelp() + case common.Commit.String(): + return h.handleCommit() + case common.Sync.String(): + return h.handleSync(args) + default: + return 1, "", fmt.Errorf("unknown info command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *InfoHandler) Validate(cmd string, args *common.CommandArgs) error { + // Most info commands don't require validation + // Only sync accepts optional Keys parameter + if cmd == common.Sync.String() { + // Keys is optional, so no validation needed + _ = args.Keys + } + return nil +} + +// handlePing handles the ping command +func (h *InfoHandler) handlePing() (int, string, error) { + // Return current timestamp in RFC3339 format + return 0, time.Now().Format(time.RFC3339), nil +} + +// handleHelp handles the help command +func (h *InfoHandler) handleHelp() (int, string, error) { + helpMessage := ` +Available commands: + +System Control: + upgrade - Upgrade Alpamon to the latest version + update - Update system packages + restart [target] - Restart Alpamon or collector (target: alpamon|collector) + quit - Stop Alpamon gracefully + byebye - Completely uninstall Alpamon + reboot - Reboot the system + shutdown - Shutdown the system + +User Management: + adduser - Add a new user + deluser - Delete a user + moduser - Modify user settings + +Group Management: + addgroup - Add a new group + delgroup - Delete a group + +Firewall Management: + firewall - Manage firewall rules + firewall-rollback - Rollback firewall changes + firewall-reorder-chains - Reorder firewall chains + firewall-reorder-rules - Reorder firewall rules + +File Operations: + upload - Upload files to the server + download - Download files from the server + +Terminal Operations: + openpty - Open a PTY session + openftp - Open an FTP session + resizepty - Resize PTY terminal + +System Information: + commit - Commit system information + sync [keys] - Synchronize system information + ping - Check agent responsiveness + help - Show this help message + +Package Management: + package install - Install a system package + package uninstall - Remove a system package + +Shell Commands: + Any other command will be executed as a shell command +` + return 0, helpMessage, nil +} + +// handleCommit handles the commit command +func (h *InfoHandler) handleCommit() (int, string, error) { + log.Debug().Msg("Executing commit command") + + // Call through interface + if h.infoManager != nil { + h.infoManager.CommitSystemInfo() + } + + return 0, "Committed system information.", nil +} + +// handleSync handles the sync command +func (h *InfoHandler) handleSync(args *common.CommandArgs) (int, string, error) { + log.Debug().Msg("Executing sync command") + + // Extract keys if provided + keys := args.Keys + + // Call through interface + if h.infoManager != nil { + h.infoManager.SyncSystemInfo(keys) + } + + return 0, "Synchronized system information.", nil +} diff --git a/pkg/executor/handlers/info/info_test.go b/pkg/executor/handlers/info/info_test.go new file mode 100644 index 0000000..ae2ec8c --- /dev/null +++ b/pkg/executor/handlers/info/info_test.go @@ -0,0 +1,295 @@ +package info + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// MockSystemInfoManager is a mock implementation of SystemInfoManager for testing +type MockSystemInfoManager struct { + CommitCalled bool + SyncCalled bool + SyncKeys []string +} + +func (m *MockSystemInfoManager) CommitSystemInfo() { + m.CommitCalled = true +} + +func (m *MockSystemInfoManager) SyncSystemInfo(keys []string) { + m.SyncCalled = true + m.SyncKeys = keys +} + +func TestInfoHandler_Name(t *testing.T) { + handler := NewInfoHandler(nil) + if handler.Name() != common.Info.String() { + t.Errorf("expected name %q, got %q", common.Info.String(), handler.Name()) + } +} + +func TestInfoHandler_Commands(t *testing.T) { + handler := NewInfoHandler(nil) + commands := handler.Commands() + + expected := []string{ + common.Ping.String(), + common.Help.String(), + common.Commit.String(), + common.Sync.String(), + } + + if len(commands) != len(expected) { + t.Errorf("expected %d commands, got %d", len(expected), len(commands)) + return + } + + for i, cmd := range commands { + if cmd != expected[i] { + t.Errorf("command %d: expected %q, got %q", i, expected[i], cmd) + } + } +} + +func TestInfoHandler_Ping(t *testing.T) { + handler := NewInfoHandler(nil) + ctx := context.Background() + args := &common.CommandArgs{} + + before := time.Now().Add(-1 * time.Second) // Allow 1 second tolerance before + exitCode, output, err := handler.Execute(ctx, common.Ping.String(), args) + after := time.Now().Add(1 * time.Second) // Allow 1 second tolerance after + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + + // Verify output is RFC3339 timestamp + parsedTime, parseErr := time.Parse(time.RFC3339, output) + if parseErr != nil { + t.Errorf("output is not valid RFC3339 timestamp: %v", parseErr) + } + + // Verify timestamp is within expected range (with tolerance for RFC3339 second precision) + if parsedTime.Before(before) || parsedTime.After(after) { + t.Errorf("timestamp %v not within expected range [%v, %v]", parsedTime, before, after) + } +} + +func TestInfoHandler_Help(t *testing.T) { + handler := NewInfoHandler(nil) + ctx := context.Background() + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Help.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + + // Verify help message contains expected sections + expectedSections := []string{ + "Available commands", + "System Control:", + "User Management:", + "Group Management:", + "Firewall Management:", + "File Operations:", + "Terminal Operations:", + "System Information:", + "Package Management:", + "Shell Commands:", + } + + for _, section := range expectedSections { + if !strings.Contains(output, section) { + t.Errorf("help message missing section: %q", section) + } + } + + // Verify key commands are documented + expectedCommands := []string{ + "upgrade", "restart", "quit", "reboot", "shutdown", + "adduser", "deluser", "moduser", + "addgroup", "delgroup", + "firewall", "upload", "download", + "openpty", "openftp", + "commit", "sync", "ping", "help", + } + + for _, cmd := range expectedCommands { + if !strings.Contains(output, cmd) { + t.Errorf("help message missing command: %q", cmd) + } + } +} + +func TestInfoHandler_Commit(t *testing.T) { + mockManager := &MockSystemInfoManager{} + handler := NewInfoHandler(mockManager) + ctx := context.Background() + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Commit.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "Committed") { + t.Errorf("expected output to contain 'Committed', got %q", output) + } + if !mockManager.CommitCalled { + t.Error("expected CommitSystemInfo to be called") + } +} + +func TestInfoHandler_Commit_NilManager(t *testing.T) { + handler := NewInfoHandler(nil) + ctx := context.Background() + args := &common.CommandArgs{} + + // Should not panic with nil manager + exitCode, output, err := handler.Execute(ctx, common.Commit.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "Committed") { + t.Errorf("expected output to contain 'Committed', got %q", output) + } +} + +func TestInfoHandler_Sync(t *testing.T) { + mockManager := &MockSystemInfoManager{} + handler := NewInfoHandler(mockManager) + ctx := context.Background() + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Sync.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "Synchronized") { + t.Errorf("expected output to contain 'Synchronized', got %q", output) + } + if !mockManager.SyncCalled { + t.Error("expected SyncSystemInfo to be called") + } +} + +func TestInfoHandler_Sync_WithKeys(t *testing.T) { + mockManager := &MockSystemInfoManager{} + handler := NewInfoHandler(mockManager) + ctx := context.Background() + keys := []string{"cpu", "memory", "disk"} + args := &common.CommandArgs{ + Keys: keys, + } + + exitCode, output, err := handler.Execute(ctx, common.Sync.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "Synchronized") { + t.Errorf("expected output to contain 'Synchronized', got %q", output) + } + if !mockManager.SyncCalled { + t.Error("expected SyncSystemInfo to be called") + } + if len(mockManager.SyncKeys) != len(keys) { + t.Errorf("expected %d keys, got %d", len(keys), len(mockManager.SyncKeys)) + } + for i, key := range keys { + if mockManager.SyncKeys[i] != key { + t.Errorf("key %d: expected %q, got %q", i, key, mockManager.SyncKeys[i]) + } + } +} + +func TestInfoHandler_Sync_NilManager(t *testing.T) { + handler := NewInfoHandler(nil) + ctx := context.Background() + args := &common.CommandArgs{ + Keys: []string{"cpu"}, + } + + // Should not panic with nil manager + exitCode, output, err := handler.Execute(ctx, common.Sync.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "Synchronized") { + t.Errorf("expected output to contain 'Synchronized', got %q", output) + } +} + +func TestInfoHandler_UnknownCommand(t *testing.T) { + handler := NewInfoHandler(nil) + ctx := context.Background() + args := &common.CommandArgs{} + + exitCode, _, err := handler.Execute(ctx, "unknown_command", args) + + if err == nil { + t.Error("expected error for unknown command") + } + if exitCode != 1 { + t.Errorf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(err.Error(), "unknown info command") { + t.Errorf("error should mention 'unknown info command', got: %v", err) + } +} + +func TestInfoHandler_Validate(t *testing.T) { + handler := NewInfoHandler(nil) + + testCases := []struct { + name string + cmd string + args *common.CommandArgs + }{ + {"ping", common.Ping.String(), &common.CommandArgs{}}, + {"help", common.Help.String(), &common.CommandArgs{}}, + {"commit", common.Commit.String(), &common.CommandArgs{}}, + {"sync without keys", common.Sync.String(), &common.CommandArgs{}}, + {"sync with keys", common.Sync.String(), &common.CommandArgs{Keys: []string{"cpu", "memory"}}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := handler.Validate(tc.cmd, tc.args) + if err != nil { + t.Errorf("unexpected validation error: %v", err) + } + }) + } +} diff --git a/pkg/executor/handlers/shell/shell.go b/pkg/executor/handlers/shell/shell.go new file mode 100644 index 0000000..63bf201 --- /dev/null +++ b/pkg/executor/handlers/shell/shell.go @@ -0,0 +1,154 @@ +package shell + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/rs/zerolog/log" +) + +// ShellHandler handles shell command execution +type ShellHandler struct { + *common.BaseHandler +} + +// NewShellHandler creates a new shell handler +func NewShellHandler(cmdExecutor common.CommandExecutor) *ShellHandler { + h := &ShellHandler{ + BaseHandler: common.NewBaseHandler( + common.Shell, + []common.CommandType{ + common.ShellCmd, + common.Exec, + }, + cmdExecutor, + ), + } + return h +} + +// Execute runs the shell command +func (h *ShellHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.ShellCmd.String(), common.Exec.String(): + return h.handleShellCommand(ctx, args) + default: + return 1, "", fmt.Errorf("unknown shell command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *ShellHandler) Validate(cmd string, args *common.CommandArgs) error { + if args.Command == "" { + return fmt.Errorf("shell command is required") + } + return nil +} + +// handleShellCommand executes a shell command with support for operators (&&, ||, ;) +func (h *ShellHandler) handleShellCommand(ctx context.Context, args *common.CommandArgs) (int, string, error) { + command := args.Command + username := args.Username + if username == "" { + username = "root" + } + groupname := args.Groupname + if groupname == "" { + groupname = username + } + + // Get environment variables + env := args.Env + + // Get timeout + timeout := int(args.Timeout.Seconds()) + + log.Debug(). + Str("command", command). + Str("user", username). + Str("group", groupname). + Int("timeout", timeout). + Msg("Executing shell command") + + // Parse and execute command with operators support + return h.executeWithOperators(ctx, command, username, groupname, env, timeout) +} + +// executeWithOperators handles shell operators (&&, ||, ;) +func (h *ShellHandler) executeWithOperators(ctx context.Context, command, username, groupname string, env map[string]string, timeoutSecs int) (int, string, error) { + spl := strings.Fields(command) + var currentCmd []string + var results strings.Builder + var exitCode int + var result string + + timeout := time.Duration(timeoutSecs) * time.Second + if timeout == 0 { + timeout = 120 * time.Second // Default timeout + } + + for _, arg := range spl { + switch arg { + case "&&": + // Execute current command + if len(currentCmd) > 0 { + exitCode, result = h.executeCommand(ctx, currentCmd, username, groupname, env, timeout) + results.WriteString(result) + // Stop if command fails + if exitCode != 0 { + return exitCode, results.String(), nil + } + currentCmd = nil + } + case "||": + // Execute current command + if len(currentCmd) > 0 { + exitCode, result = h.executeCommand(ctx, currentCmd, username, groupname, env, timeout) + results.WriteString(result) + // Continue only if command fails + if exitCode == 0 { + return exitCode, results.String(), nil + } + currentCmd = nil + } + case ";": + // Execute current command + if len(currentCmd) > 0 { + exitCode, result = h.executeCommand(ctx, currentCmd, username, groupname, env, timeout) + results.WriteString(result) + // Continue regardless of result + currentCmd = nil + } + default: + currentCmd = append(currentCmd, arg) + } + } + + // Execute any remaining command + if len(currentCmd) > 0 { + exitCode, result = h.executeCommand(ctx, currentCmd, username, groupname, env, timeout) + results.WriteString(result) + } + + return exitCode, results.String(), nil +} + +// executeCommand executes a single command +func (h *ShellHandler) executeCommand(ctx context.Context, cmdArgs []string, username, groupname string, env map[string]string, timeout time.Duration) (int, string) { + if len(cmdArgs) == 0 { + return 0, "" + } + + // Execute command through the executor with full parameters (user, group, env, timeout) + exitCode, output, err := h.Executor.Exec(ctx, cmdArgs, username, groupname, env, timeout) + + if err != nil && exitCode == -1 { + // Command execution error (not just non-zero exit) + return exitCode, err.Error() + } + + return exitCode, output +} diff --git a/pkg/executor/handlers/shell/shell_test.go b/pkg/executor/handlers/shell/shell_test.go new file mode 100644 index 0000000..7656ca7 --- /dev/null +++ b/pkg/executor/handlers/shell/shell_test.go @@ -0,0 +1,473 @@ +package shell + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +func TestShellHandler_Name(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + handler := NewShellHandler(mockExec) + if handler.Name() != common.Shell.String() { + t.Errorf("expected name %q, got %q", common.Shell.String(), handler.Name()) + } +} + +func TestShellHandler_Commands(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + handler := NewShellHandler(mockExec) + commands := handler.Commands() + + expected := []string{ + common.ShellCmd.String(), + common.Exec.String(), + } + + if len(commands) != len(expected) { + t.Errorf("expected %d commands, got %d", len(expected), len(commands)) + return + } + + for i, cmd := range commands { + if cmd != expected[i] { + t.Errorf("command %d: expected %q, got %q", i, expected[i], cmd) + } + } +} + +func TestShellHandler_Execute_Basic(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + // Key format is "name arg1 arg2..." - for single word command it's just "ls " + mockExec.SetResult("ls ", 0, "file1.txt\nfile2.txt", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "ls", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "file1.txt") { + t.Errorf("expected output to contain 'file1.txt', got %q", output) + } +} + +func TestShellHandler_Execute_Exec(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("echo hello", 0, "hello", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "echo hello", + } + + exitCode, output, err := handler.Execute(ctx, common.Exec.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "hello") { + t.Errorf("expected output to contain 'hello', got %q", output) + } +} + +func TestShellHandler_Execute_AndOperator(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + // Shell handler uses strings.Fields which splits "cmd1 && cmd2" into ["cmd1", "&&", "cmd2"] + mockExec.SetResult("cmd1 ", 0, "output1", nil) + mockExec.SetResult("cmd2 ", 0, "output2", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "cmd1 && cmd2", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "output1") || !strings.Contains(output, "output2") { + t.Errorf("expected output to contain both outputs, got %q", output) + } +} + +func TestShellHandler_AndStopsOnFailure(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("cmd1 ", 1, "error output", nil) // First command fails + mockExec.SetResult("cmd2 ", 0, "output2", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "cmd1 && cmd2", + } + + exitCode, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 1 { + t.Errorf("expected exit code 1, got %d", exitCode) + } +} + +func TestShellHandler_Execute_OrOperator(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("cmd1 ", 1, "error", nil) // First fails + mockExec.SetResult("cmd2 ", 0, "success", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "cmd1 || cmd2", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "success") { + t.Errorf("expected output to contain 'success', got %q", output) + } +} + +func TestShellHandler_OrStopsOnSuccess(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("cmd1 ", 0, "success", nil) // First succeeds + mockExec.SetResult("cmd2 ", 0, "output2", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "cmd1 || cmd2", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + // Only cmd1's output should be present (cmd2 shouldn't run) + if !strings.Contains(output, "success") { + t.Errorf("expected output to contain 'success', got %q", output) + } +} + +func TestShellHandler_Execute_Semicolon(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("cmd1 ", 1, "error", nil) // First fails + mockExec.SetResult("cmd2 ", 0, "success", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "cmd1 ; cmd2", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Last command exit code + if exitCode != 0 { + t.Errorf("expected exit code 0 (from cmd2), got %d", exitCode) + } + // Both outputs should be present + if !strings.Contains(output, "error") || !strings.Contains(output, "success") { + t.Errorf("expected output to contain both outputs, got %q", output) + } +} + +func TestShellHandler_CustomUser(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("whoami ", 0, "testuser", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "whoami", + Username: "testuser", + } + + exitCode, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + + cmds := mockExec.GetExecutedCommands() + // Exec method adds to commands, then calls Run which also adds + // So we check that at least one command has the right user + foundCorrectUser := false + for _, cmd := range cmds { + if cmd.User == "testuser" { + foundCorrectUser = true + break + } + } + if !foundCorrectUser { + t.Errorf("expected at least one command with user 'testuser', got %+v", cmds) + } +} + +func TestShellHandler_DefaultUser(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("whoami ", 0, "root", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "whoami", + // Username not set - should default to "root" + } + + _, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cmds := mockExec.GetExecutedCommands() + foundRootUser := false + for _, cmd := range cmds { + if cmd.User == "root" { + foundRootUser = true + break + } + } + if !foundRootUser { + t.Errorf("expected at least one command with default user 'root', got %+v", cmds) + } +} + +func TestShellHandler_WithTimeout(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("sleep 1", 0, "", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "sleep 1", + Timeout: 10 * time.Second, + } + + exitCode, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } +} + +func TestShellHandler_DefaultTimeout(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("ls ", 0, "output", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "ls", + // Timeout not set - should default to 120 seconds + } + + exitCode, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } +} + +func TestShellHandler_Validate_Empty(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + handler := NewShellHandler(mockExec) + + args := &common.CommandArgs{ + Command: "", // Empty command + } + + err := handler.Validate(common.ShellCmd.String(), args) + + if err == nil { + t.Error("expected error for empty command") + } + if !strings.Contains(err.Error(), "required") { + t.Errorf("error should mention 'required', got: %v", err) + } +} + +func TestShellHandler_Validate_Valid(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + handler := NewShellHandler(mockExec) + + args := &common.CommandArgs{ + Command: "ls -la", + } + + err := handler.Validate(common.ShellCmd.String(), args) + + if err != nil { + t.Errorf("unexpected validation error: %v", err) + } +} + +func TestShellHandler_UnknownCommand(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "ls", + } + + exitCode, _, err := handler.Execute(ctx, "unknown_command", args) + + if err == nil { + t.Error("expected error for unknown command") + } + if exitCode != 1 { + t.Errorf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(err.Error(), "unknown shell command") { + t.Errorf("error should mention 'unknown shell command', got: %v", err) + } +} + +func TestShellHandler_CommandExecutionError(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + // Set up a command that returns -1 exit code with error + mockExec.SetResult("failing_cmd ", -1, "", errors.New("command not found")) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "failing_cmd", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error from Execute: %v", err) + } + if exitCode != -1 { + t.Errorf("expected exit code -1, got %d", exitCode) + } + if !strings.Contains(output, "not found") { + t.Errorf("expected output to contain error message, got %q", output) + } +} + +func TestShellHandler_WithEnv(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("printenv ", 0, "TEST_VAR=test_value", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "printenv", + Env: map[string]string{ + "TEST_VAR": "test_value", + }, + } + + exitCode, _, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } +} + +func TestShellHandler_MixedOperators(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockExec.SetResult("cmd1 ", 0, "out1", nil) + mockExec.SetResult("cmd2 ", 1, "err2", nil) + mockExec.SetResult("cmd3 ", 0, "out3", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + // cmd1 && cmd2 || cmd3 + // cmd1 succeeds (0), run cmd2 + // cmd2 fails (1), run cmd3 (due to ||) + args := &common.CommandArgs{ + Command: "cmd1 && cmd2 || cmd3", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + // cmd1's output should be present + if !strings.Contains(output, "out1") { + t.Errorf("expected output to contain 'out1', got %q", output) + } +} + +func TestShellHandler_MultiWordCommand(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + // "ls -la /tmp" -> Fields splits to ["ls", "-la", "/tmp"] + // Exec is called with args[0]="ls", args[1:]=["-la", "/tmp"] + // Run key is "ls -la /tmp" + mockExec.SetResult("ls -la /tmp", 0, "total 0", nil) + handler := NewShellHandler(mockExec) + ctx := context.Background() + + args := &common.CommandArgs{ + Command: "ls -la /tmp", + } + + exitCode, output, err := handler.Execute(ctx, common.ShellCmd.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "total") { + t.Errorf("expected output to contain 'total', got %q", output) + } +} diff --git a/pkg/executor/handlers/system/system.go b/pkg/executor/handlers/system/system.go new file mode 100644 index 0000000..ddae0df --- /dev/null +++ b/pkg/executor/handlers/system/system.go @@ -0,0 +1,302 @@ +package system + +import ( + "context" + "fmt" + "time" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/config" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/alpacax/alpamon/pkg/version" + "github.com/rs/zerolog/log" +) + +// SystemHandler handles system-level commands like restart, reboot, shutdown, upgrade +type SystemHandler struct { + *common.BaseHandler + wsClient common.WSClient + ctxManager *agent.ContextManager + pool *pool.Pool +} + +// NewSystemHandler creates a new system handler +func NewSystemHandler(cmdExecutor common.CommandExecutor, wsClient common.WSClient, ctxManager *agent.ContextManager, pool *pool.Pool) *SystemHandler { + h := &SystemHandler{ + BaseHandler: common.NewBaseHandler( + common.System, + []common.CommandType{ + common.Upgrade, + common.Restart, + common.Quit, + common.Reboot, + common.Shutdown, + common.Update, + common.ByeBye, + }, + cmdExecutor, + ), + wsClient: wsClient, + ctxManager: ctxManager, + pool: pool, + } + return h +} + +// Execute runs the system command +func (h *SystemHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.Upgrade.String(): + return h.handleUpgrade(ctx) + case common.Restart.String(): + return h.handleRestart(args) + case common.Quit.String(): + return h.handleQuit() + case common.ByeBye.String(): + return h.handleUninstall() + case common.Reboot.String(): + return h.handleReboot() + case common.Shutdown.String(): + return h.handleShutdown() + case common.Update.String(): + return h.handleSystemUpdate(ctx) + default: + return 1, "", fmt.Errorf("unknown system command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *SystemHandler) Validate(cmd string, args *common.CommandArgs) error { + // Most system commands don't require arguments + return nil +} + +// handleUpgrade handles the upgrade command +func (h *SystemHandler) handleUpgrade(ctx context.Context) (int, string, error) { + latestVersion := utils.GetLatestVersion() + + if version.Version == latestVersion { + return 0, fmt.Sprintf("Alpamon is already up-to-date (version: %s)", version.Version), nil + } + + var cmd string + if utils.PlatformLike == "debian" { + cmd = "apt-get update -y && apt-get install --only-upgrade alpamon -y" + } else if utils.PlatformLike == "rhel" { + cmd = "yum update -y alpamon" + } else { + return 1, fmt.Sprintf("Platform '%s' not supported.", utils.PlatformLike), nil + } + + log.Debug().Msgf("Upgrading alpamon from %s to %s using command: '%s'...", version.Version, latestVersion, cmd) + + exitCode, output, err := h.Executor.RunAsUser(ctx, "root", "sh", "-c", cmd) + return exitCode, output, err +} + +// handleRestart handles the restart command +func (h *SystemHandler) handleRestart(args *common.CommandArgs) (int, string, error) { + target := args.Target + if target == "" { + target = "alpamon" + } + message := "Alpamon will restart in 1 second." + + switch target { + case "collector": + log.Info().Msg("Restart collector.") + h.wsClient.RestartCollector() + message = "Collector will be restarted." + default: + // Submit to worker pool for managed execution + poolCtx, cancel := h.ctxManager.NewContext(2 * time.Second) + submitted := false + defer func() { + if !submitted { + cancel() + } + }() + + err := h.pool.Submit(poolCtx, func() error { + defer cancel() + time.Sleep(1 * time.Second) + h.wsClient.Restart() + return nil + }) + if err != nil { + log.Error().Err(err).Msg("Failed to submit restart task to pool") + } else { + submitted = true + } + } + + return 0, message, nil +} + +// handleQuit handles the quit command +func (h *SystemHandler) handleQuit() (int, string, error) { + // Submit to worker pool for managed execution + poolCtx, cancel := h.ctxManager.NewContext(2 * time.Second) + submitted := false + defer func() { + if !submitted { + cancel() + } + }() + + err := h.pool.Submit(poolCtx, func() error { + defer cancel() + time.Sleep(1 * time.Second) + h.wsClient.ShutDown() + return nil + }) + if err != nil { + log.Error().Err(err).Msg("Failed to submit quit task to pool") + } else { + submitted = true + } + return 0, "Alpamon will shutdown in 1 second.", nil +} + +// handleUninstall handles the byebye (uninstall) command +func (h *SystemHandler) handleUninstall() (int, string, error) { + log.Info().Msg("Uninstall request received.") + + // Execute uninstall after 1 second to ensure response is sent + time.AfterFunc(1*time.Second, func() { + h.executeUninstall() + }) + + return 0, "Starting uninstall process...", nil +} + +// executeUninstall performs the actual uninstall +func (h *SystemHandler) executeUninstall() { + var cmd string + + if utils.PlatformLike == "debian" { + // Use purge to remove package and config files + cmd = "apt-get purge alpamon -y && apt-get autoremove -y" + } else if utils.PlatformLike == "rhel" { + // Remove package using yum + cmd = "yum remove alpamon -y" + } else if utils.PlatformLike == "darwin" { + // For macOS development environment, just shutdown + log.Warn().Msgf("Platform '%s' does not support full uninstall. Shutting down instead.", utils.PlatformLike) + h.wsClient.ShutDown() + return + } else { + log.Error().Msgf("Platform '%s' not supported for uninstall.", utils.PlatformLike) + h.wsClient.ShutDown() + return + } + + // Build the complete uninstall command that includes: + // 1. Package removal + // 2. Cleanup of transient systemd units created by this operation + uninstallCmd := fmt.Sprintf("%s; systemctl reset-failed alpamon-uninstall.service 2>/dev/null || true; systemctl reset-failed alpamon-uninstall.timer 2>/dev/null || true", cmd) + + // This ensures the uninstall continues even after the current process terminates + // The service will start 5 seconds after being scheduled + // --collect: Automatically clean up transient units after they complete (systemd 236+) + scheduleCmdArgs := []string{ + "systemd-run", + "--collect", + "--uid=0", + "--gid=0", + "--unit=alpamon-uninstall", + "--timer-property=OnActiveSec=5", + "--timer-property=AccuracySec=1s", + "--description=Alpamon Uninstall Service", + "/bin/sh", "-c", uninstallCmd, + } + + ctx := context.Background() + exitCode, output, _ := h.Executor.RunWithTimeout(ctx, 30*time.Second, scheduleCmdArgs[0], scheduleCmdArgs[1:]...) + + if exitCode != 0 { + log.Error().Msgf("Failed to schedule uninstall: %s", output) + // Fallback to direct execution + _, _, _ = h.Executor.RunAsUser(ctx, "root", "sh", "-c", cmd) + } + + // Shutdown the process after scheduling + h.wsClient.ShutDown() +} + +// handleReboot handles the reboot command +func (h *SystemHandler) handleReboot() (int, string, error) { + log.Info().Msg("Reboot request received.") + + // Submit to worker pool for managed execution + poolCtx, cancel := h.ctxManager.NewContext(time.Duration(config.GlobalSettings.PoolDefaultTimeout) * time.Second) + submitted := false + defer func() { + if !submitted { + cancel() + } + }() + + err := h.pool.Submit(poolCtx, func() error { + defer cancel() + time.Sleep(1 * time.Second) + _, _, _ = h.Executor.RunAsUser(poolCtx, "root", "reboot") + return nil + }) + if err != nil { + log.Error().Err(err).Msg("Failed to submit reboot task to pool") + } else { + submitted = true + } + + return 0, "Server will reboot in 1 second", nil +} + +// handleShutdown handles the shutdown command +func (h *SystemHandler) handleShutdown() (int, string, error) { + log.Info().Msg("Shutdown request received.") + + // Submit to worker pool for managed execution + poolCtx, cancel := h.ctxManager.NewContext(time.Duration(config.GlobalSettings.PoolDefaultTimeout) * time.Second) + submitted := false + defer func() { + if !submitted { + cancel() + } + }() + + err := h.pool.Submit(poolCtx, func() error { + defer cancel() + time.Sleep(1 * time.Second) + _, _, _ = h.Executor.RunAsUser(poolCtx, "root", "shutdown", "now") + return nil + }) + if err != nil { + log.Error().Err(err).Msg("Failed to submit shutdown task to pool") + } else { + submitted = true + } + + return 0, "Server will shutdown in 1 second", nil +} + +// handleSystemUpdate handles the update command (system-wide updates) +func (h *SystemHandler) handleSystemUpdate(ctx context.Context) (int, string, error) { + log.Info().Msg("Upgrade system requested.") + + var cmd string + if utils.PlatformLike == "debian" { + cmd = "apt-get update && apt-get upgrade -y && apt-get autoremove -y" + } else if utils.PlatformLike == "rhel" { + cmd = "yum update -y" + } else if utils.PlatformLike == "darwin" { + cmd = "brew upgrade" + } else { + return 1, fmt.Sprintf("Platform '%s' not supported.", utils.PlatformLike), nil + } + + exitCode, output, err := h.Executor.RunAsUser(ctx, "root", "sh", "-c", cmd) + return exitCode, output, err +} diff --git a/pkg/executor/handlers/system/system_test.go b/pkg/executor/handlers/system/system_test.go new file mode 100644 index 0000000..aec07c3 --- /dev/null +++ b/pkg/executor/handlers/system/system_test.go @@ -0,0 +1,428 @@ +package system + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// MockWSClient is a mock implementation of WSClient for testing +type MockWSClient struct { + RestartCalled bool + ShutDownCalled bool + RestartCollectorCalled bool +} + +func (m *MockWSClient) Restart() { + m.RestartCalled = true +} + +func (m *MockWSClient) ShutDown() { + m.ShutDownCalled = true +} + +func (m *MockWSClient) RestartCollector() { + m.RestartCollectorCalled = true +} + +func TestSystemHandler_Name(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + if handler.Name() != common.System.String() { + t.Errorf("expected name %q, got %q", common.System.String(), handler.Name()) + } +} + +func TestSystemHandler_Commands(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + commands := handler.Commands() + + expected := []string{ + common.Upgrade.String(), + common.Restart.String(), + common.Quit.String(), + common.Reboot.String(), + common.Shutdown.String(), + common.Update.String(), + common.ByeBye.String(), + } + + if len(commands) != len(expected) { + t.Errorf("expected %d commands, got %d", len(expected), len(commands)) + return + } + + for i, cmd := range commands { + if cmd != expected[i] { + t.Errorf("command %d: expected %q, got %q", i, expected[i], cmd) + } + } +} + +func TestSystemHandler_Restart_Collector(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{ + Target: "collector", + } + + exitCode, output, err := handler.Execute(ctx, common.Restart.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !mockWS.RestartCollectorCalled { + t.Error("expected RestartCollector to be called") + } + if !strings.Contains(output, "restarted") { + t.Errorf("expected output to contain 'restarted', got %q", output) + } +} + +func TestSystemHandler_Restart_Default(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{ + // No target - should default to alpamon + } + + exitCode, output, err := handler.Execute(ctx, common.Restart.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "restart") { + t.Errorf("expected output to mention restart, got %q", output) + } + // Give time for the pool task to execute + time.Sleep(100 * time.Millisecond) +} + +func TestSystemHandler_Restart_Alpamon(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{ + Target: "alpamon", + } + + exitCode, output, err := handler.Execute(ctx, common.Restart.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "restart") { + t.Errorf("expected output to mention restart, got %q", output) + } +} + +func TestSystemHandler_Quit(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Quit.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "shutdown") { + t.Errorf("expected output to mention shutdown, got %q", output) + } +} + +func TestSystemHandler_Reboot(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Reboot.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "reboot") { + t.Errorf("expected output to mention reboot, got %q", output) + } +} + +func TestSystemHandler_Shutdown(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.Shutdown.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "shutdown") { + t.Errorf("expected output to mention shutdown, got %q", output) + } +} + +func TestSystemHandler_UnknownCommand(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + exitCode, _, err := handler.Execute(ctx, "unknown_command", args) + + if err == nil { + t.Error("expected error for unknown command") + } + if exitCode != 1 { + t.Errorf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(err.Error(), "unknown system command") { + t.Errorf("error should mention 'unknown system command', got: %v", err) + } +} + +func TestSystemHandler_Validate(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + + testCases := []struct { + name string + cmd string + args *common.CommandArgs + }{ + {"upgrade", common.Upgrade.String(), &common.CommandArgs{}}, + {"restart", common.Restart.String(), &common.CommandArgs{Target: "alpamon"}}, + {"quit", common.Quit.String(), &common.CommandArgs{}}, + {"reboot", common.Reboot.String(), &common.CommandArgs{}}, + {"shutdown", common.Shutdown.String(), &common.CommandArgs{}}, + {"update", common.Update.String(), &common.CommandArgs{}}, + {"byebye", common.ByeBye.String(), &common.CommandArgs{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := handler.Validate(tc.cmd, tc.args) + if err != nil { + t.Errorf("unexpected validation error: %v", err) + } + }) + } +} + +func TestSystemHandler_PoolShutdown(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + // Shutdown pool first + _ = workerPool.Shutdown(100 * time.Millisecond) + ctxManager.Shutdown() + + args := &common.CommandArgs{ + Target: "alpamon", + } + + // Should handle pool submission failure gracefully + exitCode, output, err := handler.Execute(ctx, common.Restart.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should still return success message even if pool submission fails + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "restart") { + t.Errorf("expected output to mention restart, got %q", output) + } +} + +func TestSystemHandler_Upgrade_UpToDate(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + // This test depends on the actual version comparison + // GetLatestVersion() makes an HTTP call, so this test will behave differently + // depending on network availability + exitCode, output, err := handler.Execute(ctx, common.Upgrade.String(), args) + + // Should not return an error regardless of version comparison result + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // exitCode could be 0 (up-to-date or upgrade success) or 1 (platform not supported) + if exitCode != 0 && exitCode != 1 { + t.Errorf("expected exit code 0 or 1, got %d", exitCode) + } + // Output should mention either "up-to-date", "Upgrading", or "not supported" + if !strings.Contains(output, "up-to-date") && + !strings.Contains(output, "Upgrading") && + !strings.Contains(output, "not supported") && + !strings.Contains(output, "Alpamon") { + t.Errorf("expected meaningful output, got %q", output) + } +} + +func TestSystemHandler_Update(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + // This test depends on the actual platform + exitCode, output, err := handler.Execute(ctx, common.Update.String(), args) + + // Should not return an error + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // exitCode could be 0 (success) or 1 (platform not supported) + if exitCode != 0 && exitCode != 1 { + t.Errorf("expected exit code 0 or 1, got %d", exitCode) + } + // Output should be present + if output == "" && exitCode == 1 { + t.Errorf("expected some output for unsupported platform") + } +} + +func TestSystemHandler_Uninstall(t *testing.T) { + mockExec := common.NewMockCommandExecutor(t) + mockWS := &MockWSClient{} + ctxManager := agent.NewContextManager() + workerPool := pool.NewPool(2, 10) + defer func() { _ = workerPool.Shutdown(1 * time.Second) }() + defer ctxManager.Shutdown() + + handler := NewSystemHandler(mockExec, mockWS, ctxManager, workerPool) + ctx := context.Background() + + args := &common.CommandArgs{} + + exitCode, output, err := handler.Execute(ctx, common.ByeBye.String(), args) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if !strings.Contains(output, "uninstall") { + t.Errorf("expected output to mention uninstall, got %q", output) + } +} diff --git a/pkg/executor/handlers/terminal/terminal.go b/pkg/executor/handlers/terminal/terminal.go new file mode 100644 index 0000000..297f4e9 --- /dev/null +++ b/pkg/executor/handlers/terminal/terminal.go @@ -0,0 +1,199 @@ +package terminal + +import ( + "context" + "fmt" + "os" + "os/exec" + "syscall" + + "github.com/alpacax/alpamon/internal/protocol" + "github.com/alpacax/alpamon/pkg/config" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/runner" + "github.com/alpacax/alpamon/pkg/scheduler" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// TerminalHandler handles PTY and FTP terminal commands +type TerminalHandler struct { + *common.BaseHandler + apiSession *scheduler.Session +} + +// NewTerminalHandler creates a new terminal handler +func NewTerminalHandler(cmdExecutor common.CommandExecutor, apiSession *scheduler.Session) *TerminalHandler { + h := &TerminalHandler{ + BaseHandler: common.NewBaseHandler( + common.Terminal, + []common.CommandType{ + common.OpenPty, + common.OpenFtp, + common.ResizePty, + }, + cmdExecutor, + ), + apiSession: apiSession, + } + return h +} + +// Execute runs the terminal command +func (h *TerminalHandler) Execute(_ context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.OpenPty.String(): + return h.handleOpenPTY(args) + case common.OpenFtp.String(): + return h.handleOpenFTP(args) + case common.ResizePty.String(): + return h.handleResizePTY(args) + default: + return 1, "", fmt.Errorf("unknown terminal command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *TerminalHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.OpenPty.String(): + data := PTYData{ + SessionID: args.SessionID, + URL: args.URL, + Username: args.Username, + Groupname: args.Groupname, + HomeDirectory: args.HomeDirectory, + Rows: int(args.Rows), + Cols: int(args.Cols), + } + return h.ValidateStruct(data) + + case common.OpenFtp.String(): + data := FTPData{ + SessionID: args.SessionID, + URL: args.URL, + Username: args.Username, + Groupname: args.Groupname, + HomeDirectory: args.HomeDirectory, + } + return h.ValidateStruct(data) + + case common.ResizePty.String(): + data := ResizePTYData{ + SessionID: args.SessionID, + Rows: int(args.Rows), + Cols: int(args.Cols), + } + return h.ValidateStruct(data) + + default: + return fmt.Errorf("unknown terminal command: %s", cmd) + } +} + +// handleOpenPTY opens a new PTY terminal session +func (h *TerminalHandler) handleOpenPTY(args *common.CommandArgs) (int, string, error) { + err := h.Validate(common.OpenPty.String(), args) + if err != nil { + return 1, fmt.Sprintf("openpty: Not enough information. %s", err.Error()), nil + } + + data := protocol.CommandData{ + SessionID: args.SessionID, + URL: args.URL, + Username: args.Username, + Groupname: args.Groupname, + HomeDirectory: args.HomeDirectory, + Rows: uint16(args.Rows), + Cols: uint16(args.Cols), + } + + log.Info(). + Str("sessionID", data.SessionID). + Str("username", data.Username). + Uint16("rows", data.Rows). + Uint16("cols", data.Cols). + Msg("Opening PTY terminal") + + ptyClient := runner.NewPtyClient(data, h.apiSession) + go ptyClient.RunPtyBackground() + + return 0, "Spawned a pty terminal.", nil +} + +// handleOpenFTP opens a new FTP session +func (h *TerminalHandler) handleOpenFTP(args *common.CommandArgs) (int, string, error) { + err := h.Validate(common.OpenFtp.String(), args) + if err != nil { + return 1, fmt.Sprintf("openftp: Not enough information. %s", err.Error()), nil + } + + log.Info(). + Str("sessionID", args.SessionID). + Str("username", args.Username). + Str("url", args.URL). + Msg("Opening FTP session") + + result, err := utils.Demote(args.Username, args.Groupname, utils.DemoteOptions{ValidateGroup: false}) + if err != nil { + log.Debug().Err(err).Msg("Failed to get demote permission") + return 1, fmt.Sprintf("openftp: Failed to get demoted permission. %v", err), nil + } + + var sysProcAttr *syscall.SysProcAttr + var homeDirectory string + if result != nil { + sysProcAttr = result.SysProcAttr + if result.User != nil { + homeDirectory = result.User.HomeDir + } + } + + executable, err := os.Executable() + if err != nil { + log.Debug().Err(err).Msg("Failed to get executable path") + return 1, fmt.Sprintf("openftp: Failed to get executable path. %v", err), nil + } + + cmd := exec.Command( + executable, + "ftp", + args.URL, + config.GlobalSettings.ServerURL, + homeDirectory, + ) + cmd.SysProcAttr = sysProcAttr + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err = cmd.Start() + if err != nil { + log.Debug().Err(err).Msg("Failed to start ftp worker process") + return 1, fmt.Sprintf("openftp: Failed to start ftp worker process. %v", err), nil + } + + go func() { _ = cmd.Wait() }() + + return 0, "Spawned a ftp terminal.", nil +} + +// handleResizePTY resizes a PTY terminal +func (h *TerminalHandler) handleResizePTY(args *common.CommandArgs) (int, string, error) { + log.Info(). + Str("sessionID", args.SessionID). + Int("rows", int(args.Rows)). + Int("cols", int(args.Cols)). + Msg("Resizing PTY") + + terminal := runner.GetTerminal(args.SessionID) + if terminal == nil { + return 1, "Invalid session ID", nil + } + + err := terminal.Resize(uint16(args.Rows), uint16(args.Cols)) + if err != nil { + return 1, err.Error(), nil + } + + return 0, fmt.Sprintf("Resized terminal for %s to %dx%d.", args.SessionID, args.Cols, args.Rows), nil +} diff --git a/pkg/executor/handlers/terminal/terminal_test.go b/pkg/executor/handlers/terminal/terminal_test.go new file mode 100644 index 0000000..95d47b7 --- /dev/null +++ b/pkg/executor/handlers/terminal/terminal_test.go @@ -0,0 +1,131 @@ +package terminal + +import ( + "context" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +func TestTerminalHandler_Validate(t *testing.T) { + handler := NewTerminalHandler(common.NewMockCommandExecutor(t), nil) + + tests := []struct { + name string + cmd string + args *common.CommandArgs + wantErr bool + }{ + { + name: "openpty valid", + cmd: "openpty", + args: &common.CommandArgs{ + SessionID: "session123", + URL: "ws://localhost:8080", + Username: "testuser", + Groupname: "testgroup", + HomeDirectory: "/home/testuser", + Rows: 24, + Cols: 80, + }, + wantErr: false, + }, + { + name: "openpty missing required fields", + cmd: "openpty", + args: &common.CommandArgs{ + SessionID: "session123", + // Missing URL and Username + }, + wantErr: true, + }, + { + name: "openftp valid", + cmd: "openftp", + args: &common.CommandArgs{ + SessionID: "ftp123", + URL: "ftp://localhost", + Username: "testuser", + }, + wantErr: false, + }, + { + name: "openftp missing username", + cmd: "openftp", + args: &common.CommandArgs{ + SessionID: "ftp123", + URL: "ftp://localhost", + }, + wantErr: true, + }, + { + name: "resizepty valid", + cmd: "resizepty", + args: &common.CommandArgs{ + SessionID: "session123", + Rows: 40, + Cols: 120, + }, + wantErr: false, + }, + { + name: "resizepty missing session ID", + cmd: "resizepty", + args: &common.CommandArgs{ + Rows: 40, + Cols: 120, + }, + wantErr: true, + }, + { + name: "unknown command", + cmd: "unknown", + args: &common.CommandArgs{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handler.Validate(tt.cmd, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTerminalHandler_Execute_UnknownCommand(t *testing.T) { + handler := NewTerminalHandler(common.NewMockCommandExecutor(t), nil) + + exitCode, _, err := handler.Execute(context.TODO(), "unknown", &common.CommandArgs{}) + + if err == nil { + t.Error("Execute() expected error for unknown command") + } + if exitCode != 1 { + t.Errorf("Execute() exitCode = %v, want 1", exitCode) + } +} + +func TestTerminalHandler_ResizePTY_InvalidSession(t *testing.T) { + handler := NewTerminalHandler(common.NewMockCommandExecutor(t), nil) + + args := &common.CommandArgs{ + SessionID: "nonexistent", + Rows: 40, + Cols: 120, + } + + exitCode, output, err := handler.Execute(context.TODO(), "resizepty", args) + + if err != nil { + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != 1 { + t.Errorf("Execute() exitCode = %v, want 1", exitCode) + } + if output != "Invalid session ID" { + t.Errorf("Execute() output = %v, want 'Invalid session ID'", output) + } +} diff --git a/pkg/executor/handlers/terminal/types.go b/pkg/executor/handlers/terminal/types.go new file mode 100644 index 0000000..96b597d --- /dev/null +++ b/pkg/executor/handlers/terminal/types.go @@ -0,0 +1,28 @@ +package terminal + +// PTYData contains data for PTY operations +type PTYData struct { + SessionID string `json:"session_id" validate:"required"` + URL string `json:"url" validate:"required"` + Username string `json:"username" validate:"required"` + Groupname string `json:"groupname"` + HomeDirectory string `json:"home_directory"` + Rows int `json:"rows"` + Cols int `json:"cols"` +} + +// FTPData contains data for FTP operations +type FTPData struct { + SessionID string `json:"session_id" validate:"required"` + URL string `json:"url" validate:"required"` + Username string `json:"username" validate:"required"` + Groupname string `json:"groupname"` + HomeDirectory string `json:"home_directory"` +} + +// ResizePTYData contains data for resizing PTY +type ResizePTYData struct { + SessionID string `json:"session_id" validate:"required"` + Rows int `json:"rows" validate:"required"` + Cols int `json:"cols" validate:"required"` +} diff --git a/pkg/executor/handlers/tunnel/tunnel.go b/pkg/executor/handlers/tunnel/tunnel.go new file mode 100644 index 0000000..0d3e1e7 --- /dev/null +++ b/pkg/executor/handlers/tunnel/tunnel.go @@ -0,0 +1,109 @@ +package tunnel + +import ( + "context" + "fmt" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/runner" + "github.com/rs/zerolog/log" +) + +// TunnelHandler handles tunnel connection commands (opentunnel, closetunnel) +type TunnelHandler struct { + *common.BaseHandler +} + +// NewTunnelHandler creates a new tunnel handler +func NewTunnelHandler(cmdExecutor common.CommandExecutor) *TunnelHandler { + h := &TunnelHandler{ + BaseHandler: common.NewBaseHandler( + common.Tunnel, + []common.CommandType{ + common.OpenTunnel, + common.CloseTunnel, + }, + cmdExecutor, + ), + } + return h +} + +// Execute runs the tunnel command +func (h *TunnelHandler) Execute(_ context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + switch cmd { + case common.OpenTunnel.String(): + return h.handleOpenTunnel(args) + case common.CloseTunnel.String(): + return h.handleCloseTunnel(args) + default: + return 1, "", fmt.Errorf("unknown tunnel command: %s", cmd) + } +} + +// Validate checks if the arguments are valid for the command +func (h *TunnelHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.OpenTunnel.String(): + data := OpenTunnelData{ + SessionID: args.SessionID, + URL: args.URL, + TargetPort: args.TargetPort, + } + if err := h.ValidateStruct(data); err != nil { + return err + } + // Check for duplicate tunnel to prevent process leak + if _, exists := runner.GetActiveTunnel(args.SessionID); exists { + return fmt.Errorf("tunnel with session ID %s already exists", args.SessionID) + } + return nil + + case common.CloseTunnel.String(): + data := CloseTunnelData{ + SessionID: args.SessionID, + } + return h.ValidateStruct(data) + + default: + return fmt.Errorf("unknown tunnel command: %s", cmd) + } +} + +// handleOpenTunnel opens a new tunnel connection +func (h *TunnelHandler) handleOpenTunnel(args *common.CommandArgs) (int, string, error) { + err := h.Validate(common.OpenTunnel.String(), args) + if err != nil { + return 1, fmt.Sprintf("opentunnel: Not enough information. %s", err.Error()), nil + } + + log.Info(). + Str("sessionID", args.SessionID). + Int("targetPort", args.TargetPort). + Str("url", args.URL). + Msg("Opening tunnel connection") + + tunnelClient := runner.NewTunnelClient(args.SessionID, args.TargetPort, args.URL) + go tunnelClient.RunTunnelBackground() + + return 0, fmt.Sprintf("Tunnel opened for session %s to port %d.", args.SessionID, args.TargetPort), nil +} + +// handleCloseTunnel closes an existing tunnel connection +func (h *TunnelHandler) handleCloseTunnel(args *common.CommandArgs) (int, string, error) { + err := h.Validate(common.CloseTunnel.String(), args) + if err != nil { + return 1, fmt.Sprintf("closetunnel: Not enough information. %s", err.Error()), nil + } + + log.Info(). + Str("sessionID", args.SessionID). + Msg("Closing tunnel connection") + + err = runner.CloseTunnel(args.SessionID) + if err != nil { + return 1, fmt.Sprintf("closetunnel: Failed to close tunnel. %s", err.Error()), nil + } + + return 0, fmt.Sprintf("Tunnel closed for session %s.", args.SessionID), nil +} diff --git a/pkg/executor/handlers/tunnel/types.go b/pkg/executor/handlers/tunnel/types.go new file mode 100644 index 0000000..2c67a6d --- /dev/null +++ b/pkg/executor/handlers/tunnel/types.go @@ -0,0 +1,13 @@ +package tunnel + +// OpenTunnelData contains data for opening a tunnel +type OpenTunnelData struct { + SessionID string `json:"session_id" validate:"required"` + URL string `json:"url" validate:"required"` + TargetPort int `json:"target_port" validate:"required,min=1,max=65535"` +} + +// CloseTunnelData contains data for closing a tunnel +type CloseTunnelData struct { + SessionID string `json:"session_id" validate:"required"` +} diff --git a/pkg/executor/handlers/user/types.go b/pkg/executor/handlers/user/types.go new file mode 100644 index 0000000..96f99fb --- /dev/null +++ b/pkg/executor/handlers/user/types.go @@ -0,0 +1,27 @@ +package user + +// UserData contains data for user operations +type UserData struct { + Username string `validate:"required"` + UID uint64 `validate:"required"` + GID uint64 `validate:"required"` + Comment string `validate:"required"` + HomeDirectory string `validate:"required"` + HomeDirectoryPermission string `validate:"omitempty"` + Shell string `validate:"required"` + Groupname string `validate:"required"` + Groups []uint64 `validate:"omitempty"` +} + +// DeleteUserData contains data for user deletion +type DeleteUserData struct { + Username string `validate:"required"` + PurgeHomeDirectory bool `validate:"omitempty"` +} + +// ModUserData contains data for modifying user +type ModUserData struct { + Username string `validate:"required"` + Groupnames []string `validate:"required"` + Comment string `validate:"required"` +} diff --git a/pkg/executor/handlers/user/user.go b/pkg/executor/handlers/user/user.go new file mode 100644 index 0000000..8aa688b --- /dev/null +++ b/pkg/executor/handlers/user/user.go @@ -0,0 +1,359 @@ +package user + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/executor/services" + "github.com/alpacax/alpamon/pkg/utils" + "github.com/rs/zerolog/log" +) + +// UserHandler handles user management commands +type UserHandler struct { + *common.BaseHandler + groupService services.GroupService + syncManager common.SystemInfoManager +} + +// NewUserHandler creates a new user handler +func NewUserHandler(cmdExecutor common.CommandExecutor, groupService services.GroupService, syncManager common.SystemInfoManager) *UserHandler { + h := &UserHandler{ + BaseHandler: common.NewBaseHandler( + common.User, + []common.CommandType{ + common.AddUser, + common.DelUser, + common.ModUser, + }, + cmdExecutor, + ), + groupService: groupService, + syncManager: syncManager, + } + return h +} + +// Execute runs the user management command +func (h *UserHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + var exitCode int + var output string + var err error + + switch cmd { + case common.AddUser.String(): + exitCode, output, err = h.handleAddUser(ctx, args) + case common.DelUser.String(): + exitCode, output, err = h.handleDelUser(ctx, args) + case common.ModUser.String(): + exitCode, output, err = h.handleModUser(ctx, args) + default: + return 1, "", fmt.Errorf("unknown user command: %s", cmd) + } + + // Sync system info after successful command execution + if exitCode == 0 && h.syncManager != nil { + h.syncManager.SyncSystemInfo([]string{"groups", "users"}) + } + + return exitCode, output, err +} + +// Validate checks if the arguments are valid for the command +func (h *UserHandler) Validate(cmd string, args *common.CommandArgs) error { + switch cmd { + case common.AddUser.String(): + data := UserData{ + Username: args.Username, + UID: args.UID, + GID: args.GID, + Comment: args.Comment, + HomeDirectory: args.HomeDirectory, + HomeDirectoryPermission: args.HomeDirectoryPermission, + Shell: args.Shell, + Groupname: args.Groupname, + Groups: args.Groups, + } + if data.HomeDirectoryPermission == "" { + data.HomeDirectoryPermission = "0755" + } + if data.Shell == "" { + data.Shell = "/bin/bash" + } + return h.ValidateStruct(data) + + case common.DelUser.String(): + data := DeleteUserData{ + Username: args.Username, + PurgeHomeDirectory: args.PurgeHomeDirectory, + } + return h.ValidateStruct(data) + + case common.ModUser.String(): + data := ModUserData{ + Username: args.Username, + Groupnames: args.Groupnames, + Comment: args.Comment, + } + return h.ValidateStruct(data) + + default: + return fmt.Errorf("unknown user command: %s", cmd) + } +} + +// handleAddUser handles the adduser command +func (h *UserHandler) handleAddUser(ctx context.Context, args *common.CommandArgs) (int, string, error) { + // Extract and validate arguments + data := UserData{ + Username: args.Username, + UID: args.UID, + GID: args.GID, + Comment: args.Comment, + HomeDirectory: args.HomeDirectory, + HomeDirectoryPermission: args.HomeDirectoryPermission, + Shell: args.Shell, + Groupname: args.Groupname, + Groups: args.Groups, + } + if data.HomeDirectoryPermission == "" { + data.HomeDirectoryPermission = "0755" + } + if data.Shell == "" { + data.Shell = "/bin/bash" + } + + err := h.Validate(common.AddUser.String(), args) + if err != nil { + return 1, err.Error(), nil + } + + log.Info(). + Str("username", data.Username). + Uint64("uid", data.UID). + Uint64("gid", data.GID). + Str("home", data.HomeDirectory). + Msg("Adding user") + + var exitCode int + var output string + + // Platform-specific user addition + if utils.PlatformLike == "debian" { + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/adduser", + "--home", data.HomeDirectory, + "--shell", data.Shell, + "--uid", strconv.FormatUint(data.UID, 10), + "--gid", strconv.FormatUint(data.GID, 10), + "--gecos", data.Comment, + "--disabled-password", + data.Username, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else if utils.PlatformLike == "rhel" { + // Create primary group first if needed + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/groupadd", + "--gid", strconv.FormatUint(data.GID, 10), + data.Groupname, + ) + // Ignore if group already exists + if exitCode != 0 && !strings.Contains(output, "already exists") { + return exitCode, output, err + } + + // Create user + exitCode, output, err = h.Executor.Run( + ctx, + "/usr/sbin/useradd", + "--home-dir", data.HomeDirectory, + "--shell", data.Shell, + "--uid", strconv.FormatUint(data.UID, 10), + "--gid", strconv.FormatUint(data.GID, 10), + "--comment", data.Comment, + "--create-home", + data.Username, + ) + if exitCode != 0 { + return exitCode, output, err + } + } else { + return 1, fmt.Sprintf("Platform '%s' not supported for user management", utils.PlatformLike), nil + } + + // Set home directory permissions if specified + // codeql[go/path-injection]: Intentional - Admin-specified home directory permission + if data.HomeDirectoryPermission != "" && data.HomeDirectoryPermission != "0755" { + mode, err := strconv.ParseUint(data.HomeDirectoryPermission, 8, 32) + if err == nil { + _ = os.Chmod(data.HomeDirectory, os.FileMode(mode)) // lgtm[go/path-injection] + } + } + + // Add user to additional groups if specified + if len(data.Groups) > 0 && h.groupService != nil { + if err := h.groupService.AddUserToGroups(ctx, data.Username, data.Groups); err != nil { + log.Warn().Err(err).Msg("Failed to add user to additional groups") + return 0, fmt.Sprintf("User '%s' created but failed to add to groups: %v", data.Username, err), nil + } + } + + log.Info(). + Str("username", data.Username). + Int("exitCode", exitCode). + Msg("User added successfully") + + return exitCode, fmt.Sprintf("User '%s' added successfully", data.Username), nil +} + +// backupHomeDirectory backs up the user's home directory before deletion +func (h *UserHandler) backupHomeDirectory(username string) error { + homeDir := fmt.Sprintf("/home/%s", username) + timestamp := time.Now().UTC().Format(time.RFC3339) + backupDir := fmt.Sprintf("/home/deleted_users/%s_%s", username, timestamp) + + // Create backup parent directory + if err := os.MkdirAll("/home/deleted_users", 0700); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Check if home directory exists + // codeql[go/path-injection]: Intentional - User home directory for backup + if _, err := os.Stat(homeDir); err != nil { // lgtm[go/path-injection] + return fmt.Errorf("%s not exist: %w", homeDir, err) + } + + // Move home directory to backup location + // codeql[go/path-injection]: Intentional - Backup destination path + if err := os.Rename(homeDir, backupDir); err != nil { // lgtm[go/path-injection] + return fmt.Errorf("failed to move home directory: %w", err) + } + + // Change ownership to root + if err := utils.ChownRecursive(backupDir, 0, 0); err != nil { + return fmt.Errorf("failed to chown backup directory: %w", err) + } + + log.Info(). + Str("username", username). + Str("backupDir", backupDir). + Msg("Home directory backed up") + + return nil +} + +// handleDelUser handles the deluser command +func (h *UserHandler) handleDelUser(ctx context.Context, args *common.CommandArgs) (int, string, error) { + // Extract and validate arguments + data := DeleteUserData{ + Username: args.Username, + PurgeHomeDirectory: args.PurgeHomeDirectory, + } + + err := h.Validate(common.DelUser.String(), args) + if err != nil { + return 1, err.Error(), nil + } + + log.Info(). + Str("username", data.Username). + Bool("purge", data.PurgeHomeDirectory). + Msg("Deleting user") + + // Backup home directory if not purging + if !data.PurgeHomeDirectory { + if err := h.backupHomeDirectory(data.Username); err != nil { + return 1, err.Error(), nil + } + } + + var exitCode int + var output string + cmdArgs := []string{} + + // Platform-specific user deletion + if utils.PlatformLike == "debian" { + cmdArgs = append(cmdArgs, "/usr/sbin/deluser") + if data.PurgeHomeDirectory { + cmdArgs = append(cmdArgs, "--remove-home") + } + cmdArgs = append(cmdArgs, data.Username) + } else if utils.PlatformLike == "rhel" { + cmdArgs = append(cmdArgs, "/usr/sbin/userdel") + if data.PurgeHomeDirectory { + cmdArgs = append(cmdArgs, "-r") + } + cmdArgs = append(cmdArgs, data.Username) + } else { + return 1, fmt.Sprintf("Platform '%s' not supported for user management", utils.PlatformLike), nil + } + + exitCode, output, err = h.Executor.Run( + ctx, + cmdArgs[0], cmdArgs[1:]..., + ) + if exitCode != 0 { + return exitCode, output, err + } + + log.Info(). + Str("username", data.Username). + Int("exitCode", exitCode). + Msg("User deleted successfully") + + return exitCode, fmt.Sprintf("User '%s' deleted successfully", data.Username), nil +} + +// handleModUser handles the moduser command +func (h *UserHandler) handleModUser(ctx context.Context, args *common.CommandArgs) (int, string, error) { + // Extract and validate arguments + data := ModUserData{ + Username: args.Username, + Groupnames: args.Groupnames, + Comment: args.Comment, + } + + err := h.Validate(common.ModUser.String(), args) + if err != nil { + return 1, err.Error(), nil + } + + log.Info(). + Str("username", data.Username). + Strs("groups", data.Groupnames). + Str("comment", data.Comment). + Msg("Modifying user") + + // Build usermod arguments + cmdArgs := []string{"/usr/sbin/usermod"} + if data.Comment != "" { + cmdArgs = append(cmdArgs, "--comment", data.Comment) + } + if len(data.Groupnames) > 0 { + cmdArgs = append(cmdArgs, "-G", strings.Join(data.Groupnames, ",")) + } + cmdArgs = append(cmdArgs, data.Username) + + // Execute usermod + exitCode, output, err := h.Executor.Run(ctx, cmdArgs[0], cmdArgs[1:]...) + if exitCode != 0 { + return exitCode, output, err + } + + log.Info(). + Str("username", data.Username). + Int("exitCode", exitCode). + Msg("User modified successfully") + + return exitCode, fmt.Sprintf("User '%s' modified successfully", data.Username), nil +} diff --git a/pkg/executor/handlers/user/user_test.go b/pkg/executor/handlers/user/user_test.go new file mode 100644 index 0000000..b35d401 --- /dev/null +++ b/pkg/executor/handlers/user/user_test.go @@ -0,0 +1,298 @@ +package user + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/alpacax/alpamon/pkg/utils" +) + +// MockGroupService implements services.GroupService for testing +type MockGroupService struct { + AddUserToGroupsCalled bool + AddUserToGroupsError error +} + +func (m *MockGroupService) AddUserToGroups(ctx context.Context, username string, gids []uint64) error { + m.AddUserToGroupsCalled = true + return m.AddUserToGroupsError +} + +func TestUserHandler_Execute(t *testing.T) { + tests := []struct { + name string + cmd string + args *common.CommandArgs + setupMock func(*common.MockCommandExecutor) + groupService *MockGroupService + wantCode int + wantErr bool + }{ + { + name: "adduser success debian", + cmd: "adduser", + args: &common.CommandArgs{ + Username: "testuser", + UID: 1001, + GID: 1001, + Comment: "Test User", + HomeDirectory: "/home/testuser", + Shell: "/bin/bash", + Groupname: "testgroup", + Groups: []uint64{1002, 1003}, + }, + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult(fmt.Sprintf("/usr/sbin/adduser --home /home/testuser --shell /bin/bash --uid %d --gid %d --gecos Test User --disabled-password testuser", 1001, 1001), 0, "User created", nil) + }, + groupService: &MockGroupService{}, + wantCode: 0, + wantErr: false, + }, + { + name: "deluser success", + cmd: "deluser", + args: &common.CommandArgs{ + Username: "testuser", + PurgeHomeDirectory: true, + }, + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult("/usr/sbin/deluser --remove-home testuser", 0, "User deleted", nil) + }, + groupService: &MockGroupService{}, + wantCode: 0, + wantErr: false, + }, + { + name: "moduser success", + cmd: "moduser", + args: &common.CommandArgs{ + Username: "testuser", + Groupnames: []string{"sudo", "docker"}, + Comment: "Updated comment", + }, + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult("/usr/sbin/usermod -c Updated comment testuser", 0, "User comment updated", nil) + mock.SetResult("/usr/sbin/usermod -G sudo,docker testuser", 0, "User groups updated", nil) + }, + groupService: &MockGroupService{}, + wantCode: 0, + wantErr: false, + }, + { + name: "unknown command", + cmd: "unknownuser", + args: &common.CommandArgs{}, + groupService: &MockGroupService{}, + wantCode: 1, + wantErr: true, + }, + { + name: "adduser missing username", + cmd: "adduser", + args: &common.CommandArgs{ + UID: 1001, + GID: 1001, + }, + groupService: &MockGroupService{}, + wantCode: 1, + wantErr: false, + }, + { + name: "adduser failure", + cmd: "adduser", + args: &common.CommandArgs{ + Username: "testuser", + UID: 1001, + GID: 1001, + Comment: "Test User", + HomeDirectory: "/home/testuser", + Shell: "/bin/bash", + Groupname: "testgroup", + }, + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult(fmt.Sprintf("/usr/sbin/adduser --home /home/testuser --shell /bin/bash --uid %d --gid %d --gecos Test User --disabled-password testuser", 1001, 1001), 1, "User add failed", errors.New("user add error")) + }, + groupService: &MockGroupService{}, + wantCode: 1, + wantErr: true, // error is returned as part of output, not actual go error + }, + } + + for _, tt := range tests { + tt := tt // Capture range variable. + t.Run(tt.name, func(t *testing.T) { + // Set platform like to debian for the test + originalPlatformLike := utils.PlatformLike + utils.SetPlatformLike("debian") + t.Cleanup(func() { + utils.SetPlatformLike(originalPlatformLike) + }) + + mock := common.NewMockCommandExecutor(t) + if tt.setupMock != nil { + tt.setupMock(mock) + } + + handler := NewUserHandler(mock, tt.groupService, nil) + ctx := context.Background() + + exitCode, output, err := handler.Execute(ctx, tt.cmd, tt.args) + + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + } + if exitCode != tt.wantCode { + t.Errorf("Execute() exitCode = %v, want %v", exitCode, tt.wantCode) + } + if exitCode == 0 && output == "" && !tt.wantErr { + t.Error("Execute() returned success but no output") + } + }) + } +} + +func TestUserHandler_Validate(t *testing.T) { + handler := NewUserHandler(common.NewMockCommandExecutor(t), &MockGroupService{}, nil) + + tests := []struct { + name string + cmd string + args *common.CommandArgs + wantErr bool + }{ + { + name: "adduser valid", + cmd: "adduser", + args: &common.CommandArgs{ + Username: "testuser", + UID: 1001, + GID: 1001, + Comment: "Test User", + HomeDirectory: "/home/testuser", + Shell: "/bin/bash", + Groupname: "testgroup", + }, + wantErr: false, + }, + { + name: "adduser missing required fields", + cmd: "adduser", + args: &common.CommandArgs{ + Username: "testuser", + // Missing other required fields + }, + wantErr: true, + }, + { + name: "deluser valid", + cmd: "deluser", + args: &common.CommandArgs{ + Username: "testuser", + }, + wantErr: false, + }, + { + name: "moduser valid", + cmd: "moduser", + args: &common.CommandArgs{ + Username: "testuser", + Groupnames: []string{"sudo"}, + Comment: "Updated", + }, + wantErr: false, + }, + { + name: "unknown command", + cmd: "unknown", + args: &common.CommandArgs{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handler.Validate(tt.cmd, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestUserHandler_AddUserWithGroups(t *testing.T) { + tests := []struct { + name string + setupMock func(*common.MockCommandExecutor) + groupService *MockGroupService + wantCode int + calledGroups bool + }{ + { + name: "add user to groups success", + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult(fmt.Sprintf("/usr/sbin/adduser --home /home/testuser --shell /bin/bash --uid %d --gid %d --gecos Test User --disabled-password testuser", 1001, 1001), 0, "User created", nil) + }, + groupService: &MockGroupService{}, + wantCode: 0, + calledGroups: true, + }, + { + name: "add user to groups failure", + setupMock: func(mock *common.MockCommandExecutor) { + mock.SetResult(fmt.Sprintf("/usr/sbin/adduser --home /home/testuser --shell /bin/bash --uid %d --gid %d --gecos Test User --disabled-password testuser", 1001, 1001), 0, "User created", nil) + }, + groupService: &MockGroupService{AddUserToGroupsError: errors.New("failed to add to groups")}, + wantCode: 0, // Still want 0 for user creation, group error is logged + calledGroups: true, + }, + } + + for _, tt := range tests { + tt := tt // Capture range variable. + t.Run(tt.name, func(t *testing.T) { + // Set platform like to debian for the test + originalPlatformLike := utils.PlatformLike + utils.SetPlatformLike("debian") + t.Cleanup(func() { + utils.SetPlatformLike(originalPlatformLike) + }) + + mock := common.NewMockCommandExecutor(t) + if tt.setupMock != nil { + tt.setupMock(mock) + } + + handler := NewUserHandler(mock, tt.groupService, nil) + ctx := context.Background() + + args := &common.CommandArgs{ + Username: "testuser", + UID: 1001, + GID: 1001, + Comment: "Test User", + HomeDirectory: "/home/testuser", + Shell: "/bin/bash", + Groupname: "testgroup", + Groups: []uint64{1002, 1003}, + } + + exitCode, _, err := handler.Execute(ctx, "adduser", args) + + if err != nil && !tt.groupService.AddUserToGroupsCalled { // only expect error if AddUserToGroups not called + t.Errorf("Execute() unexpected error: %v", err) + } + if exitCode != tt.wantCode { + t.Errorf("Execute() exitCode = %v, want %v", exitCode, tt.wantCode) + } + if tt.calledGroups && !tt.groupService.AddUserToGroupsCalled { + t.Error("Execute() did not call AddUserToGroups on group service") + } + if !tt.calledGroups && tt.groupService.AddUserToGroupsCalled { + t.Error("Execute() unexpectedly called AddUserToGroups on group service") + } + }) + } +} diff --git a/pkg/executor/integration_test.go b/pkg/executor/integration_test.go new file mode 100644 index 0000000..177b69a --- /dev/null +++ b/pkg/executor/integration_test.go @@ -0,0 +1,292 @@ +package executor + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// IntegrationMockHandler is a more complete mock handler for integration testing +type IntegrationMockHandler struct { + name string + commands []string + executeCount int + validateCount int + executionDelay time.Duration + mu sync.Mutex +} + +func (h *IntegrationMockHandler) Name() string { + return h.name +} + +func (h *IntegrationMockHandler) Commands() []string { + return h.commands +} + +func (h *IntegrationMockHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + h.mu.Lock() + h.executeCount++ + h.mu.Unlock() + + if h.executionDelay > 0 { + select { + case <-ctx.Done(): + return 1, "", ctx.Err() + case <-time.After(h.executionDelay): + } + } + + return 0, "executed: " + cmd, nil +} + +func (h *IntegrationMockHandler) Validate(cmd string, args *common.CommandArgs) error { + h.mu.Lock() + h.validateCount++ + h.mu.Unlock() + return nil +} + +func (h *IntegrationMockHandler) GetExecuteCount() int { + h.mu.Lock() + defer h.mu.Unlock() + return h.executeCount +} + +func (h *IntegrationMockHandler) GetValidateCount() int { + h.mu.Lock() + defer h.mu.Unlock() + return h.validateCount +} + +// TestIntegration_RegistryWithHandlers tests that handlers can be registered and retrieved +func TestIntegration_RegistryWithHandlers(t *testing.T) { + registry := NewRegistry() + + handler1 := &IntegrationMockHandler{ + name: "handler1", + commands: []string{"cmd1", "cmd2"}, + } + handler2 := &IntegrationMockHandler{ + name: "handler2", + commands: []string{"cmd3", "cmd4"}, + } + + // Register handlers + if err := registry.Register(handler1); err != nil { + t.Fatalf("failed to register handler1: %v", err) + } + if err := registry.Register(handler2); err != nil { + t.Fatalf("failed to register handler2: %v", err) + } + + // Verify all commands are accessible + for _, cmd := range []string{"cmd1", "cmd2", "cmd3", "cmd4"} { + if !registry.IsCommandRegistered(cmd) { + t.Errorf("command %q should be registered", cmd) + } + } + + // Get handlers and verify names + h1, err := registry.Get("cmd1") + if err != nil { + t.Fatalf("failed to get handler for cmd1: %v", err) + } + if h1.Name() != "handler1" { + t.Errorf("expected handler1, got %q", h1.Name()) + } + + h2, err := registry.Get("cmd3") + if err != nil { + t.Fatalf("failed to get handler for cmd3: %v", err) + } + if h2.Name() != "handler2" { + t.Errorf("expected handler2, got %q", h2.Name()) + } +} + +// TestIntegration_HandlerExecution tests handler execution through registry +func TestIntegration_HandlerExecution(t *testing.T) { + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "test_handler", + commands: []string{"test_cmd"}, + } + _ = registry.Register(handler) + + // Get handler and execute + h, err := registry.Get("test_cmd") + if err != nil { + t.Fatalf("failed to get handler: %v", err) + } + + ctx := context.Background() + args := &common.CommandArgs{} + + // Validate first + if err := h.Validate("test_cmd", args); err != nil { + t.Fatalf("validation failed: %v", err) + } + + // Execute + exitCode, output, err := h.Execute(ctx, "test_cmd", args) + if err != nil { + t.Fatalf("execution failed: %v", err) + } + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } + if output == "" { + t.Error("expected non-empty output") + } + + // Verify counts + if handler.GetExecuteCount() != 1 { + t.Errorf("expected execute count 1, got %d", handler.GetExecuteCount()) + } + if handler.GetValidateCount() != 1 { + t.Errorf("expected validate count 1, got %d", handler.GetValidateCount()) + } +} + +// TestIntegration_ContextCancellation tests that context cancellation is propagated +func TestIntegration_ContextCancellation(t *testing.T) { + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "slow_handler", + commands: []string{"slow_cmd"}, + executionDelay: 2 * time.Second, + } + _ = registry.Register(handler) + + h, _ := registry.Get("slow_cmd") + + // Create context with short timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + args := &common.CommandArgs{} + + // Execute - should timeout + exitCode, _, err := h.Execute(ctx, "slow_cmd", args) + + if err == nil { + t.Error("expected context cancellation error") + } + if exitCode != 1 { + t.Errorf("expected exit code 1 on cancellation, got %d", exitCode) + } +} + +// TestIntegration_ConcurrentExecution tests concurrent handler execution +func TestIntegration_ConcurrentExecution(t *testing.T) { + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "concurrent_handler", + commands: []string{"concurrent_cmd"}, + executionDelay: 10 * time.Millisecond, + } + _ = registry.Register(handler) + + h, _ := registry.Get("concurrent_cmd") + ctx := context.Background() + args := &common.CommandArgs{} + + var wg sync.WaitGroup + concurrency := 50 + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _, _ = h.Execute(ctx, "concurrent_cmd", args) + }() + } + + wg.Wait() + + if handler.GetExecuteCount() != concurrency { + t.Errorf("expected %d executions, got %d", concurrency, handler.GetExecuteCount()) + } +} + +// TestIntegration_PoolWithRegistry tests pool integration with registry +func TestIntegration_PoolWithRegistry(t *testing.T) { + workerPool := pool.NewPool(5, 100) + defer func() { _ = workerPool.Shutdown(5 * time.Second) }() + + ctxManager := agent.NewContextManager() + defer ctxManager.Shutdown() + + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "pool_handler", + commands: []string{"pool_cmd"}, + } + _ = registry.Register(handler) + + h, _ := registry.Get("pool_cmd") + args := &common.CommandArgs{} + + var wg sync.WaitGroup + taskCount := 20 + + for i := 0; i < taskCount; i++ { + wg.Add(1) + ctx, cancel := ctxManager.NewContext(5 * time.Second) + + err := workerPool.Submit(ctx, func() error { + defer wg.Done() + defer cancel() + _, _, _ = h.Execute(ctx, "pool_cmd", args) + return nil + }) + if err != nil { + wg.Done() + cancel() + t.Logf("failed to submit task: %v", err) + } + } + + wg.Wait() + + // Allow for some tasks to fail due to pool dynamics + if handler.GetExecuteCount() < taskCount/2 { + t.Errorf("expected at least %d executions, got %d", taskCount/2, handler.GetExecuteCount()) + } +} + +// TestIntegration_UnregisterHandler tests handler unregistration +func TestIntegration_UnregisterHandler(t *testing.T) { + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "removable", + commands: []string{"remove_cmd"}, + } + _ = registry.Register(handler) + + // Verify registered + if !registry.IsCommandRegistered("remove_cmd") { + t.Error("command should be registered") + } + + // Unregister + if err := registry.Unregister("removable"); err != nil { + t.Fatalf("failed to unregister: %v", err) + } + + // Verify unregistered + if registry.IsCommandRegistered("remove_cmd") { + t.Error("command should not be registered after unregister") + } +} diff --git a/pkg/executor/registry.go b/pkg/executor/registry.go new file mode 100644 index 0000000..c9835b2 --- /dev/null +++ b/pkg/executor/registry.go @@ -0,0 +1,145 @@ +// Package executor provides the command execution framework for Alpamon +package executor + +import ( + "fmt" + "sync" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" + "github.com/rs/zerolog/log" +) + +// Registry manages the registration and lookup of command handlers +type Registry struct { + handlers map[string]common.Handler // handler name -> handler + cmdToHandler map[string]common.Handler // command -> handler + mu sync.RWMutex +} + +// NewRegistry creates a new handler registry +func NewRegistry() *Registry { + return &Registry{ + handlers: make(map[string]common.Handler), + cmdToHandler: make(map[string]common.Handler), + } +} + +// Register adds a new handler to the registry +func (r *Registry) Register(h common.Handler) error { + r.mu.Lock() + defer r.mu.Unlock() + + name := h.Name() + + // Check if handler already exists + if _, exists := r.handlers[name]; exists { + return fmt.Errorf("handler already registered: %s", name) + } + + // Register the handler + r.handlers[name] = h + log.Debug().Msgf("Registered handler: %s", name) + + // Map each command to this handler + for _, cmd := range h.Commands() { + if existing, exists := r.cmdToHandler[cmd]; exists { + return fmt.Errorf("command %s already registered to handler %s", cmd, existing.Name()) + } + r.cmdToHandler[cmd] = h + log.Debug().Msgf("Registered command %s to handler %s", cmd, name) + } + + return nil +} + +// Get retrieves a handler by command name +func (r *Registry) Get(cmd string) (common.Handler, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + handler, exists := r.cmdToHandler[cmd] + if !exists { + return nil, fmt.Errorf("no handler for command: %s", cmd) + } + + return handler, nil +} + +// GetHandler retrieves a handler by handler name +func (r *Registry) GetHandler(name string) (common.Handler, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + handler, exists := r.handlers[name] + if !exists { + return nil, fmt.Errorf("handler not found: %s", name) + } + + return handler, nil +} + +// List returns all registered handlers +func (r *Registry) List() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.handlers)) + for name := range r.handlers { + names = append(names, name) + } + return names +} + +// ListCommands returns all registered commands +func (r *Registry) ListCommands() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + commands := make([]string, 0, len(r.cmdToHandler)) + for cmd := range r.cmdToHandler { + commands = append(commands, cmd) + } + return commands +} + +// IsCommandRegistered checks if a command is registered +func (r *Registry) IsCommandRegistered(cmd string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, exists := r.cmdToHandler[cmd] + return exists +} + +// Unregister removes a handler from the registry +func (r *Registry) Unregister(name string) error { + r.mu.Lock() + defer r.mu.Unlock() + + handler, exists := r.handlers[name] + if !exists { + return fmt.Errorf("handler not found: %s", name) + } + + // Remove command mappings + for _, cmd := range handler.Commands() { + delete(r.cmdToHandler, cmd) + log.Debug().Msgf("Unregistered command %s from handler %s", cmd, name) + } + + // Remove handler + delete(r.handlers, name) + log.Debug().Msgf("Unregistered handler %s", name) + + return nil +} + +// Clear removes all handlers from the registry +func (r *Registry) Clear() { + r.mu.Lock() + defer r.mu.Unlock() + + r.handlers = make(map[string]common.Handler) + r.cmdToHandler = make(map[string]common.Handler) + log.Debug().Msg("Cleared all handlers from registry") +} diff --git a/pkg/executor/registry_test.go b/pkg/executor/registry_test.go new file mode 100644 index 0000000..aea6cc5 --- /dev/null +++ b/pkg/executor/registry_test.go @@ -0,0 +1,288 @@ +package executor + +import ( + "context" + "testing" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// MockHandler is a mock implementation of Handler interface for testing +type MockHandler struct { + name string + commands []string +} + +func (h *MockHandler) Name() string { + return h.name +} + +func (h *MockHandler) Commands() []string { + return h.commands +} + +func (h *MockHandler) Execute(ctx context.Context, cmd string, args *common.CommandArgs) (int, string, error) { + return 0, "mock execution", nil +} + +func (h *MockHandler) Validate(cmd string, args *common.CommandArgs) error { + return nil +} + +func TestRegistry_Register(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + // Test successful registration + err := registry.Register(handler) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + // Test duplicate handler registration + err = registry.Register(handler) + if err == nil { + t.Error("Expected error for duplicate handler registration") + } + + // Test duplicate command registration + handler2 := &MockHandler{ + name: "test2", + commands: []string{"cmd1"}, // cmd1 is already registered + } + err = registry.Register(handler2) + if err == nil { + t.Error("Expected error for duplicate command registration") + } +} + +func TestRegistry_Get(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + _ = registry.Register(handler) + + // Test getting existing command + h, err := registry.Get("cmd1") + if err != nil { + t.Fatalf("Failed to get handler for cmd1: %v", err) + } + if h.Name() != "test" { + t.Errorf("Expected handler name 'test', got '%s'", h.Name()) + } + + // Test getting non-existent command + _, err = registry.Get("nonexistent") + if err == nil { + t.Error("Expected error for non-existent command") + } +} + +func TestRegistry_GetHandler(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + _ = registry.Register(handler) + + // Test getting existing handler + h, err := registry.GetHandler("test") + if err != nil { + t.Fatalf("Failed to get handler 'test': %v", err) + } + if h.Name() != "test" { + t.Errorf("Expected handler name 'test', got '%s'", h.Name()) + } + + // Test getting non-existent handler + _, err = registry.GetHandler("nonexistent") + if err == nil { + t.Error("Expected error for non-existent handler") + } +} + +func TestRegistry_List(t *testing.T) { + registry := NewRegistry() + + handler1 := &MockHandler{ + name: "handler1", + commands: []string{"cmd1"}, + } + handler2 := &MockHandler{ + name: "handler2", + commands: []string{"cmd2"}, + } + + _ = registry.Register(handler1) + _ = registry.Register(handler2) + + handlers := registry.List() + if len(handlers) != 2 { + t.Errorf("Expected 2 handlers, got %d", len(handlers)) + } + + // Check that both handlers are in the list + foundHandler1 := false + foundHandler2 := false + for _, name := range handlers { + if name == "handler1" { + foundHandler1 = true + } + if name == "handler2" { + foundHandler2 = true + } + } + if !foundHandler1 || !foundHandler2 { + t.Error("Not all handlers found in list") + } +} + +func TestRegistry_ListCommands(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2", "cmd3"}, + } + + _ = registry.Register(handler) + + commands := registry.ListCommands() + if len(commands) != 3 { + t.Errorf("Expected 3 commands, got %d", len(commands)) + } + + // Check that all commands are in the list + commandMap := make(map[string]bool) + for _, cmd := range commands { + commandMap[cmd] = true + } + for _, expected := range []string{"cmd1", "cmd2", "cmd3"} { + if !commandMap[expected] { + t.Errorf("Command %s not found in list", expected) + } + } +} + +func TestRegistry_IsCommandRegistered(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + _ = registry.Register(handler) + + if !registry.IsCommandRegistered("cmd1") { + t.Error("Expected cmd1 to be registered") + } + if !registry.IsCommandRegistered("cmd2") { + t.Error("Expected cmd2 to be registered") + } + if registry.IsCommandRegistered("nonexistent") { + t.Error("Expected 'nonexistent' to not be registered") + } +} + +func TestRegistry_Unregister(t *testing.T) { + registry := NewRegistry() + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + _ = registry.Register(handler) + + // Verify handler is registered + if !registry.IsCommandRegistered("cmd1") { + t.Fatal("Handler not registered properly") + } + + // Unregister handler + err := registry.Unregister("test") + if err != nil { + t.Fatalf("Failed to unregister handler: %v", err) + } + + // Verify handler is no longer registered + if registry.IsCommandRegistered("cmd1") { + t.Error("Command still registered after unregistering handler") + } + + // Test unregistering non-existent handler + err = registry.Unregister("nonexistent") + if err == nil { + t.Error("Expected error when unregistering non-existent handler") + } +} + +func TestRegistry_Clear(t *testing.T) { + registry := NewRegistry() + + // Register multiple handlers + handler1 := &MockHandler{ + name: "handler1", + commands: []string{"cmd1"}, + } + handler2 := &MockHandler{ + name: "handler2", + commands: []string{"cmd2"}, + } + + _ = registry.Register(handler1) + _ = registry.Register(handler2) + + // Verify handlers are registered + if len(registry.List()) != 2 { + t.Fatal("Handlers not registered properly") + } + + // Clear registry + registry.Clear() + + // Verify registry is empty + if len(registry.List()) != 0 { + t.Error("Registry not cleared properly") + } + if len(registry.ListCommands()) != 0 { + t.Error("Commands not cleared properly") + } +} + +func TestRegistry_ThreadSafety(t *testing.T) { + registry := NewRegistry() + + // Test concurrent registration + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(n int) { + handler := &MockHandler{ + name: "handler" + string(rune('A'+n)), + commands: []string{"cmd" + string(rune('A'+n))}, + } + _ = registry.Register(handler) + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // The test passes if there's no race condition or panic + t.Log("Thread safety test passed") +} diff --git a/pkg/executor/regression_test.go b/pkg/executor/regression_test.go new file mode 100644 index 0000000..8efcacc --- /dev/null +++ b/pkg/executor/regression_test.go @@ -0,0 +1,311 @@ +package executor + +import ( + "testing" + "time" + + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// TestRegression_AllHandlerTypes verifies all expected handler types are available +func TestRegression_AllHandlerTypes(t *testing.T) { + expectedTypes := []common.HandlerType{ + common.System, + common.User, + common.Group, + common.Firewall, + common.FileTransfer, + common.Shell, + common.Terminal, + common.Info, + } + + for _, handlerType := range expectedTypes { + if handlerType.String() == "" { + t.Errorf("handler type %v has empty string representation", handlerType) + } + } +} + +// TestRegression_AllCommandTypes verifies all expected command types are available +func TestRegression_AllCommandTypes(t *testing.T) { + expectedCommands := []common.CommandType{ + // System commands + common.Upgrade, + common.Restart, + common.Quit, + common.Reboot, + common.Shutdown, + common.Update, + common.ByeBye, + + // User commands + common.AddUser, + common.DelUser, + common.ModUser, + + // Group commands + common.AddGroup, + common.DelGroup, + + // Firewall commands + common.FirewallCmd, + common.FirewallRollback, + + // File commands + common.Upload, + common.Download, + + // Shell commands + common.ShellCmd, + common.Exec, + + // Terminal commands + common.OpenPty, + common.OpenFtp, + common.ResizePty, + + // Info commands + common.Ping, + common.Help, + common.Commit, + common.Sync, + } + + for _, cmd := range expectedCommands { + if cmd.String() == "" { + t.Errorf("command type %v has empty string representation", cmd) + } + } +} + +// TestRegression_RegistryOperations verifies registry basic operations work +func TestRegression_RegistryOperations(t *testing.T) { + registry := NewRegistry() + + // Test empty registry + if len(registry.List()) != 0 { + t.Error("new registry should be empty") + } + + // Test registration + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1", "cmd2"}, + } + + if err := registry.Register(handler); err != nil { + t.Fatalf("registration failed: %v", err) + } + + // Test listing + if len(registry.List()) != 1 { + t.Error("should have 1 handler after registration") + } + + // Test command check + if !registry.IsCommandRegistered("cmd1") { + t.Error("cmd1 should be registered") + } + + // Test get + h, err := registry.Get("cmd1") + if err != nil { + t.Fatalf("get failed: %v", err) + } + if h.Name() != "test" { + t.Errorf("expected handler name 'test', got '%s'", h.Name()) + } + + // Test unregister + if err := registry.Unregister("test"); err != nil { + t.Fatalf("unregister failed: %v", err) + } + + if registry.IsCommandRegistered("cmd1") { + t.Error("cmd1 should not be registered after unregister") + } + + // Test clear + _ = registry.Register(handler) + registry.Clear() + if len(registry.List()) != 0 { + t.Error("registry should be empty after clear") + } +} + +// TestRegression_CommandArgsFields verifies all CommandArgs fields exist +func TestRegression_CommandArgsFields(t *testing.T) { + args := &common.CommandArgs{ + // User/Group management + Username: "test", + Groupname: "test", + Shell: "/bin/bash", + UID: 1000, + GID: 1000, + + // Shell execution + Command: "ls", + Env: map[string]string{"KEY": "VALUE"}, + Timeout: 30 * time.Second, + + // Firewall + Rules: []common.FirewallRule{}, + + // File transfer + Path: "/test/path", + URL: "http://example.com", + + // Terminal + SessionID: "session-123", + Rows: 24, + Cols: 80, + + // System + Target: "alpamon", + + // Info + Keys: []string{"cpu", "memory"}, + } + + // Verify all fields are accessible + if args.Username == "" { + t.Error("Username field not accessible") + } + if args.Groupname == "" { + t.Error("Groupname field not accessible") + } + if args.Shell == "" { + t.Error("Shell field not accessible") + } + if args.UID == 0 { + t.Error("UID field not accessible") + } + if args.GID == 0 { + t.Error("GID field not accessible") + } + if args.Command == "" { + t.Error("Command field not accessible") + } + if args.Env == nil { + t.Error("Env field not accessible") + } + if args.Timeout == 0 { + t.Error("Timeout field not accessible") + } + if args.Rules == nil { + t.Error("Rules field not accessible") + } + if args.Path == "" { + t.Error("Path field not accessible") + } + if args.URL == "" { + t.Error("URL field not accessible") + } + if args.SessionID == "" { + t.Error("SessionID field not accessible") + } + if args.Rows == 0 { + t.Error("Rows field not accessible") + } + if args.Cols == 0 { + t.Error("Cols field not accessible") + } + if args.Target == "" { + t.Error("Target field not accessible") + } + if len(args.Keys) == 0 { + t.Error("Keys field not accessible") + } +} + +// TestRegression_HandlerInterface verifies Handler interface contract +func TestRegression_HandlerInterface(t *testing.T) { + var _ common.Handler = (*MockHandler)(nil) + + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1"}, + } + + // Name() should return non-empty string + if handler.Name() == "" { + t.Error("Name() should not return empty string") + } + + // Commands() should return non-empty slice + if len(handler.Commands()) == 0 { + t.Error("Commands() should not return empty slice") + } +} + +// TestRegression_CommandExecutorInterface verifies CommandExecutor interface exists +func TestRegression_CommandExecutorInterface(t *testing.T) { + // Verify MockCommandExecutor implements CommandExecutor + mockExec := common.NewMockCommandExecutor(t) + + var _ common.CommandExecutor = mockExec + + // Test all methods exist + mockExec.SetResult("test ", 0, "output", nil) + cmds := mockExec.GetExecutedCommands() + if cmds == nil { + t.Error("GetExecutedCommands should not return nil") + } +} + +// TestRegression_FirewallRule verifies FirewallRule structure +func TestRegression_FirewallRule(t *testing.T) { + rule := common.FirewallRule{ + ChainName: "INPUT", + Method: "append", + Chain: "INPUT", + Protocol: "tcp", + PortStart: 22, + PortEnd: 22, + Source: "0.0.0.0/0", + Destination: "0.0.0.0/0", + Target: "ACCEPT", + Description: "Allow SSH", + Priority: 0, + RuleType: "port", + RuleID: "rule-1", + } + + if rule.ChainName == "" { + t.Error("ChainName field not accessible") + } + if rule.Method == "" { + t.Error("Method field not accessible") + } + if rule.Chain == "" { + t.Error("Chain field not accessible") + } + if rule.Protocol == "" { + t.Error("Protocol field not accessible") + } + if rule.PortStart == 0 { + t.Error("PortStart field not accessible") + } + if rule.PortEnd == 0 { + t.Error("PortEnd field not accessible") + } + if rule.Source == "" { + t.Error("Source field not accessible") + } + if rule.Destination == "" { + t.Error("Destination field not accessible") + } + if rule.Target == "" { + t.Error("Target field not accessible") + } + if rule.Description == "" { + t.Error("Description field not accessible") + } + if rule.RuleType == "" { + t.Error("RuleType field not accessible") + } + if rule.RuleID == "" { + t.Error("RuleID field not accessible") + } +} diff --git a/pkg/executor/resource_test.go b/pkg/executor/resource_test.go new file mode 100644 index 0000000..19f1ee3 --- /dev/null +++ b/pkg/executor/resource_test.go @@ -0,0 +1,302 @@ +package executor + +import ( + "context" + "runtime" + "sync" + "testing" + "time" + + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" +) + +// TestPerformance_MemoryUsageIdle verifies idle memory usage is within limits +func TestPerformance_MemoryUsageIdle(t *testing.T) { + // Force garbage collection before measuring + runtime.GC() + time.Sleep(50 * time.Millisecond) + + // Create typical components + registry := NewRegistry() + workerPool := pool.NewPool(10, 100) + ctxManager := agent.NewContextManager() + + // Register a mock handler + handler := &MockHandler{ + name: "test", + commands: []string{"cmd1"}, + } + _ = registry.Register(handler) + + // Force garbage collection and let things settle + runtime.GC() + time.Sleep(100 * time.Millisecond) + + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // Calculate total memory in use (HeapAlloc is current heap usage) + memUsedMB := float64(m.HeapAlloc) / 1024 / 1024 + + t.Logf("Heap memory in use: %.2f MB (HeapAlloc: %d bytes)", memUsedMB, m.HeapAlloc) + + // Cleanup + _ = workerPool.Shutdown(5 * time.Second) + ctxManager.Shutdown() + registry.Clear() + + // Allow significant margin - components should use well under 50MB + // This is a sanity check, not a strict performance requirement + // Note: This measures total heap, not just our components + if memUsedMB > 50 { + t.Errorf("Idle heap memory %.2f MB exceeds 50MB limit", memUsedMB) + } +} + +// TestPerformance_StartupTime verifies component startup is fast +func TestPerformance_StartupTime(t *testing.T) { + start := time.Now() + + // Create components + registry := NewRegistry() + workerPool := pool.NewPool(10, 100) + ctxManager := agent.NewContextManager() + + // Register some handlers + for i := 0; i < 10; i++ { + handler := &MockHandler{ + name: "handler" + string(rune('A'+i)), + commands: []string{"cmd" + string(rune('A'+i))}, + } + _ = registry.Register(handler) + } + + startupTime := time.Since(start) + + t.Logf("Startup time: %v", startupTime) + + // Cleanup + _ = workerPool.Shutdown(5 * time.Second) + ctxManager.Shutdown() + registry.Clear() + + // Startup should be under 1 second + if startupTime > 1*time.Second { + t.Errorf("Startup time %v exceeds 1 second limit", startupTime) + } +} + +// TestPerformance_CommandOverhead measures command execution overhead +func TestPerformance_CommandOverhead(t *testing.T) { + registry := NewRegistry() + + handler := &IntegrationMockHandler{ + name: "perf_handler", + commands: []string{"perf_cmd"}, + executionDelay: 0, // No delay - measure pure overhead + } + _ = registry.Register(handler) + + h, _ := registry.Get("perf_cmd") + ctx := context.Background() + args := &common.CommandArgs{} + + // Warm up + for i := 0; i < 10; i++ { + _, _, _ = h.Execute(ctx, "perf_cmd", args) + } + + // Measure execution time + iterations := 1000 + start := time.Now() + + for i := 0; i < iterations; i++ { + _, _, _ = h.Execute(ctx, "perf_cmd", args) + } + + elapsed := time.Since(start) + avgOverhead := elapsed / time.Duration(iterations) + + t.Logf("Average command overhead: %v (total: %v for %d iterations)", avgOverhead, elapsed, iterations) + + // Each command execution should have minimal overhead (< 1ms) + if avgOverhead > 1*time.Millisecond { + t.Errorf("Average command overhead %v exceeds 1ms limit", avgOverhead) + } +} + +// TestPerformance_ConcurrentCommandScaling tests performance under concurrent load +func TestPerformance_ConcurrentCommandScaling(t *testing.T) { + workerPool := pool.NewPool(10, 200) + defer func() { _ = workerPool.Shutdown(5 * time.Second) }() + + ctxManager := agent.NewContextManager() + defer ctxManager.Shutdown() + + registry := NewRegistry() + handler := &IntegrationMockHandler{ + name: "scale_handler", + commands: []string{"scale_cmd"}, + executionDelay: 1 * time.Millisecond, // Small delay to simulate work + } + _ = registry.Register(handler) + + h, _ := registry.Get("scale_cmd") + args := &common.CommandArgs{} + + // Test with different concurrency levels + concurrencyLevels := []int{1, 5, 10} + + for _, concurrency := range concurrencyLevels { + taskCount := 100 + var wg sync.WaitGroup + var completed int32 + + start := time.Now() + + for i := 0; i < taskCount; i++ { + wg.Add(1) + ctx, cancel := ctxManager.NewContext(5 * time.Second) + + err := workerPool.Submit(ctx, func() error { + defer wg.Done() + defer cancel() + _, _, err := h.Execute(ctx, "scale_cmd", args) + if err == nil { + completed++ + } + return err + }) + + if err != nil { + wg.Done() + cancel() + } + } + + wg.Wait() + elapsed := time.Since(start) + + t.Logf("Concurrency %d: completed %d/%d tasks in %v (%.2f tasks/sec)", + concurrency, completed, taskCount, elapsed, float64(completed)/elapsed.Seconds()) + } +} + +// TestPerformance_RegistryLookupSpeed tests registry lookup performance +func TestPerformance_RegistryLookupSpeed(t *testing.T) { + registry := NewRegistry() + + // Register many handlers + for i := 0; i < 100; i++ { + handler := &MockHandler{ + name: "handler" + string(rune(i)), + commands: []string{"cmd" + string(rune(i))}, + } + _ = registry.Register(handler) + } + + // Warm up + for i := 0; i < 10; i++ { + _, _ = registry.Get("cmd" + string(rune(50))) + } + + // Measure lookup time + iterations := 10000 + start := time.Now() + + for i := 0; i < iterations; i++ { + cmdIdx := i % 100 + _, _ = registry.Get("cmd" + string(rune(cmdIdx))) + } + + elapsed := time.Since(start) + avgLookup := elapsed / time.Duration(iterations) + + t.Logf("Average registry lookup: %v (total: %v for %d lookups)", avgLookup, elapsed, iterations) + + // Lookup should be very fast (< 100µs) + if avgLookup > 100*time.Microsecond { + t.Errorf("Average lookup time %v exceeds 100µs limit", avgLookup) + } + + registry.Clear() +} + +// TestPerformance_GoroutineLimit verifies goroutine limits are enforced +func TestPerformance_GoroutineLimit(t *testing.T) { + maxWorkers := 10 + workerPool := pool.NewPool(maxWorkers, 100) + defer func() { _ = workerPool.Shutdown(5 * time.Second) }() + + ctx := context.Background() + + // Track concurrent goroutines + var maxConcurrent int32 + var current int32 + var mu sync.Mutex + + // Submit tasks that take some time + var wg sync.WaitGroup + taskCount := 50 + + for i := 0; i < taskCount; i++ { + wg.Add(1) + err := workerPool.Submit(ctx, func() error { + defer wg.Done() + + mu.Lock() + current++ + if current > maxConcurrent { + maxConcurrent = current + } + mu.Unlock() + + time.Sleep(20 * time.Millisecond) + + mu.Lock() + current-- + mu.Unlock() + + return nil + }) + if err != nil { + wg.Done() + } + } + + wg.Wait() + + t.Logf("Max concurrent goroutines: %d (limit: %d)", maxConcurrent, maxWorkers) + + if int(maxConcurrent) > maxWorkers { + t.Errorf("Max concurrent goroutines %d exceeded worker limit %d", maxConcurrent, maxWorkers) + } +} + +// TestPerformance_ContextCancellationSpeed tests context cancellation overhead +func TestPerformance_ContextCancellationSpeed(t *testing.T) { + ctxManager := agent.NewContextManager() + defer ctxManager.Shutdown() + + // Measure context creation and cancellation + iterations := 1000 + start := time.Now() + + for i := 0; i < iterations; i++ { + ctx, cancel := ctxManager.NewContext(5 * time.Second) + _ = ctx + cancel() + } + + elapsed := time.Since(start) + avgTime := elapsed / time.Duration(iterations) + + t.Logf("Average context create/cancel: %v (total: %v for %d iterations)", avgTime, elapsed, iterations) + + // Context operations should be fast (< 100µs) + if avgTime > 100*time.Microsecond { + t.Errorf("Average context operation time %v exceeds 100µs limit", avgTime) + } +} diff --git a/pkg/executor/services/group_service.go b/pkg/executor/services/group_service.go new file mode 100644 index 0000000..0de8183 --- /dev/null +++ b/pkg/executor/services/group_service.go @@ -0,0 +1,75 @@ +package services + +import ( + "context" + "fmt" + "os/user" + "strconv" + "strings" + + "github.com/rs/zerolog/log" +) + +// CommandExecutor interface defines the contract for executing system commands +// This is duplicated here to avoid circular dependency with handlers package +type CommandExecutor interface { + Run(ctx context.Context, name string, args ...string) (exitCode int, output string, err error) + RunAsUser(ctx context.Context, username string, name string, args ...string) (exitCode int, output string, err error) + RunWithInput(ctx context.Context, input string, name string, args ...string) (exitCode int, output string, err error) +} + +// GroupService provides group management operations for use by other handlers +type GroupService interface { + // AddUserToGroups adds a user to one or more groups by GID + AddUserToGroups(ctx context.Context, username string, gids []uint64) error +} + +// DefaultGroupService is the default implementation of GroupService +type DefaultGroupService struct { + executor CommandExecutor +} + +// NewDefaultGroupService creates a new DefaultGroupService +func NewDefaultGroupService(executor CommandExecutor) *DefaultGroupService { + return &DefaultGroupService{ + executor: executor, + } +} + +// AddUserToGroups adds a user to one or more groups by GID +func (s *DefaultGroupService) AddUserToGroups(ctx context.Context, username string, gids []uint64) error { + if len(gids) == 0 { + return nil + } + + log.Info(). + Str("user", username). + Uints64("gids", gids). + Msg("Adding user to groups") + + // Convert GIDs to group names + var groups []string + for _, gid := range gids { + group, err := user.LookupGroupId(strconv.FormatUint(gid, 10)) + if err != nil { + log.Warn().Uint64("gid", gid).Err(err).Msg("Failed to lookup group by GID") + continue + } + groups = append(groups, group.Name) + } + + if len(groups) == 0 { + return nil + } + + // Use usermod -a -G to add user to groups + // Join groups with comma for single command + groupList := strings.Join(groups, ",") + + exitCode, output, _ := s.executor.RunAsUser(ctx, "root", "usermod", "-a", "-G", groupList, username) + if exitCode != 0 { + return fmt.Errorf("failed to add user %s to groups %v: %s", username, groups, output) + } + + return nil +} diff --git a/pkg/logger/server.go b/pkg/logger/server.go index b47caf9..b67c875 100644 --- a/pkg/logger/server.go +++ b/pkg/logger/server.go @@ -8,6 +8,8 @@ import ( "net" "time" + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/scheduler" "github.com/rs/zerolog/log" ) @@ -19,9 +21,11 @@ const ( type LogServer struct { listener net.Listener shutDownChan chan struct{} + workerPool *pool.Pool + ctxManager *agent.ContextManager } -func NewLogServer() *LogServer { +func NewLogServer(workerPool *pool.Pool, ctxManager *agent.ContextManager) *LogServer { listener, err := net.Listen("tcp", address) if err != nil { log.Error().Err(err).Msgf("Log server startup failed: cannot bind to %s.", address) @@ -31,6 +35,8 @@ func NewLogServer() *LogServer { return &LogServer{ listener: listener, shutDownChan: make(chan struct{}), + workerPool: workerPool, + ctxManager: ctxManager, } } @@ -50,7 +56,18 @@ func (ls *LogServer) StartLogServer() { log.Error().Err(err).Msg("Failed to accept socket.") continue } - go ls.handleConnection(conn) + // Submit connection handler to worker pool + ctx, cancel := ls.ctxManager.NewContext(0) // No timeout for connection handlers + err = ls.workerPool.Submit(ctx, func() error { + defer cancel() + ls.handleConnection(conn) + return nil + }) + if err != nil { + cancel() + log.Error().Err(err).Msg("Failed to submit connection handler to pool") + conn.Close() + } } } } diff --git a/pkg/runner/client.go b/pkg/runner/client.go index c7a1d3f..a163bf5 100644 --- a/pkg/runner/client.go +++ b/pkg/runner/client.go @@ -3,12 +3,14 @@ package runner import ( "context" "crypto/tls" - "encoding/json" "fmt" "net/http" "os" "time" + "github.com/alpacax/alpamon/internal/pool" + "github.com/alpacax/alpamon/internal/protocol" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/config" "github.com/alpacax/alpamon/pkg/scheduler" "github.com/alpacax/alpamon/pkg/utils" @@ -34,9 +36,12 @@ type WebsocketClient struct { RestartChan chan struct{} ShutDownChan chan struct{} CollectorRestartChan chan struct{} + pool *pool.Pool + ctxManager *agent.ContextManager + dispatcher CommandDispatcher } -func NewWebsocketClient(session *scheduler.Session) *WebsocketClient { +func NewWebsocketClient(session *scheduler.Session, ctxManager *agent.ContextManager, workerPool *pool.Pool) *WebsocketClient { headers := http.Header{ "Authorization": {fmt.Sprintf(`id="%s", key="%s"`, config.GlobalSettings.ID, config.GlobalSettings.Key)}, "Origin": {config.GlobalSettings.ServerURL}, @@ -49,9 +54,16 @@ func NewWebsocketClient(session *scheduler.Session) *WebsocketClient { RestartChan: make(chan struct{}), ShutDownChan: make(chan struct{}), CollectorRestartChan: make(chan struct{}, 1), + pool: workerPool, + ctxManager: ctxManager, } } +// SetDispatcher sets the dispatcher for handling commands with dispatcher +func (wc *WebsocketClient) SetDispatcher(dispatcher CommandDispatcher) { + wc.dispatcher = dispatcher +} + func (wc *WebsocketClient) RunForever(ctx context.Context) { wc.Connect() @@ -205,48 +217,61 @@ func (wc *WebsocketClient) RestartCollector() { } func (wc *WebsocketClient) CommandRequestHandler(message []byte) { - var content Content - var data CommandData - if len(message) == 0 { return } - err := json.Unmarshal(message, &content) + msg, err := protocol.ParseMessage(message) if err != nil { log.Warn().Err(err).Msgf("Inappropriate message: %s.", string(message)) return } - if content.Command.Data != "" { - err = json.Unmarshal([]byte(content.Command.Data), &data) - if err != nil { - log.Warn().Err(err).Msgf("Inappropriate message: %s.", string(message)) - return - } - } - - switch content.Query { - case "ping": + switch msg.Query { + case protocol.MessageTypePing: + // Respond to ping with pong if err := wc.SendPongResponse(); err != nil { log.Debug().Err(err).Msg("Failed to send pong response.") } - case "command": - scheduler.Rqueue.Post(fmt.Sprintf(eventCommandAckURL, content.Command.ID), + case protocol.MessageTypeCommand: + if msg.Command == nil { + log.Warn().Msg("Command message without command data") + return + } + + scheduler.Rqueue.Post(fmt.Sprintf(eventCommandAckURL, msg.Command.ID), nil, 10, time.Time{}, ) - commandRunner := NewCommandRunner(wc, wc.apiSession, content.Command, data) - go commandRunner.Run() - case "quit": - log.Debug().Msgf("Quit requested for reason: %s.", content.Reason) + + data, err := msg.Command.ParseCommandData() + if err != nil { + log.Warn().Err(err).Msgf("Failed to parse command data: %s.", string(message)) + return + } + + // Use modular handler system + if wc.dispatcher != nil { + wc.handleCommand(*msg.Command, *data) + } else { + log.Error().Msg("Dispatcher not initialized") + // Send failure notification + payload := protocol.NewCommandResponse(false, "Internal error: dispatcher not initialized", 0) + scheduler.Rqueue.Post(fmt.Sprintf(eventCommandFinURL, msg.Command.ID), + payload, + 10, + time.Time{}, + ) + } + case protocol.MessageTypeQuit: + log.Debug().Msgf("Quit requested for reason: %s.", msg.Reason) wc.ShutDown() - case "reconnect": - log.Debug().Msgf("Reconnect requested for reason: %s.", content.Reason) + case protocol.MessageTypeReconnect: + log.Debug().Msgf("Reconnect requested for reason: %s.", msg.Reason) wc.Close() default: - log.Warn().Msgf("Not implemented query: %s.", content.Query) + log.Warn().Msgf("Not implemented query: %s.", msg.Query) } } @@ -258,3 +283,36 @@ func (wc *WebsocketClient) WriteJSON(data interface{}) error { } return nil } + +func (wc *WebsocketClient) handleCommand(command protocol.Command, data protocol.CommandData) { + // Create CommandRunner with dispatcher for direct execution + commandRunner := NewCommandRunner(wc, wc.apiSession, command, data, wc.dispatcher) + + // Submit to pool with context + var ctx context.Context + var cancel context.CancelFunc + if config.GlobalSettings.PoolDefaultTimeout > 0 { + ctx, cancel = wc.ctxManager.NewContext(time.Duration(config.GlobalSettings.PoolDefaultTimeout) * time.Second) + } else { + ctx, cancel = wc.ctxManager.NewContext(0) + } + + err := wc.pool.Submit(ctx, func() error { + defer cancel() + // Run the command - it handles result notification internally via defer + return commandRunner.Run(ctx) + }) + + if err != nil { + cancel() + log.Error().Err(err).Msgf("Failed to submit command %s to pool", command.ID) + // Send failure notification + start := time.Now() + payload := protocol.NewCommandResponse(false, fmt.Sprintf("Failed to submit command: %v", err), time.Since(start).Seconds()) + scheduler.Rqueue.Post(fmt.Sprintf(eventCommandFinURL, command.ID), + payload, + 10, + time.Time{}, + ) + } +} diff --git a/pkg/runner/command.go b/pkg/runner/command.go index 411f14b..dd354bf 100644 --- a/pkg/runner/command.go +++ b/pkg/runner/command.go @@ -1,47 +1,34 @@ package runner import ( - "archive/zip" - "bytes" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "errors" + "context" "fmt" - "io" - "mime/multipart" - "net/http" - "net/url" - "os" - "os/exec" - "os/user" - "path/filepath" - "strconv" "strings" - "syscall" "time" - "github.com/alpacax/alpamon/pkg/config" + "github.com/alpacax/alpamon/internal/protocol" + "github.com/alpacax/alpamon/pkg/executor/handlers/common" "github.com/alpacax/alpamon/pkg/scheduler" - "github.com/alpacax/alpamon/pkg/utils" - "github.com/alpacax/alpamon/pkg/version" - "github.com/google/uuid" "github.com/rs/zerolog/log" - "gopkg.in/go-playground/validator.v9" ) -const ( - fileUploadTimeout = 60 * 10 - serverUnregisterURL = "/api/servers/servers/-/unregister/" -) +// CommandDispatcher interface to avoid circular import with executor package +type CommandDispatcher interface { + Execute(ctx context.Context, command string, args *common.CommandArgs) (int, string, error) + HasHandler(command string) bool +} -func init() { - // Inject runCmdWithOutput function into utils.firewall package - utils.SetFirewallCommandExecutor(runCmdWithOutput) +// CommandRunner executes commands received from the server +type CommandRunner struct { + name string + command protocol.Command + wsClient *WebsocketClient + apiSession *scheduler.Session + data protocol.CommandData + dispatcher CommandDispatcher } -func NewCommandRunner(wsClient *WebsocketClient, apiSession *scheduler.Session, command Command, data CommandData) *CommandRunner { +func NewCommandRunner(wsClient *WebsocketClient, apiSession *scheduler.Session, command protocol.Command, data protocol.CommandData, dispatcher CommandDispatcher) *CommandRunner { var name string if command.ID != "" { name = fmt.Sprintf("CommandRunner-%s", strings.Split(command.ID, "-")[0]) @@ -53,2296 +40,82 @@ func NewCommandRunner(wsClient *WebsocketClient, apiSession *scheduler.Session, data: data, wsClient: wsClient, apiSession: apiSession, - validator: validator.New(), + dispatcher: dispatcher, } } -func (cr *CommandRunner) Run() { +func (cr *CommandRunner) Run(ctx context.Context) error { var exitCode int var result string - - log.Debug().Msgf("Received command: %s > %s", cr.command.Shell, cr.command.Line) - start := time.Now() - switch cr.command.Shell { - case "internal": - exitCode, result = cr.handleInternalCmd() - case "system": - exitCode, result = cr.handleShellCmd(cr.command.Line, cr.command.User, cr.command.Group, cr.command.Env) - default: - exitCode = 1 - result = "Invalid command shell argument." - } - - if cr.command.ID != "" { - finURL := fmt.Sprintf(eventCommandFinURL, cr.command.ID) - - payload := &commandFin{ - Success: exitCode == 0, - Result: result, - ElapsedTime: time.Since(start).Seconds(), - } - scheduler.Rqueue.Post(finURL, payload, 10, time.Time{}) - } -} - -func (cr *CommandRunner) handleInternalCmd() (int, string) { - args := strings.Fields(cr.command.Line) - if len(args) == 0 { - return 1, "No command provided" - } - - for i, arg := range args { - unquotedArg, err := strconv.Unquote(arg) - if err == nil { - args[i] = unquotedArg - } - } - - var cmd string - switch args[0] { - case "upgrade": - latestVersion := utils.GetLatestVersion() - - if version.Version == latestVersion { - return 0, fmt.Sprintf("Alpamon is already up-to-date (version: %s)", version.Version) - } - - if utils.PlatformLike == "debian" { - cmd = "apt-get update -y && " + - "apt-get install --only-upgrade alpamon -y" - } else if utils.PlatformLike == "rhel" { - cmd = "yum update -y alpamon" - } else { - return 1, fmt.Sprintf("Platform '%s' not supported.", utils.PlatformLike) - } - log.Debug().Msgf("Upgrading alpamon from %s to %s using command: '%s'...", version.Version, latestVersion, cmd) - return cr.handleShellCmd(cmd, "root", "root", nil) - case "commit": - cr.commit() - return 0, "Committed system information." - case "sync": - cr.sync(cr.data.Keys) - return 0, "Synchronized system information." - case "adduser": - return cr.addUser() - case "addgroup": - return cr.addGroup() - case "deluser": - return cr.delUser() - case "delgroup": - return cr.delGroup() - case "moduser": - return cr.modUser() - case "ping": - _ = cr.wsClient.SendPongResponse() - return 0, time.Now().Format(time.RFC3339) - case "download": - return cr.runFileDownload(args[1]) - case "upload": - code, message := cr.runFileUpload(args[1]) - statFileTransfer(code, DOWNLOAD, message, cr.data) - - return code, message - case "openpty": - data := openPtyData{ - SessionID: cr.data.SessionID, - URL: cr.data.URL, - Username: cr.data.Username, - Groupname: cr.data.Groupname, - HomeDirectory: cr.data.HomeDirectory, - Rows: cr.data.Rows, - Cols: cr.data.Cols, - } - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("openpty: Not enough information. %s", err.Error()) - } - - ptyClient := NewPtyClient(cr.data, cr.apiSession) - go ptyClient.RunPtyBackground() - - return 0, "Spawned a pty terminal." - case "openftp": - data := openFtpData{ - SessionID: cr.data.SessionID, - URL: cr.data.URL, - Username: cr.data.Username, - Groupname: cr.data.Groupname, - HomeDirectory: cr.data.HomeDirectory, - } - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("openftp: Not enough information. %s", err.Error()) - } - - err = cr.openFtp(data) - if err != nil { - return 1, fmt.Sprintf("%v", err) - } - - return 0, "Spawned a ftp terminal." - case "opentunnel": - log.Debug(). - Str("sessionID", cr.data.SessionID). - Int("targetPort", cr.data.TargetPort). - Str("url", cr.data.URL). - Msg("Received opentunnel command") - - // Validate port range (1-65535, 0 is reserved) - if cr.data.TargetPort < 1 || cr.data.TargetPort > 65535 { - return 1, fmt.Sprintf("opentunnel: Invalid target port %d. Must be between 1 and 65535.", cr.data.TargetPort) - } - - data := openTunnelData{ - SessionID: cr.data.SessionID, - TargetPort: cr.data.TargetPort, - URL: cr.data.URL, - } - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("opentunnel: Not enough information. %s", err.Error()) - } - - // Check if tunnel already exists - if _, exists := GetActiveTunnel(cr.data.SessionID); exists { - return 1, fmt.Sprintf("opentunnel: Tunnel session %s already exists.", cr.data.SessionID) - } - - tunnelClient := NewTunnelClient( - cr.data.SessionID, - cr.data.TargetPort, - cr.data.URL, - ) - go tunnelClient.RunTunnelBackground() - - return 0, fmt.Sprintf("Spawned a tunnel for session %s, target port %d.", cr.data.SessionID, cr.data.TargetPort) - case "closetunnel": - data := closeTunnelData{ - SessionID: cr.data.SessionID, - } - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("closetunnel: Not enough information. %s", err.Error()) - } - if err := CloseTunnel(cr.data.SessionID); err != nil { - return 1, fmt.Sprintf("closetunnel: %s", err.Error()) - } - - return 0, fmt.Sprintf("Closed tunnel session %s.", cr.data.SessionID) - case "resizepty": - if terminals[cr.data.SessionID] != nil { - err := terminals[cr.data.SessionID].resize(cr.data.Rows, cr.data.Cols) - if err != nil { - return 1, err.Error() - } - return 0, fmt.Sprintf("Resized terminal for %s to %dx%d.", cr.data.SessionID, cr.data.Cols, cr.data.Rows) - } - return 1, "Invalid session ID" - case "restart": - target := "alpamon" - message := "Alpamon will restart in 1 second." - if len(args) >= 2 { - target = args[1] - } - - switch target { - case "collector": - log.Info().Msg("Restart collector.") - cr.wsClient.RestartCollector() - message = "Collector will be restarted." - default: - time.AfterFunc(1*time.Second, func() { - cr.wsClient.Restart() - }) - } - - return 0, message - case "quit": - time.AfterFunc(1*time.Second, func() { - cr.wsClient.ShutDown() - }) - return 0, "Alpamon will shutdown in 1 second." - case "byebye": - log.Info().Msg("Uninstall request received.") - - // Execute uninstall after 1 second to ensure response is sent - time.AfterFunc(1*time.Second, func() { - cr.executeUninstall() - }) - - return 0, "Alpamon will be completely uninstalled in 1 second. Goodbye!" - case "reboot": - log.Info().Msg("Reboot request received.") - time.AfterFunc(1*time.Second, func() { - cr.handleShellCmd("reboot", "root", "root", nil) - }) - - return 0, "Server will reboot in 1 second" - case "shutdown": - log.Info().Msg("Shutdown request received.") - time.AfterFunc(1*time.Second, func() { - cr.handleShellCmd("shutdown", "root", "root", nil) - }) - - return 0, "Server will shutdown in 1 second" - case "update": - log.Info().Msg("Upgrade system requested.") - if utils.PlatformLike == "debian" { - cmd = "apt-get update && apt-get upgrade -y && apt-get autoremove -y" - } else if utils.PlatformLike == "rhel" { - cmd = "yum update -y" - } else if utils.PlatformLike == "darwin" { - cmd = "brew upgrade" - } else { - return 1, fmt.Sprintf("Platform '%s' not supported.", utils.PlatformLike) - } - - return cr.handleShellCmd(cmd, "root", "root", nil) - case "restartcoll": - log.Info().Msg("Restart collector.") - cr.wsClient.RestartCollector() - - return 0, "Collector will be restarted." - case "firewall": - if utils.IsFirewallDisabled() { - log.Warn().Msg("Firewall command ignored - firewall functionality is temporarily disabled") - return 0, "Firewall functionality is temporarily disabled" - } - if detected, toolName := utils.DetectHighLevelFirewall(); detected { - return 1, fmt.Sprintf("Alpacon firewall management is disabled because %s is active. Please use %s to manage firewall rules.", toolName, toolName) - } - return cr.firewall() - case "firewall-rollback": - if utils.IsFirewallDisabled() { - log.Warn().Msg("Firewall rollback command ignored - firewall functionality is temporarily disabled") - return 0, "Firewall functionality is temporarily disabled" - } - if detected, toolName := utils.DetectHighLevelFirewall(); detected { - return 1, fmt.Sprintf("Alpacon firewall management is disabled because %s is active. Please use %s to manage firewall rules.", toolName, toolName) - } - return cr.firewallRollback() - case "firewall-reorder-chains": - if utils.IsFirewallDisabled() { - log.Warn().Msg("Firewall reorder-chains command ignored - firewall functionality is temporarily disabled") - return 0, "Firewall functionality is temporarily disabled" - } - if detected, toolName := utils.DetectHighLevelFirewall(); detected { - return 1, fmt.Sprintf("Alpacon firewall management is disabled because %s is active. Please use %s to manage firewall rules.", toolName, toolName) - } - return cr.firewallReorderChains() - case "firewall-reorder-rules": - if utils.IsFirewallDisabled() { - log.Warn().Msg("Firewall reorder-rules command ignored - firewall functionality is temporarily disabled") - return 0, "Firewall functionality is temporarily disabled" - } - if detected, toolName := utils.DetectHighLevelFirewall(); detected { - return 1, fmt.Sprintf("Alpacon firewall management is disabled because %s is active. Please use %s to manage firewall rules.", toolName, toolName) + defer func() { + if cr.command.ID != "" { + finURL := fmt.Sprintf(eventCommandFinURL, cr.command.ID) + payload := protocol.NewCommandResponse(exitCode == 0, result, time.Since(start).Seconds()) + scheduler.Rqueue.Post(finURL, payload, 10, time.Time{}) } - return cr.firewallReorderRules() - case "help": - helpMessage := ` - Available commands: - package install : install a system package - package uninstall : remove a system package - upgrade: upgrade alpamon - restart: restart alpamon - quit: stop alpamon - byebye: completely uninstall alpamon - update: update system - reboot: reboot system - shutdown: shutdown system - ` - return 0, helpMessage + }() - case "sudo_approval_response": - return cr.handleSudoApprovalResponse() + log.Debug().Msgf("Received command: %s > %s", cr.command.Shell, cr.command.Line) + // Check if context is already cancelled before starting + select { + case <-ctx.Done(): + result = fmt.Sprintf("Command cancelled before execution: %v", ctx.Err()) + exitCode = 1 + return fmt.Errorf("command failed with exit code %d: %s", exitCode, result) default: - return 1, fmt.Sprintf("Invalid command %s", args[0]) - } -} - -func (cr *CommandRunner) handleShellCmd(command, user, group string, env map[string]string) (exitCode int, result string) { - spl := strings.Fields(command) - args := []string{} - results := "" - - if group == "" { - group = user - } - - for _, arg := range spl { - switch arg { - case "&&": - exitCode, result = runCmdWithOutput(args, user, group, env, 0) - results += result - // stop executing if command fails - if exitCode != 0 { - return exitCode, results - } - args = []string{} - case "||": - exitCode, result = runCmdWithOutput(args, user, group, env, 0) - results += result - // execute next only if command fails - if exitCode == 0 { - return exitCode, results - } - args = []string{} - case ";": - exitCode, result = runCmdWithOutput(args, user, group, env, 0) - results += result - args = []string{} - default: - if strings.HasSuffix(arg, ";") { - args = append(args, strings.TrimSuffix(arg, ";")) - exitCode, result = runCmdWithOutput(args, user, group, env, 0) - results += result - args = []string{} - } else { - args = append(args, arg) - } - } - } - - if len(args) > 0 { - exitCode, result = runCmdWithOutput(args, user, group, env, 0) - results += result - } - - return exitCode, results -} - -func (cr *CommandRunner) commit() { - commitSystemInfo() -} - -func (cr *CommandRunner) sync(keys []string) { - syncSystemInfo(cr.wsClient.apiSession, keys) -} - -func (cr *CommandRunner) addUser() (exitCode int, result string) { - data := addUserData{ - Username: cr.data.Username, - UID: cr.data.UID, - GID: cr.data.GID, - Comment: cr.data.Comment, - HomeDirectory: cr.data.HomeDirectory, - HomeDirectoryPermission: cr.data.HomeDirectoryPermission, - Shell: cr.data.Shell, - Groupname: cr.data.Groupname, - } - - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("adduser: Not enough information. %s", err) - } - - if utils.PlatformLike == "debian" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/adduser", - "--home", data.HomeDirectory, - "--shell", data.Shell, - "--uid", strconv.FormatUint(data.UID, 10), - "--gid", strconv.FormatUint(data.GID, 10), - "--gecos", data.Comment, - "--disabled-password", - data.Username, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - - for _, gid := range cr.data.Groups { - if gid == data.GID { - continue - } - // get groupname from gid - group, err := user.LookupGroupId(strconv.FormatUint(gid, 10)) - if err != nil { - return 1, err.Error() - } - - // invoke adduser - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/adduser", - data.Username, - group.Name, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } - } else if utils.PlatformLike == "rhel" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/useradd", - "--home-dir", data.HomeDirectory, - "--shell", data.Shell, - "--uid", strconv.FormatUint(data.UID, 10), - "--gid", strconv.FormatUint(data.GID, 10), - "--groups", utils.JoinUint64s(cr.data.Groups), - "--comment", data.Comment, - data.Username, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else { - return 1, "Not implemented 'adduser' command for this platform." - } - - // Set default permission for home directory if not provided - if data.HomeDirectoryPermission == "" { - data.HomeDirectoryPermission = "700" - } - - exitCode, result = runCmdWithOutput( - []string{ - "chmod", data.HomeDirectoryPermission, data.HomeDirectory, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - - cr.sync([]string{"groups", "users"}) - return 0, "Successfully added new user." -} - -func (cr *CommandRunner) addGroup() (exitCode int, result string) { - data := addGroupData{ - Groupname: cr.data.Groupname, - GID: cr.data.GID, - } - - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("addgroup: Not enough information. %s", err) - } - - if utils.PlatformLike == "debian" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/addgroup", - "--gid", strconv.FormatUint(data.GID, 10), - data.Groupname, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else if utils.PlatformLike == "rhel" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/groupadd", - "--gid", strconv.FormatUint(data.GID, 10), - data.Groupname, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else { - return 1, "Not implemented 'addgroup' command for this platform." - } - - cr.sync([]string{"groups", "users"}) - return 0, "Successfully added new group." -} - -func (cr *CommandRunner) delUser() (exitCode int, result string) { - data := deleteUserData{ - Username: cr.data.Username, - PurgeHomeDirectory: cr.data.PurgeHomeDirectory, } - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("deluser: Not enough information. %s", err) + // Check if dispatcher is available + if cr.dispatcher == nil { + exitCode = 1 + result = "Internal error: dispatcher not initialized" + return nil } - cmd := "/usr/sbin/userdel" - args := []string{} + var command string + var args *common.CommandArgs - switch utils.PlatformLike { - case "debian": - cmd = "/usr/sbin/deluser" - if data.PurgeHomeDirectory { - args = append(args, "--remove-home") - } - case "rhel": - if data.PurgeHomeDirectory { - args = append(args, "--remove") + switch cr.command.Shell { + case "internal": + fields := strings.Fields(cr.command.Line) + if len(fields) == 0 { + exitCode = 1 + result = "No command provided" + return nil + } + command = fields[0] + args = cr.data.ToArgs() + case "system": + command = common.ShellCmd.String() + args = &common.CommandArgs{ + Command: cr.command.Line, + Username: cr.command.User, + Groupname: cr.command.Group, + Env: cr.command.Env, } default: - return 1, "Not implemented 'deluser' command for this platform." - } - - if !data.PurgeHomeDirectory { - homeDir := fmt.Sprintf("/home/%s", data.Username) - timestamp := time.Now().UTC().Format(time.RFC3339) - backupDir := fmt.Sprintf("/home/deleted_users/%s_%s", data.Username, timestamp) - - err = os.MkdirAll("/home/deleted_users", 0700) - if err != nil { - return 1, fmt.Sprintf("Failed to create backup directory: %v", err) - } - - _, err = os.Stat(homeDir) - if err != nil { - return 1, fmt.Sprintf("%s not exist: %v", homeDir, err) - } - - err = os.Rename(homeDir, backupDir) - if err != nil { - return 1, fmt.Sprintf("Failed to move home directory: %v", err) - } - - err = utils.ChownRecursive(backupDir, 0, 0) - if err != nil { - return 1, fmt.Sprintf("Failed to chown backup directory: %v", err) - } - } - - args = append(args, data.Username) - cmdString := append([]string{cmd}, args...) - - exitCode, result = runCmdWithOutput( - cmdString, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - - cr.sync([]string{"groups", "users"}) - return 0, "Successfully deleted the user." -} - -func (cr *CommandRunner) delGroup() (exitCode int, result string) { - data := deleteGroupData{ - Groupname: cr.data.Groupname, - } - - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("delgroup: Not enough information. %s", err) - } - - if utils.PlatformLike == "debian" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/delgroup", - data.Groupname, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else if utils.PlatformLike == "rhel" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/groupdel", - data.Groupname, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else { - return 1, "Not implemented 'delgroup' command for this platform." - } - - cr.sync([]string{"groups", "users"}) - return 0, "Successfully deleted the group." -} - -func (cr *CommandRunner) modUser() (exitCode int, result string) { - data := modUserData{ - Username: cr.data.Username, - Groupnames: cr.data.Groupnames, - Comment: cr.data.Comment, - } - - err := cr.validateData(data) - if err != nil { - return 1, fmt.Sprintf("moduser: Not enough information. %s", err) - } - - if utils.PlatformLike == "debian" || utils.PlatformLike == "rhel" { - exitCode, result = runCmdWithOutput( - []string{ - "/usr/sbin/usermod", - "--comment", data.Comment, - "-G", strings.Join(data.Groupnames, ","), - data.Username, - }, - "root", "", nil, 60, - ) - if exitCode != 0 { - return exitCode, result - } - } else { - return 1, "Not implemented 'moduser' command for this platform." - } - - cr.sync([]string{"groups", "users"}) - return 0, "Successfully modified user information." -} - -func (cr *CommandRunner) runFileUpload(fileName string) (exitCode int, result string) { - log.Debug().Msgf("Uploading file to %s. (username: %s, groupname: %s)", fileName, cr.data.Username, cr.data.Groupname) - - sysProcAttr, homeDirectory, err := demoteFtp(cr.data.Username, cr.data.Groupname) - if err != nil { - log.Error().Err(err).Msg("Failed to demote user.") - return 1, err.Error() - } - - if len(cr.data.Paths) == 0 { - return 1, "No paths provided" - } - - paths, bulk, recursive, err := parsePaths(homeDirectory, cr.data.Paths) - if err != nil { - log.Error().Err(err).Msg("Failed to parse paths") - return 1, err.Error() - } - - name, err := makeArchive(paths, bulk, recursive, sysProcAttr) - if err != nil { - log.Error().Err(err).Msg("Failed to create archive") - return 1, err.Error() - } - - if bulk || recursive { - defer func() { _ = os.Remove(name) }() - } - - cmd := exec.Command("cat", name) - cmd.SysProcAttr = sysProcAttr - - output, err := cmd.Output() - if err != nil { - log.Error().Err(err).Msgf("Failed to cat file: %s", output) - return 1, err.Error() - } - - requestBody, contentType, err := createMultipartBody(output, filepath.Base(name), cr.data.UseBlob, recursive) - if err != nil { - log.Error().Err(err).Msgf("Failed to make request body") - return 1, err.Error() - } - - _, statusCode, err := cr.fileUpload(requestBody, contentType) - if err != nil { - log.Error().Err(err).Msgf("Failed to upload file: %s", fileName) - return 1, err.Error() - } - - if statusCode == http.StatusOK { - return 0, fmt.Sprintf("Successfully uploaded %s.", fileName) - } - - return 1, "You do not have permission to read on the directory. or directory does not exist" -} - -func (cr *CommandRunner) fileUpload(body bytes.Buffer, contentType string) ([]byte, int, error) { - if cr.data.UseBlob { - return utils.Put(cr.data.Content, body, 0) - } - - return cr.wsClient.apiSession.MultipartRequest(cr.data.Content, body, contentType, fileUploadTimeout) -} - -func (cr *CommandRunner) runFileDownload(fileName string) (exitCode int, result string) { - log.Debug().Msgf("Downloading file to %s. (username: %s, groupname: %s)", fileName, cr.data.Username, cr.data.Groupname) - - var code int - var message string - sysProcAttr, err := demote(cr.data.Username, cr.data.Groupname) - if err != nil { - log.Error().Err(err).Msg("Failed to demote user.") - return 1, err.Error() - } - - if len(cr.data.Files) == 0 { - code, message = fileDownload(cr.data, sysProcAttr) - statFileTransfer(code, UPLOAD, message, cr.data) - } else { - for _, file := range cr.data.Files { - cmdData := CommandData{ - Username: file.Username, - Groupname: file.Groupname, - Type: file.Type, - Content: file.Content, - Path: file.Path, - AllowOverwrite: file.AllowOverwrite, - AllowUnzip: file.AllowUnzip, - URL: file.URL, - } - code, message = fileDownload(cmdData, sysProcAttr) - statFileTransfer(code, UPLOAD, message, cmdData) - } - } - - if code != 0 { - return code, message - } - - return 0, fmt.Sprintf("Successfully downloaded %s.", fileName) -} - -func (cr *CommandRunner) validateData(data interface{}) error { - err := cr.validator.Struct(data) - if err != nil { - return err - } - return nil -} - -func (cr *CommandRunner) openFtp(data openFtpData) error { - sysProcAttr, homeDirectory, err := demoteFtp(data.Username, data.Groupname) - if err != nil { - log.Debug().Err(err).Msg("Failed to get demote permission") - - return fmt.Errorf("openftp: Failed to get demoted permission. %w", err) + exitCode = 1 + result = "Invalid command shell argument." + return nil } - executable, err := os.Executable() - if err != nil { - log.Debug().Err(err).Msg("Failed to get executable path") - - return fmt.Errorf("openftp: Failed to get executable path. %w", err) + // Check if handler exists for the command + if !cr.dispatcher.HasHandler(command) { + exitCode = 1 + result = fmt.Sprintf("Unknown command: %s", command) + return nil } - cmd := exec.Command( - executable, - "ftp", - data.URL, - config.GlobalSettings.ServerURL, - homeDirectory, - ) - cmd.SysProcAttr = sysProcAttr - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + log.Debug().Msgf("Executing %s command: %s", cr.command.Shell, command) - err = cmd.Start() + var err error + exitCode, result, err = cr.dispatcher.Execute(ctx, command, args) if err != nil { - log.Debug().Err(err).Msg("Failed to start ftp worker process") - - return fmt.Errorf("openftp: Failed to start ftp worker process. %w", err) + log.Error().Err(err).Str("command", command).Msg("Command execution failed") } - go func() { _ = cmd.Wait() }() - return nil } - -func (cr *CommandRunner) firewall() (exitCode int, result string) { - log.Info().Msgf("Firewall operation: %s, ChainName: %s", cr.data.Operation, cr.data.ChainName) - - // Validate required fields based on operation - if cr.data.ChainName == "" { - return 1, "firewall: chain_name is required" - } - if cr.data.Operation == "" { - return 1, "firewall: operation is required" - } - - // Route to appropriate operation handler - switch cr.data.Operation { - case "batch": - return cr.handleBatchOperation() - case "flush": - return cr.handleFlushOperation() - case "delete": - return cr.handleDeleteOperation() - case "add": - return cr.handleAddOperation() - case "update": - return cr.handleUpdateOperation() - default: - return 1, fmt.Sprintf("firewall: Unknown operation '%s'. Supported: batch, flush, delete, add, update", cr.data.Operation) - } -} - -// handleBatchOperation handles batch application of firewall rules -func (cr *CommandRunner) handleBatchOperation() (exitCode int, result string) { - log.Info().Msgf("Firewall batch operation - ChainName: %s, RuleCount: %d", - cr.data.ChainName, len(cr.data.Rules)) - - if len(cr.data.Rules) == 0 { - // Empty batch is considered successful (no-op) - log.Warn().Msgf("Firewall batch operation with no rules - treating as no-op for chain: %s", cr.data.ChainName) - return 0, `{"success": true, "applied_rules": 0, "failed_rules": [], "rolled_back": false, "rollback_reason": null, "message": "No rules to apply"}` - } - - // Use the common batch apply logic with rollback on failure - appliedRules, failedRules, rolledBack, rollbackReason := cr.applyRulesBatchWithFlush() - - // Prepare response in batch format - if rolledBack { - return 1, fmt.Sprintf(`{"success": false, "error": "Failed to apply rules", "applied_rules": %d, "failed_rules": %d, "rolled_back": true, "rollback_reason": "%s"}`, - appliedRules, len(failedRules), rollbackReason) - } - - return 0, fmt.Sprintf(`{"success": true, "applied_rules": %d, "failed_rules": [], "rolled_back": false, "rollback_reason": null}`, appliedRules) -} - -// handleFlushOperation handles flushing a firewall chain -func (cr *CommandRunner) handleFlushOperation() (exitCode int, result string) { - log.Info().Msgf("Firewall flush operation - ChainName: %s", cr.data.ChainName) - - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall flush: Failed to check firewall tools. %s", err) - } - - if nftablesInstalled { - return cr.performNftablesRollback(cr.data.ChainName, "flush") - } else if iptablesInstalled { - return cr.performIptablesRollback(cr.data.ChainName, "flush") - } - - return 1, "firewall flush: No firewall management tool installed" -} - -// handleDeleteOperation handles deleting a specific firewall rule by rule_id -func (cr *CommandRunner) handleDeleteOperation() (exitCode int, result string) { - log.Info().Msgf("Firewall delete operation - ChainName: %s, RuleID: %s", cr.data.ChainName, cr.data.RuleID) - - // Validate required fields - if cr.data.RuleID == "" { - return 1, "firewall delete: rule_id is required for delete operation" - } - - // Create backup before deleting - backup, err := utils.BackupFirewallRules() - if err != nil { - return 1, fmt.Sprintf("firewall delete: Failed to create backup: %v", err) - } - - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall delete: Failed to check firewall tools. %s", err) - } - - var deleteExitCode int - var deleteResult string - - if nftablesInstalled { - deleteExitCode, deleteResult = cr.deleteNftablesRuleByID(cr.data.ChainName, cr.data.RuleID) - } else if iptablesInstalled { - deleteExitCode, deleteResult = cr.deleteIptablesRuleByID(cr.data.ChainName, cr.data.RuleID) - } else { - return 1, "firewall delete: No firewall management tool installed" - } - - // If deletion failed, restore backup - if deleteExitCode != 0 { - log.Error().Msgf("Failed to delete rule, restoring backup: %s", deleteResult) - if restoreErr := utils.RestoreFirewallRules(backup); restoreErr != nil { - log.Error().Err(restoreErr).Msg("Failed to restore backup after delete failure") - return deleteExitCode, fmt.Sprintf("firewall delete: Failed and restore failed: %s", deleteResult) - } - return deleteExitCode, fmt.Sprintf("firewall delete: Failed, backup restored: %s", deleteResult) - } - - return deleteExitCode, deleteResult -} - -// handleAddOperation handles adding a single firewall rule -func (cr *CommandRunner) handleAddOperation() (exitCode int, result string) { - log.Info().Msgf("Firewall add operation - ChainName: %s", cr.data.ChainName) - - // Log all received data for debugging - log.Debug().Msgf("Received firewall add data: ChainName=%s, Method=%s, Chain=%s, Protocol=%s, PortStart=%d, PortEnd=%d, DPorts=%v, ICMPType=%s, Source=%s, Destination=%s, Target=%s, Priority=%d, RuleID=%s, RuleType=%s", - cr.data.ChainName, cr.data.Method, cr.data.Chain, cr.data.Protocol, - cr.data.PortStart, cr.data.PortEnd, cr.data.DPorts, cr.data.ICMPType, - cr.data.Source, cr.data.Destination, cr.data.Target, cr.data.Priority, - cr.data.RuleID, cr.data.RuleType) - - // Validate required fields for rule addition - if err := cr.validateFirewallRuleData(); err != nil { - return 1, fmt.Sprintf("firewall add: Validation failed. %s", err) - } - - return cr.executeSingleFirewallRule() -} - -// handleUpdateOperation handles updating a firewall rule -func (cr *CommandRunner) handleUpdateOperation() (exitCode int, result string) { - log.Info().Msgf("Firewall update operation - ChainName: %s, OldRuleID: %s, NewRuleID: %s", - cr.data.ChainName, cr.data.OldRuleID, cr.data.RuleID) - - // Validate required fields for rule update - if err := cr.validateFirewallRuleData(); err != nil { - return 1, fmt.Sprintf("firewall update: Validation failed. %s", err) - } - - // For update operation: delete old rule first, then add new one with new ID - // old_rule_id: the rule to delete - // rule_id: the new rule to add - // TODO: Consider changing order to add-then-delete for better safety - - if cr.data.OldRuleID == "" { - return 1, "firewall update: old_rule_id is required for update operation" - } - - // Create backup before updating - backup, err := utils.BackupFirewallRules() - if err != nil { - return 1, fmt.Sprintf("firewall update: Failed to create backup: %v", err) - } - - // Step 1: Check firewall tools - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall update: Failed to check firewall tools. %s", err) - } - - // Step 2: Delete the old rule using old_rule_id - var deleteExitCode int - var deleteResult string - - if nftablesInstalled { - deleteExitCode, deleteResult = cr.deleteNftablesRuleByID(cr.data.ChainName, cr.data.OldRuleID) - } else if iptablesInstalled { - deleteExitCode, deleteResult = cr.deleteIptablesRuleByID(cr.data.ChainName, cr.data.OldRuleID) - } else { - return 1, "firewall update: No firewall tool available" - } - - if deleteExitCode != 0 { - // If deletion fails, restore backup - log.Error().Msgf("Failed to delete old rule during update, restoring backup: %s", deleteResult) - if restoreErr := utils.RestoreFirewallRules(backup); restoreErr != nil { - log.Error().Err(restoreErr).Msg("Failed to restore backup after delete failure") - return deleteExitCode, fmt.Sprintf("firewall update: Failed to delete old rule and restore failed: %s", deleteResult) - } - return deleteExitCode, fmt.Sprintf("firewall update: Failed to delete old rule, backup restored: %s", deleteResult) - } - - // Step 3: Add the new rule with new rule_id (stored in cr.data.RuleID) - addExitCode, addResult := cr.executeSingleFirewallRule() - - if addExitCode != 0 { - // Adding new rule failed, restore backup (old rule was deleted) - log.Error().Msgf("Failed to add new rule during update, restoring backup: %s", addResult) - if restoreErr := utils.RestoreFirewallRules(backup); restoreErr != nil { - log.Error().Err(restoreErr).Msg("Failed to restore backup after add failure") - return addExitCode, fmt.Sprintf("firewall update: Failed to add new rule and restore failed: %s", addResult) - } - return addExitCode, fmt.Sprintf("firewall update: Failed to add new rule, backup restored: %s", addResult) - } - - log.Info().Msgf("Successfully updated firewall rule: deleted %s, added %s", cr.data.OldRuleID, cr.data.RuleID) - return 0, fmt.Sprintf("Successfully updated rule: deleted %s, added %s", cr.data.OldRuleID, cr.data.RuleID) -} - -// validateFirewallRuleData performs validation for single rule operations -func (cr *CommandRunner) validateFirewallRuleData() error { - // Set default rule type if not provided - if cr.data.RuleType == "" { - cr.data.RuleType = "alpacon" - } - - // Generate rule ID if not provided - if cr.data.RuleID == "" { - cr.data.RuleID = uuid.New().String() - } - - data := firewallData{ - ChainName: cr.data.ChainName, - Method: cr.data.Method, - Chain: cr.data.Chain, - Protocol: cr.data.Protocol, - PortStart: cr.data.PortStart, - PortEnd: cr.data.PortEnd, - DPorts: cr.data.DPorts, - ICMPType: cr.data.ICMPType, - Source: cr.data.Source, - Destination: cr.data.Destination, - Target: cr.data.Target, - Description: cr.data.Description, - Priority: cr.data.Priority, - RuleType: cr.data.RuleType, - RuleID: cr.data.RuleID, - Operation: cr.data.Operation, - } - - return cr.validateFirewallData(data) -} - -// executeSingleFirewallRule executes a single firewall rule operation -func (cr *CommandRunner) executeSingleFirewallRule() (exitCode int, result string) { - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall: Failed to check firewall tools. %s", err) - } - - if nftablesInstalled { - return cr.executeNftablesRule() - } else if iptablesInstalled { - return cr.executeIptablesRule() - } - - return 1, "firewall: No firewall management tool installed" -} - -// executeNftablesRule executes nftables rule -func (cr *CommandRunner) executeNftablesRule() (exitCode int, result string) { - log.Info().Msg("Using nftables for firewall management.") - - // Create table dynamically - tableCmdArgs := []string{"nft", "add", "table", "inet", cr.data.ChainName} - _, _ = runCmdWithOutput(tableCmdArgs, "root", "", nil, 60) - - // Create chain in the new table - chainCmdArgs := []string{"nft", "add", "chain", "inet", cr.data.ChainName, strings.ToLower(cr.data.Chain)} - switch strings.ToUpper(cr.data.Chain) { - case "INPUT": - chainCmdArgs = append(chainCmdArgs, "{", "type", "filter", "hook", "input", "priority", strconv.Itoa(cr.data.Priority), ";", "policy", "accept;", "}") - case "OUTPUT": - chainCmdArgs = append(chainCmdArgs, "{", "type", "filter", "hook", "output", "priority", strconv.Itoa(cr.data.Priority), ";", "policy", "accept;", "}") - case "FORWARD": - chainCmdArgs = append(chainCmdArgs, "{", "type", "filter", "hook", "forward", "priority", strconv.Itoa(cr.data.Priority), ";", "policy", "accept;", "}") - default: - chainCmdArgs = append(chainCmdArgs, "{", "type", "filter", "hook", "prerouting", "priority", strconv.Itoa(cr.data.Priority), ";", "policy", "accept;", "}") - } - _, _ = runCmdWithOutput(chainCmdArgs, "root", "", nil, 60) - - // Add rule to the dynamic table/chain - args := []string{"nft"} - switch cr.data.Method { - case "-A": - args = append(args, "add") - case "-I": - args = append(args, "insert") - case "-R": - args = append(args, "replace") - case "-D": - args = append(args, "delete") - } - args = append(args, "rule", "inet", cr.data.ChainName, strings.ToLower(cr.data.Chain)) - - if cr.data.Source != "" && cr.data.Source != "0.0.0.0/0" { - args = append(args, "ip", "saddr", cr.data.Source) - } - - if cr.data.Destination != "" && cr.data.Destination != "0.0.0.0/0" { - args = append(args, "ip", "daddr", cr.data.Destination) - } - - if cr.data.Protocol != "all" { - if cr.data.Protocol == "icmp" { - args = append(args, "ip", "protocol", cr.data.Protocol) - if cr.data.ICMPType != "" { - args = append(args, "icmp", "type", cr.data.ICMPType) - } - } else if cr.data.Protocol == "tcp" || cr.data.Protocol == "udp" { - // For TCP/UDP, use proper nftables protocol syntax - if len(cr.data.DPorts) > 0 { - args = append(args, cr.data.Protocol) - var portList []string - for _, port := range cr.data.DPorts { - portList = append(portList, strconv.Itoa(port)) - } - args = append(args, "dport", "{", strings.Join(portList, ","), "}") - } else if cr.data.PortStart != 0 { - args = append(args, cr.data.Protocol) - // Handle single port or port range - if cr.data.PortEnd != 0 && cr.data.PortEnd != cr.data.PortStart { - portStr := fmt.Sprintf("%d-%d", cr.data.PortStart, cr.data.PortEnd) - args = append(args, "dport", portStr) - } else { - args = append(args, "dport", strconv.Itoa(cr.data.PortStart)) - } - } else { - // No port specified, use ip protocol syntax - args = append(args, "ip", "protocol", cr.data.Protocol) - } - } else { - // For other protocols - args = append(args, "ip", "protocol", cr.data.Protocol) - } - } - - // Add target action (accept/drop/reject) - targetAction := strings.ToLower(cr.data.Target) - if targetAction == "accept" || targetAction == "drop" || targetAction == "reject" { - args = append(args, targetAction) - } else { - // Default action if target is not specified or invalid - args = append(args, "accept") - } - - // Add comment with rule_id and rule_type - if cr.data.RuleID != "" || cr.data.RuleType != "" { - var commentParts []string - if cr.data.RuleID != "" { - commentParts = append(commentParts, fmt.Sprintf("rule_id:%s", cr.data.RuleID)) - } - if cr.data.RuleType != "" { - commentParts = append(commentParts, fmt.Sprintf("type:%s", cr.data.RuleType)) - } - ruleComment := strings.Join(commentParts, ",") - args = append(args, "comment", fmt.Sprintf("\"%s\"", ruleComment)) - } - - // Log the final nftables command - log.Info().Msgf("Executing nftables command: %s", strings.Join(args, " ")) - - exitCode, result = runCmdWithOutput(args, "root", "", nil, 60) - - if exitCode != 0 { - log.Error().Msgf("nftables command failed (exit code %d): %s", exitCode, result) - return exitCode, fmt.Sprintf("nftables error: %s", result) - } - - log.Info().Msgf("Successfully executed nftables rule for table %s", cr.data.ChainName) - return 0, fmt.Sprintf("Successfully executed rule for security group table %s.", cr.data.ChainName) -} - -// executeIptablesRule executes iptables rule -func (cr *CommandRunner) executeIptablesRule() (exitCode int, result string) { - log.Info().Msg("Using iptables for firewall management.") - - chainName := cr.data.ChainName + "_" + strings.ToLower(cr.data.Chain) - - // Create chain dynamically in filter table - chainCreateCmdArgs := []string{"iptables", "-N", chainName} - _, _ = runCmdWithOutput(chainCreateCmdArgs, "root", "", nil, 60) - - // Add rule to the dynamic chain - args := []string{"iptables", cr.data.Method, chainName} - - // Add protocol - if cr.data.Protocol != "all" { - args = append(args, "-p", cr.data.Protocol) - } - - // Add source if specified - if cr.data.Source != "" && cr.data.Source != "0.0.0.0/0" { - args = append(args, "-s", cr.data.Source) - } - - // Add destination if specified - if cr.data.Destination != "" && cr.data.Destination != "0.0.0.0/0" { - args = append(args, "-d", cr.data.Destination) - } - - // Handle ports based on protocol - if cr.data.Protocol == "icmp" { - if cr.data.ICMPType != "" { - args = append(args, "--icmp-type", cr.data.ICMPType) - } - } else if cr.data.Protocol == "tcp" || cr.data.Protocol == "udp" { - // Handle multiport - if len(cr.data.DPorts) > 0 { - var portList []string - for _, port := range cr.data.DPorts { - portList = append(portList, strconv.Itoa(port)) - } - args = append(args, "-m", "multiport", "--dports", strings.Join(portList, ",")) - } else if cr.data.PortStart != 0 { - // Handle single port or port range - if cr.data.PortEnd != 0 && cr.data.PortEnd != cr.data.PortStart { - portStr := fmt.Sprintf("%d:%d", cr.data.PortStart, cr.data.PortEnd) - args = append(args, "--dport", portStr) - } else { - args = append(args, "--dport", strconv.Itoa(cr.data.PortStart)) - } - } - } - - // Add target - args = append(args, "-j", cr.data.Target) - - // Add comment with rule_id and rule_type - if cr.data.RuleID != "" || cr.data.RuleType != "" { - var commentParts []string - if cr.data.RuleID != "" { - commentParts = append(commentParts, fmt.Sprintf("rule_id:%s", cr.data.RuleID)) - } - if cr.data.RuleType != "" { - commentParts = append(commentParts, fmt.Sprintf("type:%s", cr.data.RuleType)) - } - ruleComment := strings.Join(commentParts, ",") - args = append(args, "-m", "comment", "--comment", ruleComment) - } - - // Log the final iptables command - log.Info().Msgf("Executing iptables command: %s", strings.Join(args, " ")) - - exitCode, result = runCmdWithOutput(args, "root", "", nil, 60) - - if exitCode != 0 { - log.Error().Msgf("iptables command failed (exit code %d): %s", exitCode, result) - return exitCode, fmt.Sprintf("iptables error: %s", result) - } - - log.Info().Msgf("Successfully executed iptables rule for chain %s", chainName) - return 0, fmt.Sprintf("Successfully executed rule for security group chain %s.", chainName) -} - -// deleteNftablesRuleByID deletes a specific nftables rule by finding its handle using rule_id in comment -func (cr *CommandRunner) deleteNftablesRuleByID(chainName, ruleID string) (exitCode int, result string) { - log.Info().Msgf("Deleting nftables rule by ID: %s in chain %s", ruleID, chainName) - - // First, list rules with handles to find the target rule - listArgs := []string{"nft", "--handle", "list", "table", "inet", chainName} - listExitCode, listOutput := runCmdWithOutput(listArgs, "root", "", nil, 60) - - if listExitCode != 0 { - log.Error().Msgf("Failed to list nftables rules: %s", listOutput) - return listExitCode, fmt.Sprintf("Failed to list rules: %s", listOutput) - } - - // Parse the output to find rule handle and chain type with matching rule_id in comment - ruleHandle, chainType := cr.findNftablesRuleHandleAndChain(listOutput, ruleID) - if ruleHandle == "" { - log.Warn().Msgf("Rule with ID %s not found in table %s", ruleID, chainName) - return 1, fmt.Sprintf("Rule with ID %s not found", ruleID) - } - - // Delete the rule using its handle - // nftables syntax: nft delete rule inet handle - deleteArgs := []string{"nft", "delete", "rule", "inet", chainName, chainType, "handle", ruleHandle} - deleteExitCode, deleteOutput := runCmdWithOutput(deleteArgs, "root", "", nil, 60) - - if deleteExitCode != 0 { - log.Error().Msgf("Failed to delete nftables rule: %s", deleteOutput) - return deleteExitCode, fmt.Sprintf("Failed to delete rule: %s", deleteOutput) - } - - log.Info().Msgf("Successfully deleted nftables rule with ID %s (handle %s) from chain %s", ruleID, ruleHandle, chainType) - return 0, fmt.Sprintf("Successfully deleted rule with ID %s", ruleID) -} - -// findNftablesRuleHandleAndChain parses nft list output to find rule handle and chain by rule_id in comment -func (cr *CommandRunner) findNftablesRuleHandleAndChain(listOutput, ruleID string) (string, string) { - lines := strings.Split(listOutput, "\n") - targetComment := fmt.Sprintf("rule_id:%s", ruleID) - currentChain := "" - - for _, line := range lines { - // Check for chain declarations (e.g., "chain input {", "chain output {") - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "chain ") && strings.Contains(trimmed, "{") { - // Extract chain name from "chain {" - parts := strings.Fields(trimmed) - if len(parts) >= 2 { - currentChain = parts[1] - } - } - - // Look for lines containing the target comment and handle - if strings.Contains(line, targetComment) && strings.Contains(line, "# handle") { - // Extract handle number from the line - if handleIndex := strings.Index(line, "# handle"); handleIndex != -1 { - handlePart := line[handleIndex+9:] // Skip "# handle " - handle := "" - if spaceIndex := strings.Index(handlePart, " "); spaceIndex != -1 { - handle = strings.TrimSpace(handlePart[:spaceIndex]) - } else { - handle = strings.TrimSpace(handlePart) - } - return handle, currentChain - } - } - } - - return "", "" -} - -// deleteIptablesRuleByID deletes a specific iptables rule by matching rule specifications -func (cr *CommandRunner) deleteIptablesRuleByID(chainName, ruleID string) (exitCode int, result string) { - log.Info().Msgf("Deleting iptables rule - ChainName: %s, RuleID: %s", chainName, ruleID) - - fullChainName := chainName + "_" + strings.ToLower(cr.data.Chain) - - // Note: For iptables rule deletion with comment, we rely on rule specification matching - // since comment format may include additional type information - - // Build delete command with rule specifications - args := []string{"iptables", "-D", fullChainName} - - // Add protocol - if cr.data.Protocol != "" && cr.data.Protocol != "all" { - args = append(args, "-p", cr.data.Protocol) - } - - // Add source if specified - if cr.data.Source != "" && cr.data.Source != "0.0.0.0/0" { - args = append(args, "-s", cr.data.Source) - } - - // Add destination if specified - if cr.data.Destination != "" && cr.data.Destination != "0.0.0.0/0" { - args = append(args, "-d", cr.data.Destination) - } - - // Handle ports based on protocol - if cr.data.Protocol == "icmp" { - if cr.data.ICMPType != "" { - args = append(args, "--icmp-type", cr.data.ICMPType) - } - } else if cr.data.Protocol == "tcp" || cr.data.Protocol == "udp" { - // Handle multiport - if len(cr.data.DPorts) > 0 { - var portList []string - for _, port := range cr.data.DPorts { - portList = append(portList, strconv.Itoa(port)) - } - args = append(args, "-m", "multiport", "--dports", strings.Join(portList, ",")) - } else if cr.data.PortStart != 0 { - // Handle single port or port range - if cr.data.PortEnd != 0 && cr.data.PortEnd != cr.data.PortStart { - portStr := fmt.Sprintf("%d:%d", cr.data.PortStart, cr.data.PortEnd) - args = append(args, "--dport", portStr) - } else { - args = append(args, "--dport", strconv.Itoa(cr.data.PortStart)) - } - } - } - - // Add target - if cr.data.Target != "" { - args = append(args, "-j", cr.data.Target) - } - - // Skip comment matching for deletion since the comment format may have changed - // to include type information. Rule specification matching should be sufficient. - - // Execute delete command - deleteExitCode, deleteOutput := runCmdWithOutput(args, "root", "", nil, 60) - - if deleteExitCode != 0 { - log.Error().Msgf("Failed to delete iptables rule: %s", deleteOutput) - return deleteExitCode, fmt.Sprintf("Failed to delete rule: %s", deleteOutput) - } - - log.Info().Msgf("Successfully deleted iptables rule with ID %s", ruleID) - return 0, fmt.Sprintf("Successfully deleted rule with ID %s", ruleID) -} - -func getFileData(data CommandData) ([]byte, error) { - var content []byte - switch data.Type { - case "url": - parsedRequestURL, err := url.Parse(data.Content) - if err != nil { - return nil, fmt.Errorf("failed to parse URL '%s': %w", data.Content, err) - } - - req, err := http.NewRequest(http.MethodGet, parsedRequestURL.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - parsedServerURL, err := url.Parse(config.GlobalSettings.ServerURL) - if err != nil { - return nil, fmt.Errorf("failed to parse url: %w", err) - } - - if parsedRequestURL.Host == parsedServerURL.Host && parsedRequestURL.Scheme == parsedServerURL.Scheme { - req.Header.Set("Authorization", fmt.Sprintf(`id="%s", key="%s"`, - config.GlobalSettings.ID, config.GlobalSettings.Key)) - } - - client := http.Client{} - - tlsConfig := &tls.Config{} - if config.GlobalSettings.CaCert != "" { - caCertPool := x509.NewCertPool() - caCert, err := os.ReadFile(config.GlobalSettings.CaCert) - if err != nil { - log.Error().Err(err).Msg("Failed to read CA certificate.") - } - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig.RootCAs = caCertPool - } - - tlsConfig.InsecureSkipVerify = !config.GlobalSettings.SSLVerify - client.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, - } - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to download content from URL: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if (resp.StatusCode / 100) != 2 { - log.Error().Msgf("Failed to download content from URL: %d %s", resp.StatusCode, parsedRequestURL) - return nil, errors.New("downloading content failed") - } - content, err = io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - case "text": - content = []byte(data.Content) - case "base64": - var err error - content, err = base64.StdEncoding.DecodeString(data.Content) - if err != nil { - return nil, fmt.Errorf("failed to decode base64 content: %w", err) - } - default: - return nil, fmt.Errorf("unknown file type: %s", data.Type) - } - - if content == nil { - return nil, errors.New("content is nil") - } - - return content, nil -} - -func parsePaths(homeDirectory string, pathList []string) (parsedPaths []string, isBulk bool, isRecursive bool, err error) { - paths := make([]string, len(pathList)) - for i, path := range pathList { - if strings.HasPrefix(path, "~") { - path = strings.Replace(path, "~", homeDirectory, 1) - } - - if !filepath.IsAbs(path) { - path = filepath.Join(homeDirectory, path) - } - - absPath, err := filepath.Abs(path) - if err != nil { - return nil, false, false, err - } - paths[i] = absPath - } - - isBulk = len(pathList) > 1 - isRecursive = false - - if !isBulk { - fileInfo, err := os.Stat(paths[0]) - if err != nil { - return nil, false, false, err - } - isRecursive = fileInfo.IsDir() - } - - return paths, isBulk, isRecursive, nil -} - -func makeArchive(paths []string, bulk, recursive bool, sysProcAttr *syscall.SysProcAttr) (string, error) { - var archiveName string - var cmd *exec.Cmd - path := paths[0] - - if bulk { - archiveName = filepath.Dir(path) + "/" + uuid.New().String() + ".zip" - dirPath := filepath.Dir(path) - basePaths := make([]string, len(paths)) - for i, path := range paths { - basePaths[i] = filepath.Base(path) - } - - cmd = exec.Command("zip", "-r", archiveName) - cmd.SysProcAttr = sysProcAttr - cmd.Args = append(cmd.Args, basePaths...) - cmd.Dir = dirPath - } else { - if recursive { - archiveName = path + ".zip" - cmd = exec.Command("zip", "-r", archiveName, filepath.Base(path)) - cmd.SysProcAttr = sysProcAttr - cmd.Dir = filepath.Dir(path) - } else { - archiveName = path - } - } - - if bulk || recursive { - err := cmd.Run() - if err != nil { - return "", err - } - } - - return archiveName, nil -} - -func createMultipartBody(output []byte, filePath string, useBlob, isRecursive bool) (bytes.Buffer, string, error) { - if useBlob { - return *bytes.NewBuffer(output), "", nil - } - - var requestBody bytes.Buffer - writer := multipart.NewWriter(&requestBody) - - fileWriter, err := writer.CreateFormFile("content", filePath) - if err != nil { - return bytes.Buffer{}, "", err - } - - _, err = fileWriter.Write(output) - if err != nil { - return bytes.Buffer{}, "", err - } - - if isRecursive { - err = writer.WriteField("name", filePath) - if err != nil { - return bytes.Buffer{}, "", err - } - } - - _ = writer.Close() - - return requestBody, writer.FormDataContentType(), nil -} - -func fileDownload(data CommandData, sysProcAttr *syscall.SysProcAttr) (exitCode int, result string) { - var cmd *exec.Cmd - content, err := getFileData(data) - if err != nil { - return 1, err.Error() - } - - if !data.AllowOverwrite && isFileExist(data.Path) { - return 1, fmt.Sprintf("%s already exists.", data.Path) - } - - isZip := isZipFile(content, filepath.Ext(data.Path)) - if isZip && data.AllowUnzip { - escapePath := utils.Quote(data.Path) - escapeDirPath := utils.Quote(filepath.Dir(data.Path)) - // Use -o (overwrite) if AllowOverwrite is true, otherwise -n (never overwrite) - unzipOpt := "-n" - if data.AllowOverwrite { - unzipOpt = "-o" - } - command := fmt.Sprintf("tee %s > /dev/null && unzip %s %s -d %s; rm %s", - escapePath, - unzipOpt, - escapePath, - escapeDirPath, - escapePath) - cmd = exec.Command("sh", "-c", command) - } else { - cmd = exec.Command("sh", "-c", fmt.Sprintf("tee %s > /dev/null", utils.Quote(data.Path))) - } - - cmd.SysProcAttr = sysProcAttr - cmd.Stdin = bytes.NewReader(content) - - output, err := cmd.Output() - if err != nil { - log.Error().Err(err).Msgf("Failed to write file: %s", output) - return 1, "You do not have permission to read on the directory. or directory does not exist" - } - - return 0, fmt.Sprintf("Successfully downloaded %s.", data.Path) -} - -func isZipFile(content []byte, ext string) bool { - if _, found := nonZipExt[ext]; found { - return false - } - - _, err := zip.NewReader(bytes.NewReader(content), int64(len(content))) - - return err == nil -} - -func isFileExist(path string) bool { - _, err := os.Stat(path) - return !os.IsNotExist(err) -} - -func statFileTransfer(code int, transferType transferType, message string, data CommandData) { - statURL := fmt.Sprint(data.URL + "stat/") - isSuccess := code == 0 - - payload := &commandStat{ - Success: isSuccess, - Message: message, - Type: transferType, - } - scheduler.Rqueue.Post(statURL, payload, 10, time.Time{}) -} - -func (cr *CommandRunner) handleSudoApprovalResponse() (int, string) { - var sudoApprovalResponse SudoApprovalResponse - if cr.command.Data != "" { - err := json.Unmarshal([]byte(cr.command.Data), &sudoApprovalResponse) - if err != nil { - log.Error().Err(err).Msg("Failed to parse sudo_approval_response data") - return 1, "Invalid sudo_approval_response data format" - } - } else { - return 1, "No sudo_approval_response data provided" - } - - if authManager == nil { - log.Error().Msg("AuthManager not available") - return 1, "AuthManager not available" - } - - // SudoApprovalResponse - err := authManager.HandleSudoApprovalResponse(sudoApprovalResponse) - if err != nil { - log.Error().Err(err).Msg("Failed to handle sudo_approval_response") - return 1, fmt.Sprintf("Failed to handle sudo_approval_response: %v", err) - } - - return 0, "Sudo approval response processed successfully" -} - -// validateFirewallData performs enhanced validation for firewall data -func (cr *CommandRunner) validateFirewallData(data firewallData) error { - // Basic validation using struct tags - if err := cr.validateData(data); err != nil { - return fmt.Errorf("basic validation failed: %w", err) - } - - // Enhanced validation logic - validMethods := []string{"-A", "-I", "-R", "-D"} - found := false - for _, method := range validMethods { - if data.Method == method { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid method '%s', must be one of: %v", data.Method, validMethods) - } - - validProtocols := []string{"tcp", "udp", "icmp", "all"} - found = false - for _, protocol := range validProtocols { - if data.Protocol == protocol { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid protocol '%s', must be one of: %v", data.Protocol, validProtocols) - } - - validTargets := []string{"ACCEPT", "DROP", "REJECT", "LOG", "RETURN"} - found = false - for _, target := range validTargets { - if data.Target == target { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid target '%s', must be one of: %v", data.Target, validTargets) - } - - // Protocol-specific validation - if data.Protocol == "icmp" { - if data.PortStart != 0 || data.PortEnd != 0 || len(data.DPorts) > 0 { - return fmt.Errorf("ICMP protocol cannot have port specifications") - } - } - - // Port validation - if data.PortStart != 0 { - if data.PortStart < 1 || data.PortStart > 65535 { - return fmt.Errorf("PortStart must be between 1 and 65535, got %d", data.PortStart) - } - } - - if data.PortEnd != 0 { - if data.PortEnd < 1 || data.PortEnd > 65535 { - return fmt.Errorf("PortEnd must be between 1 and 65535, got %d", data.PortEnd) - } - if data.PortStart != 0 && data.PortEnd < data.PortStart { - return fmt.Errorf("PortEnd (%d) cannot be less than PortStart (%d)", data.PortEnd, data.PortStart) - } - } - - // DPorts validation - if len(data.DPorts) > 0 { - if len(data.DPorts) > 15 { - return fmt.Errorf("too many ports in multiport rule (max 15), got %d", len(data.DPorts)) - } - - // Check for duplicates and validate range - seen := make(map[int]bool) - for _, port := range data.DPorts { - if port < 1 || port > 65535 { - return fmt.Errorf("DPort must be between 1 and 65535, got %d", port) - } - if seen[port] { - return fmt.Errorf("duplicate port %d in DPorts", port) - } - seen[port] = true - } - - // Cannot have both DPorts and single port/range - if data.PortStart != 0 || data.PortEnd != 0 { - return fmt.Errorf("cannot specify both individual ports (PortStart/PortEnd) and multiport (DPorts)") - } - } - - // ICMP type validation - if data.Protocol == "icmp" && data.ICMPType != "" { - // Check if numeric - if icmpTypeNum, err := strconv.Atoi(data.ICMPType); err == nil { - if icmpTypeNum < 0 || icmpTypeNum > 255 { - return fmt.Errorf("ICMP type must be between 0 and 255, got %d", icmpTypeNum) - } - } else { - // Validate common ICMP type names - validICMPTypes := []string{ - "echo-request", "echo-reply", "destination-unreachable", - "source-quench", "redirect", "time-exceeded", - "parameter-problem", "timestamp-request", "timestamp-reply", - } - found := false - for _, validType := range validICMPTypes { - if data.ICMPType == validType { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid ICMP type '%s'", data.ICMPType) - } - } - } else if data.Protocol != "icmp" && data.ICMPType != "" { - return fmt.Errorf("ICMP type can only be specified for ICMP protocol") - } - - return nil -} - -// firewallRollback handles firewall rollback operations -func (cr *CommandRunner) firewallRollback() (exitCode int, result string) { - log.Info().Msgf("Firewall rollback command received - Operation: %s, ChainName: %s", - cr.data.Operation, cr.data.ChainName) - - // Handle both old and new field names for backward compatibility - if cr.data.ChainName == "" && cr.data.Operation == "" { - return 1, "firewall-rollback: ChainName or Operation is required" - } - - // Determine the action (flush or restore) - action := cr.data.Operation - if action == "" { - // Fallback to Method field for backward compatibility - if cr.data.Method != "" { - action = cr.data.Method - } else { - action = "flush" // Default action - } - } - - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - log.Error().Err(err).Msg("Failed to check firewall tools for rollback") - return 1, fmt.Sprintf("firewall-rollback: Failed to check firewall tools. %s", err) - } - - // Handle different rollback actions - switch action { - case "flush": - // Simple flush operation - remove all rules - if nftablesInstalled { - return cr.performNftablesRollback(cr.data.ChainName, "flush") - } else if iptablesInstalled { - return cr.performIptablesRollback(cr.data.ChainName, "flush") - } - - case "restore": - // Restore from snapshot - flush then apply new rules - if len(cr.data.Rules) == 0 { - return 1, "firewall-rollback: No rules provided for restore action" - } - - // First flush the chain - var flushExitCode int - var flushResult string - if nftablesInstalled { - flushExitCode, flushResult = cr.performNftablesRollback(cr.data.ChainName, "flush") - } else if iptablesInstalled { - flushExitCode, flushResult = cr.performIptablesRollback(cr.data.ChainName, "flush") - } - - if flushExitCode != 0 { - return flushExitCode, fmt.Sprintf("firewall-rollback: Failed to flush before restore - %s", flushResult) - } - - // Then apply each rule from the snapshot - successCount := 0 - failedRules := []string{} - - for i, ruleData := range cr.data.Rules { - // Convert rule data to CommandData fields using existing function with rule ID generation - cr.data = cr.convertRuleDataToCommandData(ruleData, cr.data) - - ruleExitCode, ruleResult := cr.executeSingleFirewallRule() - if ruleExitCode == 0 { - successCount++ - } else { - failedRules = append(failedRules, fmt.Sprintf("Rule %d: %s", i+1, ruleResult)) - } - } - - if len(failedRules) > 0 { - return 1, fmt.Sprintf("firewall-rollback: Restored %d/%d rules. Failed rules: %s", - successCount, len(cr.data.Rules), strings.Join(failedRules, "; ")) - } - - return 0, fmt.Sprintf("firewall-rollback: Successfully restored %d rules", successCount) - - case "delete": - // Delete entire table/chain structure - if nftablesInstalled { - return cr.performNftablesRollback(cr.data.ChainName, "delete") - } else if iptablesInstalled { - return cr.performIptablesRollback(cr.data.ChainName, "delete") - } - - default: - return 1, fmt.Sprintf("firewall-rollback: Unknown action '%s', use 'flush', 'restore', or 'delete'", action) - } - - return 1, "firewall-rollback: No firewall management tool installed" -} - -// performNftablesRollback performs rollback operations for nftables -func (cr *CommandRunner) performNftablesRollback(chainName, method string) (int, string) { - log.Info().Msgf("Performing nftables rollback for table: %s, method: %s", chainName, method) - - var exitCode int - var result string - - switch method { - case "flush": - // Flush all chains in the table - // For nftables, we flush all chains (INPUT, OUTPUT, FORWARD) in the security group table - chainTypes := []string{"input", "output", "forward"} - successCount := 0 - - for _, chainType := range chainTypes { - args := []string{"nft", "flush", "chain", "inet", chainName, chainType} - exitCode, result = runCmdWithOutput(args, "root", "", nil, 60) - - if exitCode == 0 { - successCount++ - log.Info().Msgf("Successfully flushed nftables chain: %s %s", chainName, chainType) - } else { - // Chain might not exist, which is OK - log.Debug().Msgf("Failed to flush nftables chain %s %s: %s (chain may not exist)", chainName, chainType, result) - } - } - - if successCount > 0 { - log.Info().Msgf("Successfully flushed %d chains in table %s", successCount, chainName) - return 0, fmt.Sprintf("Successfully flushed %d chains in table %s", successCount, chainName) - } - - log.Warn().Msgf("No chains flushed in table %s (table may not exist)", chainName) - return 0, fmt.Sprintf("No chains to flush in table %s", chainName) - - case "delete": - // Delete the entire table - args := []string{"nft", "delete", "table", "inet", chainName} - exitCode, result = runCmdWithOutput(args, "root", "", nil, 60) - - if exitCode != 0 { - log.Error().Msgf("Failed to delete nftables table %s: %s", chainName, result) - // If table doesn't exist, consider it success - if strings.Contains(result, "No such file or directory") { - log.Info().Msgf("nftables table %s already deleted", chainName) - return 0, fmt.Sprintf("Table %s was already deleted", chainName) - } - return exitCode, fmt.Sprintf("nftables delete error: %s", result) - } - - log.Info().Msgf("Successfully deleted nftables table: %s", chainName) - return 0, fmt.Sprintf("Successfully deleted table %s", chainName) - - default: - return 1, fmt.Sprintf("nftables rollback: unsupported method '%s', use 'flush' or 'delete'", method) - } -} - -// performIptablesRollback performs rollback operations for iptables -func (cr *CommandRunner) performIptablesRollback(chainName, method string) (int, string) { - log.Info().Msgf("Performing iptables rollback for chain: %s, method: %s", chainName, method) - - var exitCode int - var result string - - // For iptables, we need to handle chains differently - chainTypes := []string{"input", "output", "forward"} - - switch method { - case "flush": - successCount := 0 - for _, chainType := range chainTypes { - fullChainName := chainName + "_" + chainType - - // Flush the chain - args := []string{"iptables", "-F", fullChainName} - exitCode, result = runCmdWithOutput(args, "root", "", nil, 60) - - if exitCode == 0 { - successCount++ - log.Info().Msgf("Successfully flushed iptables chain: %s", fullChainName) - } else { - log.Warn().Msgf("Failed to flush iptables chain %s: %s", fullChainName, result) - } - } - - if successCount > 0 { - return 0, fmt.Sprintf("Successfully flushed %d chains for security group %s", successCount, chainName) - } else { - return 1, fmt.Sprintf("Failed to flush any chains for security group %s", chainName) - } - - case "delete": - successCount := 0 - for _, chainType := range chainTypes { - fullChainName := chainName + "_" + chainType - - // First flush the chain - flushArgs := []string{"iptables", "-F", fullChainName} - runCmdWithOutput(flushArgs, "root", "", nil, 60) - - // Then delete the chain - deleteArgs := []string{"iptables", "-X", fullChainName} - exitCode, result = runCmdWithOutput(deleteArgs, "root", "", nil, 60) - - if exitCode == 0 { - successCount++ - log.Info().Msgf("Successfully deleted iptables chain: %s", fullChainName) - } else { - log.Warn().Msgf("Failed to delete iptables chain %s: %s", fullChainName, result) - } - } - - if successCount > 0 { - return 0, fmt.Sprintf("Successfully deleted %d chains for security group %s", successCount, chainName) - } else { - return 1, fmt.Sprintf("Failed to delete any chains for security group %s", chainName) - } - - default: - return 1, fmt.Sprintf("iptables rollback: unsupported method '%s', use 'flush' or 'delete'", method) - } -} - -// applyRulesBatchWithFlush applies a batch of firewall rules with optional rollback on failure -// Returns: appliedRules, failedRules, rolledBack, rollbackReason -// flushBeforeApply: if true, flush all existing rules before applying new ones (full replacement) -// -// if false, only add new rules (incremental) -func (cr *CommandRunner) applyRulesBatchWithFlush() (int, []map[string]interface{}, bool, string) { - // Check firewall tools if needed - _, _, err := utils.CheckFirewallTool() - if err != nil { - log.Error().Err(err).Msg("Failed to check firewall tools") - return 0, nil, false, "" - } - - // Backup current rules before applying changes - backup, err := utils.BackupFirewallRules() - if err != nil { - log.Error().Err(err).Msg("Failed to create firewall backup") - return 0, nil, false, fmt.Sprintf("Failed to create backup before applying rules: %v", err) - } - log.Info().Msg("Created firewall backup before batch apply") - - // Apply rules - appliedRules := 0 - var failedRules []map[string]interface{} - var rollbackReason string - rolledBack := false - - // Store original data to restore later - originalChainName := cr.data.ChainName - originalMethod := cr.data.Method - originalChain := cr.data.Chain - originalProtocol := cr.data.Protocol - originalPortStart := cr.data.PortStart - originalPortEnd := cr.data.PortEnd - originalSource := cr.data.Source - originalTarget := cr.data.Target - originalDescription := cr.data.Description - originalPriority := cr.data.Priority - originalICMPType := cr.data.ICMPType - originalDPorts := cr.data.DPorts - - defer func() { - // Restore original data - cr.data.ChainName = originalChainName - cr.data.Method = originalMethod - cr.data.Chain = originalChain - cr.data.Protocol = originalProtocol - cr.data.PortStart = originalPortStart - cr.data.PortEnd = originalPortEnd - cr.data.Source = originalSource - cr.data.Target = originalTarget - cr.data.Description = originalDescription - cr.data.Priority = originalPriority - cr.data.ICMPType = originalICMPType - cr.data.DPorts = originalDPorts - }() - - for i, ruleData := range cr.data.Rules { - // Convert rule data to CommandData fields - cr.data = cr.convertRuleDataToCommandData(ruleData, cr.data) - - var ruleExitCode int - var ruleResult string - - // Check if rule has an operation field for UUID-based operations - if operation, ok := ruleData["operation"].(string); ok && operation != "" { - // Handle UUID-based operations (update/delete/add) - switch operation { - case "update": - // Update operation requires rule_id (new) and old_rule_id (to delete) - ruleID, hasRuleID := ruleData["rule_id"].(string) - oldRuleID, hasOldRuleID := ruleData["old_rule_id"].(string) - - if hasRuleID && ruleID != "" && hasOldRuleID && oldRuleID != "" { - cr.data.RuleID = ruleID - cr.data.OldRuleID = oldRuleID - ruleExitCode, ruleResult = cr.handleUpdateOperation() - } else { - ruleExitCode = 1 - ruleResult = "update operation requires both rule_id (new) and old_rule_id (to delete)" - } - case "delete": - // Delete operation requires rule_id - if ruleID, ok := ruleData["rule_id"].(string); ok && ruleID != "" { - cr.data.RuleID = ruleID - ruleExitCode, ruleResult = cr.handleDeleteOperation() - } else { - ruleExitCode = 1 - ruleResult = "delete operation requires rule_id" - } - case "add": - // Add operation - use handleAddOperation for proper validation and logging - log.Debug().Msgf("Batch add operation for rule %d/%d", i+1, len(cr.data.Rules)) - ruleExitCode, ruleResult = cr.handleAddOperation() - default: - ruleExitCode = 1 - ruleResult = fmt.Sprintf("unknown operation: %s", operation) - } - } else { - // Default: use method-based execution (-A, -I, -R, -D) - // This applies validation and logging via handleAddOperation - log.Debug().Msgf("Batch method-based operation for rule %d/%d (method: %s)", i+1, len(cr.data.Rules), cr.data.Method) - ruleExitCode, ruleResult = cr.handleAddOperation() - } - - if ruleExitCode == 0 { - appliedRules++ - log.Info().Msgf("Successfully applied rule %d/%d", i+1, len(cr.data.Rules)) - } else { - failedRule := map[string]interface{}{ - "rule": fmt.Sprintf("Rule %d: %s", i+1, cr.data.Description), - "error": ruleResult, - } - failedRules = append(failedRules, failedRule) - log.Error().Msgf("Failed to apply rule %d/%d: %s", i+1, len(cr.data.Rules), ruleResult) - - // Trigger rollback on first failure if enabled - if !rolledBack { - rollbackReason = fmt.Sprintf("Rule %d failed to apply", i+1) - rolledBack = true - - // Perform rollback to previous state - log.Info().Msg("Initiating rollback due to rule failure") - // Restore from backup - if restoreErr := utils.RestoreFirewallRules(backup); restoreErr != nil { - log.Error().Err(restoreErr).Msg("Failed to restore firewall rules from backup") - rollbackReason = fmt.Sprintf("Rule %d failed and restore failed: %v", i+1, restoreErr) - } else { - log.Info().Msg("Successfully restored firewall rules from backup") - } - - // Stop processing remaining rules - break - } - } - } - - return appliedRules, failedRules, rolledBack, rollbackReason -} - -// executeUninstall performs complete uninstallation of Alpamon -func (cr *CommandRunner) executeUninstall() { - // Only debian and rhel support systemd-run based uninstall - if utils.PlatformLike == "debian" { - cmd := "apt-get purge alpamon -y && apt-get autoremove -y" - cr.scheduleSystemdUninstall(cmd) - } else if utils.PlatformLike == "rhel" { - cmd := "yum remove alpamon -y" - cr.scheduleSystemdUninstall(cmd) - } else if utils.PlatformLike == "darwin" { - log.Warn().Msgf("Platform '%s' does not support full uninstall. Shutting down instead.", utils.PlatformLike) - } else { - log.Error().Msgf("Platform '%s' not supported for uninstall.", utils.PlatformLike) - } - - // Send deletion event to server synchronously before shutdown (ALL platforms) - _, statusCode, err := cr.apiSession.Delete(serverUnregisterURL, nil, 10) - if err != nil { - log.Error().Err(err).Msg("Failed to send unregister request to server") - } else if statusCode >= 200 && statusCode < 300 { - log.Info().Msgf("Successfully sent unregister request to server (status: %d)", statusCode) - } else { - log.Warn().Msgf("Unregister request returned status: %d", statusCode) - } - - // ShutDown asynchronously to ensure cleanup completes (ALL platforms) - go cr.wsClient.ShutDown() -} - -// scheduleSystemdUninstall schedules the uninstall command via systemd-run -func (cr *CommandRunner) scheduleSystemdUninstall(cmd string) { - // Build the complete uninstall command that includes: - // 1. Package removal - // 2. Cleanup of transient systemd units created by this operation - uninstallCmd := fmt.Sprintf("%s; systemctl reset-failed alpamon-uninstall.service 2>/dev/null || true; systemctl reset-failed alpamon-uninstall.timer 2>/dev/null || true", cmd) - - // This ensures the uninstall continues even after the current process terminates - // The service will start 5 seconds after being scheduled - // Use runCmdWithOutput directly to avoid shell parsing issues with handleShellCmd - // --collect: Automatically clean up transient units after they complete (systemd 236+) - scheduleCmdArgs := []string{ - "systemd-run", - "--on-active=5", - "--unit=alpamon-uninstall", - "--collect", - "/bin/sh", - "-c", - uninstallCmd, - } - - log.Debug().Msgf("Scheduling uninstall via systemd-run: %s", strings.Join(scheduleCmdArgs, " ")) - exitCode, result := runCmdWithOutput(scheduleCmdArgs, "root", "root", nil, 60) - - if exitCode != 0 { - log.Error().Msgf("Failed to schedule uninstall: %s", result) - } else { - log.Info().Msg("Alpamon uninstall scheduled via systemd, will execute in 5 seconds") - } -} - -// convertRuleDataToCommandData converts rule data map to CommandData fields -func (cr *CommandRunner) convertRuleDataToCommandData(ruleData map[string]interface{}, data CommandData) CommandData { - // Reset all optional fields to prevent conflicts between rules in batch operations - // This ensures each rule starts with a clean slate - data.Method = "-A" // Default to append - data.Chain = "" - data.Protocol = "" - data.PortStart = 0 - data.PortEnd = 0 - data.DPorts = nil - data.ICMPType = "" - data.Source = "" - data.Destination = "" - data.Target = "" - data.Description = "" - data.Priority = 0 - data.RuleType = "alpacon" // Default to alpacon type - data.RuleID = "" - data.OldRuleID = "" - - // Now set values from ruleData - if chainName, ok := ruleData["chain_name"].(string); ok { - data.ChainName = chainName - } - if method, ok := ruleData["method"].(string); ok { - data.Method = method - } - if chain, ok := ruleData["chain"].(string); ok { - data.Chain = chain - } - if protocol, ok := ruleData["protocol"].(string); ok { - data.Protocol = protocol - } - if portStart, ok := ruleData["port_start"].(float64); ok { - data.PortStart = int(portStart) - } - if portEnd, ok := ruleData["port_end"].(float64); ok { - data.PortEnd = int(portEnd) - } - if source, ok := ruleData["source"].(string); ok { - data.Source = source - } - if destination, ok := ruleData["destination"].(string); ok { - data.Destination = destination - } - if target, ok := ruleData["target"].(string); ok { - data.Target = target - } - if description, ok := ruleData["description"].(string); ok { - data.Description = description - } - if priority, ok := ruleData["priority"].(float64); ok { - data.Priority = int(priority) - } - if ruleType, ok := ruleData["rule_type"].(string); ok { - data.RuleType = ruleType - } - if icmpType, ok := ruleData["icmp_type"].(string); ok { - data.ICMPType = icmpType - } - if ruleID, ok := ruleData["rule_id"].(string); ok { - data.RuleID = ruleID - } else { - // Generate rule ID if not provided - data.RuleID = uuid.New().String() - } - - // Handle operation field for batch operations (add, update, delete) - if operation, ok := ruleData["operation"].(string); ok { - data.Operation = operation - } - - // Handle dports array - if dportsInterface, ok := ruleData["dports"].([]interface{}); ok { - dports := []int{} - for _, p := range dportsInterface { - if portStr, ok := p.(string); ok { - if port, err := strconv.Atoi(portStr); err == nil { - dports = append(dports, port) - } - } else if port, ok := p.(float64); ok { - dports = append(dports, int(port)) - } - } - data.DPorts = dports - } - - return data -} diff --git a/pkg/runner/command_types.go b/pkg/runner/command_types.go deleted file mode 100644 index 853bc8d..0000000 --- a/pkg/runner/command_types.go +++ /dev/null @@ -1,206 +0,0 @@ -package runner - -import ( - "github.com/alpacax/alpamon/pkg/scheduler" - "gopkg.in/go-playground/validator.v9" -) - -type Content struct { - Query string `json:"query"` - Command Command `json:"command,omitempty"` - Reason string `json:"reason,omitempty"` -} - -type Command struct { - ID string `json:"id"` - Shell string `json:"shell"` - Line string `json:"line"` - User string `json:"user"` - Group string `json:"group"` - Env map[string]string `json:"env"` - Data string `json:"data,omitempty"` -} - -type File struct { - Username string `json:"username"` - Groupname string `json:"groupname"` - Type string `json:"type"` - Content string `json:"content"` - Path string `json:"path"` - AllowOverwrite bool `json:"allow_overwrite"` - AllowUnzip bool `json:"allow_unzip"` - URL string `json:"url"` -} - -type CommandData struct { - SessionID string `json:"session_id"` - URL string `json:"url"` - Rows uint16 `json:"rows"` - Cols uint16 `json:"cols"` - Username string `json:"username"` - Groupname string `json:"groupname"` - Groupnames []string `json:"groupnames"` - HomeDirectory string `json:"home_directory"` - HomeDirectoryPermission string `json:"home_directory_permission"` - PurgeHomeDirectory bool `json:"purge_home"` - UID uint64 `json:"uid"` - GID uint64 `json:"gid"` - Comment string `json:"comment"` - Shell string `json:"shell"` - Groups []uint64 `json:"groups"` - Type string `json:"type"` - Content string `json:"content"` - Path string `json:"path"` - Paths []string `json:"paths"` - Files []File `json:"files,omitempty"` - AllowOverwrite bool `json:"allow_overwrite,omitempty"` - AllowUnzip bool `json:"allow_unzip,omitempty"` - UseBlob bool `json:"use_blob,omitempty"` - Keys []string `json:"keys"` - ChainName string `json:"chain_name"` - Method string `json:"method"` - Chain string `json:"chain"` - Protocol string `json:"protocol"` - PortStart int `json:"port_start"` - PortEnd int `json:"port_end"` - DPorts []int `json:"dports"` - ICMPType string `json:"icmp_type"` - Source string `json:"source"` - Destination string `json:"destination"` - Target string `json:"target"` - Description string `json:"description"` - Priority int `json:"priority"` - RuleType string `json:"rule_type"` - Rules []map[string]interface{} `json:"rules"` - Operation string `json:"operation"` // batch, flush, delete, add, update - RuleID string `json:"rule_id"` // for rule-specific operations (add/update: new rule ID) - OldRuleID string `json:"old_rule_id"` // for update operation: old rule ID to delete - AssignmentID string `json:"assignment_id"` - ServerID string `json:"server_id"` - ChainNames []string `json:"chain_names"` // for firewall-reorder-chains - TargetPort int `json:"target_port"` // for tunneling -} - -type firewallData struct { - ChainName string `validate:"required"` - Method string `validate:"omitempty"` - Chain string `validate:"omitempty"` - Protocol string `validate:"omitempty"` - PortStart int `validate:"omitempty"` - PortEnd int `validate:"omitempty"` - DPorts []int `validate:"omitempty"` - ICMPType string `validate:"omitempty"` - Source string `validate:"omitempty"` - Destination string `validate:"omitempty"` - Target string `validate:"omitempty"` - Description string `validate:"omitempty"` - Priority int `validate:"omitempty"` - RuleType string `validate:"omitempty,oneof=alpacon server"` - RuleID string `validate:"omitempty"` - Operation string `validate:"required"` // batch, flush, delete, add, update -} -type CommandRunner struct { - name string - command Command - wsClient *WebsocketClient - apiSession *scheduler.Session - data CommandData - validator *validator.Validate -} - -// Structs defining the required input data for command validation purposes. // - -type addUserData struct { - Username string `validate:"required"` - UID uint64 `validate:"required"` - GID uint64 `validate:"required"` - Comment string `validate:"required"` - HomeDirectory string `validate:"required"` - HomeDirectoryPermission string `validate:"omitempty"` // Use omitempty for backward compatibility - Shell string `validate:"required"` - Groupname string `validate:"required"` -} - -type addGroupData struct { - Groupname string `validate:"required"` - GID uint64 `validate:"required"` -} - -type deleteUserData struct { - Username string `validate:"required"` - PurgeHomeDirectory bool `validate:"omitempty"` -} - -type deleteGroupData struct { - Groupname string `validate:"required"` -} - -type modUserData struct { - Username string `validate:"required"` - Groupnames []string `validate:"required"` - Comment string `validate:"required"` -} - -type openPtyData struct { - SessionID string `validate:"required"` - URL string `validate:"required"` - Username string `validate:"required"` - Groupname string `validate:"required"` - HomeDirectory string `validate:"required"` - Rows uint16 `validate:"required"` - Cols uint16 `validate:"required"` -} - -type openFtpData struct { - SessionID string `validate:"required"` - URL string `validate:"required"` - Username string `validate:"required"` - Groupname string `validate:"required"` - HomeDirectory string `validate:"required"` -} - -type openTunnelData struct { - SessionID string `validate:"required"` - TargetPort int `validate:"required"` - URL string `validate:"required"` -} - -type closeTunnelData struct { - SessionID string `validate:"required"` -} - -type commandFin struct { - Success bool `json:"success"` - Result string `json:"result"` - ElapsedTime float64 `json:"elapsed_time"` -} - -type commandStat struct { - Success bool `json:"success"` - Message string `json:"message"` - Type transferType `json:"type"` -} - -type transferType string - -const ( - DOWNLOAD transferType = "download" - UPLOAD transferType = "upload" -) - -var nonZipExt = map[string]bool{ - ".jar": true, - ".war": true, - ".ear": true, - ".apk": true, - ".xpi": true, - ".vsix": true, - ".crx": true, - ".egg": true, - ".whl": true, - ".appx": true, - ".msix": true, - ".ipk": true, - ".nupkg": true, - ".kmz": true, -} diff --git a/pkg/runner/commit.go b/pkg/runner/commit.go index 405a87f..eb3b81d 100644 --- a/pkg/runner/commit.go +++ b/pkg/runner/commit.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/scheduler" "github.com/alpacax/alpamon/pkg/utils" "github.com/alpacax/alpamon/pkg/version" @@ -41,18 +42,31 @@ const ( var syncMutex sync.Mutex -func CommitAsync(session *scheduler.Session, commissioned bool) { +// CommitAsync commits system information asynchronously +// Uses ContextManager for coordinated lifecycle management +func CommitAsync(session *scheduler.Session, commissioned bool, ctxManager *agent.ContextManager) { if commissioned { + // Use a goroutine with delayed execution for commissioned systems go func() { - time.Sleep(5 * time.Second) - syncSystemInfo(session, nil) + // Get application-level context for shutdown coordination + ctx := ctxManager.Root() + + // Wait for either timeout or shutdown signal + select { + case <-time.After(5 * time.Second): + // Timeout occurred, proceed with sync + SyncSystemInfo(session, nil) + case <-ctx.Done(): + // Shutdown occurred before timeout, skip sync + log.Debug().Msg("Skipping syncSystemInfo due to shutdown") + } }() } else { - go commitSystemInfo() + go CommitSystemInfo() } } -func commitSystemInfo() { +func CommitSystemInfo() { log.Debug().Msg("Start committing system information.") data := collectData() @@ -64,24 +78,18 @@ func commitSystemInfo() { "description": "Committed system information. version: %s"}`, version.Version)), 80, time.Time{}) // Sync firewall rules after committing system info - // Skip if firewall functionality is disabled or high-level firewall tools are detected + // Skip if firewall functionality is disabled + // Note: Full firewall sync is handled by FirewallHandler in executor package if utils.IsFirewallDisabled() { log.Info().Msg("Skipping firewall sync - firewall functionality is temporarily disabled") - } else if detected, toolName := utils.DetectHighLevelFirewall(); detected { - log.Info().Msgf("Skipping firewall sync - %s is active", toolName) } else { - firewallData, err := utils.CollectFirewallRules() - if err != nil { - log.Debug().Err(err).Msg("Failed to collect firewall rules during commit.") - } else { - scheduler.Rqueue.Post(firewallSyncURL, firewallData, 80, time.Time{}) - } + log.Debug().Msg("Firewall sync delegated to executor FirewallHandler") } log.Info().Msg("Completed committing system information.") } -func syncSystemInfo(session *scheduler.Session, keys []string) { +func SyncSystemInfo(session *scheduler.Session, keys []string) { log.Debug().Msg("Start system information synchronization.") syncMutex.Lock() @@ -161,21 +169,13 @@ func syncSystemInfo(session *scheduler.Session, keys []string) { remoteData = &[]Partition{} case "firewall": // Firewall sync only posts current rules without comparison - // Skip if firewall functionality is disabled or high-level firewall tools are detected + // Skip if firewall functionality is disabled + // Note: Full firewall sync is handled by FirewallHandler in executor package if utils.IsFirewallDisabled() { log.Info().Msg("Skipping firewall sync - firewall functionality is temporarily disabled") continue } - if detected, toolName := utils.DetectHighLevelFirewall(); detected { - log.Info().Msgf("Skipping firewall sync - %s is active", toolName) - continue - } - firewallData, err := utils.CollectFirewallRules() - if err != nil { - log.Debug().Err(err).Msg("Failed to collect firewall rules.") - continue - } - scheduler.Rqueue.Post(utils.JoinPath(entry.URL, entry.URLSuffix), firewallData, 80, time.Time{}) + log.Debug().Msg("Firewall sync delegated to executor FirewallHandler") continue default: log.Warn().Msgf("Unknown key: %s", key) diff --git a/pkg/runner/firewall.go b/pkg/runner/firewall.go deleted file mode 100644 index e2274d4..0000000 --- a/pkg/runner/firewall.go +++ /dev/null @@ -1,128 +0,0 @@ -package runner - -import ( - "fmt" - - "github.com/alpacax/alpamon/pkg/utils" - "github.com/rs/zerolog/log" -) - -// firewallReorderRules handles the firewall-reorder-rules command -// Flushes a chain and reapplies all rules in the specified order -// TODO: Implement zero-downtime reordering to prevent complete firewall shutdown: -// 1. Create temporary table/chain with new rules in desired order -// 2. Atomically swap with existing table/chain -// 3. Clean up old table/chain -// This prevents firewall from being completely down during flush+reorder -func (cr *CommandRunner) firewallReorderRules() (exitCode int, result string) { - log.Info().Msgf("Firewall reorder rules command received for chain: %s", cr.data.ChainName) - - // Validate required fields - if cr.data.ChainName == "" { - return 1, "firewall-reorder-rules: chain_name is required" - } - if len(cr.data.Rules) == 0 { - return 1, "firewall-reorder-rules: rules array is required" - } - - log.Debug().Msgf("Reordering %d rules in chain %s", len(cr.data.Rules), cr.data.ChainName) - - // Detect firewall backend - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall-reorder-rules: Failed to check firewall installation: %v", err) - } - - // Flush the chain/table before applying new rules - var flushExitCode int - var flushOutput string - - if nftablesInstalled { - // For nftables, chain_name is actually the table name (security group) - flushExitCode, flushOutput = runCmdWithOutput( - []string{"nft", "flush", "table", "inet", cr.data.ChainName}, - "root", "", nil, 10, - ) - } else if iptablesInstalled { - // For iptables, chain_name is the actual chain name - flushExitCode, flushOutput = runCmdWithOutput( - []string{"iptables", "-F", cr.data.ChainName}, - "root", "", nil, 10, - ) - } else { - return 1, "firewall-reorder-rules: No firewall backend available" - } - - if flushExitCode != 0 { - log.Error().Msgf("Failed to flush %s: %s", cr.data.ChainName, flushOutput) - return 1, fmt.Sprintf("firewall-reorder-rules: Failed to flush %s: %s", cr.data.ChainName, flushOutput) - } - - log.Info().Msgf("Successfully flushed %s", cr.data.ChainName) - - // Use the common batch apply logic with flush (same as batch operation) - appliedRules, failedRules, rolledBack, rollbackReason := cr.applyRulesBatchWithFlush() - - // Prepare response - if rolledBack { - return 1, fmt.Sprintf(`{"success": false, "error": "Failed to reorder rules", "applied_rules": %d, "failed_rules": %d, "rolled_back": true, "rollback_reason": "%s"}`, - appliedRules, len(failedRules), rollbackReason) - } - - log.Info().Msgf("Successfully reordered %d rules in chain %s", appliedRules, cr.data.ChainName) - return 0, fmt.Sprintf(`{"success": true, "message": "Rules reordered successfully", "chain": "%s", "applied_rules": %d, "failed_rules": [], "rolled_back": false}`, - cr.data.ChainName, appliedRules) -} - -// firewallReorderChains handles the firewall-reorder-chains command -// Reorders INPUT chain jump rules for security groups -// TODO: Implement zero-downtime chain reordering to prevent firewall shutdown: -// 1. Create temporary chain with jump rules in new order -// 2. Atomically swap INPUT chain reference to temporary chain -// 3. Clean up old chain -// This prevents firewall from being completely down during chain reordering -func (cr *CommandRunner) firewallReorderChains() (exitCode int, result string) { - log.Info().Msg("Firewall reorder chains command received") - - // Get chain_names from data - chainNames := cr.data.ChainNames - if len(chainNames) == 0 { - return 1, "firewall-reorder-chains: No chain_names provided" - } - - log.Debug().Msgf("Reordering chains: %v", chainNames) - - // Detect firewall backend - nftablesInstalled, iptablesInstalled, err := utils.CheckFirewallTool() - if err != nil { - return 1, fmt.Sprintf("firewall-reorder-chains: Failed to check firewall installation: %v", err) - } - - var deletedRules int - - // Execute reordering based on backend - if nftablesInstalled { - resultData, err := utils.ReorderNftablesChains(chainNames) - if err != nil { - log.Error().Err(err).Msg("Failed to reorder firewall chains") - return 1, fmt.Sprintf("firewall-reorder-chains: %v", err) - } - if count, ok := resultData["deleted_rules"].(int); ok { - deletedRules = count - } - } else if iptablesInstalled { - resultData, err := utils.ReorderIptablesChains(chainNames) - if err != nil { - log.Error().Err(err).Msg("Failed to reorder firewall chains") - return 1, fmt.Sprintf("firewall-reorder-chains: %v", err) - } - if count, ok := resultData["deleted_rules"].(int); ok { - deletedRules = count - } - } else { - return 1, "firewall-reorder-chains: No firewall backend available" - } - - log.Info().Msgf("Successfully reordered %d chains", len(chainNames)) - return 0, fmt.Sprintf(`{"success": true, "message": "Chains reordered successfully", "reordered_chains": %d, "deleted_rules": %d}`, len(chainNames), deletedRules) -} diff --git a/pkg/runner/pty.go b/pkg/runner/pty.go index db5484c..94bfdd1 100644 --- a/pkg/runner/pty.go +++ b/pkg/runner/pty.go @@ -10,9 +10,11 @@ import ( "os/exec" "strconv" "strings" + "sync" "sync/atomic" "time" + "github.com/alpacax/alpamon/internal/protocol" "github.com/alpacax/alpamon/pkg/config" "github.com/alpacax/alpamon/pkg/scheduler" "github.com/alpacax/alpamon/pkg/utils" @@ -48,13 +50,16 @@ const ( sessionCloseCode = 4000 ) -var terminals map[string]*PtyClient +var ( + terminals map[string]*PtyClient + terminalsMu sync.RWMutex +) func init() { terminals = make(map[string]*PtyClient) } -func NewPtyClient(data CommandData, apiSession *scheduler.Session) *PtyClient { +func NewPtyClient(data protocol.CommandData, apiSession *scheduler.Session) *PtyClient { headers := http.Header{ "Authorization": {fmt.Sprintf(`id="%s", key="%s"`, config.GlobalSettings.ID, config.GlobalSettings.Key)}, "Origin": {config.GlobalSettings.ServerURL}, @@ -100,8 +105,11 @@ func (pc *PtyClient) initializePtySession() error { return fmt.Errorf("failed to start pty: %w", err) } + terminalsMu.Lock() terminals[pc.sessionID] = pc + terminalsMu.Unlock() + // Add PID-to-session mapping for auth manager pid := pc.cmd.Process.Pid sessionInfo := &SessionInfo{ SessionID: pc.sessionID, @@ -266,6 +274,18 @@ func (pc *PtyClient) resize(rows, cols uint16) error { return nil } +// Resize is the exported version of resize for external packages +func (pc *PtyClient) Resize(rows, cols uint16) error { + return pc.resize(rows, cols) +} + +// GetTerminal returns the PTY client for the given session ID +func GetTerminal(sessionID string) *PtyClient { + terminalsMu.RLock() + defer terminalsMu.RUnlock() + return terminals[sessionID] +} + // close terminates the PTY session and cleans up resources. // It ensures that the PTY, command, and WebSocket connection are properly closed. func (pc *PtyClient) close() { @@ -283,9 +303,11 @@ func (pc *PtyClient) close() { _ = pc.cmd.Wait() } + terminalsMu.Lock() if terminals[pc.sessionID] != nil { delete(terminals, pc.sessionID) } + terminalsMu.Unlock() if pc.conn != nil { err := pc.conn.WriteControl( diff --git a/pkg/runner/shell.go b/pkg/runner/shell.go deleted file mode 100644 index f18d9eb..0000000 --- a/pkg/runner/shell.go +++ /dev/null @@ -1,216 +0,0 @@ -package runner - -import ( - "context" - "fmt" - "os" - "os/exec" - "os/user" - "strconv" - "strings" - "syscall" - "time" - - "github.com/alpacax/alpamon/pkg/utils" - "github.com/rs/zerolog/log" -) - -func demote(username, groupname string) (*syscall.SysProcAttr, error) { - currentUid := os.Getuid() - - if username == "" || groupname == "" { - log.Debug().Msg("No username or groupname provided, running as the current user.") - return nil, nil - } - - if currentUid != 0 { - log.Warn().Msg("Alpamon is not running as root. Falling back to the current user.") - return nil, nil - } - - usr, err := user.Lookup(username) - if err != nil { - return nil, fmt.Errorf("there is no corresponding %s username in this server", username) - } - - group, err := user.LookupGroup(groupname) - if err != nil { - return nil, fmt.Errorf("there is no corresponding %s groupname in this server", groupname) - } - - uid, err := strconv.ParseUint(usr.Uid, 10, 32) - if err != nil { - return nil, err - } - - gid, err := strconv.ParseUint(group.Gid, 10, 32) - if err != nil { - return nil, err - } - - groupIds, err := usr.GroupIds() - if err != nil { - return nil, err - } - - groups := make([]uint32, 0, len(groupIds)) - groupInList := false - for _, gidStr := range groupIds { - gidUint, err := strconv.ParseUint(gidStr, 10, 32) - if err != nil { - return nil, err - } - if gidUint == gid { - groupInList = true - } - groups = append(groups, uint32(gidUint)) - } - if !groupInList { - return nil, fmt.Errorf("groupname %s is not in user %s's group list", groupname, username) - } - - log.Debug().Msgf("Demote permission to match user: %s, group: %s.", username, groupname) - - return &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - Groups: groups, - }, - }, nil -} - -func demoteFtp(username, groupname string) (*syscall.SysProcAttr, string, error) { - currentUid := os.Getuid() - - if username == "" || groupname == "" { - log.Debug().Msg("No username or groupname provided, running as the current user.") - return nil, "", nil - } - - if currentUid != 0 { - log.Warn().Msg("Alpamon is not running as root. Falling back to the current user.") - return nil, "", nil - } - - usr, err := user.Lookup(username) - if err != nil { - return nil, "", fmt.Errorf("there is no corresponding %s username in this server", username) - } - - group, err := user.LookupGroup(groupname) - if err != nil { - return nil, "", fmt.Errorf("there is no corresponding %s groupname in this server", groupname) - } - - uid, err := strconv.ParseUint(usr.Uid, 10, 32) - if err != nil { - return nil, "", err - } - - gid, err := strconv.ParseUint(group.Gid, 10, 32) - if err != nil { - return nil, "", err - } - - groupIds, err := usr.GroupIds() - if err != nil { - return nil, "", err - } - - groups := make([]uint32, 0, len(groupIds)) - for _, gidStr := range groupIds { - gidInt, err := strconv.Atoi(gidStr) - if err != nil { - return nil, "", err - } - groups = append(groups, uint32(gidInt)) - } - - log.Debug().Msgf("Demote permission to match user: %s, group: %s.", username, groupname) - - return &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - Groups: groups, - }, - }, usr.HomeDir, nil -} - -func runCmdWithOutput(args []string, username, groupname string, env map[string]string, timeout int) (exitCode int, result string) { - if env != nil { - defaultEnv := getDefaultEnv() - for key, value := range defaultEnv { - if _, exists := env[key]; !exists { - env[key] = value - } - } - for i := range args { - if strings.HasPrefix(args[i], "${") && strings.HasSuffix(args[i], "}") { - varName := args[i][2 : len(args[i])-1] - if val, ok := env[varName]; ok { - args[i] = val - } - } else if strings.HasPrefix(args[i], "$") { - varName := args[i][1:] - if val, ok := env[varName]; ok { - args[i] = val - } - } - } - } - - var ctx context.Context - var cancel context.CancelFunc - - if timeout > 0 { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) - } else { - ctx, cancel = context.WithCancel(context.Background()) - } - defer cancel() - - // Check if args is empty - if len(args) == 0 { - return 1, "no command provided" - } - - // Get user info first to determine working directory for glob expansion - usr, err := utils.GetSystemUser(username) - if err != nil { - return 1, err.Error() - } - - // Expand glob patterns in arguments using the user's home directory as base - expandedArgs := utils.ExpandGlobArgs(args[1:], usr.HomeDir) - cmd := exec.CommandContext(ctx, args[0], expandedArgs...) - - if username != "root" { - sysProcAttr, err := demote(username, groupname) - if err != nil { - log.Error().Err(err).Msg("Failed to demote user.") - return -1, err.Error() - } - if sysProcAttr != nil { - cmd.SysProcAttr = sysProcAttr - } - } - - for key, value := range env { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value)) - } - - cmd.Dir = usr.HomeDir - - log.Debug().Msgf("Executing command as user '%s' (group: '%s') -> '%s'", username, groupname, strings.Join(args, " ")) - output, err := cmd.CombinedOutput() - if err != nil { - if exitError, ok := err.(*exec.ExitError); ok { - return exitError.ExitCode(), string(output) - } - return -1, err.Error() - } - - return 0, string(output) -} diff --git a/pkg/scheduler/reporter.go b/pkg/scheduler/reporter.go index c31bc72..6e15a8d 100644 --- a/pkg/scheduler/reporter.go +++ b/pkg/scheduler/reporter.go @@ -1,6 +1,7 @@ package scheduler import ( + "context" "encoding/json" "fmt" "math" @@ -8,6 +9,7 @@ import ( "sync" "time" + "github.com/alpacax/alpamon/pkg/agent" "github.com/alpacax/alpamon/pkg/config" "github.com/alpacax/alpamon/pkg/utils" "github.com/alpacax/alpamon/pkg/version" @@ -32,20 +34,58 @@ func NewReporter(index int, session *Session) *Reporter { } } -func StartReporters(session *Session) { +func StartReporters(session *Session, ctxManager *agent.ContextManager) *ReporterManager { newRequestQueue() // init RequestQueue - wg := sync.WaitGroup{} + wg := &sync.WaitGroup{} + cancels := make([]func(), 0, config.GlobalSettings.HTTPThreads) + for i := 0; i < config.GlobalSettings.HTTPThreads; i++ { wg.Add(1) reporter := NewReporter(i, session) - go func() { + // Create context for each reporter with no timeout + ctx, cancel := ctxManager.NewContext(0) + cancels = append(cancels, cancel) + go func(ctx context.Context) { defer wg.Done() - reporter.Run() - }() + reporter.Run(ctx) + }(ctx) } reportStartupEvent() + + return &ReporterManager{ + wg: wg, + cancels: cancels, + } +} + +// Shutdown gracefully stops all reporters with a timeout +func (rm *ReporterManager) Shutdown(timeout time.Duration) error { + log.Info().Msg("Shutting down reporters...") + + // Cancel all reporter contexts + for _, cancel := range rm.cancels { + cancel() + } + + // Wake up all waiting reporters + Rqueue.cond.Broadcast() + + // Wait for all reporters to finish with timeout + done := make(chan struct{}) + go func() { + rm.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Info().Msg("All reporters shut down gracefully") + return nil + case <-time.After(timeout): + return fmt.Errorf("reporter shutdown timeout after %v", timeout) + } } func reportStartupEvent() { @@ -99,28 +139,70 @@ func (r *Reporter) query(entry PriorityEntry) { } } -func (r *Reporter) Run() { +func (r *Reporter) Run(ctx context.Context) { for { - Rqueue.cond.L.Lock() - for Rqueue.queue.Size() == 0 { - Rqueue.cond.Wait() + // Check for shutdown signal + select { + case <-ctx.Done(): + log.Debug().Msgf("Reporter %s shutting down", r.name) + return + default: } - entry, err := Rqueue.queue.Get() - Rqueue.cond.L.Unlock() - if err != nil { - continue + + // Wait for queue entry + entry, ok := r.waitForEntry(ctx) + if !ok { + return // Context cancelled during wait + } + + // Process the entry + r.processEntry(entry) + } +} + +// waitForEntry waits for an entry from the queue or context cancellation +func (r *Reporter) waitForEntry(ctx context.Context) (PriorityEntry, bool) { + Rqueue.cond.L.Lock() + defer Rqueue.cond.L.Unlock() + + for Rqueue.queue.Size() == 0 { + // Check context before waiting + select { + case <-ctx.Done(): + log.Debug().Msgf("Reporter %s shutting down", r.name) + return PriorityEntry{}, false + default: } + Rqueue.cond.Wait() + } - if !entry.expiry.IsZero() && entry.expiry.Before(time.Now()) { + entry, err := Rqueue.queue.Get() + if err != nil { + // Return empty entry to continue loop + return PriorityEntry{}, true + } + + return entry, true +} + +// processEntry handles the business logic for a queue entry +func (r *Reporter) processEntry(entry PriorityEntry) { + // Handle expired entries + if !entry.expiry.IsZero() && entry.expiry.Before(time.Now()) { + r.counters.ignored++ + return + } + + // Handle entries that are not yet due + if !entry.due.IsZero() && entry.due.After(time.Now()) { + err := Rqueue.queue.Offer(entry) + if err != nil { r.counters.ignored++ - } else if !entry.due.IsZero() && entry.due.After(time.Now()) { - err = Rqueue.queue.Offer(entry) - if err != nil { - r.counters.ignored++ - } - time.Sleep(1 * time.Second) - } else { - r.query(entry) } + time.Sleep(1 * time.Second) + return } + + // Process the entry + r.query(entry) } diff --git a/pkg/scheduler/types.go b/pkg/scheduler/types.go index 5f1e54e..702f57e 100644 --- a/pkg/scheduler/types.go +++ b/pkg/scheduler/types.go @@ -43,3 +43,9 @@ type counters struct { delay float64 latency float64 } + +// ReporterManager manages the lifecycle of multiple reporter goroutines +type ReporterManager struct { + wg *sync.WaitGroup + cancels []func() +} diff --git a/pkg/utils/firewall.go b/pkg/utils/firewall.go index d972777..cf38dc6 100644 --- a/pkg/utils/firewall.go +++ b/pkg/utils/firewall.go @@ -1,40 +1,10 @@ package utils import ( - "encoding/json" "fmt" - "os" - "regexp" - "strconv" "strings" - "sync" - "time" "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -// Firewall tool check state (caching) -var ( - firewallCheckMutex sync.Mutex - firewallCheckAttempted bool - firewallNftablesInstalled bool - firewallIptablesInstalled bool - firewallCheckError error - - // Feature flag to disable automatic rule recreation - // Set to true to prevent conflicts with ufw/firewalld - disableRuleRecreation = true - - // Temporary flag to disable all firewall functionality - // Set to true to completely disable alpacon firewall management - firewallFunctionalityDisabled = true - - // High-level firewall detection cache - highLevelFirewallCheckMutex sync.Mutex - highLevelFirewallCheckAttempted bool - highLevelFirewallDetected bool - highLevelFirewallToolName string ) // Default values matching alpacon-server FirewallRuleSyncSerializer @@ -50,6 +20,10 @@ const ( RuleTypeAlpacon = "alpacon" // Alpacon-created rules (remove last) ) +// Temporary flag to disable all firewall functionality +// Set to true to completely disable alpacon firewall management +var FirewallFunctionalityDisabled = true + // FirewallChainSync represents a firewall chain for sync payload type FirewallChainSync struct { Name string `json:"name"` @@ -78,190 +52,9 @@ type FirewallSyncPayload struct { Chains []FirewallChainSync `json:"chains"` } -// Note: Firewall rules caching removed to ensure real-time rule synchronization -// and prevent stale data issues during rapid rule updates - -// FirewallCommandExecutor is a function type for executing firewall commands -// This allows the runner package to inject its runCmdWithOutput function -type FirewallCommandExecutor func(args []string, user string, dir string, env map[string]string, timeout int) (exitCode int, output string) - -var commandExecutor FirewallCommandExecutor - -// SetFirewallCommandExecutor sets the command executor function -// This should be called from the runner package to inject its runCmdWithOutput -func SetFirewallCommandExecutor(executor FirewallCommandExecutor) { - commandExecutor = executor -} - -// runFirewallCommand executes a firewall command using the injected executor -func runFirewallCommand(args []string, timeout int) (exitCode int, output string) { - if commandExecutor == nil { - return 1, "firewall command executor not initialized" - } - return commandExecutor(args, "root", "", nil, timeout) -} - // IsFirewallDisabled checks if firewall functionality is disabled func IsFirewallDisabled() bool { - return firewallFunctionalityDisabled -} - -// DetectHighLevelFirewall detects if high-level firewall management tools are active -// Returns (detected, toolName) where toolName is "ufw" or "firewalld" -func DetectHighLevelFirewall() (detected bool, toolName string) { - // Use mutex to prevent concurrent checks - highLevelFirewallCheckMutex.Lock() - defer highLevelFirewallCheckMutex.Unlock() - - // Return cached result if we've already checked - if highLevelFirewallCheckAttempted { - return highLevelFirewallDetected, highLevelFirewallToolName - } - - // 1. Check ufw via systemctl (most reliable) - exitCode, output := runFirewallCommand([]string{"systemctl", "is-active", "ufw"}, 5) - if exitCode == 0 && strings.TrimSpace(output) == "active" { - highLevelFirewallCheckAttempted = true - highLevelFirewallDetected = true - highLevelFirewallToolName = "ufw" - log.Info().Msg("Detected active ufw firewall - alpacon firewall management will be disabled") - return true, "ufw" - } - - // 2. Fallback: Check ufw via direct command - exitCode, output = runFirewallCommand([]string{"ufw", "status"}, 5) - if exitCode == 0 && strings.Contains(strings.ToLower(output), "status: active") { - highLevelFirewallCheckAttempted = true - highLevelFirewallDetected = true - highLevelFirewallToolName = "ufw" - log.Info().Msg("Detected active ufw firewall - alpacon firewall management will be disabled") - return true, "ufw" - } - - // 3. Check firewalld via systemctl - exitCode, output = runFirewallCommand([]string{"systemctl", "is-active", "firewalld"}, 5) - if exitCode == 0 && strings.TrimSpace(output) == "active" { - highLevelFirewallCheckAttempted = true - highLevelFirewallDetected = true - highLevelFirewallToolName = "firewalld" - log.Info().Msg("Detected active firewalld - alpacon firewall management will be disabled") - return true, "firewalld" - } - - // 4. Fallback: Check firewalld via firewall-cmd - exitCode, output = runFirewallCommand([]string{"firewall-cmd", "--state"}, 5) - if exitCode == 0 && strings.Contains(strings.ToLower(output), "running") { - highLevelFirewallCheckAttempted = true - highLevelFirewallDetected = true - highLevelFirewallToolName = "firewalld" - log.Info().Msg("Detected active firewalld - alpacon firewall management will be disabled") - return true, "firewalld" - } - - // No high-level firewall detected - highLevelFirewallCheckAttempted = true - highLevelFirewallDetected = false - highLevelFirewallToolName = "" - log.Debug().Msg("No high-level firewall detected - alpacon firewall management enabled") - return false, "" -} - -// CheckFirewallTool checks if firewall tools (nftables or iptables) are installed -// and detects which backend to use based on existing rules -// Returns (nftablesInstalled, iptablesInstalled, error) -func CheckFirewallTool() (nftablesInstalled bool, iptablesInstalled bool, err error) { - // Use mutex to prevent concurrent checks - firewallCheckMutex.Lock() - defer firewallCheckMutex.Unlock() - - // Return cached result if we've already checked - if firewallCheckAttempted { - return firewallNftablesInstalled, firewallIptablesInstalled, firewallCheckError - } - - // Detect backend based on existing rules - backend := detectFirewallBackend() - - if backend == "iptables" { - nftablesInstalled = false - iptablesInstalled = true - log.Info().Msg("Using iptables backend (existing iptables rules detected)") - } else if backend == "nftables" { - nftablesInstalled = true - iptablesInstalled = false - log.Info().Msg("Using nftables backend (no iptables rules found)") - } else { - // Neither backend available - firewallCheckError = fmt.Errorf("firewall tool not installed: neither nftables nor iptables is available") - firewallCheckAttempted = true - return false, false, firewallCheckError - } - - // Cache the result - firewallCheckAttempted = true - firewallNftablesInstalled = nftablesInstalled - firewallIptablesInstalled = iptablesInstalled - firewallCheckError = nil - - return nftablesInstalled, iptablesInstalled, nil -} - -// detectFirewallBackend detects which firewall backend to use based on existing rules -// Returns "iptables", "nftables", or "none" -func detectFirewallBackend() string { - // 1. Try iptables-save to check for existing iptables rules - exitCode, output := runFirewallCommand([]string{"iptables-save"}, 10) - - if exitCode == 0 { - // Count actual rules (lines starting with -A or -I) - ruleCount := 0 - for _, line := range strings.Split(output, "\n") { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "-A ") || strings.HasPrefix(line, "-I ") { - ruleCount++ - } - } - - if ruleCount > 0 { - log.Debug().Msgf("Found %d iptables rules", ruleCount) - return "iptables" - } - - // iptables-save succeeded but no rules - check if nft is available - exitCode, _ := runFirewallCommand([]string{"which", "nft"}, 5) - if exitCode == 0 { - log.Debug().Msg("No iptables rules, nft available") - return "nftables" - } - - // Only iptables available, no nft - log.Debug().Msg("No iptables rules, nft not available, defaulting to iptables") - return "iptables" - } - - // 2. iptables-save failed, try fallback with iptables -S - exitCode, output = runFirewallCommand([]string{"iptables", "-S"}, 10) - if exitCode == 0 { - // Check for rules (iptables -S output starts with -P, -A, -I, etc) - for _, line := range strings.Split(output, "\n") { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "-A ") || strings.HasPrefix(line, "-I ") { - log.Debug().Msg("Found iptables rules via iptables -S") - return "iptables" - } - } - } - - // 3. No iptables rules found, check if nft is available - exitCode, _ = runFirewallCommand([]string{"which", "nft"}, 5) - if exitCode == 0 { - log.Debug().Msg("No iptables rules, using nftables") - return "nftables" - } - - // 4. Neither iptables nor nft available - log.Warn().Msg("No firewall backend available") - return "none" + return FirewallFunctionalityDisabled } // ParseFirewallComment parses firewall rule comment to extract rule_id and type @@ -341,529 +134,8 @@ func ParseCommentOrGenerate(comment string) (ruleID, ruleType string) { return ruleID, ruleType } -// RecreateNftablesRuleWithComment re-creates an nftables rule with updated comment -// Returns true if re-creation was successful -func RecreateNftablesRuleWithComment(tableName string, rule *FirewallRuleSync, newComment string) bool { - // Build nft add rule command with new comment - args := []string{"nft", "add", "rule", tableName, rule.Chain} - - // Add protocol match - if rule.Protocol != "" && rule.Protocol != DefaultProtocol { - args = append(args, "meta", "l4proto", rule.Protocol) - } - - // Add source CIDR match - if rule.SourceCIDR != "" && rule.SourceCIDR != DefaultCIDR { - args = append(args, "ip", "saddr", rule.SourceCIDR) - } - - // Add destination CIDR match - if rule.DestinationCIDR != "" && rule.DestinationCIDR != DefaultCIDR { - args = append(args, "ip", "daddr", rule.DestinationCIDR) - } - - // Add port matches - if rule.Protocol == "tcp" || rule.Protocol == "udp" { - if rule.Dports != "" { - // Multiple ports - ports := strings.Split(rule.Dports, ",") - args = append(args, rule.Protocol, "dport", "{") - args = append(args, ports...) - args = append(args, "}") - } else if rule.PortStart != nil { - if rule.PortEnd != nil && *rule.PortEnd != *rule.PortStart { - // Port range - args = append(args, rule.Protocol, "dport", fmt.Sprintf("%d-%d", *rule.PortStart, *rule.PortEnd)) - } else { - // Single port - args = append(args, rule.Protocol, "dport", fmt.Sprintf("%d", *rule.PortStart)) - } - } - } - - // Add ICMP type match - if rule.Protocol == "icmp" && rule.ICMPType != nil { - args = append(args, "icmp", "type", fmt.Sprintf("%d", *rule.ICMPType)) - } - - // Add target/verdict (must come before comment) - args = append(args, strings.ToLower(rule.Target)) - - // Add comment with updated metadata (must come after verdict) - if newComment != "" { - args = append(args, "comment", fmt.Sprintf("\"%s\"", newComment)) - } - - // Execute the command - exitCode, _ := runFirewallCommand(args, 10) - return exitCode == 0 -} - -// RecreateIptablesRuleWithComment re-creates an iptables rule with updated comment -// Returns true if re-creation was successful -func RecreateIptablesRuleWithComment(chainName string, rule *FirewallRuleSync, newComment string) bool { - // Build iptables insert command (insert at beginning to maintain priority) - args := []string{"iptables", "-I", chainName} - - // Protocol - if rule.Protocol != "" && rule.Protocol != DefaultProtocol { - args = append(args, "-p", rule.Protocol) - } - - // Source CIDR - if rule.SourceCIDR != "" && rule.SourceCIDR != DefaultCIDR { - args = append(args, "-s", rule.SourceCIDR) - } - - // Destination CIDR - if rule.DestinationCIDR != "" && rule.DestinationCIDR != DefaultCIDR { - args = append(args, "-d", rule.DestinationCIDR) - } - - // Handle ports - if rule.Protocol == "tcp" || rule.Protocol == "udp" { - if rule.Dports != "" { - args = append(args, "-m", "multiport", "--dports", rule.Dports) - } else if rule.PortStart != nil { - if rule.PortEnd != nil && *rule.PortEnd != *rule.PortStart { - args = append(args, "--dport", fmt.Sprintf("%d:%d", *rule.PortStart, *rule.PortEnd)) - } else { - args = append(args, "--dport", fmt.Sprintf("%d", *rule.PortStart)) - } - } - } - - // ICMP type - if rule.Protocol == "icmp" && rule.ICMPType != nil { - args = append(args, "--icmp-type", fmt.Sprintf("%d", *rule.ICMPType)) - } - - // Target - if rule.Target != "" { - args = append(args, "-j", rule.Target) - } - - // Comment with updated metadata - if newComment != "" { - args = append(args, "-m", "comment", "--comment", newComment) - } - - // Execute the command - exitCode, _ := runFirewallCommand(args, 10) - return exitCode == 0 -} - -// CollectFirewallRules collects current firewall rules from the system -// This is the reverse operation of command.go firewall application logic -func CollectFirewallRules() (*FirewallSyncPayload, error) { - // Check which firewall tool is available - nftablesInstalled, iptablesInstalled, err := CheckFirewallTool() - if err != nil { - return nil, fmt.Errorf("failed to check firewall installation: %w", err) - } - - var chains map[string][]FirewallRuleSync - - if nftablesInstalled { - payload, err := collectNftablesRules() - if err != nil { - return nil, err - } - chains = make(map[string][]FirewallRuleSync) - for _, chain := range payload.Chains { - chains[chain.Name] = chain.Rules - } - } else if iptablesInstalled { - payload, err := collectIptablesRules() - if err != nil { - return nil, err - } - chains = make(map[string][]FirewallRuleSync) - for _, chain := range payload.Chains { - chains[chain.Name] = chain.Rules - } - } else { - // This should never happen as CheckFirewallTool() already returns error - // if neither tool is installed, but keep as safety fallback - return nil, fmt.Errorf("no firewall tools available") - } - - return buildSyncPayload(chains), nil -} - -// collectNftablesRules extracts rules from nftables -func collectNftablesRules() (*FirewallSyncPayload, error) { - exitCode, output := runFirewallCommand([]string{"nft", "-j", "list", "ruleset"}, 30) - if exitCode != 0 { - return nil, fmt.Errorf("failed to list nftables ruleset: exit code %d", exitCode) - } - - var nftData struct { - Nftables []map[string]interface{} `json:"nftables"` - } - - if err := json.Unmarshal([]byte(output), &nftData); err != nil { - return nil, fmt.Errorf("failed to parse nftables output: %w", err) - } - - chains := make(map[string][]FirewallRuleSync) - tableNames := make(map[string]bool) - - // First pass: collect all table names - for _, item := range nftData.Nftables { - if table, ok := item["table"]; ok { - if tableMap, ok := table.(map[string]interface{}); ok { - if name, ok := tableMap["name"].(string); ok { - tableNames[name] = true - } - } - } - } - - // Second pass: collect rules grouped by table - currentTable := "" - for _, item := range nftData.Nftables { - // Track current table - if table, ok := item["table"]; ok { - if tableMap, ok := table.(map[string]interface{}); ok { - if name, ok := tableMap["name"].(string); ok { - currentTable = name - } - } - } - - // Parse rules - if rule, ok := item["rule"]; ok { - ruleMap := rule.(map[string]interface{}) - tableName, _ := ruleMap["table"].(string) - if tableName == "" { - tableName = currentTable - } - - // Only process tables we've seen in first pass - if !tableNames[tableName] { - continue - } - - if parsedRule, err := parseNftablesRuleToSync(ruleMap); err == nil { - chains[tableName] = append(chains[tableName], *parsedRule) - } - } - } - - return buildSyncPayload(chains), nil -} - -// collectIptablesRules extracts rules from iptables -func collectIptablesRules() (*FirewallSyncPayload, error) { - exitCode, output := runFirewallCommand([]string{"iptables-save"}, 30) - if exitCode != 0 { - log.Debug().Msgf("Failed to run iptables-save: exit code %d", exitCode) - return &FirewallSyncPayload{Chains: []FirewallChainSync{}}, nil - } - - chains := parseIptablesSaveOutput(output) - return buildSyncPayload(chains), nil -} - -// parseNftablesRuleToSync converts nftables rule map to FirewallRuleSync -func parseNftablesRuleToSync(ruleMap map[string]interface{}) (*FirewallRuleSync, error) { - rule := &FirewallRuleSync{ - SourceCIDR: DefaultCIDR, - Priority: DefaultPriority, - Protocol: DefaultProtocol, - Target: DefaultTarget, - } - - rule.Chain = ruleMap["chain"].(string) - tableName, _ := ruleMap["table"].(string) - - // Extract comment if present - var fullComment string - if comment, ok := ruleMap["comment"].(string); ok { - fullComment = comment - } - - // Parse expressions for protocol, ports, source, target, comment, etc. - if expr, ok := ruleMap["expr"].([]interface{}); ok { - for _, e := range expr { - exprMap, ok := e.(map[string]interface{}) - if !ok { - continue - } - - // Extract comment from expression - if commentExpr, ok := exprMap["comment"].(string); ok { - fullComment = commentExpr - } - - // Match protocol - if match, ok := exprMap["match"].(map[string]interface{}); ok { - if left, ok := match["left"].(map[string]interface{}); ok { - if protocol, ok := left["meta"].(map[string]interface{}); ok { - if key, ok := protocol["key"].(string); ok && key == "l4proto" { - if right, ok := match["right"].(string); ok { - rule.Protocol = right - } - } - } - } - - // Match ports - if right, ok := match["right"].(float64); ok { - port := int(right) - if rule.PortStart == nil { - rule.PortStart = &port - } - } else if right, ok := match["right"].(map[string]interface{}); ok { - if set, ok := right["set"].([]interface{}); ok { - var ports []string - for _, portVal := range set { - if p, ok := portVal.(float64); ok { - ports = append(ports, fmt.Sprintf("%d", int(p))) - } - } - if len(ports) > 0 { - rule.Dports = strings.Join(ports, ",") - } - } - } - } - - // Match source/destination - if match, ok := exprMap["match"].(map[string]interface{}); ok { - if left, ok := match["left"].(map[string]interface{}); ok { - if payload, ok := left["payload"].(map[string]interface{}); ok { - if field, ok := payload["field"].(string); ok { - if right, ok := match["right"].(string); ok { - if field == "saddr" { - rule.SourceCIDR = right - } else if field == "daddr" { - rule.DestinationCIDR = right - } - } - } - } - } - } - - // Match target/verdict - if accept, ok := exprMap["accept"]; ok && accept != nil { - rule.Target = "ACCEPT" - } else if drop, ok := exprMap["drop"]; ok && drop != nil { - rule.Target = "DROP" - } else if reject, ok := exprMap["reject"]; ok && reject != nil { - rule.Target = "REJECT" - } - } - } - - // Check if original comment has rule_id and type - originalRuleID, originalRuleType, existingComment := ParseFirewallComment(fullComment) - - // Parse comment to extract rule_id and type, or generate new ones - rule.RuleID, rule.RuleType = ParseCommentOrGenerate(fullComment) - - // If rule_id or type was missing, re-create the rule with proper metadata - // DISABLED: This can conflict with ufw/firewalld - if !disableRuleRecreation && (originalRuleID == "" || originalRuleType == "") { - newComment := BuildFirewallComment(existingComment, rule.RuleID, rule.RuleType) - - // Get rule handle from the ruleMap - var ruleHandle string - if handle, ok := ruleMap["handle"].(float64); ok { - ruleHandle = fmt.Sprintf("%.0f", handle) - } - - // Create the rule with updated comment first - if !RecreateNftablesRuleWithComment(tableName, rule, newComment) { - log.Warn().Msgf("Failed to re-create nftables rule %s with proper metadata", rule.RuleID) - return rule, nil - } - - // Delete the old rule using saved handle - if ruleHandle != "" { - deleteArgs := []string{"nft", "delete", "rule", tableName, rule.Chain, "handle", ruleHandle} - if exitCode, _ := runFirewallCommand(deleteArgs, 10); exitCode != 0 { - log.Warn().Msgf("Failed to delete old nftables rule with handle %s after re-creation", ruleHandle) - } else { - log.Debug().Msgf("Re-created nftables rule %s with proper metadata (table: %s, chain: %s)", rule.RuleID, tableName, rule.Chain) - } - } - } - - return rule, nil -} - -// parseIptablesSaveOutput parses iptables-save output to extract all chain rules -func parseIptablesSaveOutput(output string) map[string][]FirewallRuleSync { - chains := make(map[string][]FirewallRuleSync) - chainNames := make(map[string]bool) - lines := strings.Split(output, "\n") - - // First pass: extract all chain names - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, ":") { - parts := strings.Fields(line) - if len(parts) > 0 { - chainName := strings.TrimPrefix(parts[0], ":") - chainNames[chainName] = true - } - } - } - - // Second pass: extract rules for all chains - for _, line := range lines { - line = strings.TrimSpace(line) - - // Skip empty lines, comments, chain definitions, table markers - if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ":") || strings.HasPrefix(line, "*") || line == "COMMIT" { - continue - } - - // Process both -A (append) and -I (insert) rules - if !strings.HasPrefix(line, "-A ") && !strings.HasPrefix(line, "-I ") { - continue - } - - // Parse rule line - rule := parseIptablesSaveRuleLine(line) - if rule != nil && chainNames[rule.Chain] { - chains[rule.Chain] = append(chains[rule.Chain], *rule) - } - } - - return chains -} - -// parseIptablesSaveRuleLine parses a single iptables-save rule line -func parseIptablesSaveRuleLine(line string) *FirewallRuleSync { - // Remove -A or -I prefix - if strings.HasPrefix(line, "-A ") { - line = strings.TrimPrefix(line, "-A ") - } else if strings.HasPrefix(line, "-I ") { - line = strings.TrimPrefix(line, "-I ") - } - - parts := strings.Fields(line) - if len(parts) < 2 { - return nil - } - - // First part is chain name - chainName := parts[0] - - rule := &FirewallRuleSync{ - Chain: chainName, - SourceCIDR: DefaultCIDR, - Priority: DefaultPriority, - Protocol: DefaultProtocol, - Target: DefaultTarget, - } - - // Parse arguments - var fullComment string - for i := 1; i < len(parts); i++ { - switch parts[i] { - case "-p", "--protocol": - if i+1 < len(parts) { - rule.Protocol = parts[i+1] - i++ - } - case "-s", "--source": - if i+1 < len(parts) { - rule.SourceCIDR = parts[i+1] - i++ - } - case "-d", "--destination": - if i+1 < len(parts) { - rule.DestinationCIDR = parts[i+1] - i++ - } - case "-j", "--jump": - if i+1 < len(parts) { - rule.Target = strings.ToUpper(parts[i+1]) - i++ - } - case "--dport": - if i+1 < len(parts) { - portStr := parts[i+1] - if strings.Contains(portStr, ":") { - // Port range - portRange := strings.Split(portStr, ":") - if len(portRange) == 2 { - if start, err := strconv.Atoi(portRange[0]); err == nil { - rule.PortStart = &start - } - if end, err := strconv.Atoi(portRange[1]); err == nil { - rule.PortEnd = &end - } - } - } else { - // Single port - if port, err := strconv.Atoi(portStr); err == nil { - rule.PortStart = &port - } - } - i++ - } - case "--dports": - if i+1 < len(parts) { - rule.Dports = parts[i+1] - i++ - } - case "--icmp-type": - if i+1 < len(parts) { - if icmpType, err := strconv.Atoi(parts[i+1]); err == nil { - rule.ICMPType = &icmpType - } - i++ - } - case "--comment": - if i+1 < len(parts) { - comment := parts[i+1] - comment = strings.Trim(comment, "\"") - fullComment = comment - i++ - } - } - } - - // Check if original comment has rule_id and type - originalRuleID, originalRuleType, existingComment := ParseFirewallComment(fullComment) - - // Parse comment to extract rule_id and type, or generate new ones - rule.RuleID, rule.RuleType = ParseCommentOrGenerate(fullComment) - - // If rule_id or type was missing, re-create the rule with proper metadata - // DISABLED: This can conflict with ufw/firewalld - if !disableRuleRecreation && (originalRuleID == "" || originalRuleType == "") { - newComment := BuildFirewallComment(existingComment, rule.RuleID, rule.RuleType) - - // Create the rule with updated comment first - if !RecreateIptablesRuleWithComment(chainName, rule, newComment) { - log.Warn().Msgf("Failed to re-create iptables rule %s with proper metadata", rule.RuleID) - return rule - } - - // Delete the old rule - oldRule := *rule - oldRule.RuleID = originalRuleID - oldRule.RuleType = originalRuleType - - if err := RemoveIptablesRule(chainName, oldRule); err != nil { - log.Warn().Err(err).Msgf("Failed to delete old iptables rule after re-creation") - } else { - log.Debug().Msgf("Re-created iptables rule %s with proper metadata (chain: %s)", rule.RuleID, chainName) - } - } - - return rule -} - -// buildSyncPayload creates sync payload from parsed rules -func buildSyncPayload(chains map[string][]FirewallRuleSync) *FirewallSyncPayload { +// BuildSyncPayload creates sync payload from parsed rules +func BuildSyncPayload(chains map[string][]FirewallRuleSync) *FirewallSyncPayload { chainsList := make([]FirewallChainSync, 0) for name, rules := range chains { @@ -879,439 +151,3 @@ func buildSyncPayload(chains map[string][]FirewallRuleSync) *FirewallSyncPayload Chains: chainsList, } } - -// RemoveFirewallRulesByType removes all firewall rules of a specific type -func RemoveFirewallRulesByType(ruleType string) (int, error) { - // Create backup before removing rules - backup, err := BackupFirewallRules() - if err != nil { - return 0, fmt.Errorf("failed to backup firewall rules: %w", err) - } - - payload, err := CollectFirewallRules() - if err != nil { - return 0, fmt.Errorf("failed to collect firewall rules: %w", err) - } - - if len(payload.Chains) == 0 { - log.Debug().Msgf("No firewall chains found for rule type: %s", ruleType) - return 0, nil - } - - nftablesInstalled, iptablesInstalled, err := CheckFirewallTool() - if err != nil { - return 0, fmt.Errorf("failed to check firewall availability: %w", err) - } - - removedCount := 0 - - for _, chain := range payload.Chains { - for _, rule := range chain.Rules { - if rule.RuleType != ruleType { - continue - } - - var removeErr error - if nftablesInstalled { - removeErr = RemoveNftablesRule(chain.Name, rule) - } else if iptablesInstalled { - removeErr = RemoveIptablesRule(chain.Name, rule) - } - - if removeErr != nil { - log.Error().Err(removeErr).Msgf("Failed to remove rule %s from chain %s, restoring backup", rule.RuleID, chain.Name) - if restoreErr := RestoreFirewallRules(backup); restoreErr != nil { - log.Error().Err(restoreErr).Msg("Failed to restore backup after removal failure") - return removedCount, fmt.Errorf("failed to remove rule and restore failed: %w", restoreErr) - } - return removedCount, fmt.Errorf("failed to remove rule %s, backup restored: %w", rule.RuleID, removeErr) - } - - removedCount++ - log.Debug().Msgf("Removed rule %s (type: %s) from chain %s", rule.RuleID, ruleType, chain.Name) - } - } - - log.Info().Msgf("Removed %d firewall rules of type: %s", removedCount, ruleType) - - return removedCount, nil -} - -// RemoveNftablesRule removes a specific rule from nftables -func RemoveNftablesRule(tableName string, rule FirewallRuleSync) error { - exitCode, output := runFirewallCommand([]string{"nft", "-a", "list", "table", tableName}, 10) - if exitCode != 0 { - return fmt.Errorf("failed to list nftables table %s", tableName) - } - - lines := strings.Split(output, "\n") - for _, line := range lines { - if rule.RuleID != "" && strings.Contains(line, rule.RuleID) { - if idx := strings.Index(line, "# handle "); idx != -1 { - handleStr := strings.TrimSpace(line[idx+9:]) - handleParts := strings.Fields(handleStr) - if len(handleParts) > 0 { - handle := handleParts[0] - deleteArgs := []string{"nft", "delete", "rule", tableName, rule.Chain, "handle", handle} - exitCode, _ := runFirewallCommand(deleteArgs, 10) - if exitCode == 0 { - return nil - } - return fmt.Errorf("failed to delete nftables rule handle %s", handle) - } - } - } - } - - return fmt.Errorf("rule not found in nftables table %s", tableName) -} - -// RemoveIptablesRule removes a specific rule from iptables -func RemoveIptablesRule(chainName string, rule FirewallRuleSync) error { - args := []string{"iptables", "-D", chainName} - - if rule.Protocol != "" && rule.Protocol != DefaultProtocol { - args = append(args, "-p", rule.Protocol) - } - - if rule.SourceCIDR != "" && rule.SourceCIDR != DefaultCIDR { - args = append(args, "-s", rule.SourceCIDR) - } - - if rule.DestinationCIDR != "" && rule.DestinationCIDR != DefaultCIDR { - args = append(args, "-d", rule.DestinationCIDR) - } - - if rule.Protocol == "tcp" || rule.Protocol == "udp" { - if rule.Dports != "" { - args = append(args, "-m", "multiport", "--dports", rule.Dports) - } else if rule.PortStart != nil { - if rule.PortEnd != nil && *rule.PortEnd != *rule.PortStart { - args = append(args, "--dport", fmt.Sprintf("%d:%d", *rule.PortStart, *rule.PortEnd)) - } else { - args = append(args, "--dport", fmt.Sprintf("%d", *rule.PortStart)) - } - } - } - - if rule.Protocol == "icmp" && rule.ICMPType != nil { - args = append(args, "--icmp-type", fmt.Sprintf("%d", *rule.ICMPType)) - } - - if rule.Target != "" { - args = append(args, "-j", rule.Target) - } - - if rule.RuleID != "" { - args = append(args, "-m", "comment", "--comment") - comment := BuildFirewallComment("", rule.RuleID, rule.RuleType) - args = append(args, comment) - } - - exitCode, output := runFirewallCommand(args, 10) - if exitCode != 0 { - return fmt.Errorf("failed to delete iptables rule from chain %s: %s", chainName, output) - } - - return nil -} - -// ReorderNftablesChains reorders nftables INPUT chain jump rules -func ReorderNftablesChains(chainNames []string) (map[string]interface{}, error) { - log.Debug().Msg("Starting nftables chain reordering") - - // Backup current ruleset - backup, err := BackupFirewallRules() - if err != nil { - return nil, fmt.Errorf("failed to backup nftables ruleset: %w", err) - } - - // Get current INPUT chain rules with handles - exitCode, output := runFirewallCommand([]string{"nft", "-a", "list", "chain", "inet", "filter", "INPUT"}, 30) - if exitCode != 0 { - return nil, fmt.Errorf("failed to list INPUT chain rules") - } - - // Parse and find alpacon jump rule handles - jumpHandles := []string{} - lines := strings.Split(output, "\n") - - for _, line := range lines { - isJumpRule := false - for _, chainName := range chainNames { - if strings.Contains(line, fmt.Sprintf("jump %s", chainName)) { - isJumpRule = true - break - } - } - - if !isJumpRule { - continue - } - - handleRegex := regexp.MustCompile(`# handle (\d+)`) - matches := handleRegex.FindStringSubmatch(line) - if len(matches) > 1 { - jumpHandles = append(jumpHandles, matches[1]) - log.Debug().Msgf("Found jump rule handle: %s", matches[1]) - } - } - - if len(jumpHandles) == 0 { - log.Warn().Msg("No jump rules found to reorder") - return map[string]interface{}{ - "reordered_chains": chainNames, - "deleted_rules": 0, - }, nil - } - - // Delete old jump rules - for _, handle := range jumpHandles { - exitCode, errOutput := runFirewallCommand( - []string{"nft", "delete", "rule", "inet", "filter", "INPUT", "handle", handle}, - 10, - ) - if exitCode != 0 { - log.Error().Msgf("Failed to delete rule handle %s: %s", handle, errOutput) - if err := RestoreFirewallRules(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore backup after delete failure") - } - return nil, fmt.Errorf("failed to delete rule handle %s", handle) - } - log.Debug().Msgf("Deleted rule handle: %s", handle) - } - - // Add jump rules in new order - for _, chainName := range chainNames { - exitCode, errOutput := runFirewallCommand( - []string{"nft", "add", "rule", "inet", "filter", "INPUT", "jump", chainName}, - 10, - ) - if exitCode != 0 { - log.Error().Msgf("Failed to add jump rule for chain %s: %s", chainName, errOutput) - if err := RestoreFirewallRules(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore backup after add failure") - } - return nil, fmt.Errorf("failed to add jump rule for chain %s", chainName) - } - log.Debug().Msgf("Added jump rule for chain: %s", chainName) - } - - return map[string]interface{}{ - "reordered_chains": chainNames, - "deleted_rules": len(jumpHandles), - }, nil -} - -// ReorderIptablesChains reorders iptables INPUT chain jump rules -func ReorderIptablesChains(chainNames []string) (map[string]interface{}, error) { - log.Debug().Msg("Starting iptables chain reordering") - - // Backup current rules - backup, err := BackupFirewallRules() - if err != nil { - return nil, fmt.Errorf("failed to backup iptables rules: %w", err) - } - - // Get current INPUT chain rules - exitCode, output := runFirewallCommand([]string{"iptables", "-L", "INPUT", "--line-numbers", "-n"}, 30) - if exitCode != 0 { - return nil, fmt.Errorf("failed to list INPUT chain rules") - } - - // Find alpacon jump rule line numbers - jumpLines := []int{} - lines := strings.Split(output, "\n") - - for _, line := range lines { - parts := strings.Fields(line) - if len(parts) < 3 { - continue - } - - for _, chainName := range chainNames { - if parts[1] == chainName || (len(parts) > 2 && parts[2] == chainName) { - lineNum := 0 - _, err := fmt.Sscanf(parts[0], "%d", &lineNum) - if err == nil && lineNum > 0 { - jumpLines = append(jumpLines, lineNum) - log.Debug().Msgf("Found jump rule at line: %d for chain: %s", lineNum, chainName) - break - } - } - } - } - - if len(jumpLines) == 0 { - log.Warn().Msg("No jump rules found to reorder") - return map[string]interface{}{ - "reordered_chains": chainNames, - "deleted_rules": 0, - }, nil - } - - // Sort in reverse order - for i := 0; i < len(jumpLines); i++ { - for j := i + 1; j < len(jumpLines); j++ { - if jumpLines[i] < jumpLines[j] { - jumpLines[i], jumpLines[j] = jumpLines[j], jumpLines[i] - } - } - } - - // Delete old jump rules - for _, lineNum := range jumpLines { - exitCode, errOutput := runFirewallCommand( - []string{"iptables", "-D", "INPUT", fmt.Sprintf("%d", lineNum)}, - 10, - ) - if exitCode != 0 { - log.Error().Msgf("Failed to delete rule at line %d: %s", lineNum, errOutput) - if err := RestoreFirewallRules(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore backup after delete failure") - } - return nil, fmt.Errorf("failed to delete rule at line %d", lineNum) - } - log.Debug().Msgf("Deleted rule at line: %d", lineNum) - } - - // Add jump rules in new order - for _, chainName := range chainNames { - exitCode, errOutput := runFirewallCommand( - []string{"iptables", "-A", "INPUT", "-j", chainName}, - 10, - ) - if exitCode != 0 { - log.Error().Msgf("Failed to add jump rule for chain %s: %s", chainName, errOutput) - if err := RestoreFirewallRules(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore backup after add failure") - } - return nil, fmt.Errorf("failed to add jump rule for chain %s", chainName) - } - log.Debug().Msgf("Added jump rule for chain: %s", chainName) - } - - return map[string]interface{}{ - "reordered_chains": chainNames, - "deleted_rules": len(jumpLines), - }, nil -} - -// BackupFirewallRules creates a backup of current firewall rules -// Returns the backup string and error -func BackupFirewallRules() (string, error) { - nftablesInstalled, iptablesInstalled, err := CheckFirewallTool() - if err != nil { - return "", fmt.Errorf("failed to check firewall installation: %w", err) - } - - if nftablesInstalled { - exitCode, output := runFirewallCommand([]string{"nft", "list", "ruleset"}, 30) - if exitCode != 0 { - return "", fmt.Errorf("failed to backup nftables ruleset: exit code %d", exitCode) - } - log.Debug().Msg("Created nftables backup") - return output, nil - } else if iptablesInstalled { - exitCode, output := runFirewallCommand([]string{"iptables-save"}, 30) - if exitCode != 0 { - return "", fmt.Errorf("failed to backup iptables rules: exit code %d", exitCode) - } - log.Debug().Msg("Created iptables backup") - return output, nil - } - - return "", fmt.Errorf("no firewall tools available for backup") -} - -// RestoreFirewallRules restores firewall rules from backup string -// Automatically detects the firewall type and uses appropriate restore method -func RestoreFirewallRules(backup string) error { - if backup == "" { - return fmt.Errorf("empty backup provided") - } - - nftablesInstalled, iptablesInstalled, err := CheckFirewallTool() - if err != nil { - return fmt.Errorf("failed to check firewall installation: %w", err) - } - - if nftablesInstalled { - return restoreNftablesBackup(backup) - } else if iptablesInstalled { - return restoreIptablesBackup(backup) - } - - return fmt.Errorf("no firewall tools available for restore") -} - -// restoreNftablesBackup restores nftables ruleset from backup string -func restoreNftablesBackup(backup string) error { - log.Warn().Msg("Restoring nftables backup") - - tmpFile := fmt.Sprintf("/tmp/nft-backup-%d-%d.nft", os.Getpid(), time.Now().UnixNano()) - f, err := os.OpenFile(tmpFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) - if err != nil { - return fmt.Errorf("failed to create nftables backup temp file: %w", err) - } - defer os.Remove(tmpFile) - - if _, err := f.WriteString(backup); err != nil { - f.Close() - return fmt.Errorf("failed to write nftables backup: %w", err) - } - f.Close() - - runFirewallCommand([]string{"nft", "flush", "ruleset"}, 10) - exitCode, output := runFirewallCommand([]string{"nft", "-f", tmpFile}, 10) - - if exitCode != 0 { - return fmt.Errorf("failed to restore nftables backup: %s", output) - } - - log.Info().Msg("Successfully restored nftables backup") - return nil -} - -// restoreIptablesBackup restores iptables rules from backup string -func restoreIptablesBackup(backup string) error { - log.Warn().Msg("Restoring iptables backup") - - tmpFile := fmt.Sprintf("/tmp/iptables-backup-%d-%d.rules", os.Getpid(), time.Now().UnixNano()) - f, err := os.OpenFile(tmpFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) - if err != nil { - return fmt.Errorf("failed to create iptables backup temp file: %w", err) - } - defer os.Remove(tmpFile) - - if _, err := f.WriteString(backup); err != nil { - f.Close() - return fmt.Errorf("failed to write iptables backup: %w", err) - } - f.Close() - - exitCode, output := runFirewallCommand([]string{"iptables-restore", tmpFile}, 10) - - if exitCode != 0 { - return fmt.Errorf("failed to restore iptables backup: %s", output) - } - - log.Info().Msg("Successfully restored iptables backup") - return nil -} - -// RestoreNftablesBackup restores nftables ruleset from backup -// Deprecated: Use RestoreFirewallRules instead -func RestoreNftablesBackup(backup string) { - if err := restoreNftablesBackup(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore nftables backup") - } -} - -// RestoreIptablesBackup restores iptables rules from backup -// Deprecated: Use RestoreFirewallRules instead -func RestoreIptablesBackup(backup string) { - if err := restoreIptablesBackup(backup); err != nil { - log.Error().Err(err).Msg("Failed to restore iptables backup") - } -} diff --git a/pkg/utils/fs.go b/pkg/utils/fs.go index d39f43b..24aeda5 100644 --- a/pkg/utils/fs.go +++ b/pkg/utils/fs.go @@ -1,6 +1,8 @@ package utils import ( + "archive/zip" + "bytes" "fmt" "io" "os" @@ -11,6 +13,24 @@ import ( "syscall" ) +// nonZipExt contains file extensions that are zip-like but shouldn't be auto-unzipped +var nonZipExt = map[string]bool{ + ".jar": true, + ".war": true, + ".ear": true, + ".apk": true, + ".xpi": true, + ".vsix": true, + ".crx": true, + ".egg": true, + ".whl": true, + ".appx": true, + ".msix": true, + ".ipk": true, + ".nupkg": true, + ".kmz": true, +} + func CopyFile(src, dst string, allowOverwrite bool) error { srcFile, err := os.Open(src) if err != nil { @@ -223,3 +243,20 @@ func GetCopyPath(src, dst string) string { } } } + +// FileExists checks if the file exists at the given path +// codeql[go/path-injection]: Intentional - Admin-specified file path check +func FileExists(path string) bool { + _, err := os.Stat(path) // lgtm[go/path-injection] + return !os.IsNotExist(err) +} + +// IsZipFile checks if the content is a valid zip file +func IsZipFile(content []byte, ext string) bool { + if _, found := nonZipExt[ext]; found { + return false + } + + _, err := zip.NewReader(bytes.NewReader(content), int64(len(content))) + return err == nil +} diff --git a/pkg/utils/http_client.go b/pkg/utils/http_client.go index 98f6c44..379a865 100644 --- a/pkg/utils/http_client.go +++ b/pkg/utils/http_client.go @@ -13,29 +13,36 @@ import ( "github.com/rs/zerolog/log" ) -func Put(url string, body bytes.Buffer, timeout time.Duration) ([]byte, int, error) { - req, err := http.NewRequest(http.MethodPut, url, &body) - if err != nil { - return nil, 0, err +// NewHTTPClient creates an HTTP client with TLS configuration from global settings +func NewHTTPClient() *http.Client { + tlsConfig := &tls.Config{ + InsecureSkipVerify: !config.GlobalSettings.SSLVerify, } - client := &http.Client{Timeout: timeout} - - tlsConfig := &tls.Config{} if config.GlobalSettings.CaCert != "" { caCertPool := x509.NewCertPool() - caCert, err := os.ReadFile(config.GlobalSettings.CaCert) - if err != nil { + if caCert, err := os.ReadFile(config.GlobalSettings.CaCert); err == nil { + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool + } else { log.Error().Err(err).Msg("Failed to read CA certificate.") } - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig.RootCAs = caCertPool } - tlsConfig.InsecureSkipVerify = !config.GlobalSettings.SSLVerify - client.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, + return &http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsConfig}, } +} + +// codeql[go/request-forgery]: Intentional - HTTP client for admin-specified URLs +func Put(url string, body bytes.Buffer, timeout time.Duration) ([]byte, int, error) { + req, err := http.NewRequest(http.MethodPut, url, &body) // lgtm[go/request-forgery] + if err != nil { + return nil, 0, err + } + + client := NewHTTPClient() + client.Timeout = timeout resp, err := client.Do(req) if err != nil { diff --git a/pkg/utils/privilege.go b/pkg/utils/privilege.go new file mode 100644 index 0000000..6d43de8 --- /dev/null +++ b/pkg/utils/privilege.go @@ -0,0 +1,95 @@ +package utils + +import ( + "fmt" + "os" + "os/user" + "strconv" + "syscall" + + "github.com/rs/zerolog/log" +) + +// DemoteOptions configures the behavior of privilege demotion +type DemoteOptions struct { + // ValidateGroup checks if the specified group is in the user's group list + ValidateGroup bool +} + +// DemoteResult contains the result of privilege demotion +type DemoteResult struct { + // SysProcAttr contains the credentials for privilege demotion + SysProcAttr *syscall.SysProcAttr + // User contains the looked up user information + User *user.User +} + +// Demote creates syscall attributes for privilege demotion to the specified user/group. +// If username or groupname is empty, or if not running as root, returns nil without error. +// When ValidateGroup is true, returns an error if the group is not in the user's group list. +func Demote(username, groupname string, opts DemoteOptions) (*DemoteResult, error) { + if username == "" || groupname == "" { + log.Debug().Msg("No username or groupname provided, running as the current user.") + return nil, nil + } + + if os.Getuid() != 0 { + log.Warn().Msg("Not running as root. Falling back to the current user.") + return nil, nil + } + + usr, err := user.Lookup(username) + if err != nil { + return nil, fmt.Errorf("there is no corresponding %s username in this server", username) + } + + grp, err := user.LookupGroup(groupname) + if err != nil { + return nil, fmt.Errorf("there is no corresponding %s groupname in this server", groupname) + } + + uid, err := strconv.ParseUint(usr.Uid, 10, 32) + if err != nil { + return nil, err + } + + gid, err := strconv.ParseUint(grp.Gid, 10, 32) + if err != nil { + return nil, err + } + + groupIds, err := usr.GroupIds() + if err != nil { + return nil, err + } + + groups := make([]uint32, 0, len(groupIds)) + groupInList := false + for _, gidStr := range groupIds { + gidUint, err := strconv.ParseUint(gidStr, 10, 32) + if err != nil { + return nil, err + } + if gidUint == gid { + groupInList = true + } + groups = append(groups, uint32(gidUint)) + } + + if opts.ValidateGroup && !groupInList { + return nil, fmt.Errorf("groupname %s is not in user %s's group list", groupname, username) + } + + log.Debug().Msgf("Demote permission to match user: %s, group: %s.", username, groupname) + + return &DemoteResult{ + SysProcAttr: &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(gid), + Groups: groups, + }, + }, + User: usr, + }, nil +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index f4d660f..d564be6 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -53,6 +53,11 @@ func getPlatformLike() { } } +// SetPlatformLike allows setting PlatformLike for testing purposes. +func SetPlatformLike(platform string) { + PlatformLike = platform +} + func JoinPath(base string, paths ...string) string { fullURL, err := url.JoinPath(base, paths...) if err != nil {