diff --git a/.gitignore b/.gitignore index 2bd422eca7..8e391aa19c 100644 --- a/.gitignore +++ b/.gitignore @@ -223,3 +223,4 @@ gha-creds-*.json *.jsonl **/*.jsonl scr/* +.weaver/ diff --git a/AGENTS.md b/AGENTS.md index 47606d75bb..88c3991eb6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -23,7 +23,7 @@ - Assume Python >=3.11. - Always use `uv run` for Python entry points. If that fails, try `.venv/bin/python` directly. - Run `uv run python infra/pre-commit.py --all-files` before sending changes; formatting and linting are enforced with `ruff`. -- Keep type hints passing under `uv run mypy`; configuration lives in `pyproject.toml`. +- Keep type hints passing under `uv run pyrefly`; configuration lives in `pyproject.toml`. ### Communication & Commits @@ -88,6 +88,15 @@ You don't generate comments that merely restate the code, e.g. - Run the appropriate tests for your changes (for example, `uv run pytest` under the relevant directory); consult subproject guides for preferred markers. - Use pytest features like fixtures and parameterization to avoid duplication and write clean code. +PREFER: + +- Integration style tests which exercise behavior and test the output + +DO NOT: + +- Create tests which validate obvious features: if a type exists, a constant has a value, etc. + + ## Environment - Prefer to use `uv` when possible. If you can't (for instance, due to sandbox restrictions) you can use `.venv/bin/python` diff --git a/infra/pre-commit.py b/infra/pre-commit.py index 4c4478df0e..3acd397099 100755 --- a/infra/pre-commit.py +++ b/infra/pre-commit.py @@ -46,6 +46,9 @@ ".git/**", ".github/**", "tests/snapshots/**", + # grpc generated files + "**/*_connect.py", + "**/*_pb2.py", "**/*.gz", "**/*.pb", "**/*.index", diff --git a/lib/fluster/AGENTS.md b/lib/fluster/AGENTS.md new file mode 100644 index 0000000000..7294136756 --- /dev/null +++ b/lib/fluster/AGENTS.md @@ -0,0 +1,46 @@ +# Agent Tips + +* Use the connect/RPC abstractions to implement and perform RPC calls. DO NOT use httpx or raw HTTP. +* Use scripts/generate-protos.py to regenerate files after changing the `.proto` files. +* Prefer shallow, functional interfaces which return control to the user, vs callbacks or inheritance. + +e.g. + +class Scheduler: + def add_job() + def add_worker(): + def compute_schedule() -> ScheduledJobs: + +is preferable to: + +class Scheduler: + def __init__(self, job_creator: JobCreator): + self.job_creator = job_creator + def run(self): + ... self.job_creator.create_job() + +It's acceptable to have a top-level class which implements the main loop of +course, but prefer to keep other interfaces shallow and functional whenever +possible. + +* Tests should evaluate _behavior_, not implementation. Don't test things that are trivially caught by the type checker. Explicitly that means: + +- No tests for "constant = constant" +- No tests for "method exists" +- No tests for "create an object(x, y, z) and attributes are x, y, z" + +These tests have negative value - they make our code more brittle. Test +_behavior_ instead. You can use mocks as needed to isolate environments (e.g. +mock around a remote API). Prefer "fakes" -- e.g. create a real database but +with fake data -- when reasonable. + +## Protocols and Testing + +Non-trivial public classes should define a protocol which represents their +_important_ interface characteristics. Use this protocol in type hints for +when the class is used instead of the concrete class. 
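+For example, a minimal sketch of the idea (the `SchedulerLike` and `ScheduledJob` names below are illustrative, echoing the scheduler example above, not an existing fluster API):
+
+```python
+from dataclasses import dataclass
+from typing import Protocol
+
+
+@dataclass
+class ScheduledJob:
+    job_id: str
+    worker_id: str
+
+
+class SchedulerLike(Protocol):
+    """The important behavior: jobs and workers go in, a schedule comes out."""
+
+    def add_job(self, job_id: str) -> None: ...
+    def add_worker(self, worker_id: str) -> None: ...
+    def compute_schedule(self) -> list[ScheduledJob]: ...
+
+
+def dispatch(scheduler: SchedulerLike) -> list[ScheduledJob]:
+    # Depends only on the protocol, so callers and tests can pass any conforming object.
+    return scheduler.compute_schedule()
+```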
+ +Test to this protocol, not the concrete class: the protocol should describe the +interesting behavior of the class, but not betray the implementation details. + +(You may of course _instantiate_ the concrete class for testing.) \ No newline at end of file diff --git a/lib/fluster/README.md b/lib/fluster/README.md new file mode 100644 index 0000000000..ac2b466433 --- /dev/null +++ b/lib/fluster/README.md @@ -0,0 +1,130 @@ +# Fluster + +Fluster is a distributed job orchestration and RPC framework designed to replace Ray with simpler, more focused primitives. It provides job lifecycle management, actor-based RPC communication, and task dispatch capabilities for distributed Python workloads. + +## Architecture Overview + +Fluster consists of four main components: + +| Component | Description | +|-----------|-------------| +| **Controller** | Central coordinator managing job scheduling, worker registration, and service discovery | +| **Worker** | Execution agent that runs jobs in isolated containers with resource management | +| **Actor System** | RPC framework enabling Python object method invocation across processes | +| **WorkerPool** | High-level task dispatch abstraction for stateless parallel workloads | + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Controller │ +│ │ +│ Job Scheduling │ Worker Registry │ Endpoint Registry│ +└─────────────────────────────────────────────────────────────────┘ + │ │ ▲ + │ dispatch │ health │ register + ▼ ▼ │ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Worker │ │ Worker │ │ ActorServer │ +│ │ │ │ │ (in job) │ +│ runs jobs in │ │ runs jobs in │ │ │ +│ containers │ │ containers │ │ │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ +``` + +## Directory Structure + +``` +src/fluster/ +├── actor/ # Actor RPC system +│ ├── client.py # Actor method invocation +│ ├── pool.py # Multi-endpoint management +│ ├── resolver.py # Endpoint discovery +│ ├── server.py # Actor hosting +│ └── types.py # Core types +├── cluster/ # Cluster orchestration +│ ├── controller/ # Controller service +│ ├── worker/ # Worker service +│ ├── client.py # Client interface +│ └── types.py # Shared types +├── proto/ # Protocol definitions +├── worker_pool.py # Task dispatch +└── *_pb2.py, *_connect.py # Generated RPC code +``` + +## Component Documentation + +- [Controller Overview](docs/controller.md) - Job scheduling and coordination +- [Worker Overview](docs/worker.md) - Job execution and container management +- [Actor System Overview](docs/actor.md) - RPC and service discovery + +## Quick Start + +### Submitting a Job + +```python +from fluster.cluster import RpcClusterClient, Entrypoint, create_environment +from fluster.cluster_pb2 import ResourceSpec + +def my_task(): + print("Hello from fluster!") + +client = RpcClusterClient("http://controller:8080") +job_id = client.submit( + name="my-job", + entrypoint=Entrypoint(callable=my_task), + resources=ResourceSpec(cpu=1, memory="2GB"), + environment=create_environment(), +) +client.wait(job_id) +``` + +### Running an Actor Server + +```python +from fluster.actor import ActorServer, ActorContext + +class InferenceActor: + def predict(self, ctx: ActorContext, data: list) -> list: + return [x * 2 for x in data] + +server = ActorServer(controller_address="http://controller:8080") +server.register("inference", InferenceActor()) +server.serve() +``` + +### Calling Actors + +```python +from fluster.actor import ActorPool, ClusterResolver + +resolver = 
ClusterResolver("http://controller:8080") +pool: ActorPool = resolver.lookup("inference") +pool.wait_for_size(1) + +result = pool.call().predict([1, 2, 3]) +``` + +### Using WorkerPool for Task Dispatch + +```python +from fluster.worker_pool import WorkerPool, WorkerPoolConfig +from fluster.cluster import RpcClusterClient + +client = RpcClusterClient("http://controller:8080") +config = WorkerPoolConfig(num_workers=10, resources=ResourceSpec(cpu=2)) +pool = WorkerPool(client, config) + +futures = [pool.submit(process_shard, shard) for shard in shards] +results = [f.result() for f in futures] +pool.shutdown() +``` + +## Design Principles + +1. **Shallow interfaces**: Components expose minimal APIs with clear responsibilities +2. **Explicit over implicit**: No magic discovery or hidden state synchronization +3. **Stateless workers**: Task retry and load balancing work because workers maintain no shared state +4. **Arbitrary callables**: Jobs and actor methods accept any picklable Python callable + +## Related Documentation + +- [Fray-Zero Design](docs/fray-zero.md) - Original design document and rationale diff --git a/lib/fluster/buf.gen.yaml b/lib/fluster/buf.gen.yaml new file mode 100644 index 0000000000..674facd570 --- /dev/null +++ b/lib/fluster/buf.gen.yaml @@ -0,0 +1,10 @@ +version: v2 +managed: + enabled: true +plugins: + - remote: buf.build/protocolbuffers/python + out: src/fluster + - remote: buf.build/protocolbuffers/pyi + out: src/fluster + - remote: buf.build/connectrpc/python + out: src/fluster diff --git a/lib/fluster/buf.yaml b/lib/fluster/buf.yaml new file mode 100644 index 0000000000..000d8e9f4d --- /dev/null +++ b/lib/fluster/buf.yaml @@ -0,0 +1,10 @@ +version: v2 +modules: + - path: src/fluster/proto + - path: src/fluster/actor/proto +lint: + use: + - STANDARD +breaking: + use: + - FILE diff --git a/lib/fluster/docs/actor.md b/lib/fluster/docs/actor.md new file mode 100644 index 0000000000..f2c020a728 --- /dev/null +++ b/lib/fluster/docs/actor.md @@ -0,0 +1,181 @@ +# Actor System Overview + +The Actor system provides RPC communication between Fluster jobs. It allows Python objects to be exposed as network services without writing protocol definitions. Actors register with the Controller's endpoint registry for discovery. Clients use resolvers to locate actors and call their methods with automatic load balancing and retries. Arguments and return values are serialized with cloudpickle, supporting arbitrary Python objects. + +## Components + +| Component | Description | +|-----------|-------------| +| `ActorServer` | Hosts actor instances, handles incoming RPC calls | +| `ActorClient` | Calls actor methods with automatic retry | +| `ActorPool` | Manages multiple endpoints for load balancing and broadcast | +| `Resolver` | Discovers actor endpoints by name | +| `ActorContext` | Injected into actor methods, enables actors to call other actors | + +## ActorServer + +Hosts one or more actor instances and serves RPC requests. Each job runs at most one ActorServer. 
+ +```python +server = ActorServer(controller_address="http://controller:8080") +server.register("inference", InferenceModel()) +server.serve() # blocks, serving requests +``` + +| Method | Description | +|--------|-------------| +| `register(name, actor, metadata)` | Register an actor instance under a name | +| `serve()` | Start serving requests (blocking) | +| `serve_background()` | Start serving in background | +| `shutdown(grace_period)` | Stop the server | + +When an actor is registered, the server notifies the Controller's endpoint registry. Multiple actors (across different jobs) can register under the same name to form a pool. + +## ActorClient + +Calls methods on a specific actor endpoint. Method calls look like local invocations: + +```python +client = ActorClient(resolver, endpoint) +result = client.predict(data) # calls actor.predict(ctx, data) remotely +``` + +The client automatically retries failed calls with exponential backoff. Remote exceptions are propagated to the caller. + +## ActorPool + +Manages multiple endpoints registered under the same actor name. Provides two calling patterns: + +```python +pool = resolver.lookup("inference") +pool.wait_for_size(4) # wait for 4 actors to register + +# Round-robin: routes to one actor +result = pool.call().predict(data) + +# Broadcast: calls all actors in parallel +futures = pool.broadcast().shutdown() +results = [f.result() for f in futures] +``` + +| Method | Description | +|--------|-------------| +| `size` | Current number of endpoints | +| `endpoints` | List of current endpoints | +| `wait_for_size(n, timeout)` | Block until at least n actors available | +| `call()` | Get a client for round-robin calls | +| `broadcast()` | Get a handle for calling all actors | + +## Resolver + +Discovers actor endpoints by name. Three implementations: + +| Implementation | Use Case | +|----------------|----------| +| `ClusterResolver` | Production: queries Controller endpoint registry | +| `FixedResolver` | Testing: static endpoint configuration | +| `GcsResolver` | GCP: discovers from VM metadata | + +```python +resolver = ClusterResolver("http://controller:8080") +pool = resolver.lookup("inference") +``` + +The resolver returns an `ActorPool` that tracks endpoints for the given name. As actors register or unregister, the pool updates automatically. + +## ActorContext + +Passed as the first argument to actor methods. Enables actors to call other actors: + +```python +class CoordinatorActor: + def process(self, ctx: ActorContext, data): + workers = ctx.resolver.lookup("workers") + results = workers.broadcast().transform(data) + return aggregate([r.result() for r in results]) +``` + +| Field | Description | +|-------|-------------| +| `controller_address` | Controller URL | +| `job_id` | ID of the job hosting this actor | +| `namespace` | Namespace for endpoint isolation | +| `resolver` | Resolver for calling other actors | + +## Integration Points + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Controller │ +│ │ +│ Endpoint Registry │ +│ ▲ │ │ +│ RegisterEndpoint │ │ LookupEndpoint │ +└──────────────────────────┼────┼──────────────────────────────────┘ + │ │ + ┌────────────────┘ └────────────────┐ + │ │ + ▼ ▼ +┌──────────────────┐ ┌──────────────────┐ +│ ActorServer │ │ Resolver │ +│ │ │ │ +│ register(name) │ │ lookup(name) │ +│ serve() │ │ │ │ +└──────────────────┘ │ ▼ │ + ▲ │ ActorPool │ + │ │ │ │ + │ Call RPC │ ▼ │ + │ │ ActorClient │ + └───────────────────────────┴──────────────────┘ +``` + +1. 
**ActorServer** registers endpoints with the Controller +2. **Resolver** queries the Controller to discover endpoints +3. **ActorPool** tracks available endpoints and load balances calls +4. **ActorClient** makes RPC calls to individual endpoints + +## Usage Patterns + +### Single Actor +```python +# Server +server = ActorServer(controller_address) +server.register("model", MyModel()) +server.serve() + +# Client +pool = resolver.lookup("model") +result = pool.call().predict(x) +``` + +### Actor Pool (Load Balancing) +```python +# Multiple servers register under same name +for _ in range(4): + server = ActorServer(controller_address) + server.register("inference", InferenceModel()) + server.serve_background() + +# Client load balances across all +pool = resolver.lookup("inference") +pool.wait_for_size(4) +results = [pool.call().predict(batch) for batch in batches] +``` + +### Broadcast +```python +pool = resolver.lookup("workers") +futures = pool.broadcast().checkpoint() +for f in futures: + f.result() # wait for all to complete +``` + +## File Summary + +| File | Purpose | +|------|---------| +| `server.py` | `ActorServer` hosting and registration | +| `client.py` | `ActorClient` RPC calls with retry | +| `pool.py` | `ActorPool` load balancing and broadcast | +| `resolver.py` | `Resolver` protocol and implementations | +| `types.py` | `ActorContext`, `ActorEndpoint`, type definitions | diff --git a/lib/fluster/docs/cluster-tuneup.md b/lib/fluster/docs/cluster-tuneup.md new file mode 100644 index 0000000000..36ad4cce71 --- /dev/null +++ b/lib/fluster/docs/cluster-tuneup.md @@ -0,0 +1,309 @@ +# Cluster Resource Scheduling Tuneup + +This document describes the implementation plan for resource-aware scheduling in the fluster controller. + +## Goals + +1. Track worker resource availability and consumption (request-based, not actual utilization) +2. Implement FIFO queue that skips unfittable jobs (don't block smaller jobs behind large ones) +3. Add scheduling timeout with `JOB_STATE_UNSCHEDULABLE` state +4. Match jobs to workers based on CPU, memory, device type, and device variant + +## Current Problem + +The current scheduler (`scheduler.py:129-141`) has this behavior: +```python +while True: + job = self._state.pop_next_pending() + if not job: + break + worker = find_worker_for_job(self._state, job) + if not worker: + self._state.add_job(job) # re-queue + break # STOP TRYING - blocks everything! +``` + +This means a large job that can't fit blocks all smaller jobs behind it. + +## Implementation Plan + +### 1. Proto Changes (`lib/fluster/src/fluster/proto/cluster.proto`) + +Add new job state and scheduling timeout field: + +```protobuf +enum JobState { + // ... existing states ... + JOB_STATE_UNSCHEDULABLE = 8; // NEW: Couldn't be scheduled within timeout +} + +message LaunchJobRequest { + // ... existing fields ... + int32 scheduling_timeout_seconds = 8; // NEW: 0 = no timeout (wait forever) +} +``` + +Run `uv run buf generate` after changes. + +### 2. 
New Resource Utilities (`lib/fluster/src/fluster/cluster/controller/resources.py`) + +Create new module for resource parsing: + +```python +def parse_memory_string(memory_str: str) -> int: + """Parse '8g', '16gb', '512m' to bytes.""" + +def get_device_type(device: DeviceConfig) -> str: + """Return 'cpu', 'gpu', or 'tpu'.""" + +def get_device_variant(device: DeviceConfig) -> str | None: + """Return variant like 'A100', 'v5litepod-16', or None.""" + +def get_gpu_count(device: DeviceConfig) -> int: + """Return GPU count from device config.""" +``` + +### 3. State Changes (`lib/fluster/src/fluster/cluster/controller/state.py`) + +Add new methods to `ControllerState` for queue management: + +```python +def peek_pending_jobs(self) -> list[ControllerJob]: + """Return all PENDING jobs in queue order without removing them.""" + +def remove_from_queue(self, job_id: JobId) -> None: + """Remove a specific job from the queue.""" +``` + +**Note:** We do NOT track committed resources incrementally. Instead, we compute +available headroom dynamically by summing resources of jobs in `worker.running_jobs`. +This avoids sync issues and is simpler to reason about. + +### 4. Worker Matching (`lib/fluster/src/fluster/cluster/controller/workers.py`) + +Replace first-fit with resource-aware matching: + +```python +def get_committed_resources(state: ControllerState, worker: ControllerWorker) -> tuple[int, int, int]: + """Compute resources committed to running jobs on this worker. + + Dynamically sums resources from all jobs in worker.running_jobs. + Returns (cpu, memory_bytes, gpu_count). + """ + cpu, memory, gpu = 0, 0, 0 + for job_id in worker.running_jobs: + job = state.get_job(job_id) + if job: + cpu += job.request.resources.cpu + memory += parse_memory_string(job.request.resources.memory) + gpu += get_gpu_count(job.request.resources.device) + return cpu, memory, gpu + +def worker_can_fit_job(state: ControllerState, worker: ControllerWorker, job: ControllerJob) -> bool: + """Check if worker has sufficient capacity. + + Computes available headroom dynamically from running_jobs: + 1. CPU: job.cpu <= worker.total_cpu - committed_cpu + 2. Memory: job.memory <= worker.total_memory - committed_memory + 3. Device type: exact match (GPU job only on GPU worker) + 4. Device variant: if specified (not "auto"), must match worker + 5. GPU count: job.gpu_count <= available_gpus + """ + +def find_worker_for_job(state, job) -> ControllerWorker | None: + """Find first healthy worker that can fit the job.""" + for worker in state.get_available_workers(): + if worker_can_fit_job(state, worker, job): + return worker + return None +``` + +### 5. Scheduler Loop (`lib/fluster/src/fluster/cluster/controller/scheduler.py`) + +New scheduling algorithm: + +```python +def _schedule_pending_jobs(self) -> None: + """Schedule pending jobs with resource-aware matching. + + New algorithm: + 1. Peek all pending jobs (don't pop) + 2. For each job in FIFO order: + a. Check scheduling timeout - if expired, mark UNSCHEDULABLE + b. Find a worker that can fit the job (headroom computed dynamically) + c. If found: dispatch, remove from queue, add to worker.running_jobs + d. If not found: skip to next job (DON'T block queue) + """ + now_ms = int(time.time() * 1000) + pending_jobs = self._state.peek_pending_jobs() + + for job in pending_jobs: + if self._is_job_timed_out(job, now_ms): + self._mark_unschedulable(job, now_ms) + continue + + worker = find_worker_for_job(self._state, job) + if not worker: + continue # Skip, don't block! 
+ + success = self._dispatch_fn(job, worker) + if success: + self._handle_successful_dispatch(job, worker, now_ms) + else: + self._handle_failed_dispatch(job, worker) +``` + +Key helper methods: +- `_is_job_timed_out(job, now_ms)` - check if scheduling timeout exceeded +- `_mark_unschedulable(job, now_ms)` - set state, set error, remove from queue +- `_handle_successful_dispatch(job, worker, now_ms)` - update state, add to running_jobs, remove from queue + +**Note:** No explicit resource commit/release needed - headroom is computed dynamically +from `worker.running_jobs` each time we check if a job fits. + +### 6. Heartbeat Updates (`lib/fluster/src/fluster/cluster/controller/heartbeat.py`) + +**No changes needed for resource tracking.** When jobs complete: +1. Heartbeat syncs terminal state from worker +2. Job is removed from `worker.running_jobs` (existing behavior) +3. Next scheduling pass automatically sees increased headroom + +The dynamic computation approach means resource release is automatic when +`running_jobs.discard(job_id)` is called. + +### 7. Types Update (`lib/fluster/src/fluster/cluster/types.py`) + +Add UNSCHEDULABLE to terminal states: + +```python +def is_job_finished(state: int) -> bool: + return state in ( + JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_KILLED, + JOB_STATE_WORKER_FAILED, JOB_STATE_UNSCHEDULABLE, # NEW + ) +``` + +## Test Plan + +### New Test File: `tests/cluster/controller/test_resources.py` + +- `test_parse_memory_string` - parameterized for '1g', '8g', '512m', etc. +- `test_parse_memory_string_invalid` - ValueError for bad input +- `test_get_device_type_cpu/gpu/tpu` - extract device type +- `test_get_device_variant` - extract variant or None +- `test_get_gpu_count` - extract count from DeviceConfig + +### Updates to `tests/cluster/controller/test_workers.py` + +Add resource matching tests: +- `test_worker_can_fit_job_cpu_constraint` - job.cpu > available +- `test_worker_can_fit_job_memory_constraint` - job.memory > available +- `test_worker_can_fit_job_device_type_mismatch` - GPU job on CPU worker +- `test_worker_can_fit_job_gpu_variant_match` - exact variant match +- `test_worker_can_fit_job_gpu_variant_auto` - "auto" matches any + +### Updates to `tests/cluster/controller/test_scheduler.py` + +- `test_scheduler_skips_jobs_that_dont_fit` - big job doesn't block small job +- `test_scheduler_marks_job_unschedulable_on_timeout` - timeout handling +- `test_scheduler_commits_resources_on_dispatch` - committed updated +- `test_scheduler_fifo_ordering_preserved` - jobs dispatched in order when possible + +### Updates to `tests/cluster/controller/test_heartbeat.py` + +- `test_heartbeat_releases_resources_on_completion` - committed decremented + +### Tests to Remove (Obvious/Redundant) + +From `test_state.py`: +- `test_controller_job_defaults` - validates default values +- `test_controller_worker_defaults` - validates default values + +From `test_workers.py`: +- `test_load_workers_from_config_empty_list` - trivial + +From `test_service.py`: +- `test_launch_job_returns_job_id` - obvious RPC behavior +- `test_list_jobs_empty` - trivial + +## Example Updates (`lib/fluster/examples/cluster_example.py`) + +### New Example: Resource Serialization + +```python +def example_resource_scheduling(cluster: ClusterContext): + """Demonstrate resource-aware scheduling with queuing.""" + def cpu_job(n): + time.sleep(2) + return n + + # Worker has 4 CPUs. Submit 4 jobs each requiring 2 CPUs. + # Only 2 can run at a time, so jobs serialize in pairs. 
+ job_ids = [ + cluster.submit(cpu_job, i, name=f"job-{i}", resources={"cpu": 2}) + for i in range(4) + ] + + for jid in job_ids: + cluster.wait(jid) +``` + +### New Example: Scheduling Timeout + +```python +def example_scheduling_timeout(cluster: ClusterContext): + """Demonstrate scheduling timeout.""" + def impossible_job(): + pass + + job_id = cluster.submit( + impossible_job, + resources={"cpu": 100}, # More than any worker has + scheduling_timeout_seconds=5, + ) + status = cluster.wait(job_id) + assert status["state"] == "JOB_STATE_UNSCHEDULABLE" +``` + +### Update `submit()` Method + +Add `resources` and `scheduling_timeout_seconds` parameters to the `submit()` method. + +## Implementation Order + +1. Proto changes + regenerate bindings +2. `resources.py` - memory parsing, device helpers +3. `state.py` - CommittedResources, new methods +4. `workers.py` - worker_can_fit_job +5. `scheduler.py` - new scheduling loop +6. `heartbeat.py` - release resources on completion +7. `types.py` - add UNSCHEDULABLE to is_job_finished +8. Tests - add new, remove obvious +9. Examples - add resource scheduling demos + +## Files Modified + +- `lib/fluster/src/fluster/proto/cluster.proto` +- `lib/fluster/src/fluster/cluster/controller/state.py` +- `lib/fluster/src/fluster/cluster/controller/workers.py` +- `lib/fluster/src/fluster/cluster/controller/scheduler.py` +- `lib/fluster/src/fluster/cluster/controller/heartbeat.py` +- `lib/fluster/src/fluster/cluster/types.py` +- `lib/fluster/examples/cluster_example.py` +- `lib/fluster/tests/cluster/controller/test_workers.py` +- `lib/fluster/tests/cluster/controller/test_scheduler.py` +- `lib/fluster/tests/cluster/controller/test_heartbeat.py` +- `lib/fluster/tests/cluster/controller/test_state.py` (removals) +- `lib/fluster/tests/cluster/controller/test_service.py` (removals) + +## New Files + +- `lib/fluster/src/fluster/cluster/controller/resources.py` +- `lib/fluster/tests/cluster/controller/test_resources.py` + +## Verification + +1. Run existing tests: `uv run pytest lib/fluster/tests/cluster/controller/ -v` +2. Run the cluster example: `cd lib/fluster && uv run python examples/cluster_example.py` +3. Verify new examples demonstrate serialized scheduling +4. Verify scheduling timeout produces UNSCHEDULABLE state diff --git a/lib/fluster/docs/controller-class.md b/lib/fluster/docs/controller-class.md new file mode 100644 index 0000000000..5253492e6e --- /dev/null +++ b/lib/fluster/docs/controller-class.md @@ -0,0 +1,262 @@ +# Controller Class + +The `Controller` class provides a unified interface for managing all controller components and their lifecycle. 
+ +## Overview + +Instead of manually initializing and managing 5+ controller components, the `Controller` class encapsulates: + +- **ControllerState**: Thread-safe job and worker state +- **Scheduler**: Background thread for job scheduling +- **HeartbeatMonitor**: Background thread for worker health checks +- **ControllerServiceImpl**: RPC service implementation +- **ControllerDashboard**: Web dashboard and HTTP server + +## Basic Usage + +```python +from pathlib import Path +from fluster.cluster.controller import Controller, ControllerConfig +from fluster.cluster.controller.state import ControllerJob, ControllerWorker +from fluster.cluster.types import JobId, WorkerId +from fluster import cluster_pb2 + +# Define callbacks +def dispatch_job(job: ControllerJob, worker: ControllerWorker) -> bool: + """Dispatch a job to a worker.""" + # Send RPC to worker + return True + +def send_heartbeat(address: str) -> cluster_pb2.HeartbeatResponse | None: + """Check worker health.""" + # Send heartbeat RPC + return response + +def on_worker_failed(worker_id: WorkerId, job_ids: list[JobId]) -> None: + """Handle worker failure.""" + # Retry jobs, log failure, etc. + pass + +# Configure controller +config = ControllerConfig( + host="127.0.0.1", + port=8080, + bundle_dir=Path("/tmp/bundles"), + scheduler_interval_seconds=0.5, + heartbeat_interval_seconds=2.0, +) + +# Create and start controller +controller = Controller( + config=config, + dispatch_fn=dispatch_job, + heartbeat_fn=send_heartbeat, + on_worker_failed=on_worker_failed, +) +controller.start() + +try: + # Submit jobs + response = controller.launch_job(job_request) + job_id = response.job_id + + # Query status + status = controller.get_job_status(job_id) + + # Register workers + controller.register_worker(worker_request) + +finally: + controller.stop() +``` + +## Configuration + +The `ControllerConfig` dataclass provides all configuration options: + +```python +@dataclass +class ControllerConfig: + host: str = "127.0.0.1" + port: int = 0 # 0 for auto-assign + bundle_dir: Path | None = None + scheduler_interval_seconds: float = 0.5 + heartbeat_interval_seconds: float = 2.0 +``` + +## Methods + +### Lifecycle Methods + +- **`start()`**: Start all background components (scheduler, heartbeat monitor, dashboard server) +- **`stop()`**: Stop all background components gracefully + +### Job Management + +- **`launch_job(request)`**: Submit a new job +- **`get_job_status(job_id)`**: Query job status +- **`terminate_job(job_id)`**: Terminate a running job + +### Worker Management + +- **`register_worker(request)`**: Register a worker with the controller + +## Properties + +- **`state`**: Access to the underlying `ControllerState` for advanced usage +- **`url`**: HTTP URL of the controller dashboard and RPC service + +## Callbacks + +The Controller requires three callback functions: + +### dispatch_fn + +```python +def dispatch_fn(job: ControllerJob, worker: ControllerWorker) -> bool: + """Dispatch a job to a worker. + + Args: + job: Job to dispatch + worker: Worker to dispatch to + + Returns: + True if dispatch succeeded, False otherwise + """ +``` + +Called by the scheduler when a job should be dispatched to a worker. Should send an RPC to the worker to start the job. + +### heartbeat_fn + +```python +def heartbeat_fn(address: str) -> cluster_pb2.HeartbeatResponse | None: + """Check worker health. 
+ + Args: + address: Worker address (host:port) + + Returns: + HeartbeatResponse on success, None on failure + """ +``` + +Called by the heartbeat monitor to check worker health. Should send an RPC to the worker. + +### on_worker_failed + +```python +def on_worker_failed(worker_id: WorkerId, job_ids: list[JobId]) -> None: + """Handle worker failure. + + Args: + worker_id: Failed worker ID + job_ids: Jobs that were running on the worker + """ +``` + +Called when a worker exceeds the heartbeat failure threshold. Should handle job retry logic. + +## Benefits Over Manual Composition + +### Before (Manual) + +```python +# Create all components manually +state = ControllerState() +scheduler = Scheduler(state, dispatch_fn, interval_seconds=0.5) +heartbeat_monitor = HeartbeatMonitor(state, heartbeat_fn, on_worker_failed, interval_seconds=2.0) +service = ControllerServiceImpl(state, scheduler, bundle_dir=bundle_dir) +dashboard = ControllerDashboard(service, host="127.0.0.1", port=8080) + +# Start each component +scheduler.start() +heartbeat_monitor.start() +server_thread = threading.Thread(target=run_server, daemon=True) +server_thread.start() +time.sleep(1.0) + +# Use components +response = service.launch_job(request, None) + +# Stop everything +heartbeat_monitor.stop() +scheduler.stop() +``` + +### After (Controller) + +```python +# Create and configure +config = ControllerConfig(port=8080, bundle_dir=bundle_dir) +controller = Controller(config, dispatch_fn, heartbeat_fn, on_worker_failed) + +# Start +controller.start() + +# Use +response = controller.launch_job(request) + +# Stop +controller.stop() +``` + +## Dashboard Access + +When the controller is running, the dashboard is accessible at: + +- **`/`**: Web dashboard with auto-refresh +- **`/health`**: Health check endpoint +- **`/api/stats`**: Statistics JSON +- **`/api/jobs`**: Jobs list JSON +- **`/api/workers`**: Workers list JSON +- **`/api/actions`**: Recent actions log JSON +- **`/fluster.cluster.ControllerService/*`**: Connect RPC endpoints + +## Testing + +The Controller class is designed to be easily testable by accepting callbacks in the constructor: + +```python +def test_controller(): + dispatch_calls = [] + + def mock_dispatch(job, worker): + dispatch_calls.append((job, worker)) + return True + + config = ControllerConfig(port=0) + controller = Controller(config, mock_dispatch, mock_heartbeat, mock_failure) + + # Test functionality + controller.launch_job(request) + assert len(dispatch_calls) > 0 +``` + +## Advanced Usage + +For advanced use cases, access the underlying state: + +```python +# Access all jobs +all_jobs = controller.state.list_all_jobs() + +# Access all workers +all_workers = controller.state.list_all_workers() + +# Get recent actions +actions = controller.state.get_recent_actions() + +# Direct state manipulation (use with caution) +job = controller.state.get_job(JobId(job_id)) +if job: + job.state = cluster_pb2.JOB_STATE_FAILED +``` + +## See Also + +- [Controller State](../src/fluster/cluster/controller/state.py): In-memory state management +- [Scheduler](../src/fluster/cluster/controller/scheduler.py): Job scheduling algorithm +- [Heartbeat Monitor](../src/fluster/cluster/controller/heartbeat.py): Worker health checking +- [Controller Service](../src/fluster/cluster/controller/service.py): RPC implementation +- [Dashboard](../src/fluster/cluster/controller/dashboard.py): Web UI and HTTP server diff --git a/lib/fluster/docs/controller-v0.md b/lib/fluster/docs/controller-v0.md new file mode 100644 index 
0000000000..1b1ce54790 --- /dev/null +++ b/lib/fluster/docs/controller-v0.md @@ -0,0 +1,1016 @@ +# Controller V0 + +We're working on the controller from `fray-zero.md`. This controller manages a +set of VMs, starts Fluster workers on them, and manages job dispatch and +monitoring of those workers. We are going to start implementation of this work. + +We'll start with the server side before moving over to the user-facing client. + +The server RPC protocol is defined in `cluster.proto`. It accepts new jobs, +registers workers, and allows users to monitor job status or terminate jobs as +needed. The controller is based on a queueing system, where jobs are dispatched via +a priority queue. The controller has a notion of "users" and users can have +different numbers of credits. + +### Scheduling + +For our first implementation, we'll use a simple FIFO queueing system. The +controller will check if new jobs can be scheduled on the following basis: + +* On a one second timer +* Whenever a reservation is terminated +* Whenever a new worker registers + +The job scheduler runs in a separate thread and is woken in response to the above events. + +### Worker Registration + +The v0 cluster will accept a list of workers to use at startup. Future +iterations will allow workers to register themselves automatically via RPC. + +Workers have a set of properties that define their capabilities, e.g. the +number of TPUs, amount of RAM, etc. + +### Heartbeats and Health Checks + +The controller will periodically (e.g. once per second) check the health of all workers. + +* The heartbeat will contain a "since" timestamp, and the worker will respond with +a list of all jobs currently running on that worker or which have been running +on that worker since the "since" timestamp. + +* If the worker fails to respond to N consecutive heartbeats, the controller will mark the worker as unhealthy and remove it from the list of available workers. +Jobs on that worker will be terminated with JOB_STATE_WORKER_FAILED status, and retried if eligible. + +### Job Failure and Retries + +Job failures come in two types: external (worker failures) and internal (job +failures). A job may specify how many of either type of failure it should +tolerate. + +Jobs are retried at the cluster level. If a job fails for an internal reason, it +may be re-scheduled onto the same worker or a different worker. + +### Gang Scheduling + +TPU jobs are "gang-scheduled" onto a set of linked workers. If any of the +workers or jobs fails, _all_ jobs in the gang are terminated by the +controller. The job gang will be retried if eligible. + +### Dashboard + +The controller provides a web UI dashboard, which shows: + +* Recent actions log, e.g. job started, job terminated, etc. +* Job queue showing all jobs in priority order +* List of users, with links to their jobs and credits +* List of reservations, with links to their jobs and available resources +* List of workers, with overview of the worker and links to the worker status and recent jobs on that worker + +## Future Work + +This work is deferred to a future iteration: + +### Reservations + +Jobs are run inside of "reservations", which define the maximum number of +resources (e.g. RAM, TPUs) available for the job. The reservation ID, as with all +Fray information, is communicated to a job via the FRAY_RESERVATION_ID +environment variable. + +Reservations are tied to their initial job, and are cleaned up when that job completes. +Jobs are tied to their parent, and are cleaned up when that job completes. 
+Jobs run inside of a _namespace_, which is communicated to a job via the FRAY_NAMESPACE +environment variable. + +### User Credits + +Users have a set of credits that determine the maximum number of resources they +can use. Users credits are used automatically to determine the priority of +jobs. As a user runs jobs, it depletes their available credits, with future jobs +running at a lower priority. + +--- + +## Implementation Plan + +This section defines a tight, incrementally testable implementation path. Each +stage builds on the previous and has explicit test checkpoints. + +### Stage 1: Proto Updates for Controller-Worker Communication ✓ + +**Status**: Completed + +Added to `cluster.proto`: +- `JOB_STATE_WORKER_FAILED = 7` to JobState enum +- Worker registration: `WorkerInfo`, `RegisterWorkerRequest`, `RegisterWorkerResponse` +- Heartbeat: `HeartbeatRequest`, `HeartbeatResponse`, `WorkerHealthStatus` +- `ListWorkersRequest`, `ListWorkersResponse` +- New RPCs: `RegisterWorker`, `Heartbeat`, `ListWorkers` + +Cleaned up `types.py`: +- Removed redundant dataclasses (`ResourceConfig`, `JobRequest`, `VMInfo`, etc.) +- Kept: type aliases, TPU topology info, `Entrypoint`, `create_environment()` + +--- + +### Stage 2: Controller Core Data Structures + +**Goal**: In-memory state for jobs, workers, and the queue. + +```python +# lib/fluster/src/fluster/cluster/controller/state.py + +from dataclasses import dataclass, field +from collections import deque +from threading import RLock +from typing import NewType + +from fluster import cluster_pb2 +from fluster.cluster.types import JobId, WorkerId + +@dataclass +class ControllerJob: + """Controller's view of a job.""" + job_id: JobId + request: cluster_pb2.LaunchJobRequest + state: int = cluster_pb2.JOB_STATE_PENDING + worker_id: WorkerId | None = None + + # Retry tracking + failure_count: int = 0 + preemption_count: int = 0 + max_retries_failure: int = 0 + max_retries_preemption: int = 100 + + # Gang scheduling + gang_id: str | None = None + + # Timestamps + submitted_at_ms: int = 0 + started_at_ms: int | None = None + finished_at_ms: int | None = None + + error: str | None = None + exit_code: int | None = None + + +@dataclass +class ControllerWorker: + """Controller's view of a worker.""" + worker_id: WorkerId + address: str + resources: cluster_pb2.ResourceSpec + + # Health tracking + healthy: bool = True + consecutive_failures: int = 0 + last_heartbeat_ms: int = 0 + + # Current assignments + running_jobs: set[JobId] = field(default_factory=set) + + +class ControllerState: + """Thread-safe controller state. + + All mutations go through methods that acquire the lock. 
+ """ + + def __init__(self): + self._lock = RLock() + self._jobs: dict[JobId, ControllerJob] = {} + self._workers: dict[WorkerId, ControllerWorker] = {} + self._queue: deque[JobId] = deque() # FIFO queue of PENDING jobs + self._gangs: dict[str, set[JobId]] = {} # gang_id -> job_ids + + def add_job(self, job: ControllerJob) -> None: + with self._lock: + self._jobs[job.job_id] = job + self._queue.append(job.job_id) + if job.gang_id: + self._gangs.setdefault(job.gang_id, set()).add(job.job_id) + + def get_job(self, job_id: JobId) -> ControllerJob | None: + with self._lock: + return self._jobs.get(job_id) + + def pop_next_pending(self) -> ControllerJob | None: + """Pop next job from queue if available.""" + with self._lock: + while self._queue: + job_id = self._queue.popleft() + job = self._jobs.get(job_id) + if job and job.state == cluster_pb2.JOB_STATE_PENDING: + return job + return None + + def add_worker(self, worker: ControllerWorker) -> None: + with self._lock: + self._workers[worker.worker_id] = worker + + def get_worker(self, worker_id: WorkerId) -> ControllerWorker | None: + with self._lock: + return self._workers.get(worker_id) + + def get_available_workers(self) -> list[ControllerWorker]: + """Return healthy workers with capacity.""" + with self._lock: + return [w for w in self._workers.values() if w.healthy] + + def get_gang_jobs(self, gang_id: str) -> list[ControllerJob]: + with self._lock: + job_ids = self._gangs.get(gang_id, set()) + return [self._jobs[jid] for jid in job_ids if jid in self._jobs] +``` + +**Deliverable**: `ControllerState` class with thread-safe operations. + +**Test checkpoint**: +```python +def test_controller_state_fifo_order(): + state = ControllerState() + job1 = ControllerJob(job_id=JobId("j1"), request=..., submitted_at_ms=100) + job2 = ControllerJob(job_id=JobId("j2"), request=..., submitted_at_ms=200) + state.add_job(job1) + state.add_job(job2) + + assert state.pop_next_pending().job_id == "j1" + assert state.pop_next_pending().job_id == "j2" + assert state.pop_next_pending() is None + +def test_controller_state_skip_non_pending(): + state = ControllerState() + job1 = ControllerJob(job_id=JobId("j1"), request=...) + job1.state = cluster_pb2.JOB_STATE_RUNNING # Already started + job2 = ControllerJob(job_id=JobId("j2"), request=...) + state.add_job(job1) + state.add_job(job2) + + # Should skip j1 since it's not PENDING + assert state.pop_next_pending().job_id == "j2" +``` + +--- + +### Stage 3: Worker Registry (Static Configuration) + +**Goal**: Load workers from config at startup, track health. 
+ +```python +# lib/fluster/src/fluster/cluster/controller/workers.py + +from dataclasses import dataclass +import grpc +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerWorker +from fluster.cluster.types import WorkerId + +@dataclass +class WorkerConfig: + """Static worker configuration for v0.""" + worker_id: str + address: str + resources: cluster_pb2.ResourceSpec + + +def load_workers_from_config( + state: ControllerState, + workers: list[WorkerConfig], +) -> None: + """Register workers from static config.""" + import time + now_ms = int(time.time() * 1000) + + for cfg in workers: + worker = ControllerWorker( + worker_id=WorkerId(cfg.worker_id), + address=cfg.address, + resources=cfg.resources, + last_heartbeat_ms=now_ms, + ) + state.add_worker(worker) + + +def find_worker_for_job( + state: ControllerState, + job: "ControllerJob", +) -> ControllerWorker | None: + """Find a worker that can run the given job. + + For v0: simple first-fit on available workers. + Future: resource matching (TPU type, memory, etc). + """ + workers = state.get_available_workers() + for worker in workers: + # TODO: Check resource compatibility + # For now, any healthy worker works + return worker + return None +``` + +**Deliverable**: Worker loading from config, simple scheduling. + +**Test checkpoint**: +```python +def test_load_workers(): + state = ControllerState() + workers = [ + WorkerConfig("w1", "host1:8080", make_cpu_spec()), + WorkerConfig("w2", "host2:8080", make_cpu_spec()), + ] + load_workers_from_config(state, workers) + + assert len(state.get_available_workers()) == 2 + assert state.get_worker(WorkerId("w1")).address == "host1:8080" +``` + +--- + +### Stage 4: Job Scheduler Thread + +**Goal**: Background thread that matches pending jobs to workers. + +```python +# lib/fluster/src/fluster/cluster/controller/scheduler.py + +import threading +import time +import logging +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerJob +from fluster.cluster.controller.workers import find_worker_for_job + +logger = logging.getLogger(__name__) + +class Scheduler: + """Background scheduler that dispatches jobs to workers. 
+ + Wakes on: + - 1 second timer + - wake() called (new worker, job finished, etc) + """ + + def __init__( + self, + state: ControllerState, + dispatch_fn: "Callable[[ControllerJob, ControllerWorker], bool]", + interval_seconds: float = 1.0, + ): + self._state = state + self._dispatch_fn = dispatch_fn + self._interval = interval_seconds + self._wake_event = threading.Event() + self._stop = False + self._thread: threading.Thread | None = None + + def start(self) -> None: + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop = True + self._wake_event.set() + if self._thread: + self._thread.join(timeout=5.0) + + def wake(self) -> None: + """Signal scheduler to run immediately.""" + self._wake_event.set() + + def _run(self) -> None: + while not self._stop: + self._wake_event.wait(timeout=self._interval) + self._wake_event.clear() + + if self._stop: + break + + self._schedule_pending_jobs() + + def _schedule_pending_jobs(self) -> None: + """Try to schedule all pending jobs.""" + while True: + job = self._state.pop_next_pending() + if not job: + break + + worker = find_worker_for_job(self._state, job) + if not worker: + # No worker available, re-queue + self._state.add_job(job) # Goes to back of queue + break + + # Dispatch to worker + success = self._dispatch_fn(job, worker) + if success: + job.state = cluster_pb2.JOB_STATE_RUNNING + job.worker_id = worker.worker_id + job.started_at_ms = int(time.time() * 1000) + worker.running_jobs.add(job.job_id) + logger.info(f"Dispatched job {job.job_id} to worker {worker.worker_id}") + else: + # Dispatch failed, mark worker unhealthy and retry + worker.healthy = False + self._state.add_job(job) # Re-queue + logger.warning(f"Failed to dispatch to {worker.worker_id}, re-queuing job") +``` + +**Deliverable**: Scheduler that runs dispatch loop. + +**Test checkpoint**: +```python +def test_scheduler_dispatches_jobs(): + state = ControllerState() + dispatched = [] + + def mock_dispatch(job, worker): + dispatched.append((job.job_id, worker.worker_id)) + return True + + scheduler = Scheduler(state, mock_dispatch, interval_seconds=0.1) + + # Add worker and job + state.add_worker(ControllerWorker(WorkerId("w1"), "addr", make_spec())) + state.add_job(ControllerJob(JobId("j1"), request=...)) + + scheduler.start() + scheduler.wake() + time.sleep(0.2) + scheduler.stop() + + assert dispatched == [("j1", "w1")] + assert state.get_job(JobId("j1")).state == cluster_pb2.JOB_STATE_RUNNING + +def test_scheduler_requeues_when_no_workers(): + state = ControllerState() + scheduler = Scheduler(state, lambda j, w: True, interval_seconds=0.1) + + # Add job but no workers + state.add_job(ControllerJob(JobId("j1"), request=...)) + + scheduler.start() + scheduler.wake() + time.sleep(0.2) + scheduler.stop() + + # Job should still be pending (re-queued) + assert state.get_job(JobId("j1")).state == cluster_pb2.JOB_STATE_PENDING +``` + +--- + +### Stage 5: Worker Heartbeat Monitor + +**Goal**: Periodically poll workers, detect failures, mark jobs. + +```python +# lib/fluster/src/fluster/cluster/controller/heartbeat.py + +import threading +import time +import logging +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState + +logger = logging.getLogger(__name__) + +class HeartbeatMonitor: + """Monitors worker health via periodic heartbeats. 
+ + On N consecutive failures: + - Mark worker unhealthy + - Mark all running jobs as WORKER_FAILED + - Trigger retry logic + """ + + MAX_CONSECUTIVE_FAILURES = 3 + + def __init__( + self, + state: ControllerState, + heartbeat_fn: "Callable[[str], cluster_pb2.HeartbeatResponse | None]", + on_worker_failed: "Callable[[WorkerId, list[JobId]], None]", + interval_seconds: float = 1.0, + ): + self._state = state + self._heartbeat_fn = heartbeat_fn + self._on_worker_failed = on_worker_failed + self._interval = interval_seconds + self._stop = False + self._thread: threading.Thread | None = None + + def start(self) -> None: + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop = True + if self._thread: + self._thread.join(timeout=5.0) + + def _run(self) -> None: + while not self._stop: + time.sleep(self._interval) + self._check_all_workers() + + def _check_all_workers(self) -> None: + workers = list(self._state._workers.values()) # Snapshot + now_ms = int(time.time() * 1000) + + for worker in workers: + if not worker.healthy: + continue + + response = self._heartbeat_fn(worker.address) + + if response is None: + # Heartbeat failed + worker.consecutive_failures += 1 + logger.warning( + f"Heartbeat failed for {worker.worker_id} " + f"({worker.consecutive_failures}/{self.MAX_CONSECUTIVE_FAILURES})" + ) + + if worker.consecutive_failures >= self.MAX_CONSECUTIVE_FAILURES: + self._handle_worker_failure(worker) + else: + # Success - reset failure count, update job states + worker.consecutive_failures = 0 + worker.last_heartbeat_ms = now_ms + self._sync_job_states(worker, response) + + def _handle_worker_failure(self, worker) -> None: + """Mark worker dead, fail its jobs.""" + logger.error(f"Worker {worker.worker_id} declared dead") + worker.healthy = False + + failed_jobs = list(worker.running_jobs) + worker.running_jobs.clear() + + # Mark jobs as failed + for job_id in failed_jobs: + job = self._state.get_job(job_id) + if job: + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + job.finished_at_ms = int(time.time() * 1000) + job.error = f"Worker {worker.worker_id} failed" + + # Notify for retry handling + self._on_worker_failed(worker.worker_id, failed_jobs) + + def _sync_job_states(self, worker, response: cluster_pb2.HeartbeatResponse) -> None: + """Update controller state from worker's heartbeat response.""" + for status in response.jobs: + job = self._state.get_job(JobId(status.job_id)) + if not job: + continue + + # Update state from worker + if status.state in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + ): + job.state = status.state + job.finished_at_ms = status.finished_at_ms + job.error = status.error or None + job.exit_code = status.exit_code + worker.running_jobs.discard(job.job_id) +``` + +**Deliverable**: Heartbeat monitor with failure detection. + +**Test checkpoint**: +```python +def test_heartbeat_marks_worker_failed_after_n_failures(): + state = ControllerState() + worker = ControllerWorker(WorkerId("w1"), "addr", make_spec()) + job = ControllerJob(JobId("j1"), request=...) 
+ job.state = cluster_pb2.JOB_STATE_RUNNING + job.worker_id = worker.worker_id + worker.running_jobs.add(job.job_id) + + state.add_worker(worker) + state.add_job(job) # Add directly, not via queue + + failed_workers = [] + def on_failed(wid, jobs): + failed_workers.append((wid, jobs)) + + # Heartbeat always fails + monitor = HeartbeatMonitor( + state, + heartbeat_fn=lambda addr: None, # Always fail + on_worker_failed=on_failed, + interval_seconds=0.05, + ) + + monitor.start() + time.sleep(0.3) # Wait for 3+ failures + monitor.stop() + + assert not worker.healthy + assert job.state == cluster_pb2.JOB_STATE_WORKER_FAILED + assert failed_workers == [("w1", ["j1"])] +``` + +--- + +### Stage 6: Job Failure and Retry Logic + +**Goal**: Re-queue eligible failed jobs, track retry counts. + +```python +# lib/fluster/src/fluster/cluster/controller/retry.py + +import logging +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerJob +from fluster.cluster.types import JobId + +logger = logging.getLogger(__name__) + +def handle_job_failure( + state: ControllerState, + job_id: JobId, + is_worker_failure: bool, +) -> bool: + """Handle a job failure, potentially retrying. + + Returns True if job was re-queued for retry. + """ + job = state.get_job(job_id) + if not job: + return False + + if is_worker_failure: + job.preemption_count += 1 + can_retry = job.preemption_count <= job.max_retries_preemption + else: + job.failure_count += 1 + can_retry = job.failure_count <= job.max_retries_failure + + if can_retry: + logger.info( + f"Retrying job {job_id} " + f"(failures={job.failure_count}, preemptions={job.preemption_count})" + ) + job.state = cluster_pb2.JOB_STATE_PENDING + job.worker_id = None + job.started_at_ms = None + job.finished_at_ms = None + job.error = None + state.add_job(job) # Re-queue + return True + else: + logger.warning(f"Job {job_id} exceeded retry limit, not retrying") + return False + + +def handle_gang_failure( + state: ControllerState, + gang_id: str, + is_worker_failure: bool, +) -> list[JobId]: + """Handle gang failure - terminate all jobs, optionally retry. + + Returns list of job IDs that were re-queued. + """ + jobs = state.get_gang_jobs(gang_id) + if not jobs: + return [] + + # First, terminate all jobs in gang + for job in jobs: + if job.state == cluster_pb2.JOB_STATE_RUNNING: + job.state = cluster_pb2.JOB_STATE_KILLED + job.error = f"Gang {gang_id} failed" + + # Check if gang can be retried (all jobs must have retries left) + if is_worker_failure: + can_retry = all( + job.preemption_count < job.max_retries_preemption + for job in jobs + ) + else: + can_retry = all( + job.failure_count < job.max_retries_failure + for job in jobs + ) + + if can_retry: + retried = [] + for job in jobs: + if is_worker_failure: + job.preemption_count += 1 + else: + job.failure_count += 1 + job.state = cluster_pb2.JOB_STATE_PENDING + job.worker_id = None + state.add_job(job) + retried.append(job.job_id) + return retried + + return [] +``` + +**Deliverable**: Retry logic respecting limits. 
+ +**Test checkpoint**: +```python +def test_job_retry_on_worker_failure(): + state = ControllerState() + job = ControllerJob( + JobId("j1"), request=..., + max_retries_preemption=2, + ) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + state.add_job(job) + + # First failure - should retry + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.state == cluster_pb2.JOB_STATE_PENDING + assert job.preemption_count == 1 + + # Second failure - should retry + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == 2 + + # Third failure - should NOT retry (exceeded limit) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert not handle_job_failure(state, JobId("j1"), is_worker_failure=True) + +def test_gang_all_or_nothing_retry(): + state = ControllerState() + job1 = ControllerJob(JobId("j1"), request=..., gang_id="g1", max_retries_failure=1) + job2 = ControllerJob(JobId("j2"), request=..., gang_id="g1", max_retries_failure=0) + + state.add_job(job1) + state.add_job(job2) + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + + # Gang fails - j2 has 0 retries, so entire gang cannot retry + retried = handle_gang_failure(state, "g1", is_worker_failure=False) + assert retried == [] + assert job1.state == cluster_pb2.JOB_STATE_KILLED + assert job2.state == cluster_pb2.JOB_STATE_KILLED +``` + +--- + +### Stage 7: Controller Service Implementation + +**Goal**: Wire up RPC handlers using the state and scheduler. + +```python +# lib/fluster/src/fluster/cluster/controller/service.py + +import time +import uuid +from connectrpc.code import Code +from connectrpc.errors import ConnectError +from connectrpc.request import RequestContext + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerJob +from fluster.cluster.controller.scheduler import Scheduler +from fluster.cluster.types import JobId + + +class ControllerServiceImpl: + """ControllerService RPC implementation.""" + + def __init__(self, state: ControllerState, scheduler: Scheduler): + self._state = state + self._scheduler = scheduler + + def launch_job( + self, + request: cluster_pb2.LaunchJobRequest, + _ctx: RequestContext, + ) -> cluster_pb2.LaunchJobResponse: + job_id = str(uuid.uuid4()) + + job = ControllerJob( + job_id=JobId(job_id), + request=request, + submitted_at_ms=int(time.time() * 1000), + ) + + self._state.add_job(job) + self._scheduler.wake() # Try to schedule immediately + + return cluster_pb2.LaunchJobResponse(job_id=job_id) + + def get_job_status( + self, + request: cluster_pb2.GetJobStatusRequest, + _ctx: RequestContext, + ) -> cluster_pb2.GetJobStatusResponse: + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + return cluster_pb2.GetJobStatusResponse( + job=cluster_pb2.JobStatus( + job_id=job.job_id, + state=job.state, + error=job.error or "", + exit_code=job.exit_code or 0, + started_at_ms=job.started_at_ms or 0, + finished_at_ms=job.finished_at_ms or 0, + worker_id=job.worker_id or "", + ) + ) + + def terminate_job( + self, + request: cluster_pb2.TerminateJobRequest, + _ctx: RequestContext, + ) -> cluster_pb2.Empty: + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + # TODO: Send kill to worker + job.state = cluster_pb2.JOB_STATE_KILLED 
+ job.finished_at_ms = int(time.time() * 1000) + + return cluster_pb2.Empty() + + def list_jobs( + self, + request: cluster_pb2.ListJobsRequest, + _ctx: RequestContext, + ) -> cluster_pb2.ListJobsResponse: + jobs = [ + cluster_pb2.JobStatus( + job_id=j.job_id, + state=j.state, + worker_id=j.worker_id or "", + ) + for j in self._state._jobs.values() + ] + return cluster_pb2.ListJobsResponse(jobs=jobs) +``` + +**Deliverable**: Working RPC handlers. + +**Test checkpoint**: +```python +def test_launch_job_adds_to_queue(): + state = ControllerState() + scheduler = Scheduler(state, lambda j, w: True) + service = ControllerServiceImpl(state, scheduler) + + response = service.launch_job( + cluster_pb2.LaunchJobRequest(name="test"), + None, + ) + + assert response.job_id + job = state.get_job(JobId(response.job_id)) + assert job.state == cluster_pb2.JOB_STATE_PENDING +``` + +--- + +### Stage 8: Integration Test - End to End + +**Goal**: Full flow from job submission to completion. + +```python +def test_full_job_lifecycle(worker_server): + """Integration test with real worker.""" + # Start controller with worker + state = ControllerState() + load_workers_from_config(state, [ + WorkerConfig("w1", worker_server.address, make_cpu_spec()) + ]) + + # Create dispatcher that calls real worker + def dispatch(job, worker): + client = create_worker_client(worker.address) + response = client.run_job(cluster_pb2.RunJobRequest( + job_id=job.job_id, + serialized_entrypoint=job.request.serialized_entrypoint, + # ... + )) + return response.state != cluster_pb2.JOB_STATE_FAILED + + scheduler = Scheduler(state, dispatch) + heartbeat_monitor = HeartbeatMonitor( + state, + heartbeat_fn=lambda addr: do_heartbeat(addr), + on_worker_failed=lambda wid, jobs: ..., + ) + + service = ControllerServiceImpl(state, scheduler) + + # Start background threads + scheduler.start() + heartbeat_monitor.start() + + try: + # Submit job + response = service.launch_job( + cluster_pb2.LaunchJobRequest( + name="test", + serialized_entrypoint=cloudpickle.dumps(lambda: print("hello")), + ), + None, + ) + + # Wait for completion + job_id = JobId(response.job_id) + for _ in range(100): + job = state.get_job(job_id) + if job.state in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + ): + break + time.sleep(0.1) + + assert job.state == cluster_pb2.JOB_STATE_SUCCEEDED + finally: + scheduler.stop() + heartbeat_monitor.stop() +``` + +--- + +### Stage 9: Dashboard (HTTP) + +**Goal**: Simple HTML dashboard for visibility. + +```python +# lib/fluster/src/fluster/cluster/controller/dashboard.py + +from aiohttp import web +from fluster.cluster.controller.state import ControllerState + +def create_dashboard_app(state: ControllerState) -> web.Application: + app = web.Application() + + async def index(request): + workers = state.get_available_workers() + jobs = list(state._jobs.values()) + + html = f""" + + Fluster Controller + +

        <h1>Fluster Controller Dashboard</h1>

        <h2>Workers ({len(workers)} healthy)</h2>
        <table>
        <tr><th>ID</th><th>Address</th><th>Healthy</th><th>Running</th></tr>
        {"".join(f"<tr><td>{w.worker_id}</td><td>{w.address}</td><td>{w.healthy}</td><td>{len(w.running_jobs)}</td></tr>" for w in state._workers.values())}
        </table>

        <h2>Jobs ({len(jobs)} total)</h2>
        <table>
        <tr><th>ID</th><th>State</th><th>Worker</th><th>Error</th></tr>
        {"".join(f"<tr><td>{j.job_id}</td><td>{j.state}</td><td>{j.worker_id or '-'}</td><td>{j.error or '-'}</td></tr>" for j in jobs)}
        </table>
+ + """ + return web.Response(text=html, content_type="text/html") + + app.router.add_get("/", index) + return app +``` + +**Deliverable**: Basic dashboard showing workers and jobs. + +--- + +## Testing Summary + +| Stage | Test Focus | Key Assertions | +|-------|------------|----------------| +| 2 | Queue ordering | FIFO order, skip non-pending | +| 3 | Worker loading | Config -> state | +| 4 | Scheduler | Dispatch loop, re-queue on no workers | +| 5 | Heartbeat | Failure detection, job marking | +| 6 | Retry | Limits respected, gang all-or-nothing | +| 7 | RPC | Basic request/response | +| 8 | Integration | Full lifecycle | + +## File Structure + +``` +lib/fluster/src/fluster/cluster/controller/ +├── __init__.py +├── state.py # Stage 2: ControllerState, ControllerJob, ControllerWorker +├── workers.py # Stage 3: Worker config loading +├── scheduler.py # Stage 4: Scheduler thread +├── heartbeat.py # Stage 5: HeartbeatMonitor +├── retry.py # Stage 6: Retry logic +├── service.py # Stage 7: RPC handlers +└── dashboard.py # Stage 9: HTTP dashboard +``` diff --git a/lib/fluster/docs/controller-v1.md b/lib/fluster/docs/controller-v1.md new file mode 100644 index 0000000000..ada5f0e2d2 --- /dev/null +++ b/lib/fluster/docs/controller-v1.md @@ -0,0 +1,33 @@ +# Controller V1 + +# Auth and User workflow + +First version, like Ray: + +* ssh into bridge machine +* proxy to controller port +* use localhost: + +No real auth for v0. + +# Auth workflow + +Controller opens a public port, accessible without SSH. +Users have either a GCP secret or auth via e.g. Google SSO to the controller, get a session token. +Now all controller RPCs accept session token to auth user. +Controller is always at a fixed address or DNS or something like that. + +`cluster.marin.community` + +GcpResolver -> tag="fluster.controller" -> host:port + +# User workflow + +run "train" on v5p:4x4 somewhere + +# Worker and controller serialization + +We should run e..g the dashboard off of the serialized state of the controller/worker +This would be good to use for post-mortem + +# Running under appengine would simplify SSO diff --git a/lib/fluster/docs/controller.md b/lib/fluster/docs/controller.md new file mode 100644 index 0000000000..51eddd1803 --- /dev/null +++ b/lib/fluster/docs/controller.md @@ -0,0 +1,115 @@ +# Controller Overview + +The Controller is the central coordination service in Fluster. It accepts job submissions from clients, assigns jobs to available workers, tracks job status through completion, and maintains an endpoint registry for actor discovery. All cluster-wide state flows through the Controller. 
+ +## Responsibilities + +| Responsibility | Description | +|----------------|-------------| +| Job scheduling | Accepts job requests, queues them, assigns to workers with available capacity | +| Worker management | Tracks registered workers and their health status | +| Status tracking | Maintains authoritative job state, reports status to clients | +| Endpoint registry | Stores actor endpoints for service discovery | + +## RPC Interface + +The Controller exposes a single RPC service (`ControllerService`) with these methods: + +### Job Management + +| Method | Description | +|--------|-------------| +| `LaunchJob(JobRequest)` | Submit a job for execution, returns `JobId` | +| `GetJobStatus(JobId)` | Query current status of a job | +| `TerminateJob(JobId)` | Request termination of a running job | +| `ListJobs(filter)` | List jobs matching optional filter criteria | + +### Worker Management + +| Method | Description | +|--------|-------------| +| `RegisterWorker(WorkerInfo)` | Add a worker to the scheduling pool | +| `ListWorkers()` | List all registered workers | + +### Endpoint Registry + +| Method | Description | +|--------|-------------| +| `RegisterEndpoint(name, address, metadata)` | Register an actor endpoint | +| `UnregisterEndpoint(endpoint_id)` | Remove an actor endpoint | +| `LookupEndpoint(name)` | Find endpoints by actor name | +| `ListEndpoints(prefix)` | List all endpoints, optionally filtered by name prefix | + +## Job Lifecycle + +Jobs progress through these states: + +``` +PENDING ──► BUILDING ──► RUNNING ──► SUCCEEDED + │ │ + ▼ ▼ + FAILED FAILED +``` + +| State | Description | +|-------|-------------| +| `PENDING` | Job submitted, waiting for worker assignment | +| `BUILDING` | Assigned to worker, environment being prepared | +| `RUNNING` | Job entrypoint executing | +| `SUCCEEDED` | Completed with exit code 0 | +| `FAILED` | Completed with error or non-zero exit | +| `KILLED` | Terminated by user request | +| `WORKER_FAILED` | Worker became unresponsive | +| `UNSCHEDULABLE` | No worker can satisfy resource requirements | + +## Scheduling Behavior + +The Controller assigns pending jobs to workers using first-fit scheduling: + +1. Jobs are processed in submission order (FIFO) +2. Each job is assigned to the first worker with sufficient capacity +3. Jobs remain pending until a suitable worker is available +4. 
Failed jobs may be retried based on failure type and retry policy + +## Integration Points + +``` +┌──────────────┐ ┌──────────────┐ +│ Client │ │ Worker │ +│ │ │ │ +│ submit() ───┼────────►│ │ +│ status() ───┼────────►│◄─────────────┼── RegisterWorker +│ wait() ───┼────────►│ │ +│ │ │◄─────────────┼── job dispatch +└──────────────┘ │ │ + │◄─────────────┼── status updates +┌──────────────┐ │ │ +│ ActorServer │ └──────────────┘ +│ │ │ +│ register() ──┼────────────────┘ +└──────────────┘ + ▲ + │ +┌──────────────┐ +│ Resolver │ +│ │ +│ lookup() ────┼── LookupEndpoint +└──────────────┘ +``` + +- **Clients** submit jobs and query status via the Controller +- **Workers** register themselves and receive job assignments +- **ActorServers** register endpoints for discovery +- **Resolvers** query the endpoint registry to locate actors + +## File Summary + +| File | Purpose | +|------|---------| +| `controller.py` | Main `Controller` class and startup | +| `state.py` | State management and data types | +| `scheduler.py` | Job-to-worker assignment logic | +| `service.py` | RPC method implementations | +| `dashboard.py` | Web monitoring UI | +| `retry.py` | Retry policy for failed jobs | +| `workers.py` | Worker capacity evaluation | diff --git a/lib/fluster/docs/fray-zero-actor-and-resolver.md b/lib/fluster/docs/fray-zero-actor-and-resolver.md new file mode 100644 index 0000000000..4dc54d176f --- /dev/null +++ b/lib/fluster/docs/fray-zero-actor-and-resolver.md @@ -0,0 +1,1836 @@ +# fray-zero-actor-and-resolver + +Ref: [fray-zero.md](fray-zero.md) +Ref: [controller-v1.md](controller-v1.md) +Ref: [impl-recipe.md](impl-recipe.md) +Ref: [controller-v2.md](controller-v2.md) + +We're moving onto the _Actor_ and _Resolver_ components of the fray-zero system, +which will leverage our cluster controller and workers to run jobs. + +## Resolvers + +A resolver maps a _service name_ e.g. the name of an actor to a set of strings, +typically URLs which represent the location of the actor or gRPC endpoint. The +resolved URL indicates both the protocol to use (either grpc or actor), as well +as the host and port of the server(s) providing the service. + +We provide 3 types of resolvers: + +1. Controller Metadata Service +2. GCS VM Tags +3. Fixed Addresses + +Example usage: + +``` +resolver = GcsResolver() +resolver.resolve("fluster-controller") -> ["grpc://host:port/controller"] +``` + +## Namespaces + +In general we want actor names to be isolated across user jobs. To this end, by +default Fluster creates a new _namespace_ based on the initial job ID in a +Fluster run. A typical Fluster run involves a _leader_ job which requests +further resources. This leader job creates a new _namespace_ which is propagated +via the `FLUSTER_NAMESPACE` environment variable. + +When the fluster context is used to launch child jobs, the parent namespace +environment variable is automatically propagated by default. Users may override this to specify a shared global namespace to give children unique namespaces as needed. + +Thus a typical RL tree might look like: + +``` +: + trainer/0 + rollout/0 + rollout/1 + inference/0 + inference/1 +``` + +The actor namespace for these jobs is by default shared across all jobs, +therefore when they attempt to resolve e.g. a curriculum or training actor, they +will resolve to the same actor. 
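As a minimal sketch of this defaulting behavior (the helper name `resolve_default_namespace` is illustrative only, and the exact derivation from the job ID is still open; the requirements are spelled out in the bullets that follow):

```python
import os

def resolve_default_namespace() -> str:
    """Illustrative only: pick the namespace a job should use by default."""
    # An explicitly propagated namespace always wins.
    ns = os.environ.get("FLUSTER_NAMESPACE")
    if ns:
        return ns
    # Otherwise derive one from the originating job, e.g. by reusing its ID.
    job_id = os.environ.get("FLUSTER_JOB_ID")
    if job_id:
        return job_id
    # Running outside a cluster: fall back to the empty namespace.
    return ""
```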
+ +From the perspective of the Actor/RPC system, it means we should: + +* Accept a namespace argument for the MetadataResolver +* Accept a namespace argument for the ActorServer +* Default these arguments to the FLUSTER_NAMESPACE environment variable or equivalent context var +* Default the namespace to FLUSTER_NAMESPACE if not specified and FLUSTER_JOB_ID is set + +The remaining work is handled by Flusters default injection of the +FLUSTER_JOB_ID variable and propagation of the namespace to child jobs via the +`fluster.launch` API. + +## Actors and Actor Servers + +Users define actors as a Python class with a set of methods which are registered +to an ActorServer. The ActorServer optionally registers itself with a cluster +controller to allow _discovery_ via the resolver pattern. + +``` +class ActorClass: + def method_a(self, arg1: int, arg2: str) -> int: + pass + +# e.g. this can be provided by the cluster controller +class NameMapping(Protocol): + def register(self, name: str, target_url: str): + pass + +server = ActorServer(host, port, NameMapping()) +server.register("actor_name", ActorClass()) +server.start() + +class ActorServer: + def register(self, name: str, klass): + self.mapping.register(name, f"actor://{self.host}:{self.port}/{name}") +``` + +### Actor server implementation + +The actor service is a _indirection_ over a generic "Actor" gRPC service. An +example implementation might look like: + +``` +message ActorRequest { + string actor_handle = 1; + string method = 2; + bytes args = 3; +} + +message ActorResponse { + bytes result = 1; +} + +service ActorService { + rpc Invoke(ActorRequest) returns (ActorResponse); + rpc ListMethods(ListMethodsRequest) returns (ListMethodsResponse); + rpc ListActors(ListActorsRequest) returns (ListActorsResponse); +} +``` + +Note that we use an actor _handle_ to identify the actor, which ensures we +can detect if an actor was terminated between invocations. + +## Actor Client and Name Resolution + +The actor client handles both name resolution via the resolver as well as +invocation of the actor methods. 
+ +```python +from typing import Generic, TypeVar, ParamSpec, Callable, Awaitable, overload +from dataclasses import dataclass, replace + +T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") + +@dataclass(frozen=True) +class RpcOptions: + timeout: float | None = None + retries: int = 0 + +class RpcClient(Generic[T]): + def __init__(self, resolver: Resolver, cls: type[T], options: RpcOptions = RpcOptions()): + self._resolver = resolver + self._cls = cls + self._options = options + + def with_options(self, **kwargs) -> "RpcClient[T]": + return RpcClient(self._resolver, self._cls, replace(self._options, **kwargs)) + + def __getattr__(self, name: str) -> Callable[..., Awaitable]: + assert hasattr(self._cls, name), f"Method `{name}` not found on class `{self._cls.__name__}`" + return RpcMethod(self, name, self._options, self._cls) + +class RpcMethod(Generic[P, R]): + def __init__(self, client: RpcClient[T], name: str, options: RpcOptions, cls: type[T]): + self._client = client + self._name = name + self._options = options.copy() + self._cls = cls + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]: + # pickle arguments + # resolve endpoint + # send request + # return result + # potentially cache resolved endpoint and actor handle + # re-resolve endpoint and actor if handle is not found, or host is unreachable + # retry up to retries times + # raise exception if all retries fail +``` + +--- + +# Implementation Plan + +This section provides a **spiral** implementation plan for the Actor and Resolver +system. Each stage delivers a working, testable slice of functionality that builds +on the previous stage. Unlike a ladder approach (proto → types → resolver → server +→ client), each spiral stage touches multiple components to create something useful. + +## Design Decisions + +Based on discussion, the following decisions guide the implementation: + +1. **ActorContext injection**: Via contextvar. ActorServer sets `_actor_context` + contextvar before calling user code. User code calls `current_ctx()` to access. + This is cleaner than signature inspection and works with any method signature. + +2. **Proto design**: Use `bytes serialized_callable` + `bytes serialized_args` + + `bytes serialized_kwargs`. Ignore the existing proto sketch - do the right thing. + +3. **Resolver implementations**: All 3 types (ClusterResolver, FixedResolver, + GcsResolver) live in `resolver.py`. Keep files simple. + +4. **Resolver return type**: Resolvers return a `ResolveResult` (list of URLs + + metadata), not ActorPool. ActorClient and ActorPool are layered on top. + +5. **Default namespace**: `""` when running without a cluster. + +6. **Namespace isolation**: Client-side convention, not enforced by controller. + +7. **Broadcast semantics**: Returns `BroadcastFuture` with `wait_all()`, + `wait_any()`, `as_completed()` methods. + +8. **Failure handling**: `pool.call()` propagates exceptions without auto-retry. + +9. **Endpoint lifecycle**: Controller monitors job status. When a job transitions + to a terminal state, the controller removes all endpoints registered by that job. + The controller always indirects through the jobs map when resolving - if the job + is not RUNNING, the endpoint is not returned. + +10. **Testing strategy**: Prefer real implementations with dummy backends (e.g., + tempdir) over mocks. Design APIs like GcsResolver to accept an injectable + `GcsApi` interface for testing. + +11. **RPC infrastructure**: Use Connect-RPC for everything (the generated code). 
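To make the intended end state concrete before walking the stages, here is a hedged usage sketch of how these pieces are meant to compose. It uses the interfaces defined in Stages 1, 2, and 5 below (`ActorServer`, `FixedResolver`, `ActorPool`, `current_ctx`); treat it as an illustration of the design decisions rather than a finalized API.

```python
from fluster.actor.server import ActorServer
from fluster.actor.resolver import FixedResolver
from fluster.actor.pool import ActorPool
from fluster.actor.types import ActorContext, current_ctx

class EchoActor:
    def echo(self, msg: str) -> str:
        # Decision 1: context is injected via a contextvar, not the signature.
        return f"{current_ctx().job_id}: {msg}"

# Host the actor locally; port 0 lets the server pick a free port.
server = ActorServer(host="127.0.0.1")
server.register("echo", EchoActor())
ctx = ActorContext(cluster=None, resolver=None, job_id="local-job", namespace="")
port = server.serve_background(context=ctx)

# Decision 4: resolvers return endpoints; clients and pools are layered on top.
resolver = FixedResolver({"echo": [f"http://127.0.0.1:{port}"]})
pool = ActorPool(resolver, "echo")

print(pool.call().echo("hello"))                  # load-balanced single call
results = pool.broadcast().echo("hi").wait_all()  # decision 7: broadcast + wait_all()
assert all(r.success for r in results)
```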
+ +## Spiral Stages Overview + +| Stage | Deliverable | Key Components | Test | +|-------|-------------|----------------|------| +| 1 | Minimal e2e actor call | proto, server, client (hardcoded URL) | call method, get result | +| 2 | Resolver integration | FixedResolver, update client | resolve then call | +| 3 | Controller endpoint registry | state.py, service.py, job lifecycle hooks | register, lookup, job cleanup | +| 4 | ClusterResolver | ClusterResolver, integrate with client | e2e with controller discovery | +| 5 | ActorPool | pool.py with round-robin, broadcast | load-balanced and broadcast calls | +| 6 | GcsResolver | GcsResolver with injectable GcsApi | mock-based testing | +| 7 | Introspection (optional) | ListMethods, ListActors RPCs | debugging helpers | +| 8 | Integration examples | cluster_example.py updates | full demo | + +--- + +## Stage 1: Minimal End-to-End Actor Call + +**Goal**: Get a working actor server + client with direct connection. This validates +the core RPC mechanism before adding resolution complexity. + +**Files to modify/create**: +- `src/fluster/actor/proto/actor.proto` - update to final design +- `src/fluster/actor/server.py` - new +- `src/fluster/actor/client.py` - new (hardcoded URL version) +- `src/fluster/actor/types.py` - add `current_ctx()` contextvar +- `tests/actor/test_actor_e2e.py` - new + +### Proto Changes + +Replace `actor.proto` with a clean design: + +```protobuf +syntax = "proto3"; +package fluster.actor; +option py_generic_services = true; + +message ActorCall { + string method_name = 1; + string actor_name = 2; // Which actor on this server + bytes serialized_args = 3; // cloudpickle((arg1, arg2, ...)) + bytes serialized_kwargs = 4; // cloudpickle({k1: v1, ...}) +} + +message ActorResponse { + oneof result { + bytes serialized_value = 1; // cloudpickle(return_value) + ActorError error = 2; + } +} + +message ActorError { + string error_type = 1; + string message = 2; + bytes serialized_exception = 3; // cloudpickle(exception) for re-raise +} + +message Empty {} + +message HealthResponse { + bool healthy = 1; +} + +service ActorService { + rpc Call(ActorCall) returns (ActorResponse); + rpc HealthCheck(Empty) returns (HealthResponse); +} +``` + +### ActorContext via contextvar + +Update `types.py`: + +```python +from contextvars import ContextVar +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fluster.actor.resolver import Resolver + from fluster.cluster.client import Cluster + +_actor_context: ContextVar["ActorContext | None"] = ContextVar("actor_context", default=None) + +def current_ctx() -> "ActorContext": + """Get the current ActorContext. 
Raises if not in an actor call.""" + ctx = _actor_context.get() + if ctx is None: + raise RuntimeError("current_ctx() called outside of actor method") + return ctx + +def _set_actor_context(ctx: "ActorContext | None") -> None: + """Internal: set the actor context for the current call.""" + _actor_context.set(ctx) + +@dataclass +class ActorContext: + """Context available to actor methods via current_ctx().""" + cluster: "Cluster | None" + resolver: "Resolver | None" + job_id: str + namespace: str +``` + +### Minimal ActorServer + +```python +# src/fluster/actor/server.py +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable + +import cloudpickle +from starlette.applications import Starlette +from starlette.responses import Response +from starlette.routing import Route + +from fluster.actor.types import ActorContext, _set_actor_context, ActorId + + +@dataclass +class RegisteredActor: + name: str + actor_id: ActorId + instance: Any + methods: dict[str, Callable] + registered_at_ms: int = field(default_factory=lambda: int(time.time() * 1000)) + + +class ActorServer: + def __init__(self, host: str = "0.0.0.0", port: int = 0): + self._host = host + self._port = port + self._actors: dict[str, RegisteredActor] = {} + self._context: ActorContext | None = None + self._app: Starlette | None = None + self._actual_port: int | None = None + + @property + def address(self) -> str: + port = self._actual_port or self._port + return f"{self._host}:{port}" + + def register(self, name: str, actor: Any) -> ActorId: + actor_id = ActorId(f"{name}-{uuid.uuid4().hex[:8]}") + methods = { + m: getattr(actor, m) + for m in dir(actor) + if not m.startswith("_") and callable(getattr(actor, m)) + } + self._actors[name] = RegisteredActor( + name=name, + actor_id=actor_id, + instance=actor, + methods=methods, + ) + return actor_id + + def _create_app(self) -> Starlette: + async def call_handler(request): + from fluster import actor_pb2 + + # Parse Connect-RPC request + body = await request.body() + call = actor_pb2.ActorCall() + call.ParseFromString(body) + + # Find actor + actor_name = call.actor_name or next(iter(self._actors), "") + actor = self._actors.get(actor_name) + if not actor: + error = actor_pb2.ActorError( + error_type="NotFound", + message=f"Actor '{actor_name}' not found", + ) + resp = actor_pb2.ActorResponse(error=error) + return Response(resp.SerializeToString(), media_type="application/proto") + + method = actor.methods.get(call.method_name) + if not method: + error = actor_pb2.ActorError( + error_type="NotFound", + message=f"Method '{call.method_name}' not found", + ) + resp = actor_pb2.ActorResponse(error=error) + return Response(resp.SerializeToString(), media_type="application/proto") + + try: + args = cloudpickle.loads(call.serialized_args) if call.serialized_args else () + kwargs = cloudpickle.loads(call.serialized_kwargs) if call.serialized_kwargs else {} + + # Set context for this call + _set_actor_context(self._context) + try: + result = method(*args, **kwargs) + finally: + _set_actor_context(None) + + resp = actor_pb2.ActorResponse( + serialized_value=cloudpickle.dumps(result) + ) + return Response(resp.SerializeToString(), media_type="application/proto") + + except Exception as e: + error = actor_pb2.ActorError( + error_type=type(e).__name__, + message=str(e), + serialized_exception=cloudpickle.dumps(e), + ) + resp = actor_pb2.ActorResponse(error=error) + return Response(resp.SerializeToString(), media_type="application/proto") + + async def 
health_handler(request): + from fluster import actor_pb2 + resp = actor_pb2.HealthResponse(healthy=True) + return Response(resp.SerializeToString(), media_type="application/proto") + + return Starlette(routes=[ + Route("/fluster.actor.ActorService/Call", call_handler, methods=["POST"]), + Route("/fluster.actor.ActorService/HealthCheck", health_handler, methods=["POST"]), + ]) + + def serve_background(self, context: ActorContext | None = None) -> int: + """Start server in background thread. Returns actual port.""" + import threading + import uvicorn + import socket + + self._context = context + self._app = self._create_app() + + # Find available port if port=0 + if self._port == 0: + with socket.socket() as s: + s.bind(("", 0)) + self._actual_port = s.getsockname()[1] + else: + self._actual_port = self._port + + config = uvicorn.Config( + self._app, + host=self._host, + port=self._actual_port, + log_level="error", + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + import time + for _ in range(50): + try: + import httpx + httpx.get(f"http://{self._host}:{self._actual_port}/", timeout=0.1) + except Exception: + pass + time.sleep(0.1) + if server.started: + break + + return self._actual_port +``` + +### Minimal ActorClient (hardcoded URL) + +```python +# src/fluster/actor/client.py +from typing import Any + +import cloudpickle + +from fluster import actor_pb2 + + +class ActorClient: + """Simple actor client with hardcoded URL (Stage 1).""" + + def __init__(self, url: str, actor_name: str = ""): + """ + Args: + url: Direct URL to actor server (e.g., "http://localhost:8080") + actor_name: Name of actor on the server + """ + self._url = url.rstrip("/") + self._actor_name = actor_name + self._timeout = 30.0 + + def __getattr__(self, method_name: str) -> "_RpcMethod": + return _RpcMethod(self, method_name) + + +class _RpcMethod: + def __init__(self, client: ActorClient, method_name: str): + self._client = client + self._method_name = method_name + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + call = actor_pb2.ActorCall( + method_name=self._method_name, + actor_name=self._client._actor_name, + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + ) + + ... 
+ resp = actor_pb2.ActorResponse() + resp.ParseFromString(response.content) + + if resp.HasField("error"): + if resp.error.serialized_exception: + raise cloudpickle.loads(resp.error.serialized_exception) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) +``` + +### Test + +```python +# tests/actor/test_actor_e2e.py +import pytest +from fluster.actor.server import ActorServer +from fluster.actor.client import ActorClient +from fluster.actor.types import current_ctx, ActorContext + + +class Calculator: + def add(self, a: int, b: int) -> int: + return a + b + + def multiply(self, a: int, b: int) -> int: + return a * b + + def divide(self, a: int, b: int) -> float: + return a / b # May raise ZeroDivisionError + + +class ContextAwareActor: + def get_job_id(self) -> str: + return current_ctx().job_id + + +def test_basic_actor_call(): + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + port = server.serve_background() + + client = ActorClient(f"http://127.0.0.1:{port}", "calc") + assert client.add(2, 3) == 5 + assert client.multiply(4, 5) == 20 + + +def test_actor_exception_propagation(): + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + port = server.serve_background() + + client = ActorClient(f"http://127.0.0.1:{port}", "calc") + with pytest.raises(ZeroDivisionError): + client.divide(1, 0) + + +def test_actor_context_injection(): + server = ActorServer(host="127.0.0.1") + server.register("ctx_actor", ContextAwareActor()) + + ctx = ActorContext(cluster=None, resolver=None, job_id="test-job-123", namespace="") + port = server.serve_background(context=ctx) + + client = ActorClient(f"http://127.0.0.1:{port}", "ctx_actor") + assert client.get_job_id() == "test-job-123" +``` + +**Run**: `cd lib/fluster && buf generate && uv run pytest tests/actor/test_actor_e2e.py -v` + +--- + +## Stage 2: Resolver Integration + +**Goal**: Add FixedResolver, update ActorClient to use resolvers. + +**Files to modify/create**: +- `src/fluster/actor/resolver.py` - new (ResolveResult, Resolver protocol, FixedResolver) +- `src/fluster/actor/client.py` - update to accept Resolver +- `tests/actor/test_resolver.py` - new + +### Resolver Types and FixedResolver + +```python +# src/fluster/actor/resolver.py +from dataclasses import dataclass, field +from typing import Protocol + +from fluster.cluster.types import Namespace + + +@dataclass +class ResolvedEndpoint: + """A single resolved endpoint.""" + url: str # e.g., "http://host:port" + actor_id: str # Unique handle for staleness detection + metadata: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ResolveResult: + """Result of resolving an actor name.""" + name: str + namespace: Namespace + endpoints: list[ResolvedEndpoint] = field(default_factory=list) + + @property + def is_empty(self) -> bool: + return len(self.endpoints) == 0 + + def first(self) -> ResolvedEndpoint: + if not self.endpoints: + raise ValueError(f"No endpoints for '{self.name}' in namespace '{self.namespace}'") + return self.endpoints[0] + + +class Resolver(Protocol): + """Protocol for actor name resolution.""" + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + ... + + @property + def default_namespace(self) -> Namespace: + ... 
+ + +class FixedResolver: + """Resolver with statically configured endpoints.""" + + def __init__( + self, + endpoints: dict[str, str | list[str]], + namespace: Namespace = Namespace(""), + ): + self._namespace = namespace + self._endpoints: dict[str, list[str]] = {} + for name, urls in endpoints.items(): + if isinstance(urls, str): + self._endpoints[name] = [urls] + else: + self._endpoints[name] = list(urls) + + @property + def default_namespace(self) -> Namespace: + return self._namespace + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + ns = namespace or self._namespace + urls = self._endpoints.get(name, []) + endpoints = [ + ResolvedEndpoint(url=url, actor_id=f"fixed-{name}-{i}") + for i, url in enumerate(urls) + ] + return ResolveResult(name=name, namespace=ns, endpoints=endpoints) +``` + +### Updated ActorClient + +```python +# src/fluster/actor/client.py - updated +from typing import Any + +import cloudpickle +import httpx + +from fluster import actor_pb2 +from fluster.actor.resolver import Resolver, ResolveResult + + +class ActorClient: + """Actor client with resolver-based discovery.""" + + def __init__( + self, + resolver: Resolver, + name: str, + timeout: float = 30.0, + ): + self._resolver = resolver + self._name = name + self._timeout = timeout + self._cached_result: ResolveResult | None = None + + def _resolve(self) -> ResolveResult: + if self._cached_result is None or self._cached_result.is_empty: + self._cached_result = self._resolver.resolve(self._name) + return self._cached_result + + def _invalidate_cache(self) -> None: + self._cached_result = None + + def __getattr__(self, method_name: str) -> "_RpcMethod": + return _RpcMethod(self, method_name) + + +class _RpcMethod: + def __init__(self, client: ActorClient, method_name: str): + self._client = client + self._method_name = method_name + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + result = self._client._resolve() + if result.is_empty: + raise RuntimeError(f"No endpoints found for actor '{self._client._name}'") + + endpoint = result.first() + + call = actor_pb2.ActorCall( + method_name=self._method_name, + actor_name=self._client._name, + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + ) + + try: + response = httpx.post( + f"{endpoint.url}/fluster.actor.ActorService/Call", + content=call.SerializeToString(), + headers={"Content-Type": "application/proto"}, + timeout=self._client._timeout, + ) + response.raise_for_status() + except httpx.RequestError: + self._client._invalidate_cache() + raise + + resp = actor_pb2.ActorResponse() + resp.ParseFromString(response.content) + + if resp.HasField("error"): + if resp.error.serialized_exception: + raise cloudpickle.loads(resp.error.serialized_exception) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) +``` + +### Test + +```python +# tests/actor/test_resolver.py +import pytest +from fluster.actor.resolver import FixedResolver, ResolveResult +from fluster.actor.server import ActorServer +from fluster.actor.client import ActorClient +from fluster.cluster.types import Namespace + + +class Echo: + def echo(self, msg: str) -> str: + return f"echo: {msg}" + + +def test_fixed_resolver_single(): + resolver = FixedResolver({"svc": "http://localhost:8080"}) + result = resolver.resolve("svc") + assert len(result.endpoints) == 1 + assert result.first().url == "http://localhost:8080" + + +def test_fixed_resolver_multiple(): + 
resolver = FixedResolver({"svc": ["http://h1:8080", "http://h2:8080"]}) + result = resolver.resolve("svc") + assert len(result.endpoints) == 2 + + +def test_fixed_resolver_missing(): + resolver = FixedResolver({}) + result = resolver.resolve("missing") + assert result.is_empty + + +def test_client_with_resolver(): + server = ActorServer(host="127.0.0.1") + server.register("echo", Echo()) + port = server.serve_background() + + resolver = FixedResolver({"echo": f"http://127.0.0.1:{port}"}) + client = ActorClient(resolver, "echo") + + assert client.echo("hello") == "echo: hello" +``` + +**Run**: `uv run pytest tests/actor/test_resolver.py -v` + +--- + +## Stage 3: Controller Endpoint Registry + +**Goal**: Implement endpoint registry in controller state and service. The controller +tracks endpoints by job, and automatically removes them when jobs terminate. + +**Key design point**: The controller always indirects through the jobs map when +returning endpoints. If a job is not RUNNING, its endpoints are filtered out. + +**Files to modify**: +- `src/fluster/cluster/controller/state.py` - add endpoint storage +- `src/fluster/cluster/controller/service.py` - implement RPC handlers +- `src/fluster/cluster/controller/heartbeat.py` - cleanup on job termination +- `tests/cluster/controller/test_endpoint_registry.py` - new + +### State Changes + +```python +# Add to state.py + +@dataclass +class ControllerEndpoint: + """An endpoint registered with the controller.""" + endpoint_id: str + name: str + address: str + job_id: JobId + namespace: str + metadata: dict[str, str] = field(default_factory=dict) + registered_at_ms: int = 0 + + +class ControllerState: + def __init__(self): + # ... existing ... + self._endpoints: dict[str, ControllerEndpoint] = {} + self._endpoints_by_job: dict[JobId, set[str]] = {} + + def add_endpoint(self, endpoint: ControllerEndpoint) -> None: + with self._lock: + self._endpoints[endpoint.endpoint_id] = endpoint + self._endpoints_by_job.setdefault(endpoint.job_id, set()).add(endpoint.endpoint_id) + + def remove_endpoint(self, endpoint_id: str) -> ControllerEndpoint | None: + with self._lock: + endpoint = self._endpoints.pop(endpoint_id, None) + if endpoint: + job_endpoints = self._endpoints_by_job.get(endpoint.job_id) + if job_endpoints: + job_endpoints.discard(endpoint_id) + return endpoint + + def lookup_endpoints(self, name: str, namespace: str) -> list[ControllerEndpoint]: + """Find endpoints by name, filtering to only RUNNING jobs.""" + with self._lock: + results = [] + for ep in self._endpoints.values(): + if ep.name != name or ep.namespace != namespace: + continue + # Only return endpoints for running jobs + job = self._jobs.get(ep.job_id) + if job and job.state == cluster_pb2.JOB_STATE_RUNNING: + results.append(ep) + return results + + def list_endpoints_by_prefix(self, prefix: str, namespace: str) -> list[ControllerEndpoint]: + """List endpoints matching prefix, filtering to only RUNNING jobs.""" + with self._lock: + results = [] + for ep in self._endpoints.values(): + if not ep.name.startswith(prefix) or ep.namespace != namespace: + continue + job = self._jobs.get(ep.job_id) + if job and job.state == cluster_pb2.JOB_STATE_RUNNING: + results.append(ep) + return results + + def remove_endpoints_for_job(self, job_id: JobId) -> list[ControllerEndpoint]: + """Remove all endpoints for a job. 
Called on job termination.""" + with self._lock: + endpoint_ids = list(self._endpoints_by_job.get(job_id, [])) + removed = [] + for eid in endpoint_ids: + ep = self.remove_endpoint(eid) + if ep: + removed.append(ep) + return removed +``` + +### Service Changes + +```python +# Update service.py - replace the stub implementations + +def register_endpoint( + self, + request: cluster_pb2.RegisterEndpointRequest, + ctx: Any, +) -> cluster_pb2.RegisterEndpointResponse: + endpoint_id = str(uuid.uuid4()) + + # Validate job exists and is running + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + if job.state != cluster_pb2.JOB_STATE_RUNNING: + raise ConnectError(Code.FAILED_PRECONDITION, f"Job {request.job_id} is not running") + + endpoint = ControllerEndpoint( + endpoint_id=endpoint_id, + name=request.name, + address=request.address, + job_id=JobId(request.job_id), + namespace=request.namespace or "", + metadata=dict(request.metadata), + registered_at_ms=int(time.time() * 1000), + ) + self._state.add_endpoint(endpoint) + self._state.log_action( + "endpoint_registered", + job_id=job.job_id, + details=f"{request.name} at {request.address}", + ) + return cluster_pb2.RegisterEndpointResponse(endpoint_id=endpoint_id) + + +def unregister_endpoint( + self, + request: cluster_pb2.UnregisterEndpointRequest, + ctx: Any, +) -> cluster_pb2.Empty: + endpoint = self._state.remove_endpoint(request.endpoint_id) + if endpoint: + self._state.log_action( + "endpoint_unregistered", + job_id=endpoint.job_id, + details=endpoint.name, + ) + return cluster_pb2.Empty() + + +def lookup_endpoint( + self, + request: cluster_pb2.LookupEndpointRequest, + ctx: Any, +) -> cluster_pb2.LookupEndpointResponse: + namespace = request.namespace or "" + endpoints = self._state.lookup_endpoints(request.name, namespace) + if not endpoints: + return cluster_pb2.LookupEndpointResponse() + + e = endpoints[0] + return cluster_pb2.LookupEndpointResponse( + endpoint=cluster_pb2.Endpoint( + endpoint_id=e.endpoint_id, + name=e.name, + address=e.address, + job_id=e.job_id, + namespace=e.namespace, + metadata=e.metadata, + ) + ) + + +def list_endpoints( + self, + request: cluster_pb2.ListEndpointsRequest, + ctx: Any, +) -> cluster_pb2.ListEndpointsResponse: + namespace = request.namespace or "" + endpoints = self._state.list_endpoints_by_prefix(request.prefix, namespace) + return cluster_pb2.ListEndpointsResponse( + endpoints=[ + cluster_pb2.Endpoint( + endpoint_id=e.endpoint_id, + name=e.name, + address=e.address, + job_id=e.job_id, + namespace=e.namespace, + metadata=e.metadata, + ) + for e in endpoints + ] + ) +``` + +### Job Termination Cleanup + +In `heartbeat.py` or wherever job state transitions are handled, add: + +```python +def _handle_job_termination(self, job_id: JobId) -> None: + """Clean up when a job transitions to terminal state.""" + removed = self._state.remove_endpoints_for_job(job_id) + for ep in removed: + self._state.log_action( + "endpoint_removed_job_terminated", + job_id=job_id, + details=ep.name, + ) +``` + +### Test + +```python +# tests/cluster/controller/test_endpoint_registry.py +import pytest +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerJob, ControllerEndpoint +from fluster.cluster.types import JobId + + +@pytest.fixture +def state() -> ControllerState: + return ControllerState() + + +def test_add_and_lookup_endpoint(state: ControllerState): + # Create a running job 
first + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + # Register endpoint + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Lookup + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 1 + assert results[0].address == "10.0.0.1:8080" + + +def test_endpoint_not_returned_for_non_running_job(state: ControllerState): + # Create a completed job + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_SUCCEEDED, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Should not return endpoint because job is not running + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 + + +def test_remove_endpoints_on_job_termination(state: ControllerState): + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Simulate job termination + removed = state.remove_endpoints_for_job(JobId("job-1")) + assert len(removed) == 1 + + # Endpoint should be gone + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 +``` + +**Run**: `uv run pytest tests/cluster/controller/test_endpoint_registry.py -v` + +--- + +## Stage 4: ClusterResolver + +**Goal**: Implement ClusterResolver that queries the controller for endpoints. 
+ +**Files to modify/create**: +- `src/fluster/actor/resolver.py` - add ClusterResolver +- `tests/actor/test_cluster_resolver.py` - new + +### ClusterResolver + +```python +# Add to resolver.py + +import httpx +from fluster import cluster_pb2 + + +class ClusterResolver: + """Resolver backed by the cluster controller's endpoint registry.""" + + def __init__( + self, + controller_address: str, + namespace: Namespace | None = None, + timeout: float = 5.0, + ): + self._address = controller_address.rstrip("/") + self._timeout = timeout + + import os + self._namespace = namespace or Namespace( + os.environ.get("FLUSTER_NAMESPACE", "") + ) + + @property + def default_namespace(self) -> Namespace: + return self._namespace + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + ns = namespace or self._namespace + + request = cluster_pb2.ListEndpointsRequest( + prefix=name, + namespace=ns, + ) + + response = httpx.post( + f"{self._address}/fluster.cluster.ControllerService/ListEndpoints", + content=request.SerializeToString(), + headers={"Content-Type": "application/proto"}, + timeout=self._timeout, + ) + response.raise_for_status() + + resp = cluster_pb2.ListEndpointsResponse() + resp.ParseFromString(response.content) + + # Filter to exact name matches + endpoints = [ + ResolvedEndpoint( + url=f"http://{ep.address}", + actor_id=ep.endpoint_id, + metadata=dict(ep.metadata), + ) + for ep in resp.endpoints + if ep.name == name + ] + + return ResolveResult(name=name, namespace=ns, endpoints=endpoints) +``` + +### Test (with real controller) + +```python +# tests/actor/test_cluster_resolver.py +import threading +import pytest +import uvicorn +from starlette.applications import Starlette + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState, ControllerJob, ControllerEndpoint +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.controller.scheduler import Scheduler +from fluster.actor.resolver import ClusterResolver +from fluster.cluster.types import JobId, Namespace + + +def create_controller_app(state: ControllerState) -> Starlette: + """Create a minimal controller app for testing.""" + from starlette.responses import Response + from starlette.routing import Route + + scheduler = Scheduler(state, interval=1.0) + service = ControllerServiceImpl(state, scheduler) + + async def list_endpoints_handler(request): + body = await request.body() + req = cluster_pb2.ListEndpointsRequest() + req.ParseFromString(body) + resp = service.list_endpoints(req, None) + return Response(resp.SerializeToString(), media_type="application/proto") + + return Starlette(routes=[ + Route("/fluster.cluster.ControllerService/ListEndpoints", list_endpoints_handler, methods=["POST"]), + ]) + + +@pytest.fixture +def controller_with_endpoint(): + """Start a controller with a registered endpoint.""" + import socket + + state = ControllerState() + + # Add a running job + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + # Add an endpoint + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="inference", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Find free port + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + app = create_controller_app(state) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") 
+ server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server + import time + for _ in range(50): + if server.started: + break + time.sleep(0.1) + + yield f"http://127.0.0.1:{port}", state + + +def test_cluster_resolver_finds_endpoint(controller_with_endpoint): + address, state = controller_with_endpoint + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert "10.0.0.1:8080" in result.first().url + + +def test_cluster_resolver_missing_endpoint(controller_with_endpoint): + address, state = controller_with_endpoint + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("nonexistent") + + assert result.is_empty +``` + +**Run**: `uv run pytest tests/actor/test_cluster_resolver.py -v` + +--- + +## Stage 5: ActorPool + +**Goal**: Implement ActorPool for load-balanced and broadcast calls. + +**Files to create**: +- `src/fluster/actor/pool.py` +- `tests/actor/test_actor_pool.py` + +### Implementation + +```python +# src/fluster/actor/pool.py +import itertools +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Callable, Generic, Iterator, TypeVar + +import cloudpickle +import httpx + +from fluster import actor_pb2 +from fluster.actor.resolver import ResolveResult, ResolvedEndpoint, Resolver + +T = TypeVar("T") + + +@dataclass +class CallResult: + """Result of a single call in a broadcast.""" + endpoint: ResolvedEndpoint + value: Any | None = None + exception: BaseException | None = None + + @property + def success(self) -> bool: + return self.exception is None + + +class BroadcastFuture(Generic[T]): + """Future for broadcast results.""" + + def __init__(self, futures: list[tuple[ResolvedEndpoint, Future]]): + self._futures = futures + + def wait_all(self, timeout: float | None = None) -> list[CallResult]: + results = [] + for endpoint, future in self._futures: + try: + value = future.result(timeout=timeout) + results.append(CallResult(endpoint=endpoint, value=value)) + except Exception as e: + results.append(CallResult(endpoint=endpoint, exception=e)) + return results + + def wait_any(self, timeout: float | None = None) -> CallResult: + for future in as_completed([f for _, f in self._futures], timeout=timeout): + idx = next(i for i, (_, f) in enumerate(self._futures) if f is future) + endpoint = self._futures[idx][0] + try: + value = future.result() + return CallResult(endpoint=endpoint, value=value) + except Exception as e: + return CallResult(endpoint=endpoint, exception=e) + raise TimeoutError("No results within timeout") + + def as_completed(self, timeout: float | None = None) -> Iterator[CallResult]: + endpoint_map = {id(f): ep for ep, f in self._futures} + for future in as_completed([f for _, f in self._futures], timeout=timeout): + endpoint = endpoint_map[id(future)] + try: + value = future.result() + yield CallResult(endpoint=endpoint, value=value) + except Exception as e: + yield CallResult(endpoint=endpoint, exception=e) + + +class ActorPool(Generic[T]): + """Pool of actors for load-balanced and broadcast calls.""" + + def __init__(self, resolver: Resolver, name: str, timeout: float = 30.0): + self._resolver = resolver + self._name = name + self._timeout = timeout + self._round_robin: itertools.cycle | None = None + self._cached_result: ResolveResult | None = None + self._executor = 
ThreadPoolExecutor(max_workers=32) + + def _resolve(self) -> ResolveResult: + result = self._resolver.resolve(self._name) + if self._cached_result is None or result.endpoints != self._cached_result.endpoints: + self._round_robin = itertools.cycle(result.endpoints) if result.endpoints else None + self._cached_result = result + return result + + @property + def size(self) -> int: + return len(self._resolve().endpoints) + + @property + def endpoints(self) -> list[ResolvedEndpoint]: + return list(self._resolve().endpoints) + + def _call_endpoint( + self, + endpoint: ResolvedEndpoint, + method_name: str, + args: tuple, + kwargs: dict, + ) -> Any: + call = actor_pb2.ActorCall( + method_name=method_name, + actor_name=self._name, + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + ) + + response = httpx.post( + f"{endpoint.url}/fluster.actor.ActorService/Call", + content=call.SerializeToString(), + headers={"Content-Type": "application/proto"}, + timeout=self._timeout, + ) + response.raise_for_status() + + resp = actor_pb2.ActorResponse() + resp.ParseFromString(response.content) + + if resp.HasField("error"): + if resp.error.serialized_exception: + raise cloudpickle.loads(resp.error.serialized_exception) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) + + def call(self) -> "_PoolCallProxy[T]": + return _PoolCallProxy(self) + + def broadcast(self) -> "_PoolBroadcastProxy[T]": + return _PoolBroadcastProxy(self) + + +class _PoolCallProxy(Generic[T]): + def __init__(self, pool: ActorPool[T]): + self._pool = pool + + def __getattr__(self, method_name: str) -> Callable[..., Any]: + def call(*args, **kwargs): + self._pool._resolve() + if self._pool._round_robin is None: + raise RuntimeError(f"No endpoints for '{self._pool._name}'") + endpoint = next(self._pool._round_robin) + return self._pool._call_endpoint(endpoint, method_name, args, kwargs) + return call + + +class _PoolBroadcastProxy(Generic[T]): + def __init__(self, pool: ActorPool[T]): + self._pool = pool + + def __getattr__(self, method_name: str) -> Callable[..., BroadcastFuture]: + def broadcast(*args, **kwargs) -> BroadcastFuture: + result = self._pool._resolve() + futures = [] + for endpoint in result.endpoints: + future = self._pool._executor.submit( + self._pool._call_endpoint, + endpoint, + method_name, + args, + kwargs, + ) + futures.append((endpoint, future)) + return BroadcastFuture(futures) + return broadcast +``` + +### Test + +```python +# tests/actor/test_actor_pool.py +import pytest +from fluster.actor.server import ActorServer +from fluster.actor.pool import ActorPool +from fluster.actor.resolver import FixedResolver + + +class Counter: + def __init__(self, start: int = 0): + self._value = start + + def get(self) -> int: + return self._value + + def increment(self) -> int: + self._value += 1 + return self._value + + +def test_pool_round_robin(): + servers = [] + urls = [] + + for i in range(3): + server = ActorServer(host="127.0.0.1") + server.register("counter", Counter(start=i * 100)) + port = server.serve_background() + servers.append(server) + urls.append(f"http://127.0.0.1:{port}") + + resolver = FixedResolver({"counter": urls}) + pool = ActorPool(resolver, "counter") + + assert pool.size == 3 + + # Round-robin should cycle through servers + results = [pool.call().get() for _ in range(6)] + # Should see values from all three servers (0, 100, 200, 0, 100, 200) + assert set(results) == {0, 100, 200} + + +def 
test_pool_broadcast(): + servers = [] + urls = [] + + for i in range(3): + server = ActorServer(host="127.0.0.1") + server.register("counter", Counter(start=i)) + port = server.serve_background() + servers.append(server) + urls.append(f"http://127.0.0.1:{port}") + + resolver = FixedResolver({"counter": urls}) + pool = ActorPool(resolver, "counter") + + broadcast = pool.broadcast().get() + results = broadcast.wait_all() + + assert len(results) == 3 + assert all(r.success for r in results) + assert {r.value for r in results} == {0, 1, 2} +``` + +**Run**: `uv run pytest tests/actor/test_actor_pool.py -v` + +--- + +## Stage 6: GcsResolver + +**Goal**: Implement GcsResolver with injectable GcsApi for testing. + +**Files to modify**: +- `src/fluster/actor/resolver.py` - add GcsResolver, GcsApi protocol +- `tests/actor/test_gcs_resolver.py` + +### Implementation + +```python +# Add to resolver.py + +from typing import Protocol as TypingProtocol + + +class GcsApi(TypingProtocol): + """Protocol for GCS Compute API operations.""" + + def list_instances(self, project: str, zone: str) -> list[dict]: + """List VM instances with metadata.""" + ... + + +class RealGcsApi: + """Real GCS API using google-cloud-compute.""" + + def list_instances(self, project: str, zone: str) -> list[dict]: + from google.cloud import compute_v1 + + client = compute_v1.InstancesClient() + instances = [] + for instance in client.list(project=project, zone=zone): + metadata = {} + if instance.metadata and instance.metadata.items: + for item in instance.metadata.items: + metadata[item.key] = item.value + + internal_ip = None + if instance.network_interfaces: + internal_ip = instance.network_interfaces[0].network_i_p + + instances.append({ + "name": instance.name, + "internal_ip": internal_ip, + "metadata": metadata, + "status": instance.status, + }) + return instances + + +class MockGcsApi: + """Mock GCS API for testing.""" + + def __init__(self, instances: list[dict] | None = None): + self._instances = instances or [] + + def set_instances(self, instances: list[dict]) -> None: + self._instances = instances + + def list_instances(self, project: str, zone: str) -> list[dict]: + return self._instances + + +class GcsResolver: + """Resolver using GCS VM instance metadata tags.""" + + ACTOR_PREFIX = "fluster_actor_" + NAMESPACE_KEY = "fluster_namespace" + + def __init__( + self, + project: str, + zone: str, + namespace: Namespace | None = None, + api: GcsApi | None = None, + ): + self._project = project + self._zone = zone + self._api = api or RealGcsApi() + + import os + self._namespace = namespace or Namespace( + os.environ.get("FLUSTER_NAMESPACE", "") + ) + + @property + def default_namespace(self) -> Namespace: + return self._namespace + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + ns = namespace or self._namespace + endpoints = [] + + instances = self._api.list_instances(self._project, self._zone) + + for instance in instances: + if instance.get("status") != "RUNNING": + continue + + metadata = instance.get("metadata", {}) + instance_ns = metadata.get(self.NAMESPACE_KEY, "") + + if instance_ns != ns: + continue + + actor_key = f"{self.ACTOR_PREFIX}{name}" + if actor_key in metadata: + port = metadata[actor_key] + ip = instance.get("internal_ip") + if ip: + endpoints.append(ResolvedEndpoint( + url=f"http://{ip}:{port}", + actor_id=f"gcs-{instance['name']}-{name}", + metadata={"instance": instance["name"]}, + )) + + return ResolveResult(name=name, namespace=ns, endpoints=endpoints) +``` + 
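As a usage note, a VM advertises an actor by carrying the metadata keys defined above (`fluster_actor_<name>` mapping to a port, plus `fluster_namespace`), and the injectable `GcsApi` is what keeps the resolver testable without touching GCP (design decision 10). A minimal sketch, with placeholder project/zone/instance values:

```python
from fluster.actor.resolver import GcsResolver, MockGcsApi

# A VM advertising "inference" on port 8080 in the default namespace;
# GcsResolver only considers RUNNING instances in the matching namespace.
api = MockGcsApi([
    {
        "name": "inference-vm-0",
        "internal_ip": "10.0.0.7",
        "status": "RUNNING",
        "metadata": {
            "fluster_namespace": "",
            "fluster_actor_inference": "8080",
        },
    },
])

# Tests inject MockGcsApi; production omits `api` and falls back to RealGcsApi.
resolver = GcsResolver("my-project", "us-central1-a", api=api)
print(resolver.resolve("inference").first().url)  # http://10.0.0.7:8080
```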
+### Test + +```python +# tests/actor/test_gcs_resolver.py +import pytest +from fluster.actor.resolver import GcsResolver, MockGcsApi +from fluster.cluster.types import Namespace + + +def test_gcs_resolver_finds_actors(): + api = MockGcsApi([ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ]) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert "10.0.0.1:8080" in result.first().url + + +def test_gcs_resolver_filters_namespace(): + api = MockGcsApi([ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "other-ns", + "fluster_actor_inference": "8080", + }, + }, + ]) + resolver = GcsResolver("project", "zone", namespace=Namespace(""), api=api) + result = resolver.resolve("inference") + + assert result.is_empty + + +def test_gcs_resolver_ignores_non_running(): + api = MockGcsApi([ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "TERMINATED", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ]) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert result.is_empty +``` + +**Run**: `uv run pytest tests/actor/test_gcs_resolver.py -v` + +--- + +## Stage 7: Introspection RPCs (Optional) + +**Goal**: Add ListMethods and ListActors for debugging. + +This is a polish stage. Add to `actor.proto`: + +```protobuf +message ListMethodsRequest { + string actor_name = 1; +} + +message MethodInfo { + string name = 1; + string signature = 2; + string docstring = 3; +} + +message ListMethodsResponse { + repeated MethodInfo methods = 1; +} + +message ListActorsRequest {} + +message ActorInfo { + string name = 1; + string actor_id = 2; + int64 registered_at_ms = 3; + map metadata = 4; +} + +message ListActorsResponse { + repeated ActorInfo actors = 1; +} + +service ActorService { + rpc Call(ActorCall) returns (ActorResponse); + rpc HealthCheck(Empty) returns (HealthResponse); + rpc ListMethods(ListMethodsRequest) returns (ListMethodsResponse); + rpc ListActors(ListActorsRequest) returns (ListActorsResponse); +} +``` + +Then add handlers in `ActorServer`. This is deferred until the core path works. + +--- + +## Stage 8: Integration Examples + +**Goal**: Add examples to `cluster_example.py` demonstrating actor patterns. + +**Implemented**: ✅ + +Added three comprehensive examples to `examples/cluster_example.py`: + +### 1. Basic Actor Pattern (`example_actor_basic`) +Demonstrates: +- Creating and registering an actor server +- Registering endpoints with the controller +- Using ActorClient with ClusterResolver for discovery +- Calling actor methods with arguments and return values + +Example actor: Calculator with add(), multiply(), and get_history() methods. + +### 2. Coordinator Pattern (`example_actor_coordinator`) +Demonstrates: +- Coordinator actor managing a task queue +- Worker actors fetching tasks from coordinator +- Context injection via `current_ctx()` for actor-to-actor communication +- Workers using ActorClient to communicate with coordinator + +Shows the pull-based task distribution pattern where workers fetch tasks, process them, and report results back. + +### 3. 
Actor Pool Pattern (`example_actor_pool`) +Demonstrates: +- ActorPool for load-balanced calls across multiple instances +- Round-robin distribution for inference requests +- Broadcast operations (update_weights) to all instances +- Collecting results from broadcast with `wait_all()` + +Example: Multiple inference servers that can be called via round-robin or broadcast. + +### CLI Updates +Added `--mode` option to cluster_example.py: +- `--mode actors`: Run only actor examples (no Docker required) +- `--mode jobs`: Run only cluster job examples (requires Docker) +- `--mode all`: Run all examples (default) + +**Run examples**: +```bash +cd lib/fluster +uv run python examples/cluster_example.py --mode actors +``` + +--- + +## Test Commands Summary + +```bash +cd lib/fluster + +# Stage 1: Minimal e2e +buf generate +uv run pytest tests/actor/test_actor_e2e.py -v + +# Stage 2: Resolver +uv run pytest tests/actor/test_resolver.py -v + +# Stage 3: Endpoint registry +uv run pytest tests/cluster/controller/test_endpoint_registry.py -v + +# Stage 4: ClusterResolver +uv run pytest tests/actor/test_cluster_resolver.py -v + +# Stage 5: Pool +uv run pytest tests/actor/test_actor_pool.py -v + +# Stage 6: GcsResolver +uv run pytest tests/actor/test_gcs_resolver.py -v + +# All actor tests +uv run pytest tests/actor/ -v +``` diff --git a/lib/fluster/docs/fray-zero.md b/lib/fluster/docs/fray-zero.md new file mode 100644 index 0000000000..123093f207 --- /dev/null +++ b/lib/fluster/docs/fray-zero.md @@ -0,0 +1,1277 @@ +# Fray-Zero Design + +As discussed in +[Fray Presentation](https://docs.google.com/presentation/d/1qgPGK7zYmiOSOn70W-rPIrcF7vuuHgg68Qr3e5UIuxw/edit) and +[Fray/RPC Design](https://docs.google.com/document/d/1UteX9nD9obY5ypV2o8KbHF72xwX1JezseRE5uLFBulk/edit?tab=t.0), +we think it's a good time to "grug-ify" our RPC and clustering system while +moving off of the complex and fragile Ray. + +## Original Design and Progress + +Our original Ray challenges doc +[Ray Infrastructure Challenges](https://docs.google.com/document/d/1gtCz3aN2q72ZF-BNK_nKHCjuS9CT88QSH5vmkTktwWQ/edit?tab=t.0#heading=h.9k9q6db2omrh) +and migration plan +[Ray Migration](https://docs.google.com/document/d/1r-YKwxMxD8dJPKFmQrdJdIsfVvrLnIPo8RnlANbwM5o/edit?tab=t.0) +outlined a mostly "drop-in" approach to replacing Ray with Fray. We would +introduce new wrapper APIs which hid Ray's usage (fray.cluster and fray.job), +move most direct usage into a data library (zephyr), and reduce the complexity +of our dependencies to remove the need for complex venv creation logic. + +We've executed most of this plan unaltered: + +* We've dramatically reduced the scope of "raw" Ray usage by using Zephyr for our data pipelines +* We've abstracted our Ray usage behind 2 APIs: fray.cluster and fray.job. +* We've dramatically pruned our "extra" Python dependencies, removing the need for on-demand Python environments + +We still use Ray for our cluster and job RPC management, but for local tests and +speedruns we now use a local implementation which completely avoids the use of +Ray. + +## Refining Our Vision + +In the process of implementing the final set of changes to implement our own Ray +compatible backend and clustering system, we started to ask "is this actually +what we want?". 
That is, while we feel our work up to this point has been
+building useful abstractions around Ray, we're now confronted with the
+complexity of re-implementing the bulk of Ray, and _only then_ proceeding to
+revisit the API decisions and try to incrementally improve them.
+
+While acknowledging the value of incremental changes, we realized we're loath
+to create the same mess we started with. And since we've been so productive at
+gradually refactoring our codebase to migrate off of Ray, we want to consider
+whether we can push _further_ in that direction: instead of recreating Ray,
+what simpler primitives can we create that we can use in Marin _in place of
+Ray_?
+
+## Requirements
+
+We have a few job types in Marin, which in general run independently of each
+other (that is, they are launched independently and don't talk to a different
+job type).
+
+### 1. Training
+
+Our simplest case: training jobs run on TPU slices and communicate entirely via
+JAX collectives after startup. Moreover, training tasks run for a long time, so
+startup overhead is not a problem.
+
+### 2. Data Processing
+
+Data processing needs to use CPU resources flexibly. With our latest Zephyr
+refactoring, there is minimal communication between tasks, as all work is
+staged to object storage.
+
+### 3. Inference/Evaluation
+
+Inference is a combination of training and data processing. We want to start N
+inference servers (potentially on slices), and then dispatch work to that pool
+via one or more CPU jobs. As we will typically be inference-bound, it's likely a
+single CPU task will be sufficient for our initial work.
+
+### 4. RL
+
+RL has 2-3 different job types which use accelerators:
+
+* Training worker
+* Rollout workers (environments)
+* Inference workers
+
+Internally these jobs use JAX as expected, but they also need to communicate
+metadata about progress and new checkpoints via actors which are shared
+across all processes (CurriculumActor and TrainingActor).
+
+### 5. (Flex) Multi-slice Training
+
+For flex-multi-slice, we have multiple training jobs, each running on a separate
+TPU slice. Slices can come and go over time, but we'd like to take advantage of
+slices when they are available. The simplest way to express this is to move our
+workers and leader into actors, and dispatch work from the leader:
+
+```python
+# leader.py
+def multislice_train():
+    while True:
+        slice_workers = build_multi_slice()  # build a cluster based on available slices
+        slice_workers.train_step()  # or train_n_steps, or whatever
+
+# worker.py
+class Worker:
+    def __init__(self):
+        self.weights = load_weights()
+        self.data = load_data(slice_id)
+        self.model = create_model([peer_slices])
+
+    def train_step(self):
+        self.model.step(next(self.data))
+```
+
+## Design
+
+Our plan is to focus on doing a few simple things well, instead of many things
+poorly. We'll explicitly break our goals into two parts: a _job management
+system_ which excels at managing reservations and booting up UV environments on
+a set of VMs, and an _RPC system_ which makes it easy for users to set up
+_their own_ communication between tasks.
+
+### Job Management
+
+In Marin a typical workflow involves starting a controller job which runs
+Executor, which then spawns one or more sub-jobs.
The job management system is +responsible for managing the lifecycle of these jobs, including: + +* Reserving resources +* Booting up UV environments +* Managing auto-scaling +* Managing task failures/restarts/termination +* Providing access to task metrics/logs +* Fairly sharing resources across users + +The cluster manager manages workloads for all users and all regions. It manages +a set of VMs as well as auto-scaling VM requests in response to demand. To +request resources from the cluster users create _reservations_ which specify the +minimum and maximum set of resources required for a set of jobs. + +For the purposes of locality (we want RL inference workers to run in the same +cluster as the training workers), a user can also define a _job group_ which +specifies a co-located set of job requests. + +A job template defines: + * _entrypoint_ - a Python callable with optional arguments + * _environment_ - pip packages, extras, and environment variables + * _resources_ - accelerators, CPU, memory requirements + +A user may request multiple instantiations of a job template, and a friendly +name prefix for grouping (e.g. `{user}/rl/rollout/`). Every job receives +a globally unique ID. + +**Namespaces**: When a user starts a top-level job, the cluster creates a new +namespace derived from the job name. Sub-jobs launched from within that job +inherit the same namespace. Actor names are scoped to namespaces, so actors in +different namespaces don't collide. + +**Environment Variables**: The cluster injects standard environment variables +into every job process: + +* `FRAY_JOB_ID` - Unique identifier for this job instance +* `FRAY_JOB_NAME` - User-provided job name +* `FRAY_NAMESPACE` - Namespace for actor registration/lookup +* `FRAY_CLUSTER_ADDRESS` - Address of the cluster controller + +**Health Monitoring**: The cluster monitors job health via: + +* Process status monitoring (exit codes, crashes) +* Optional health check endpoint - jobs can expose a health route that the + cluster pings periodically + +**Lifecycle Management**: When a parent job terminates, the cluster +automatically terminates all child jobs spawned by that parent. This ensures +cleanup happens automatically without requiring explicit shutdown coordination. + +The cluster system provides access to job logs, and handles task failures and +pre-emption recovery. It attempts to minimize the startup time of tasks by +reusing warm environments and re-using warm workers when possible. The cluster +system _will not_ attempt to reuse warm processes, thus the target startup time +for tasks is expected to be <10s for a warm image. + +### Metadata Service + +The cluster provides a metadata service that maps actor names to endpoints +(job ID + address + port). Because the cluster owns both the jobs and the +metadata service, it can: + +* Automatically clean up mappings when jobs terminate +* Atomically update mappings when jobs restart after failure +* Garbage collect stale entries + +Multiple actors can register under the same name. The metadata service maintains +a list of all actors for each name. When a job terminates, all actors registered +by that job are automatically removed from any name mappings they participate in. + +The metadata service is internal to the cluster. Clients interact with it +through a **Resolver**. + +### Resolver and ActorPool + +A Resolver provides actor discovery and connection management. It maps actor +names to `ActorPool` instances that handle load balancing, broadcasting, and +failure recovery. 
+ +```python +# Create resolver (explicitly, from cluster or environment) +resolver = ClusterResolver(cluster) # uses cluster's metadata service +resolver = FixedResolver({"inference": "localhost:8080"}) # for testing + +# Look up actors - always returns an ActorPool +pool = resolver.lookup("inference_pool") + +# Single call (round-robin to one actor) +result = pool.call().predict(x) + +# Broadcast to all actors +futures = pool.broadcast().predict(x) +results = [f.result() for f in futures] # Returns all results, including failures + +# Query pool state +print(f"Pool has {pool.size} actors") +pool.wait_for_size(4, timeout=60.0) # Wait until N actors available +``` + +The resolver handles: + +* **Worker resolution**: Maps actor names to current addresses +* **Fault tolerance**: Re-resolves addresses when workers fail; retries on transient failures +* **Load balancing**: Round-robin distribution via `pool.call()` +* **Fan-out**: Broadcast to all actors via `pool.broadcast()` + +### ActorServer + +ActorServer lets users expose Python classes as RPC services without writing +protos. Each job runs at most one ActorServer (since it binds to a port). +Serialization uses pickle for arguments and return values. + +Actor methods receive an `ActorContext` as their first argument, which provides +access to the cluster, resolver, and job information - enabling actors to call +other actors. + +```python +class InferenceActor: + def __init__(self): + self.model = load_model() + + def predict(self, ctx: ActorContext, x): + # ctx.resolver available if this actor needs to call other actors + return self.model(x) + +cluster = current_cluster() +server = ActorServer(cluster) +server.register("inference_pool", InferenceActor()) # Register under pool name +server.serve() # blocks, handling requests +``` + +When a job hosting an ActorServer fails and restarts, the cluster automatically +updates the metadata mappings. Clients holding pool references will transparently +reconnect to the new instance on their next call. + +### WorkerPool + +For task dispatch patterns (like Zephyr), we provide a WorkerPool abstraction +built on top of actors and jobs. WorkerPool manages a set of stateless workers +that can execute arbitrary callables. + +```python +# Create a pool of workers +pool = WorkerPool( + cluster=current_cluster(), + num_workers=10, + resources=ResourceConfig.with_cpu(cpu=2, memory="4GB"), +) + +# Submit tasks - returns futures +futures = [pool.submit(process_shard, shard) for shard in shards] + +# Gather results +results = [f.result() for f in futures] + +# Cleanup +pool.shutdown() +``` + +Internally, WorkerPool: +1. Launches N worker jobs, each running an actor that accepts callables +2. Distributes work via round-robin through the resolver +3. 
Handles retries on worker failure (stateless workers allow retry on any worker) + +### Use Case Implementations + +* **Training**: Single cluster job launch, no actors needed +* **Inference**: Launch N server jobs, each registers under same actor name, clients use `pool.call()` for load balancing +* **Zephyr**: Uses WorkerPool to dispatch shard processing tasks +* **RL**: Coordinator job hosts shared actors (curriculum, weights), workers connect via resolver + +```python +# rl_controller.py +cluster = current_cluster() +resolver = ClusterResolver(cluster) + +# Launch coordinator hosting shared actors +cluster.launch(coordinator_job) + +# Launch workers - they connect to coordinator via resolver +cluster.launch(train_worker_job) +for i in range(num_rollout_workers): + cluster.launch(rollout_worker_job(i)) + +# Workers internally do: +# resolver = ClusterResolver(cluster) +# curriculum = resolver.lookup("curriculum") +# curriculum.call().get_lesson() +``` + + +## Local Development + +For local development, all components have in-process implementations that +preserve the same interfaces. Code written for production works locally +without modification. + +* Jobs run as threads in the current process +* Actors are called directly (in-process) but serialization still occurs +* The same code paths execute, catching serialization bugs early + +```python +# Works identically in local and production +cluster = current_cluster() # Returns LocalCluster when FRAY_CLUSTER_SPEC unset +resolver = ClusterResolver(cluster) + +server = ActorServer(cluster) +server.register("my_actor", MyActor()) +server.serve_background() # Spawns a thread locally + +pool = resolver.lookup("my_actor") +result = pool.call().process(data) # Serializes args, calls in-process +``` + +## Detailed Design + +This section provides the complete Python interfaces for the Fray system. The +architecture consists of four main components: + +1. **Cluster** - Job lifecycle management (launch, monitor, terminate). Owns a + Metadata service for actor registration. +2. **Resolver** - Actor discovery and connection management. Maps actor names + to ActorPools and handles reconnection on failure. +3. **ActorServer** - Hosts actor instances, handles RPC calls, registers with + the cluster's metadata service. +4. **WorkerPool** - High-level task dispatch abstraction built on actors. + +### Core Types (should be defined via proto files) + +These are just for reference, the actual types should be defined via proto files for objects that will be serialized. +JobState etc types are shared between the cluster controller and worker. 
+ +```python +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, Protocol, Sequence, TypeVar +from enum import StrEnum + + +# Type aliases for clarity +JobId = str +ActorId = str +ReservationId = str +Namespace = str + + +class JobState(StrEnum): + """Status of a job in the cluster.""" + PENDING = "pending" # Waiting for resources + RUNNING = "running" # Currently executing + SUCCEEDED = "succeeded" # Completed successfully + FAILED = "failed" # Terminated with error + STOPPED = "stopped" # Manually terminated + + +@dataclass +class JobInfo: + """Information about a job's current state.""" + job_id: JobId + name: str + state: JobState + error_message: str | None = None + start_time: float | None = None + end_time: float | None = None +``` + +### Resource Configuration + +```python +@dataclass +class CpuConfig: + """CPU-only resource configuration.""" + kind: str = "cpu" + variant: str = "default" + + +@dataclass +class TpuConfig: + """TPU accelerator configuration.""" + kind: str = "tpu" + variant: str # e.g., "v5litepod-4", "v5litepod-16" + count: int = 1 + + +@dataclass +class GpuConfig: + """GPU accelerator configuration.""" + kind: str = "gpu" + variant: str # e.g., "a100-40gb", "h100" + count: int = 1 + + +DeviceConfig = CpuConfig | TpuConfig | GpuConfig + + +@dataclass +class ResourceConfig: + """Resource requirements for a job.""" + device: DeviceConfig = field(default_factory=CpuConfig) + cpu: int = 1 + memory: str = "2GB" + replicas: int = 1 + + @classmethod + def with_cpu(cls, cpu: int = 1, memory: str = "2GB") -> "ResourceConfig": + """Create CPU-only resource config.""" + return cls(device=CpuConfig(), cpu=cpu, memory=memory) + + @classmethod + def with_tpu( + cls, + tpu_type: str, + slice_count: int = 1, + ) -> "ResourceConfig": + """Create TPU resource config. + + Args: + tpu_type: TPU variant (e.g., "v5litepod-4", "v5litepod-16") + slice_count: Number of TPU slices + """ + return cls( + device=TpuConfig(variant=tpu_type), + replicas=slice_count, + ) + + @classmethod + def with_gpu( + cls, + gpu_type: str, + count: int = 1, + cpu: int = 4, + memory: str = "16GB", + ) -> "ResourceConfig": + """Create GPU resource config.""" + return cls( + device=GpuConfig(variant=gpu_type, count=count), + cpu=cpu, + memory=memory, + ) +``` + +### Entrypoint and Environment + +```python +@dataclass +class Entrypoint: + """Job entrypoint specification. + + Jobs are started by invoking a Python callable with the provided arguments. + The callable and arguments must be picklable. + """ + callable: Callable + args: tuple = () + kwargs: dict = field(default_factory=dict) + + @classmethod + def from_callable( + cls, + fn: Callable, + args: tuple = (), + kwargs: dict | None = None, + ) -> "Entrypoint": + """Create entrypoint from a callable. 
+ + Args: + fn: Python callable to execute + args: Positional arguments to pass + kwargs: Keyword arguments to pass + """ + return cls(callable=fn, args=args, kwargs=kwargs or {}) + + +@dataclass +class EnvironmentConfig: + """Environment configuration for job execution.""" + extras: list[str] = field(default_factory=list) + pip_packages: list[str] = field(default_factory=list) + env_vars: dict[str, str] = field(default_factory=dict) + + @classmethod + def create( + cls, + extras: list[str] | None = None, + pip_packages: list[str] | None = None, + env_vars: dict[str, str] | None = None, + ) -> "EnvironmentConfig": + """Create environment config with optional overrides.""" + return cls( + extras=extras or [], + pip_packages=pip_packages or [], + env_vars=env_vars or {}, + ) +``` + +### Job Request and Group Configuration + +```python +@dataclass +class JobRequest: + """Request to launch a job on the cluster.""" + name: str + resources: ResourceConfig + entrypoint: Entrypoint + environment: EnvironmentConfig = field(default_factory=EnvironmentConfig) + + +@dataclass +class JobGroupConfig: + """Configuration for a group of co-located jobs. + + Jobs within a group are scheduled together with network locality + guarantees. Use this for RL training where inference workers should + be close to training workers. + """ + name: str + namespace: Namespace | None = None + same_region: bool = True + same_zone: bool = False # Stricter locality for low-latency RPCs + + +@dataclass +class ReservationConfig: + """Resource reservation for elastic job scheduling.""" + min_resources: ResourceConfig + max_resources: ResourceConfig + priority: int = 0 + ttl_seconds: int = 3600 + preemptible: bool = True + + +@dataclass +class ReservationInfo: + """Status of an active reservation.""" + reservation_id: ReservationId + allocated_resources: ResourceConfig + pending_resources: ResourceConfig + jobs: list[JobId] + expires_at: float +``` + +### Cluster Interface + +The Cluster is responsible for job lifecycle management. It does not +provide task execution primitives - jobs are complete processes that run +independently. + +```python +class Cluster(Protocol): + """Abstract interface for cluster job scheduling. + + The Cluster manages job lifecycle: launching, monitoring, and terminating + jobs. Jobs are independent processes - there is no distributed task + execution through the Cluster interface. + + Implementations: FrayCluster, LocalCluster + """ + + def launch(self, request: JobRequest) -> JobId: + """Launch a job on the cluster. + + The job runs as an independent process with its own lifecycle. + Use the Resolver to discover actors within jobs. + """ + ... + + def monitor(self, job_id: JobId) -> JobInfo: + """Stream logs from a running job, blocking until completion.""" + ... + + def poll(self, job_id: JobId) -> JobInfo: + """Get current status of a job without blocking.""" + ... + + def terminate(self, job_id: JobId) -> None: + """Terminate a running job. + + Also terminates any child jobs spawned by this job. + """ + ... + + def list_jobs(self) -> list[JobInfo]: + """List all jobs managed by this cluster.""" + ... + + def wait( + self, + job_ids: JobId | Sequence[JobId], + raise_on_failure: bool = False, + ) -> JobInfo | list[JobInfo]: + """Block until job(s) complete. + + Args: + job_ids: Single job ID or sequence of job IDs + raise_on_failure: If True, raise exception if any job fails + """ + ... 
+ + def create_reservation(self, config: ReservationConfig) -> ReservationId: + """Create a resource reservation for upcoming jobs.""" + ... + + def release_reservation(self, reservation_id: ReservationId) -> None: + """Release a reservation, freeing its resources.""" + ... + + def get_reservation(self, reservation_id: ReservationId) -> ReservationInfo: + """Get current status of a reservation.""" + ... + + def launch_group( + self, + requests: Sequence[JobRequest], + group: JobGroupConfig, + reservation_id: ReservationId | None = None, + ) -> list[JobId]: + """Launch a group of co-located jobs atomically. + + All jobs in the group are scheduled together with locality + guarantees. If any job cannot be scheduled, none are started. + """ + ... + + @property + def namespace(self) -> Namespace: + """The namespace for this cluster connection.""" + ... + + +def current_cluster() -> Cluster: + """Get the current cluster from environment. + + Reads FRAY_CLUSTER_SPEC environment variable: + - "local" or unset: LocalCluster (in-process execution) + - "fray://host:port": FrayCluster connecting to controller + + Jobs inherit namespace from FRAY_NAMESPACE environment variable. + """ + ... +``` + +### Metadata Service (Internal) + +The Metadata service is internal to the Cluster. It maps actor names to +endpoints (job ID + address + port). The Cluster owns it and automatically: + +* Cleans up entries when jobs terminate +* Maintains list of all actors registered under each name +* Supports multiple actors per name for load balancing + +Clients do not interact with Metadata directly - they use a Resolver. + +```python +@dataclass +class ActorEndpoint: + """Endpoint information for a registered actor.""" + actor_id: ActorId + name: str + job_id: JobId + namespace: Namespace + metadata: dict[str, str] = field(default_factory=dict) + + +class Metadata(Protocol): + """Internal actor registration service owned by the Cluster.""" + + def register( + self, + name: str, + job_id: JobId, + metadata: dict[str, str] | None = None, + ) -> ActorId: + """Register an actor endpoint. + + Multiple actors can register under the same name. All registrations + are tracked and returned by lookup_all(). + """ + ... + + def unregister(self, actor_id: ActorId) -> None: + """Remove an actor registration.""" + ... + + def lookup(self, name: str) -> ActorEndpoint | None: + """Look up one actor by name (returns None if not found).""" + ... + + def lookup_all(self, name: str) -> list[ActorEndpoint]: + """Look up all actors registered under a name.""" + ... + + def list_actors(self, prefix: str = "") -> list[ActorEndpoint]: + """List all registered actors, optionally filtered by name prefix.""" + ... +``` + +### Resolver and ActorPool + +A Resolver provides actor discovery and returns ActorPool instances for +managing calls to one or more actors. + +```python +T = TypeVar("T") + + +class ActorPool(Generic[T]): + """Pool of actors registered under a common name. + + Provides load-balanced calls, broadcasting, and pool state queries. + All calls go through the resolver for automatic failure handling. + """ + + @property + def size(self) -> int: + """Current number of actors in the pool.""" + ... + + @property + def endpoints(self) -> list[ActorEndpoint]: + """Current list of actor endpoints (snapshot).""" + ... + + def wait_for_size( + self, + min_size: int, + timeout: float = 60.0, + ) -> None: + """Block until pool has at least min_size actors. + + Useful during startup when waiting for workers to register. 
+ + Raises: + TimeoutError: If timeout expires before min_size reached + """ + ... + + def call(self) -> T: + """Get a handle for single-actor calls (round-robin). + + Returns a proxy that routes method calls to one actor in the pool, + cycling through actors on successive calls. + + Example: + pool = resolver.lookup("inference") + result = pool.call().predict(x) # Routes to one actor + """ + ... + + def broadcast(self) -> "BroadcastHandle[T]": + """Get a handle for broadcasting to all actors. + + Returns a proxy that calls all actors in parallel and collects + results. Failed calls return exceptions in the results list. + + Example: + pool = resolver.lookup("workers") + futures = pool.broadcast().shutdown() + results = [f.result() for f in futures] # May contain exceptions + """ + ... + + +class BroadcastHandle(Generic[T]): + """Handle for broadcasting method calls to all actors in a pool.""" + + def __getattr__(self, method_name: str) -> Callable[..., list["ActorFuture"]]: + """Broadcast a method call to all actors. + + Returns a list of futures, one per actor. Each future resolves to + the result or exception from that actor's call. + """ + ... + + +class Resolver(Protocol): + """Actor discovery and connection management. + + Resolvers map actor names to ActorPools and manage reconnection on failure. + """ + + def lookup(self, name: str) -> ActorPool: + """Look up actors by name and return a pool. + + Always returns a pool, even if empty. Use pool.wait_for_size() + to block until actors are available. + """ + ... + + +class ClusterResolver(Resolver): + """Resolver backed by a Cluster's Metadata service. + + This is the standard resolver for production use. It queries the + Cluster's metadata service and handles reconnection when actors + restart. + """ + + def __init__(self, cluster: Cluster): + ... + + +class FixedResolver(Resolver): + """Resolver with fixed actor addresses. + + Useful for testing or when connecting to known endpoints. + """ + + def __init__(self, addresses: dict[str, str | list[str]]): + """Create resolver with fixed addresses. + + Args: + addresses: Mapping of actor names to addresses. + Values can be a single address or list of addresses. + + Example: + resolver = FixedResolver({ + "inference": "localhost:8080", + "workers": ["localhost:8081", "localhost:8082"], + }) + """ + ... +``` + +### Actor System + +The Actor system provides RPC-based communication between jobs: + +- **ActorServer**: Hosts actor instances, registers with the cluster's Metadata service +- **ActorContext**: Passed to actor methods, enables actors to call other actors +- **ActorFuture**: Represents an in-flight async call + +**Scope notes**: Streaming responses are not supported. Cancellation is not +supported. RPC tracing is not included in the initial implementation. + +```python +@dataclass +class ActorContext: + """Context passed to actor methods as first argument. + + Enables actors to call other actors and access cluster services. + """ + cluster: Cluster + resolver: Resolver + job_id: JobId + namespace: Namespace + + @classmethod + def from_environment(cls) -> "ActorContext": + """Create context from FRAY_* environment variables.""" + ... + + +class ActorFuture(Protocol[T]): + """Future representing an in-flight actor method call.""" + + def result(self, timeout: float | None = None) -> T: + """Block until result is available. + + Raises the remote exception if the call failed. + """ + ... + + def done(self) -> bool: + """Check if the call has completed.""" + ... 
+ + def exception(self) -> BaseException | None: + """Get the exception if the call failed, None if succeeded or pending.""" + ... + + +class ActorServer: + """Server for hosting actors and handling RPC calls. + + Each job should run at most one ActorServer since it binds to a port. + The server reads FRAY_JOB_ID from environment to associate registrations + with the current job. + + Usage: + cluster = current_cluster() + server = ActorServer(cluster) + server.register("my_actor", MyActor(config)) + server.serve() # Blocks, serving requests + """ + + def __init__( + self, + cluster: Cluster, + host: str = "0.0.0.0", + port: int = 0, # 0 = auto-assign + ): + """Initialize actor server. + + Args: + cluster: Cluster for actor registration + host: Host address to bind + port: Port to bind (0 for auto-assignment) + """ + ... + + @property + def address(self) -> str: + """The server's bound address (host:port).""" + ... + + def register( + self, + name: str, + actor: Any, + metadata: dict[str, str] | None = None, + ) -> ActorId: + """Register an actor instance with the server. + + The actor is registered with the cluster's Metadata service. + Multiple actors (across multiple jobs) can register under the + same name for load balancing. + + Actor methods should accept ActorContext as their first argument: + + class MyActor: + def process(self, ctx: ActorContext, data: dict) -> dict: + # ctx.resolver allows calling other actors + other = ctx.resolver.lookup("other_actor") + return other.call().transform(data) + + Args: + name: Name for lookup (scoped to namespace) + actor: Actor instance (any object with callable methods) + metadata: Optional key-value metadata for discovery + + Returns: + Unique actor ID + """ + ... + + def unregister(self, name: str) -> None: + """Unregister an actor.""" + ... + + def serve(self) -> None: + """Start serving requests (blocks indefinitely).""" + ... + + def serve_background(self) -> None: + """Start serving in a background thread.""" + ... + + def shutdown(self, grace_period: float = 5.0) -> None: + """Gracefully shutdown the server.""" + ... +``` + +### WorkerPool + +WorkerPool provides Ray-like task dispatch for stateless workloads. It manages +a pool of worker jobs that can execute arbitrary callables. + +```python +class WorkerPool: + """Pool of stateless workers for task dispatch. + + Creates worker jobs that can execute arbitrary callables. Workers are + stateless - if a worker fails, tasks can be retried on any other worker. + + Usage: + pool = WorkerPool( + cluster=current_cluster(), + num_workers=10, + resources=ResourceConfig.with_cpu(cpu=2, memory="4GB"), + ) + + # Submit tasks + futures = [pool.submit(process_shard, shard) for shard in shards] + + # Wait for results + results = [f.result() for f in futures] + + pool.shutdown() + """ + + def __init__( + self, + cluster: Cluster, + num_workers: int, + resources: ResourceConfig, + environment: EnvironmentConfig | None = None, + name_prefix: str = "worker", + ): + """Create a worker pool. + + Args: + cluster: Cluster for launching worker jobs + num_workers: Number of worker jobs to launch + resources: Resource requirements per worker + environment: Optional environment config for workers + name_prefix: Prefix for worker job names + """ + ... + + @property + def size(self) -> int: + """Number of workers currently available.""" + ... + + def wait_for_workers( + self, + min_workers: int | None = None, + timeout: float = 60.0, + ) -> None: + """Wait for workers to become available. 
+ + Args: + min_workers: Minimum workers required (default: all workers) + timeout: Maximum time to wait + """ + ... + + def submit( + self, + fn: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> ActorFuture[T]: + """Submit a task for execution. + + The callable and arguments must be picklable. Tasks are distributed + round-robin across available workers. + + Args: + fn: Callable to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Future that resolves to the function's return value + """ + ... + + def map( + self, + fn: Callable[[Any], T], + items: Sequence[Any], + ) -> list[ActorFuture[T]]: + """Map a function over items in parallel. + + Args: + fn: Function to apply to each item + items: Items to process + + Returns: + List of futures, one per item + """ + ... + + def shutdown(self, wait: bool = True) -> None: + """Shutdown the worker pool. + + Args: + wait: If True, wait for pending tasks to complete + """ + ... +``` + +### Local Development + +For local development, all components have in-process implementations that +preserve the same interfaces. Jobs run as threads, actors are called directly +(but with serialization), and the code paths are identical to production. + +```python +class LocalCluster(Cluster): + """Local cluster for development and testing. + + Runs jobs as threads in the current process. Includes an embedded + Metadata service. Child jobs are automatically terminated when + parent jobs exit. + """ + + def __init__(self): + ... + + +class LocalActorServer(ActorServer): + """In-process actor server for local development. + + Actors are called directly (no network) but arguments and return + values are still serialized/deserialized to catch serialization bugs. + """ + ... + + +# Example: local development workflow +cluster = LocalCluster() +resolver = ClusterResolver(cluster) + +# Server side (same code as production) +server = ActorServer(cluster) +server.register("my_actor", MyActor()) +server.serve_background() + +# Client side (same code as production) +pool = resolver.lookup("my_actor") +result = pool.call().process(data) +``` + +### Zephyr Integration + +Zephyr uses WorkerPool internally for distributed execution. The FrayBackend +launches worker jobs and dispatches shard processing tasks to them. + +```python +class FrayBackend(Backend): + """Zephyr backend using Fray for distributed execution. + + Launches a WorkerPool and dispatches shard processing tasks. + Workers are stateless - all coordination happens through object storage. + """ + + def __init__( + self, + cluster: Cluster | None = None, # None = current_cluster() + max_parallelism: int = 100, + memory_per_worker: str = "2GB", + ): + ... + + def execute(self, dataset: Dataset) -> list[Any]: + """Execute a dataset pipeline on Fray workers. + + 1. Plans the dataset into shards + 2. Launches a WorkerPool with max_parallelism workers + 3. Submits each shard for processing + 4. Collects and returns results + """ + ... +``` + +## Example Usage + + + +### 1. Training (Simple Job) + +```python +def train_main(): + # ... training logic ... + pass + +# Launch a single job +cluster = current_cluster() +job_id = cluster.launch( + JobRequest( + name="llama-training", + resources=ResourceConfig.with_tpu("v5litepod-16"), + entrypoint=Entrypoint.from_callable(train_main), + environment=EnvironmentConfig.create(extras=["tpu"]) + ) +) + +# Wait for completion +cluster.wait(job_id, raise_on_failure=True) +``` + +### 2. 
Inference (Client/Server) + +```python +# Server: Host an actor +def server_main(): + server = ActorServer(current_cluster()) + server.register("inference", InferenceModel()) # Registers with cluster metadata + server.serve() + +# Client: Dispatch requests +def run_client(): + resolver = ClusterResolver(current_cluster()) + pool = resolver.lookup("inference") # Discovers all actors named "inference" + + # Load balance requests across the pool + results = pool.call().predict(batch_of_prompts) +``` + +### 3. Zephyr (Task Dispatch) + +```python +# Create a worker pool for map/reduce style tasks +pool = WorkerPool(current_cluster(), num_workers=100, resources=ResourceConfig.with_cpu()) + +# Dispatch tasks - pool handles distribution +futures = [pool.submit(process_shard, shard) for shard in shards] +results = [f.result() for f in futures] +pool.shutdown() +``` + +### 4. RL (Coordinated Workers) + +```python +# Coordinator: Hosts shared state actors +def coordinator_main(): + server = ActorServer(current_cluster()) + server.register("curriculum", CurriculumActor()) + server.register("weights", WeightStore()) + server.serve() + +# Worker: Connects to shared actors +def worker_main(): + resolver = ClusterResolver(current_cluster()) + curriculum = resolver.lookup("curriculum") + weights = resolver.lookup("weights") + + while True: + # Fetch latest state and data + w = weights.call().get_weights() + task = curriculum.call().get_next_task() + + # ... perform rollout ... + curriculum.call().report_result(result) + +# Launch as a co-located group to ensure low latency +cluster.launch_group( + [ + JobRequest(name="coordinator", entrypoint=Entrypoint.from_callable(coordinator_main), ...), + JobRequest(name="worker-1", entrypoint=Entrypoint.from_callable(worker_main), ...), + # ... + ], + group=JobGroupConfig(same_region=True) +) +``` + +### 5. Multi-Slice (Elastic Coordination) + +```python +class BarrierActor: + def wait_for_barrier(self, ctx: ActorContext, slice_id, step): + # ... verify all slices have reached step ... + pass + +def slice_worker(slice_id): + coordinator = ClusterResolver(current_cluster()).lookup("coordinator") + while True: + # Sync then train + coordinator.call().wait_for_barrier(slice_id, step) + train_step() + +# Elastic Scheduling +reservation = cluster.create_reservation( + ReservationConfig(min_resources=..., max_resources=...) +) + +# Launch coordinator and initial slices... +# Cluster auto-scales within reservation as resources become available +``` + +### 6. Local Development + +Code behaves identically locally. `current_cluster()` returns `LocalCluster` when no cluster address is configured. + +```python +# No code changes needed. +cluster = current_cluster() # -> LocalCluster +server = ActorServer(cluster) # -> Spawns threads +``` diff --git a/lib/fluster/docs/impl-recipe.md b/lib/fluster/docs/impl-recipe.md new file mode 100644 index 0000000000..6274b6f747 --- /dev/null +++ b/lib/fluster/docs/impl-recipe.md @@ -0,0 +1,125 @@ +# Implementation Recipe Template + +This document defines a structured workflow for implementing a feature or +component in the fluster-zero system. Each feature follows this recipe to ensure +consistent, high-quality delivery. + +### Phase 1: Research + +**Goal**: Understand the context and existing code before making changes. +**Deliverable**: List of relevant files and understanding of integration points. 
+ +Tasks: +- [ ] Read the feature description from the implementation doc provided +- [ ] Use the explore agent to find and review related code or protocol files +- [ ] Check for test patterns in existing tests +- [ ] Note any dependencies on other features + +--- + +### Phase 2: Evaluation + +**Goal**: Identify gaps between the spec and what's needed for implementation. +**Deliverable**: List of gaps, questions, and required adjustments. + +Tasks: +- [ ] Compare spec against existing code—what's already done? +- [ ] Identify missing types, imports, or proto messages +- [ ] Flag any ambiguities in the spec that need clarification +- [ ] Determine if the spec's code samples need adjustment for project conventions +- [ ] Verify dependencies are available (proto generated, types exist, etc.) + +--- + +### Phase 3: Sub-task Breakdown +**Goal**: Create a concrete list of atomic implementation tasks. +**Deliverable**: Numbered list of sub-tasks with clear acceptance criteria. + +Sub tasks should be independently testable or verifiable. + +Prefer "spiral" sub-tasks over "linear" ones: changes that touch multiple files but test a single concept, which is then expanded. + +Example sub-tasks for a new RPC method: + +* Add new proto message and stub implementation in the server, write a test which wires up client and server +* Implement initial logic in the server, update test +* Add client-level functionality e.g. tracing, update tests +* etc + +--- + +### Phase 4: Incremental Execution + +**Goal**: Implement sub-tasks one at a time, testing after each. + +Workflow: +1. Pick the next sub-task +2. Implement the change +3. Run relevant tests immediately +4. Fix any failures before proceeding +5. Repeat until all sub-tasks complete + +Rules: +- Never proceed with failing tests +- Keep changes minimal—don't refactor unrelated code + +**Deliverable**: Working implementation with all tests passing. + +--- + +### Phase 5: Validation + +**Goal**: Verify the stage is complete and correct. + +Tasks: +- [ ] All new tests pass +- [ ] Existing tests still pass +- [ ] Type checking passes: `uv run pyrefly` +- [ ] Linting passes: `uv run python infra/pre-commit.py --all-files` + +**Deliverable**: Clean test run, passing type checks, lint-clean code. + +--- + +### Phase 6: Code Review + +**Goal**: Get feedback from senior-code-reviewer agent before finalizing. + +Provide to reviewer: +- The feature and source implementation document +- Summary of changes made +- List of files modified/created +- Key design decisions and rationale +- Any deviations from the spec + +Address all feedback before proceeding. + +**Deliverable**: Reviewer-approved changes. + +Tasks: + +- [ ] Reviewer signed off on the changes + +--- + +### Phase 7: Commit + +**Goal**: Create a clean commit with descriptive message. + +Commit message format: +``` +[fluster] Implement {Stage Name} + +- {Brief description of what was added} +- {Key components implemented} +- {Tests added} + +Part of controller-v0 implementation. 
+``` + +Final checks: +- [ ] `git status` shows only relevant changes +- [ ] No debug code or temporary files included +- [ ] Commit message accurately describes changes + +YOU MUST COMMIT WITH GIT BEFORE CONSIDERING THIS PHASE COMPLETE diff --git a/lib/fluster/docs/offline-dashboard-design.md b/lib/fluster/docs/offline-dashboard-design.md new file mode 100644 index 0000000000..aaa7946aa0 --- /dev/null +++ b/lib/fluster/docs/offline-dashboard-design.md @@ -0,0 +1,444 @@ +# Offline Dashboard Mode Design + +> **Status**: Proposal (not yet implemented) + +This document outlines a design for adding "offline" mode to fluster's worker and controller dashboards, enabling them to render from serialized state snapshots rather than live APIs. + +## Motivation + +The codebase has a TODO for this (`cluster_example.py:26-27`): +```python +# TODO, consider having like a post-mortem view of the cluster state +# means cluster state should be serializable, cluster dashboard would always be a mapping over the state +``` + +**Use cases:** +- Post-mortem debugging (capture state at failure time) +- Testing dashboard rendering without spinning up infrastructure +- Demos with realistic data +- Reproducible bug reports + +## Current State + +### Worker Dashboard (`cluster/worker/dashboard.py`) +- REST endpoints call `WorkerServiceImpl` RPC methods +- `JobManager._jobs: dict[str, Job]` where `Job.to_proto()` → `cluster_pb2.JobStatus` +- Already has proto-serializable job state + +### Controller Dashboard (`cluster/controller/dashboard.py`) +- REST endpoints call `ControllerServiceImpl` and access `ControllerState` directly +- `ControllerState` contains: + - `_jobs: dict[JobId, ControllerJob]` + - `_workers: dict[WorkerId, ControllerWorker]` + - `_endpoints: dict[str, ControllerEndpoint]` + - `_actions: deque[ActionLogEntry]` + +### Gap: Internal Types Not in Protos +- `ControllerJob` (has retry tracking, gang_id not in `JobStatus`) +- `ControllerWorker` (has `running_jobs` set) +- `ActionLogEntry` (dashboard action log) +- No `ClusterSnapshot` message exists + +--- + +## Proposed Design + +### 1. Proto Schema: Separate Snapshots + +Add to `cluster.proto`: + +```protobuf +// Controller job with full tracking info +message ControllerJobSnapshot { + JobStatus status = 1; + int32 failure_count = 2; + int32 preemption_count = 3; + int32 max_retries_failure = 4; + int32 max_retries_preemption = 5; + string gang_id = 6; + int64 submitted_at_ms = 7; + LaunchJobRequest request = 8; +} + +message ControllerWorkerSnapshot { + string worker_id = 1; + string address = 2; + ResourceSpec resources = 3; + bool healthy = 4; + int32 consecutive_failures = 5; + int64 last_heartbeat_ms = 6; + repeated string running_job_ids = 7; +} + +message ActionLogEntrySnapshot { + int64 timestamp_ms = 1; + string action = 2; + string job_id = 3; + string worker_id = 4; + string details = 5; +} + +message ControllerSnapshot { + int64 captured_at_ms = 1; + repeated ControllerJobSnapshot jobs = 2; + repeated ControllerWorkerSnapshot workers = 3; + repeated Endpoint endpoints = 4; + repeated ActionLogEntrySnapshot actions = 5; + repeated string queue_order = 6; +} + +message WorkerSnapshot { + int64 captured_at_ms = 1; + string worker_id = 2; + repeated JobStatus jobs = 3; + map job_logs = 4; // job_id -> last N log lines +} + +message LogTail { + repeated LogEntry lines = 1; + int32 total_lines = 2; // Original count (for "showing X of Y") +} +``` + +### 2. 
Data Source Abstraction + +New file: `cluster/dashboard_data.py` + +```python +from typing import Protocol + +class DashboardDataSource(Protocol): + """Protocol for dashboard data access. + + Abstracts whether data comes from live APIs or serialized snapshots. + """ + + def get_stats(self) -> dict: ... + def get_jobs(self) -> list[dict]: ... + def get_workers(self) -> list[dict]: ... + def get_endpoints(self) -> list[dict]: ... + def get_actions(self, limit: int = 50) -> list[dict]: ... + def get_job_detail(self, job_id: str) -> dict | None: ... + + +class LiveControllerDataSource: + """Wraps ControllerServiceImpl + ControllerState.""" + + def __init__(self, service: ControllerServiceImpl, state: ControllerState): + self._service = service + self._state = state + + # Implements DashboardDataSource using existing code + + +class SnapshotControllerDataSource: + """Backed by ControllerSnapshot proto.""" + + def __init__(self, snapshot: cluster_pb2.ControllerSnapshot): + self._snapshot = snapshot + self._jobs = {j.status.job_id: j for j in snapshot.jobs} + self._workers = {w.worker_id: w for w in snapshot.workers} + + # Implements DashboardDataSource from snapshot data +``` + +### 3. Serialization Methods + +Add `to_snapshot()` to internal types in `state.py`: + +```python +# ControllerJob +def to_snapshot(self) -> cluster_pb2.ControllerJobSnapshot: + return cluster_pb2.ControllerJobSnapshot( + status=cluster_pb2.JobStatus( + job_id=self.job_id, + state=self.state, + # ... other fields + ), + failure_count=self.failure_count, + preemption_count=self.preemption_count, + gang_id=self.gang_id or "", + request=self.request, + ) + +# ControllerState +def to_snapshot(self) -> cluster_pb2.ControllerSnapshot: + with self._lock: + return cluster_pb2.ControllerSnapshot( + captured_at_ms=int(time.time() * 1000), + jobs=[j.to_snapshot() for j in self._jobs.values()], + workers=[w.to_snapshot() for w in self._workers.values()], + # ... + ) +``` + +### 4. Dashboard Changes + +Minimal changes to accept either live or snapshot data: + +```python +class ControllerDashboard: + def __init__( + self, + service: ControllerServiceImpl | None = None, + data_source: DashboardDataSource | None = None, + host: str = "0.0.0.0", + port: int = 8080, + ): + if data_source: + self._data_source = data_source + elif service: + self._data_source = LiveControllerDataSource(service, service._state) + else: + raise ValueError("Either service or data_source required") +``` + +--- + +## Style Alignment + +Worker and controller dashboards have different primary colors: +- Controller: Blue (`#3498db`) +- Worker: Green (`#4CAF50`) + +**Recommendation**: Keep distinct colors (helps users know which dashboard they're viewing), but align structure using CSS variables: + +```css +:root { + /* Shared status colors */ + --status-pending: #f39c12; + --status-running: #3498db; + --status-succeeded: #27ae60; + --status-failed: #e74c3c; + --status-killed: #95a5a6; + --status-building: #9b59b6; + + /* Shared layout */ + --max-width: 1400px; + --spacing-md: 20px; + --shadow-card: 0 2px 4px rgba(0,0,0,0.1); +} +``` + +Each dashboard defines its own `--primary-color` while sharing structural styles. No shared code dependencies needed. + +--- + +## Implementation Phases + +1. **Proto schema** - Add snapshot messages, regenerate +2. **Serialization** - Add `to_snapshot()` to internal types +3. **Data source abstraction** - Create protocol and implementations +4. **Dashboard integration** - Refactor to use data source +5. 
**Worker dashboard** - Same pattern +6. **Style alignment** - CSS variable extraction + +--- + +## Design Decisions + +- **Snapshot format**: Binary protobuf (compact, type-safe, matches RPC layer) +- **Log inclusion**: Last N lines (configurable, default 1000) with total count + +--- + +## Maintainability Concerns + +### 1. Schema Drift +**Risk**: Internal types (`ControllerJob`, etc.) evolve but snapshot protos don't get updated. + +**Mitigation**: +- Round-trip tests that fail if fields are missing +- Consider code generation or a single source of truth + +### 2. Dual Code Paths +**Risk**: Live and snapshot data sources diverge in behavior. + +**Mitigation**: +- Shared test fixtures that run against both implementations +- Property: `live.get_stats() == snapshot_from(live).get_stats()` + +### 3. Dashboard Abstraction Overhead +**Risk**: `DashboardDataSource` protocol adds indirection, making debugging harder. + +**Mitigation**: +- Keep the protocol simple (6 methods) +- Live implementation is a thin wrapper over existing code + +### 4. Proto Message Bloat +**Risk**: Snapshot messages duplicate information already in other protos. + +**Mitigation**: +- `ControllerJobSnapshot` embeds `JobStatus` rather than duplicating fields +- Composition over duplication + +### 5. CSS Duplication +**Risk**: Style alignment via copy-paste leads to drift. + +**Mitigation**: +- CSS variables for shared values +- No shared code = no coordination overhead +- Accept that minor drift is OK (dashboards are internal tools) + +--- + +## Files to Modify + +| File | Changes | +|------|---------| +| `proto/cluster.proto` | Add snapshot messages | +| `cluster/controller/state.py` | Add `to_snapshot()` methods | +| `cluster/dashboard_data.py` | New: data source protocol | +| `cluster/controller/dashboard.py` | Use data source | +| `cluster/worker/dashboard.py` | Same pattern | + +--- + +## Alternative: Make Internal Types Protos + +Rather than maintaining separate Python dataclasses and proto messages, we could make the internal types protos themselves. This section evaluates that approach. + +### Field-by-Field Proto Compatibility + +| Type | Proto-able | Notes | +|------|------------|-------| +| **ControllerJob** | 100% | All fields serializable | +| **ControllerWorker** | 100% | `set[JobId]` → `repeated string` | +| **ControllerEndpoint** | 100% | `Endpoint` proto already exists, just add `registered_at_ms` | +| **ActionLogEntry** | 100% | All fields serializable | +| **Job (worker)** | ~80% | `workdir: Path`, `thread: Thread` are local execution state | + +### Option A: Keep Dataclasses, Add `to_proto()` (Recommended) + +```python +@dataclass +class ControllerJob: + job_id: JobId + request: cluster_pb2.LaunchJobRequest + state: cluster_pb2.JobState = cluster_pb2.JOB_STATE_PENDING + # ... + + def to_snapshot(self) -> cluster_pb2.ControllerJobSnapshot: + return cluster_pb2.ControllerJobSnapshot(...) +``` + +**Pros:** +- Least invasive change +- Best Python ergonomics (methods, default factories, type hints) +- Familiar dataclass patterns + +**Cons:** +- Schema drift risk (field added to dataclass but not proto) +- Requires discipline to keep in sync + +**Mitigation:** Round-trip tests that fail if fields are missing. 
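+
+A sketch of such a drift test is below. It assumes the proposed
+`ControllerJobSnapshot` message has been generated; the import paths and the
+exclusion set are illustrative rather than part of the current codebase.
+
+```python
+import dataclasses
+
+from fluster import cluster_pb2
+from fluster.cluster.controller.state import ControllerJob
+
+# Fields carried inside the embedded JobStatus rather than at the top level.
+EMBEDDED_IN_STATUS = {"job_id", "state"}
+
+
+def test_controller_job_snapshot_covers_all_fields():
+    dataclass_fields = {f.name for f in dataclasses.fields(ControllerJob)}
+    proto_fields = {f.name for f in cluster_pb2.ControllerJobSnapshot.DESCRIPTOR.fields}
+    missing = dataclass_fields - proto_fields - EMBEDDED_IN_STATUS
+    assert not missing, f"ControllerJob fields missing from snapshot proto: {missing}"
+```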
+ +### Option B: Hybrid - Embed Proto in Dataclass + +```python +@dataclass +class ControllerJob: + proto: cluster_pb2.ControllerJobState # All persistent fields live here + + @property + def job_id(self) -> JobId: + return JobId(self.proto.job_id) + + @property + def state(self) -> cluster_pb2.JobState: + return self.proto.state + + def to_snapshot(self) -> cluster_pb2.ControllerJobSnapshot: + return self.proto # Already a proto! +``` + +**Pros:** +- Proto is source of truth - no schema drift +- Still get dataclass methods and ergonomics +- Adding a field to proto automatically makes it serializable + +**Cons:** +- More boilerplate (property accessors) +- Nested proto assignment is awkward: `job.proto.request.CopyFrom(req)` +- Two layers of indirection + +### Option C: Full Proto Types + +Make `ControllerJob`, `ControllerWorker`, etc. proto messages directly. + +```protobuf +message ControllerJobState { + string job_id = 1; + LaunchJobRequest request = 2; + JobState state = 3; + string worker_id = 4; + int32 failure_count = 5; + int32 preemption_count = 6; + // ... +} +``` + +```python +# Usage becomes: +job = cluster_pb2.ControllerJobState() +job.job_id = "abc" +job.request.CopyFrom(req) # Can't assign nested protos directly +job.failure_count += 1 +``` + +**Pros:** +- Single source of truth +- Automatic serialization +- No conversion code needed + +**Cons:** +- **Worse ergonomics:** + - No default factories (`running_jobs: set = field(default_factory=set)` impossible) + - Optional fields awkward (`if job.HasField('worker_id')`) + - Can't assign nested protos directly (must use `CopyFrom`) + - No methods on proto messages (`transition_to()` becomes standalone function) +- Weaker type hints (generated stubs are stringly-typed) +- Worker `Job` still needs wrapper for `thread`, `workdir` + +### Option Comparison + +| Aspect | A: Dataclass + to_proto | B: Hybrid | C: Full Proto | +|--------|------------------------|-----------|---------------| +| Schema drift risk | Medium | Low | None | +| Python ergonomics | Best | Good | Poor | +| Boilerplate | Low | Medium | Low | +| Refactoring cost | Low | Medium | High | + +### Recommendation + +**Option A** for a system this size. The schema drift risk is real but manageable with tests. + +**Option B** is worth considering if you find yourself frequently adding fields and forgetting serialization. The embedded proto acts as a forcing function. + +**Option C** is overkill - the ergonomic cost outweighs the benefit. + +### Note on Existing Protos + +The `Endpoint` proto already exists and is nearly identical to `ControllerEndpoint`: + +```protobuf +message Endpoint { + string endpoint_id = 1; + string name = 2; + string address = 3; + string job_id = 4; + string namespace = 5; + map metadata = 6; + // Missing: registered_at_ms +} +``` + +This is a candidate for consolidation - add `registered_at_ms` to `Endpoint` and use it directly instead of `ControllerEndpoint`. + +--- + +## Open Questions + +1. Should controller auto-capture snapshots on job failure? (opt-in config) +2. CLI tool for snapshot capture/viewing? (`fluster-snapshot capture`, `fluster-dashboard --snapshot`) +3. Snapshot retention policy? (N most recent, or time-based) diff --git a/lib/fluster/docs/worker.md b/lib/fluster/docs/worker.md new file mode 100644 index 0000000000..b894a46a7c --- /dev/null +++ b/lib/fluster/docs/worker.md @@ -0,0 +1,111 @@ +# Worker Overview + +The Worker is the execution agent in Fluster. 
It registers with the Controller, receives job assignments, prepares execution environments, runs jobs in isolated containers, and reports status back to the Controller. Workers are stateless—they can be added or removed from the cluster without affecting other workers or requiring coordination. + +## Responsibilities + +| Responsibility | Description | +|----------------|-------------| +| Job execution | Runs job entrypoints in isolated Docker containers | +| Environment preparation | Downloads bundles, builds images, sets up dependencies | +| Status reporting | Reports job progress and completion to the Controller | +| Log collection | Captures and serves job stdout/stderr | +| Port allocation | Assigns ephemeral ports for actor servers within jobs | + +## RPC Interface + +The Worker exposes a single RPC service (`WorkerService`) with these methods: + +| Method | Description | +|--------|-------------| +| `RunJob(JobRequest)` | Execute a job on this worker | +| `GetJobStatus(JobId)` | Query current status of a job | +| `ListJobs()` | List all jobs on this worker | +| `FetchLogs(JobId, options)` | Retrieve job logs with optional filtering | +| `KillJob(JobId)` | Terminate a running job | +| `HealthCheck()` | Liveness probe | + +## Job Execution Flow + +When the Controller dispatches a job to a worker: + +``` +1. Download bundle ──► 2. Build image ──► 3. Start container ──► 4. Monitor + │ │ │ │ + ▼ ▼ ▼ ▼ + Cache lookup Cache lookup Port allocation Log collection + or fetch URL or docker build Env var injection Status updates +``` + +1. **Download bundle**: Fetch the workspace archive containing code and dependencies +2. **Build image**: Create a Docker image with the required Python environment +3. **Start container**: Launch the container with allocated ports and environment variables +4. **Monitor**: Track container status, collect logs, report completion + +## Environment Variables + +The Worker injects these environment variables into every job container: + +| Variable | Description | +|----------|-------------| +| `FLUSTER_JOB_ID` | Unique job identifier | +| `FLUSTER_WORKER_ID` | ID of the executing worker | +| `FLUSTER_CONTROLLER_ADDRESS` | Controller URL for actor registration | +| `FLUSTER_NAMESPACE` | Namespace for actor isolation | +| `FLUSTER_PORT_` | Allocated ports (e.g., `FLUSTER_PORT_ACTOR`) | + +Jobs that run actor servers use `FLUSTER_PORT_ACTOR` to bind their server and `FLUSTER_CONTROLLER_ADDRESS` to register with the endpoint registry. 
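+
+As an illustration (not shipped example code), a job entrypoint that hosts an
+actor server might consume these variables as follows; the exact `ActorServer`
+keyword arguments may differ from the real signature.
+
+```python
+import os
+
+from fluster.actor import ActorServer
+
+
+class EchoActor:
+    def ping(self, ctx, payload):
+        return payload
+
+
+def job_main():
+    port = int(os.environ["FLUSTER_PORT_ACTOR"])           # allocated by the worker
+    controller = os.environ["FLUSTER_CONTROLLER_ADDRESS"]  # injected by the worker
+
+    server = ActorServer(controller_address=controller, port=port)
+    server.register("echo", EchoActor())
+    server.serve()  # blocks; the endpoint is registered under FLUSTER_NAMESPACE
+```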
+ +## Integration Points + +``` +┌──────────────────────────────────────────────────────────┐ +│ Controller │ +│ │ +│ ◄───────────────────┬────────────────────────────────► │ +│ RegisterWorker │ RunJob, GetJobStatus, KillJob │ +└──────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────┐ +│ Worker │ +│ │ +│ Job State ──► Docker ──► Container │ +│ │ │ +│ ▼ │ +│ Job Entrypoint │ +│ (user code) │ +│ │ │ +│ ▼ │ +│ ActorServer │ +│ (optional) │ +└──────────────────────────────────────────────────────────┘ +``` + +- The **Controller** dispatches jobs and polls for status updates +- The **Worker** orchestrates the full job lifecycle directly +- **Docker** provides container isolation +- **Job Entrypoint** is user code that may optionally start an **ActorServer** + +## Job States on Worker + +| State | Description | +|-------|-------------| +| `BUILDING` | Downloading bundle, building image | +| `RUNNING` | Container executing | +| `SUCCEEDED` | Exited with code 0 | +| `FAILED` | Exited with non-zero code or error | +| `KILLED` | Terminated by request | + +## File Summary + +| File | Purpose | +|------|---------| +| `worker.py` | Main `Worker` class with job lifecycle management, port allocation, configuration | +| `service.py` | RPC method implementations | +| `docker.py` | Container runtime interface | +| `builder.py` | Image building | +| `bundle.py` | Bundle download | +| `worker_types.py` | Internal job tracking types | +| `dashboard.py` | Web monitoring UI | +| `main.py` | CLI entry point | diff --git a/lib/fluster/examples/cluster_example.py b/lib/fluster/examples/cluster_example.py new file mode 100644 index 0000000000..7d61cd50a0 --- /dev/null +++ b/lib/fluster/examples/cluster_example.py @@ -0,0 +1,1021 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example demonstrating full cluster operation with controller and worker. 
+ +This example runs a complete mini-cluster locally: +- Controller: schedules jobs, tracks workers, serves dashboard +- Worker: executes jobs in Docker containers + +Usage: + cd lib/fluster + uv run python examples/cluster_example.py +""" + +# TODO, consider having like a post-mortem view of the cluster state +# means cluster state should be serializable, cluster dashboard would always be a mapping over the state + +import socket +import tempfile +import threading +import time +import uuid +import zipfile +from pathlib import Path + +import click +import cloudpickle +from fluster import cluster_pb2 +from fluster.cluster_connect import ControllerServiceClientSync, WorkerServiceClientSync +from fluster.cluster.client import RpcClusterClient +from fluster.cluster.controller.controller import Controller, ControllerConfig, DefaultWorkerStubFactory +from fluster.cluster.types import Entrypoint +from fluster.cluster.worker.worker import Worker, WorkerConfig + + +def find_free_port() -> int: + """Find an available port.""" + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +MINIMAL_PYPROJECT = """\ +[project] +name = "fluster-example" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [ + "cloudpickle", + "fluster", +] + +[tool.uv.sources] +fluster = { path = "./fluster" } +""" + + +def create_minimal_workspace(temp_dir: Path) -> Path: + """Create a minimal workspace with pyproject.toml, uv.lock, and fluster source.""" + workspace = temp_dir / "workspace" + workspace.mkdir(exist_ok=True) + + # Write minimal pyproject.toml with fluster dependency + (workspace / "pyproject.toml").write_text(MINIMAL_PYPROJECT) + + # Copy fluster project into workspace for bundling + # This makes the bundle self-contained + # __file__ = lib/fluster/examples/cluster_example.py + # parent = lib/fluster/examples/ + # parent.parent = lib/fluster/ (the fluster project root) + fluster_project_root = Path(__file__).parent.parent + fluster_dest = workspace / "fluster" + + import shutil + import subprocess + + # Copy only the essential files: pyproject.toml and src/ + fluster_dest.mkdir(exist_ok=True) + shutil.copy2(fluster_project_root / "pyproject.toml", fluster_dest / "pyproject.toml") + shutil.copytree( + fluster_project_root / "src", + fluster_dest / "src", + ignore=shutil.ignore_patterns("__pycache__", "*.pyc", "*.egg-info"), + ) + + # Generate uv.lock with the bundled fluster source + subprocess.run( + ["uv", "lock"], + cwd=workspace, + check=True, + capture_output=True, + ) + + return workspace + + +def create_workspace_bundle(workspace: Path, output_path: Path) -> None: + """Create a zip bundle from workspace directory.""" + with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf: + for file in workspace.rglob("*"): + if file.is_file(): + zf.write(file, file.relative_to(workspace)) + + +class LogPoller: + """Background thread that polls for job logs and prints them.""" + + def __init__( + self, + job_id: str, + worker_address: str, + poll_interval: float = 1.0, + prefix: str = "", + ): + """Initialize log poller. 
+ + Args: + job_id: Job ID to poll logs for + worker_address: Worker RPC address (e.g., "http://127.0.0.1:8080") + poll_interval: How often to poll for new logs (in seconds) + prefix: Optional prefix to add to log lines (e.g., "[calculator] ") + """ + self._job_id = job_id + self._worker_address = worker_address + self._poll_interval = poll_interval + self._prefix = prefix + self._last_timestamp_ms = 0 + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self): + """Start polling for logs in background thread.""" + if self._thread is not None: + return # Already started + + def _poll(): + client = WorkerServiceClientSync(address=self._worker_address, timeout_ms=5000) + + while not self._stop_event.is_set(): + try: + # Build filter with timestamp for incremental fetching + filter_proto = cluster_pb2.Worker.FetchLogsFilter() + if self._last_timestamp_ms > 0: + filter_proto.start_ms = self._last_timestamp_ms + + request = cluster_pb2.Worker.FetchLogsRequest( + job_id=self._job_id, + filter=filter_proto, + ) + response = client.fetch_logs(request) + + for entry in response.logs: + # Update last seen timestamp + if entry.timestamp_ms > self._last_timestamp_ms: + self._last_timestamp_ms = entry.timestamp_ms + + # Print log with prefix + print(f"{self._prefix}{entry.data}", flush=True) + + except Exception: + # Ignore errors (job may not exist yet, worker starting, etc.) + pass + + # Wait for next poll interval + self._stop_event.wait(self._poll_interval) + + self._thread = threading.Thread(target=_poll, daemon=True) + self._thread.start() + + def stop(self): + """Stop polling thread.""" + if self._thread is None: + return + + self._stop_event.set() + self._thread.join(timeout=2.0) + self._thread = None + + +class ClusterContext: + """Synchronous context manager running a controller + worker cluster. + + Provides a simple API for submitting jobs through the controller, + which schedules them to workers. 
+ + Example: + with ClusterContext() as cluster: + job_id = cluster.submit(my_function, arg1, arg2) + status = cluster.wait(job_id) + logs = cluster.logs(job_id) + """ + + def __init__( + self, + controller_port: int = 0, + worker_port: int = 0, + max_concurrent_jobs: int = 3, + registry: str = "localhost:5000", + ): + self._controller_port = controller_port or find_free_port() + self._worker_port = worker_port or find_free_port() + self._max_concurrent_jobs = max_concurrent_jobs + self._registry = registry + + # Will be initialized in __enter__ + self._temp_dir: tempfile.TemporaryDirectory | None = None + self._bundle_dir: Path | None = None + self._workspace: Path | None = None + self._worker_id: str | None = None + + # Controller and Worker + self._controller: Controller | None = None + self._worker: Worker | None = None + + # RPC client for controller calls + self._controller_client: ControllerServiceClientSync | None = None + + # Cached bundle + self._bundle_blob: bytes | None = None + + # Log polling + self._log_pollers: dict[str, LogPoller] = {} + + def __enter__(self): + """Start controller and worker.""" + self._temp_dir = tempfile.TemporaryDirectory(prefix="cluster_") + temp_path = Path(self._temp_dir.name) + self._bundle_dir = temp_path / "bundles" + self._bundle_dir.mkdir() + cache_path = temp_path / "cache" + cache_path.mkdir() + + # Create a minimal workspace with pyproject.toml and uv.lock + print("Creating minimal workspace...", flush=True) + self._workspace = create_minimal_workspace(temp_path) + + # --- Start Worker First (so it's ready when controller dispatches) --- + print("Starting worker components...") + self._worker_id = f"worker-{uuid.uuid4().hex[:8]}" + worker_config = WorkerConfig( + host="127.0.0.1", + port=self._worker_port, + cache_dir=cache_path, + registry=self._registry, + max_concurrent_jobs=self._max_concurrent_jobs, + controller_address=f"http://127.0.0.1:{self._controller_port}", + worker_id=self._worker_id, + ) + self._worker = Worker(worker_config, cache_dir=cache_path) + self._worker.start() + print(f"Worker server should be at http://127.0.0.1:{self._worker_port}") + + # --- Start Controller --- + print("Starting controller components...") + controller_config = ControllerConfig( + host="127.0.0.1", + port=self._controller_port, + bundle_dir=self._bundle_dir, + ) + self._controller = Controller( + config=controller_config, + worker_stub_factory=DefaultWorkerStubFactory(), + ) + self._controller.start() + print(f"Controller server should be at http://127.0.0.1:{self._controller_port}", flush=True) + + # Create RPC client + print("Creating RPC client...", flush=True) + self._controller_client = ControllerServiceClientSync( + address=f"http://127.0.0.1:{self._controller_port}", + timeout_ms=30000, + ) + print("RPC client created", flush=True) + + # Register worker with controller + print(f"Registering worker {self._worker_id}...", flush=True) + self._register_worker() + print("Worker registered", flush=True) + + print(f"Controller: http://127.0.0.1:{self._controller_port}", flush=True) + print(f"Worker: http://127.0.0.1:{self._worker_port}", flush=True) + + print("Cluster startup complete!", flush=True) + return self + + def start_log_polling(self, job_id: str, poll_interval: float = 1.0, prefix: str = ""): + """Start a background thread that polls for job logs. 
+ + Args: + job_id: Job ID to poll logs for + poll_interval: How often to poll for new logs (in seconds) + prefix: Optional prefix to add to log lines (e.g., "[calculator] ") + """ + if job_id in self._log_pollers: + return # Already polling + + poller = LogPoller(job_id, self.worker_url, poll_interval, prefix) + poller.start() + self._log_pollers[job_id] = poller + + def stop_log_polling(self, job_id: str): + """Stop log polling for a job. + + Args: + job_id: Job ID to stop polling for + """ + poller = self._log_pollers.pop(job_id, None) + if poller: + poller.stop() + + def __exit__(self, *args): + """Stop cluster and cleanup.""" + # Stop all log polling threads + for job_id in list(self._log_pollers.keys()): + self.stop_log_polling(job_id) + + if self._controller_client: + self._controller_client.close() + + if self._controller: + self._controller.stop() + + if self._worker: + self._worker.stop() + + if self._temp_dir: + self._temp_dir.cleanup() + + def _register_worker(self): + """Register worker with controller.""" + request = cluster_pb2.Controller.RegisterWorkerRequest( + worker_id=self._worker_id, + address=f"127.0.0.1:{self._worker_port}", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="16g", + ), + ) + self._controller_client.register_worker(request) + + def _get_bundle_blob(self) -> bytes: + """Get workspace bundle (cached).""" + if self._bundle_blob is not None: + return self._bundle_blob + + bundle_path = Path(self._temp_dir.name) / "workspace.zip" + create_workspace_bundle(self._workspace, bundle_path) + self._bundle_blob = bundle_path.read_bytes() + return self._bundle_blob + + def submit( + self, + fn, + *args, + name: str | None = None, + timeout_seconds: int = 0, + env_vars: dict[str, str] | None = None, + cpu: int = 1, + memory: str = "1g", + scheduling_timeout_seconds: int = 0, + namespace: str | None = None, + ports: list[str] | None = None, + **kwargs, + ) -> str: + """Submit a job to the cluster. 
+ + Args: + fn: Callable to execute + *args: Positional arguments for fn + name: Job name (defaults to function name) + timeout_seconds: Job timeout + env_vars: Environment variables + cpu: Number of CPUs to request + memory: Memory to request (e.g., "1g", "512m") + scheduling_timeout_seconds: How long to wait for scheduling before marking UNSCHEDULABLE + namespace: Namespace for actor isolation (defaults to "") + ports: List of port names to allocate (e.g., ["actor", "metrics"]) + **kwargs: Keyword arguments for fn + + Returns: + Job ID + """ + entrypoint = Entrypoint(callable=fn, args=args, kwargs=kwargs) + serialized = cloudpickle.dumps(entrypoint) + + # Build environment with user-provided vars + # Worker will auto-inject system FLUSTER_* variables + env = env_vars or {} + + # Add namespace as environment variable (actor-level concern, not cluster-level) + if namespace: + env["FLUSTER_NAMESPACE"] = namespace + + request = cluster_pb2.Controller.LaunchJobRequest( + name=name or fn.__name__, + serialized_entrypoint=serialized, + resources=cluster_pb2.ResourceSpec( + cpu=cpu, + memory=memory, + ), + environment=cluster_pb2.EnvironmentConfig( + workspace="/app", + env_vars=env, + ), + bundle_blob=self._get_bundle_blob(), + scheduling_timeout_seconds=scheduling_timeout_seconds, + ports=ports or [], + ) + response = self._controller_client.launch_job(request) + return response.job_id + + def status(self, job_id: str) -> dict: + """Get job status from controller.""" + request = cluster_pb2.Controller.GetJobStatusRequest(job_id=job_id) + response = self._controller_client.get_job_status(request) + # Convert protobuf to dict for compatibility with existing code + return { + "jobId": response.job.job_id, + "state": cluster_pb2.JobState.Name(response.job.state), + "exitCode": response.job.exit_code, + "error": response.job.error, + "workerId": response.job.worker_id, + } + + def wait(self, job_id: str, timeout: float = 300.0, poll_interval: float = 0.5) -> dict: + """Wait for job to complete.""" + start = time.time() + terminal_states = { + "JOB_STATE_SUCCEEDED", + "JOB_STATE_FAILED", + "JOB_STATE_KILLED", + "JOB_STATE_UNSCHEDULABLE", + } + + while time.time() - start < timeout: + status = self.status(job_id) + if status["state"] in terminal_states: + return status + time.sleep(poll_interval) + + raise TimeoutError(f"Job {job_id} did not complete in {timeout}s") + + def logs(self, job_id: str, since_ms: int | None = None) -> list[cluster_pb2.Worker.LogEntry]: + """Get job logs from worker. + + Args: + job_id: Job ID + since_ms: Only return logs after this timestamp (milliseconds since epoch). + If None, returns all logs. + + Returns: + List of LogEntry protos with timestamp_ms, source, and data. 
+ """ + # Find the worker that has this job + status = self.status(job_id) + worker_id = status.get("workerId") + + if not worker_id: + return [] + + # For now, query worker directly (in a real cluster, would route through controller) + worker_client = WorkerServiceClientSync( + address=f"http://127.0.0.1:{self._worker_port}", + timeout_ms=10000, + ) + + # Build filter with optional timestamp + filter_proto = cluster_pb2.Worker.FetchLogsFilter() + if since_ms is not None: + filter_proto.start_ms = since_ms + + request = cluster_pb2.Worker.FetchLogsRequest(job_id=job_id, filter=filter_proto) + response = worker_client.fetch_logs(request) + return list(response.logs) + + def kill(self, job_id: str) -> None: + """Kill a job via controller.""" + request = cluster_pb2.Controller.TerminateJobRequest(job_id=job_id) + self._controller_client.terminate_job(request) + + @property + def controller_url(self) -> str: + return f"http://127.0.0.1:{self._controller_port}" + + @property + def worker_url(self) -> str: + return f"http://127.0.0.1:{self._worker_port}" + + def get_client(self) -> RpcClusterClient: + """Get an RpcClusterClient for this cluster.""" + return RpcClusterClient( + controller_address=self.controller_url, + bundle_blob=self._get_bundle_blob(), + ) + + +# ============================================================================= +# ACTOR SYSTEM EXAMPLES +# ============================================================================= + + +def example_actor_job_workflow(cluster: ClusterContext): + """Demonstrate real actor job workflow with cluster integration. + + This example shows the complete end-to-end workflow: + 1. Submit a job that runs an ActorServer + 2. The job registers its actor endpoint with the controller + 3. A client uses ClusterResolver to discover and call the actor + 4. The actor can access cluster context via current_ctx() + + This is the recommended pattern for production actor deployments. 
+ """ + print("\n=== Example: Real Actor Job Workflow ===\n") + + # Step 1: Define an actor job entrypoint + # This function will run inside a cluster job and start an ActorServer + def actor_job_entrypoint(): + """Job entrypoint that starts an ActorServer and registers with controller.""" + import os + import time + + from fluster import cluster_pb2 + from fluster.actor import ActorServer + from fluster.cluster_connect import ControllerServiceClientSync + + # Get environment variables injected by the cluster + job_id = os.environ["FLUSTER_JOB_ID"] + namespace = os.environ["FLUSTER_NAMESPACE"] + controller_url = os.environ["FLUSTER_CONTROLLER_ADDRESS"] + # Port allocated by the cluster and mapped to host via Docker -p flag + allocated_port = int(os.environ["FLUSTER_PORT_ACTOR"]) + + print(f"Actor job starting: job_id={job_id}, namespace={namespace}") + print(f"Using allocated port: {allocated_port}") + + # Define our actor class inline (could also be imported) + class Calculator: + def __init__(self): + self._history = [] + + def add(self, a: int, b: int) -> int: + result = a + b + self._history.append(f"add({a}, {b}) = {result}") + print(f"Calculator.add({a}, {b}) = {result}") + return result + + def multiply(self, a: int, b: int) -> int: + result = a * b + self._history.append(f"multiply({a}, {b}) = {result}") + print(f"Calculator.multiply({a}, {b}) = {result}") + return result + + def get_history(self) -> list[str]: + return self._history + + # Start the ActorServer on the allocated port + # Use 0.0.0.0 to bind to all interfaces (necessary inside Docker) + # The port is mapped to the host via Docker -p flag + server = ActorServer(host="0.0.0.0", port=allocated_port) + server.register("calculator", Calculator()) + port = server.serve_background() + print(f"ActorServer started on port {port}") + + # Register the endpoint with the controller using Connect RPC + # Use localhost since the port is mapped from host to container via Docker -p + # Clients on the host will connect to localhost: + endpoint_address = f"localhost:{port}" + + print(f"Registering endpoint: calculator at {endpoint_address}") + try: + controller_client = ControllerServiceClientSync(address=controller_url) + request = cluster_pb2.Controller.RegisterEndpointRequest( + name="calculator", + address=endpoint_address, + job_id=job_id, + namespace=namespace, + metadata={"version": "1.0"}, + ) + response = controller_client.register_endpoint(request) + print(f"Endpoint registered successfully: {response.endpoint_id}") + except Exception as e: + print(f"Error registering endpoint: {e}") + import traceback + + traceback.print_exc() + + # Keep the job running to serve requests + print("Actor server ready, waiting for requests...") + while True: + time.sleep(1) + + # Step 2: Submit the actor job to the cluster + # Request a port named "actor" which will be allocated and mapped by Docker + print("Submitting actor job to cluster...") + job_id = cluster.submit( + actor_job_entrypoint, + name="calculator-actor", + cpu=1, + memory="512m", + namespace="", + ports=["actor"], + ) + print(f"Job submitted: {job_id}") + + # Start log polling in background thread + print("Starting log polling...") + cluster.start_log_polling(job_id, poll_interval=1.0, prefix="[calculator] ") + + # Step 3: Wait for the job to start and endpoint to be registered + print("Waiting for job to start...") + max_wait = 30 + start_time = time.time() + job_running = False + + while time.time() - start_time < max_wait: + status = cluster.status(job_id) + state = 
status.get("state", "") + print(f" Job state: {state}") + + if state == "JOB_STATE_RUNNING": + job_running = True + print("Job is running!") + break + + if state in ["JOB_STATE_FAILED", "JOB_STATE_KILLED"]: + print(f"Job failed: {status.get('error', 'Unknown error')}") + logs = cluster.logs(job_id) + if logs: + print("Job logs:") + for log in logs: + print(f" {log}") + return + + time.sleep(1) + + if not job_running: + print("Job did not start in time") + return + + # Give the actor server time to register + print("Waiting for endpoint registration...") + time.sleep(3) + + # Step 4: Use ClusterResolver to discover the actor + from fluster.actor import ActorClient, ClusterResolver + + print("\nResolving actor via ClusterResolver...") + resolver = ClusterResolver(cluster.controller_url, namespace="") + client = ActorClient(resolver, "calculator") + + # Step 5: Call the actor methods + print("\nCalling actor methods...") + result1 = client.add(10, 20) + print(f"Client received: add(10, 20) = {result1}") + + result2 = client.multiply(5, 7) + print(f"Client received: multiply(5, 7) = {result2}") + + history = client.get_history() + print(f"Operation history: {history}") + + print("\nActor job workflow complete!") + print("Note: The actor job will continue running until the cluster shuts down.") + + +def example_worker_pool(cluster: ClusterContext): + """Demonstrate WorkerPool for task dispatch.""" + from fluster.worker_pool import WorkerPool, WorkerPoolConfig + + print("\n=== Example: Worker Pool ===\n") + + client = cluster.get_client() + config = WorkerPoolConfig( + num_workers=2, + resources=cluster_pb2.ResourceSpec(cpu=1, memory="512m"), + ) + + def square(x): + return x * x + + with WorkerPool(client, config) as pool: + print(f"WorkerPool started with {pool.size} workers") + futures = pool.map(square, [1, 2, 3, 4, 5]) + results = [f.result() for f in futures] + print(f"Results: {results}") + + print("\nWorkerPool example complete!") + + +# ============================================================================= +# DEPRECATED ACTOR EXAMPLES - REMOVED +# ============================================================================= +# The old actor examples (example_actor_basic, example_actor_coordinator, +# example_actor_pool) used a standalone actor pattern that required a system +# job hack (_system_job_id). This pattern is no longer supported and the +# examples have been removed. +# +# Use example_actor_job_workflow() instead for the recommended pattern where +# actors run as cluster jobs and register with the controller properly. 
+# ============================================================================= + + +# ============================================================================= +# CLUSTER JOB EXAMPLES +# ============================================================================= + + +def example_basic(cluster: ClusterContext): + """Basic job submission through cluster.""" + print("\n=== Example: Basic Job Submission ===\n", flush=True) + + def hello(): + print("Hello from the cluster!") + return 42 + + job_id = cluster.submit(hello, name="hello-job") + print(f"Submitted: {job_id}") + + status = cluster.wait(job_id) + print(f"Completed with state: {status['state']}") + if status.get("error"): + print(f"Error: {status['error']}") + if status.get("exitCode"): + print(f"Exit code: {status['exitCode']}") + + logs = cluster.logs(job_id) + if logs: + print("Logs:") + for log in logs: + print(f" {log}") + + +def example_with_args(cluster: ClusterContext): + """Job with arguments.""" + print("\n=== Example: Job With Arguments ===\n") + + def add_numbers(a, b): + result = a + b + print(f"{a} + {b} = {result}") + return result + + job_id = cluster.submit(add_numbers, 10, 32, name="add-job") + print(f"Submitted: {job_id}") + + status = cluster.wait(job_id) + print(f"Completed: {status['state']}") + + logs = cluster.logs(job_id) + if logs: + print("Output:") + for log in logs: + print(f" {log}") + + +def example_concurrent(cluster: ClusterContext): + """Multiple concurrent jobs.""" + print("\n=== Example: Concurrent Jobs ===\n") + + def slow_job(n): + import time as t + + for i in range(3): + print(f"Job {n}: iteration {i}") + t.sleep(1) + return n + + # Submit 3 jobs + job_ids = [] + for i in range(3): + job_id = cluster.submit(slow_job, i, name=f"slow-{i}") + job_ids.append(job_id) + print(f"Submitted job {i}: {job_id[:8]}...") + + # Wait for all + print("\nWaiting for jobs to complete...") + for i, job_id in enumerate(job_ids): + status = cluster.wait(job_id) + print(f"Job {i} ({job_id[:8]}...): {status['state']}") + + +def example_kill(cluster: ClusterContext): + """Kill a running job.""" + print("\n=== Example: Kill Job ===\n") + + def long_job(): + import time as t + + for i in range(60): + print(f"Tick {i}") + t.sleep(1) + + job_id = cluster.submit(long_job, name="long-job") + print(f"Started: {job_id[:8]}...") + + # Wait for it to start running + print("Waiting for job to start...") + for _ in range(60): + status = cluster.status(job_id) + if status["state"] == "JOB_STATE_RUNNING": + print("Job is running!") + break + time.sleep(0.5) + else: + print("Job did not start in time") + return + + # Give it a moment to produce some output + time.sleep(2) + + print("Killing job...") + cluster.kill(job_id) + + status = cluster.status(job_id) + print(f"Final state: {status['state']}") + + +def example_resource_serialization(cluster: ClusterContext): + """Demonstrate job serialization based on resource constraints. + + The worker has 4 CPUs. We submit jobs requiring 2 CPUs each, + so only 2 can run concurrently. The rest must wait in queue. 
+ """ + print("\n=== Example: Resource Serialization ===\n") + + def cpu_bound_job(n): + import time as t + + print(f"Job {n}: starting (needs 2 CPUs)") + t.sleep(3) + print(f"Job {n}: completed") + return n + + # Submit 4 jobs, each requiring 2 CPUs (worker has 4 CPUs total) + # Only 2 should run at a time + job_ids = [] + for i in range(4): + job_id = cluster.submit(cpu_bound_job, i, name=f"cpu-job-{i}", cpu=2, memory="1g") + job_ids.append(job_id) + print(f"Submitted job {i}: {job_id[:8]}... (requires 2 CPUs)") + + # Check initial states - first 2 should be running, rest pending + time.sleep(2) + print("\nChecking job states after 2 seconds:") + for i, job_id in enumerate(job_ids): + status = cluster.status(job_id) + print(f" Job {i}: {status['state']}") + + # Wait for all jobs to complete + print("\nWaiting for all jobs to complete...") + for i, job_id in enumerate(job_ids): + status = cluster.wait(job_id) + print(f"Job {i} ({job_id[:8]}...): {status['state']}") + + +def example_scheduling_timeout(cluster: ClusterContext): + """Demonstrate scheduling timeout for jobs that can't be scheduled. + + Submit a job requiring more resources than available. With a short + scheduling timeout, it should become UNSCHEDULABLE. + """ + print("\n=== Example: Scheduling Timeout ===\n") + + def impossible_job(): + print("This should never run!") + return 0 + + # Submit a job requiring 100 CPUs (worker only has 4) + # With a 2 second scheduling timeout, it should fail quickly + print("Submitting job requiring 100 CPUs (worker has 4)...") + print("Setting 2 second scheduling timeout...") + job_id = cluster.submit( + impossible_job, + name="impossible-job", + cpu=100, + memory="1g", + scheduling_timeout_seconds=2, + ) + print(f"Submitted: {job_id[:8]}...") + + # Wait for it to timeout + status = cluster.wait(job_id, timeout=10.0) + print(f"Final state: {status['state']}") + if status.get("error"): + print(f"Error: {status['error']}") + + +def example_small_job_skips_queue(cluster: ClusterContext): + """Demonstrate that smaller jobs can skip ahead of larger jobs. + + Submit a large job that won't fit, then a small job. The small + job should be scheduled even though the large job is ahead in queue. 
+ """ + print("\n=== Example: Small Jobs Skip Large Jobs ===\n") + + def big_job(): + import time as t + + print("Big job running (this shouldn't happen immediately)") + t.sleep(5) + return "big" + + def small_job(): + import time as t + + print("Small job running!") + t.sleep(1) + return "small" + + # First submit a job that's too big to fit + print("Submitting big job (8 CPUs, won't fit on 4-CPU worker)...") + big_job_id = cluster.submit(big_job, name="big-job", cpu=8, memory="1g", scheduling_timeout_seconds=0) + print(f"Big job: {big_job_id[:8]}...") + + # Then submit a small job that can run + print("Submitting small job (1 CPU)...") + small_job_id = cluster.submit(small_job, name="small-job", cpu=1, memory="1g") + print(f"Small job: {small_job_id[:8]}...") + + # Small job should run even though big job is first in queue + time.sleep(2) + big_status = cluster.status(big_job_id) + small_status = cluster.status(small_job_id) + print("\nAfter 2 seconds:") + print(f" Big job: {big_status['state']}") + print(f" Small job: {small_status['state']}") + + # Wait for small job to complete + small_result = cluster.wait(small_job_id, timeout=30.0) + print(f"\nSmall job completed: {small_result['state']}") + + # Big job should still be pending (never scheduled) + big_status = cluster.status(big_job_id) + print(f"Big job still: {big_status['state']}") + + +@click.command() +@click.option( + "--wait/--no-wait", default=False, help="Wait for Ctrl+C after examples complete (for dashboard exploration)" +) +@click.option( + "--mode", + type=click.Choice(["all", "actors", "jobs"], case_sensitive=False), + default="all", + help="Which examples to run: all (default), actors (actor system only), or jobs (cluster jobs only)", +) +def main(wait: bool, mode: str): + """Run cluster and actor examples. 
+ + This example demonstrates the full Fluster system including: + - Cluster controller and worker for job scheduling + - Actor system for distributed RPC between services + - Various patterns: coordinator, pool, load-balancing, broadcast + """ + print("=" * 60) + print("Fluster Cluster & Actor System Example") + print("=" * 60) + + if mode in ["all", "jobs"]: + print("\nNote: Job examples require Docker to be running.") + + try: + with ClusterContext(max_concurrent_jobs=3) as cluster: + print(f"\nController dashboard: {cluster.controller_url}", flush=True) + print(f"Worker dashboard: {cluster.worker_url}", flush=True) + if wait: + print("\nPress Ctrl+C to stop.\n", flush=True) + + # Run actor examples + if mode in ["all", "actors"]: + print("\n" + "=" * 60) + print("ACTOR SYSTEM EXAMPLES") + print("=" * 60) + example_actor_job_workflow(cluster) + example_worker_pool(cluster) + + # Run cluster job examples + if mode in ["all", "jobs"]: + print("\n" + "=" * 60) + print("CLUSTER JOB EXAMPLES") + print("=" * 60) + print("About to run job examples...", flush=True) + example_basic(cluster) + example_with_args(cluster) + example_concurrent(cluster) + example_kill(cluster) + example_resource_serialization(cluster) + example_scheduling_timeout(cluster) + example_small_job_skips_queue(cluster) + + print("\n" + "=" * 60) + print("All examples completed!") + print("=" * 60) + + if wait: + print("Dashboards still available for exploration.") + print("Press Ctrl+C to stop.") + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\nShutting down...") + except Exception as e: + print(f"\nError: {e}") + if mode in ["all", "jobs"]: + print("\nMake sure Docker is running and try again.") + raise + + +if __name__ == "__main__": + main() diff --git a/lib/fluster/pyproject.toml b/lib/fluster/pyproject.toml new file mode 100644 index 0000000000..d3215d9bcd --- /dev/null +++ b/lib/fluster/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fluster" +version = "0.1.0" +requires-python = ">=3.11,<3.13" +dependencies = [ + "cloudpickle>=3.1.2", + "connect-python @ git+https://github.com/connectrpc/connect-python.git@5342eacecef85e52717604ee5ac7e03a1e16c7ac", + "docker>=7.0.0", + "fsspec>=2024.0.0", + "httpx>=0.28.1", + "humanfriendly>=10.0", + "pydantic>=2.12.5", + "starlette>=0.50.0", + "typing-extensions>=4.0", + "uvicorn[standard]>=0.23.0", +] + +[dependency-groups] +dev = [ + "pytest>=8.3.2", + "pytest-asyncio", + "pytest-timeout", +] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/fluster"] + +[tool.pytest.ini_options] +timeout = 60 +addopts = "--durations=10" +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] +filterwarnings = ["ignore::DeprecationWarning"] +log_level = "INFO" +log_format = "%(asctime)s %(levelname)s %(message)s" +log_date_format = "%Y-%m-%d %H:%M:%S" +log_cli_level = "INFO" diff --git a/lib/fluster/scripts/generate_protos.py b/lib/fluster/scripts/generate_protos.py new file mode 100755 index 0000000000..799166ce4e --- /dev/null +++ b/lib/fluster/scripts/generate_protos.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate protobuf and Connect files, then fix imports.""" + +import re +import subprocess +import sys +from pathlib import Path + + +def fix_imports(file_path: Path) -> None: + """Fix imports in a generated _connect.py file.""" + content = file_path.read_text() + + # Pattern: import _pb2 as __pb2 + # Replace with: from . import _pb2 as __pb2 + pattern = r"^import (\w+_pb2) as (\w+__pb2)$" + replacement = r"from . import \1 as \2" + + new_content = re.sub(pattern, replacement, content, flags=re.MULTILINE) + + if new_content != content: + file_path.write_text(new_content) + print(f"✓ Fixed imports in {file_path.name}") + else: + print(f"✓ No changes needed in {file_path.name}") + + +def run_buf_generate(root_dir: Path) -> None: + """Run buf generate.""" + print("Running buf generate...") + result = subprocess.run( + ["buf", "generate"], + cwd=root_dir, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + print(f"Error running buf generate:\n{result.stderr}", file=sys.stderr) + sys.exit(1) + + print("✓ buf generate completed successfully") + + +def main(): + """Generate protobuf files and fix imports.""" + root_dir = Path(__file__).parent.parent + src_dir = root_dir / "src" / "fluster" + + # Run buf generate + run_buf_generate(root_dir) + + # Fix imports in all generated Connect files + print("\nFixing imports in generated files...") + for connect_file in src_dir.glob("*_connect.py"): + fix_imports(connect_file) + + print("\n✓ Generation complete!") + + +if __name__ == "__main__": + main() diff --git a/lib/fluster/src/fluster/__init__.py b/lib/fluster/src/fluster/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/src/fluster/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/src/fluster/actor/__init__.py b/lib/fluster/src/fluster/actor/__init__.py new file mode 100644 index 0000000000..a1feb146ec --- /dev/null +++ b/lib/fluster/src/fluster/actor/__init__.py @@ -0,0 +1,49 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Actor system for distributed RPC.""" + +from fluster.actor.client import ActorClient +from fluster.actor.pool import ActorPool, BroadcastFuture, CallResult +from fluster.actor.resolver import ( + ClusterResolver, + FixedResolver, + GcsApi, + GcsResolver, + MockGcsApi, + ResolveResult, + ResolvedEndpoint, + Resolver, +) +from fluster.actor.server import ActorServer +from fluster.actor.types import ActorContext, ActorId, current_ctx + +__all__ = [ + "ActorClient", + "ActorContext", + "ActorId", + "ActorPool", + "ActorServer", + "BroadcastFuture", + "CallResult", + "ClusterResolver", + "FixedResolver", + "GcsApi", + "GcsResolver", + "MockGcsApi", + "ResolveResult", + "ResolvedEndpoint", + "Resolver", + "current_ctx", +] diff --git a/lib/fluster/src/fluster/actor/client.py b/lib/fluster/src/fluster/actor/client.py new file mode 100644 index 0000000000..b39ccdc57c --- /dev/null +++ b/lib/fluster/src/fluster/actor/client.py @@ -0,0 +1,244 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Actor client for making RPC calls to actor servers. + +The ActorClient provides transparent actor discovery and invocation with +automatic retry logic. When an actor name cannot be resolved immediately +(e.g., actor server still starting), the client retries with exponential +backoff until the timeout is reached. + +Typical actor startup: 2-8 seconds (Docker container + server initialization). +Default timeout (30s) and retry config handle this gracefully. + +Example: + resolver = ClusterResolver("http://controller:8080") + client = ActorClient(resolver, "my-actor") + result = client.some_method(arg1, arg2) # Retries until actor found + +Custom retry behavior: + retry_config = RetryConfig(initial_delay=0.2, max_delay=5.0) + client = ActorClient(resolver, "my-actor", retry_config=retry_config) +""" + +import logging +import random +import time +from dataclasses import dataclass +from typing import Any + +import cloudpickle + +from fluster import actor_pb2 +from fluster.actor.resolver import Resolver, ResolveResult +from fluster.actor_connect import ActorServiceClientSync + +logger = logging.getLogger(__name__) + + +@dataclass +class RetryConfig: + """Configuration for exponential backoff retry when resolving actors. + + When an actor name cannot be resolved, the client retries with + exponentially increasing delays until the timeout is reached. 
+ + Attributes: + initial_delay: Initial retry delay in seconds + max_delay: Maximum delay between retries in seconds + backoff_factor: Multiplier for exponential backoff + jitter_factor: Random jitter as fraction of delay (e.g., 0.25 = ±25%) + """ + + initial_delay: float = 0.1 + max_delay: float = 2.0 + backoff_factor: float = 2.0 + jitter_factor: float = 0.25 + + def __post_init__(self): + if self.initial_delay <= 0: + raise ValueError("initial_delay must be positive") + if self.max_delay < self.initial_delay: + raise ValueError("max_delay must be >= initial_delay") + if self.backoff_factor < 1.0: + raise ValueError("backoff_factor must be >= 1.0") + if not 0 <= self.jitter_factor < 1.0: + raise ValueError("jitter_factor must be in [0, 1)") + + +def calculate_next_delay(attempt: int, config: RetryConfig) -> float: + """Calculate next retry delay with exponential backoff and jitter. + + Args: + attempt: Retry attempt number (0-indexed) + config: Retry configuration + + Returns: + Delay in seconds before next retry + """ + # Exponential: initial * (backoff^attempt) + delay = config.initial_delay * (config.backoff_factor**attempt) + + # Cap at max_delay + delay = min(delay, config.max_delay) + + # Add jitter: random in [delay*(1-jitter), delay*(1+jitter)] + if config.jitter_factor > 0: + jitter_range = delay * config.jitter_factor + delay = delay + random.uniform(-jitter_range, jitter_range) + delay = max(0.001, delay) # Keep positive + + return delay + + +class ActorClient: + """Actor client with resolver-based discovery.""" + + def __init__( + self, + resolver: Resolver, + name: str, + timeout: float = 30.0, + retry_config: RetryConfig | None = None, + ): + """Initialize the actor client. + + Args: + resolver: Resolver instance for endpoint discovery + name: Name of the actor to invoke + timeout: Total timeout in seconds for resolution + RPC calls. + When resolving, retries continue until this timeout is reached. + retry_config: Retry configuration. If None, uses default RetryConfig(). + """ + self._resolver = resolver + self._name = name + self._timeout = timeout + self._retry_config = retry_config or RetryConfig() + self._cached_result: ResolveResult | None = None + self._client: ActorServiceClientSync | None = None + self._client_url: str | None = None + + def _resolve(self) -> ResolveResult: + """Resolve actor name with exponential backoff retry. + + Retries resolution with exponential backoff until either: + - Endpoints are found (success) + - Timeout is reached (raises TimeoutError) + + Returns: + ResolveResult with at least one endpoint + + Raises: + TimeoutError: If no endpoints found within timeout + """ + # Check cache first + if self._cached_result is not None and not self._cached_result.is_empty: + return self._cached_result + + # Retry loop with exponential backoff + start_time = time.monotonic() + attempt = 0 + + while True: + # Try to resolve + result = self._resolver.resolve(self._name) + + if not result.is_empty: + # Success! 
Cache and return + self._cached_result = result + if attempt > 0: + logger.debug( + f"Resolved actor '{self._name}' to {len(result.endpoints)} endpoint(s) " + f"after {attempt} retries in {time.monotonic() - start_time:.2f}s" + ) + return result + + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= self._timeout: + raise TimeoutError( + f"Failed to resolve actor '{self._name}' after {self._timeout}s " f"({attempt} retries)" + ) + + # Calculate next delay with exponential backoff + jitter + delay = calculate_next_delay(attempt, self._retry_config) + + # Adjust delay to not exceed timeout + remaining = self._timeout - elapsed + if delay > remaining: + delay = remaining + + if delay > 0: + logger.debug( + f"Actor '{self._name}' not found, retrying in {delay:.3f}s " + f"(attempt {attempt + 1}, elapsed {elapsed:.2f}s/{self._timeout}s)" + ) + time.sleep(delay) + + attempt += 1 + + def _invalidate_cache(self) -> None: + self._cached_result = None + self._client = None + self._client_url = None + + def _get_client(self, url: str) -> ActorServiceClientSync: + """Get or create a client for the given URL.""" + if self._client is None or self._client_url != url: + self._client = ActorServiceClientSync( + address=url, + timeout_ms=int(self._timeout * 1000), + ) + self._client_url = url + return self._client + + def __getattr__(self, method_name: str) -> "_RpcMethod": + return _RpcMethod(self, method_name) + + +class _RpcMethod: + """Represents a single RPC method call.""" + + def __init__(self, client: ActorClient, method_name: str): + self._client = client + self._method_name = method_name + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the RPC call.""" + result = self._client._resolve() + if result.is_empty: + raise RuntimeError(f"No endpoints found for actor '{self._client._name}'") + + endpoint = result.first() + + call = actor_pb2.ActorCall( + method_name=self._method_name, + actor_name=self._client._name, + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + ) + + try: + client = self._client._get_client(endpoint.url) + resp = client.call(call) + except Exception: + self._client._invalidate_cache() + raise + + if resp.HasField("error"): + if resp.error.serialized_exception: + raise cloudpickle.loads(resp.error.serialized_exception) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) diff --git a/lib/fluster/src/fluster/actor/pool.py b/lib/fluster/src/fluster/actor/pool.py new file mode 100644 index 0000000000..10ea70ea57 --- /dev/null +++ b/lib/fluster/src/fluster/actor/pool.py @@ -0,0 +1,289 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Actor pool for load-balanced and broadcast RPC calls.""" + +import threading +from collections.abc import Callable, Iterator +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +import cloudpickle + +from fluster import actor_pb2 +from fluster.actor.resolver import ResolveResult, ResolvedEndpoint, Resolver +from fluster.actor_connect import ActorServiceClientSync + +T = TypeVar("T") + + +@dataclass +class CallResult: + """Result of a single call in a broadcast. + + Attributes: + endpoint: The endpoint that was called + value: The return value (None if exception occurred) + exception: The exception raised (None if successful) + """ + + endpoint: ResolvedEndpoint + value: Any | None = None + exception: BaseException | None = None + + @property + def success(self) -> bool: + """Returns True if the call succeeded without exception.""" + return self.exception is None + + +class BroadcastFuture(Generic[T]): + """Future representing results from a broadcast call to multiple endpoints. + + Provides methods to wait for all results, wait for any result, or iterate + results as they complete. + """ + + def __init__(self, futures: list[tuple[ResolvedEndpoint, Future]]): + """Initialize with list of (endpoint, future) pairs.""" + self._futures = futures + + def wait_all(self, timeout: float | None = None) -> list[CallResult]: + """Wait for all calls to complete and return all results. + + Args: + timeout: Total timeout in seconds for all calls + + Returns: + List of CallResult, one per endpoint + """ + results = [] + for endpoint, future in self._futures: + try: + value = future.result(timeout=timeout) + results.append(CallResult(endpoint=endpoint, value=value)) + except Exception as e: + results.append(CallResult(endpoint=endpoint, exception=e)) + return results + + def wait_any(self, timeout: float | None = None) -> CallResult: + """Wait for the first call to complete and return its result. + + Args: + timeout: Timeout in seconds + + Returns: + CallResult from the first completed call + + Raises: + TimeoutError: If no results are ready within timeout + """ + for future in as_completed([f for _, f in self._futures], timeout=timeout): + idx = next(i for i, (_, f) in enumerate(self._futures) if f is future) + endpoint = self._futures[idx][0] + try: + value = future.result() + return CallResult(endpoint=endpoint, value=value) + except Exception as e: + return CallResult(endpoint=endpoint, exception=e) + raise TimeoutError("No results within timeout") + + def as_completed(self, timeout: float | None = None) -> Iterator[CallResult]: + """Iterate over results as they complete. + + Args: + timeout: Total timeout in seconds for all calls + + Yields: + CallResult for each completed call + """ + endpoint_map = {id(f): ep for ep, f in self._futures} + for future in as_completed([f for _, f in self._futures], timeout=timeout): + endpoint = endpoint_map[id(future)] + try: + value = future.result() + yield CallResult(endpoint=endpoint, value=value) + except Exception as e: + yield CallResult(endpoint=endpoint, exception=e) + + +class ActorPool(Generic[T]): + """Pool of actors for load-balanced and broadcast calls. + + Resolves a pool of endpoints for an actor name and provides methods to + distribute calls across them (round-robin) or broadcast to all endpoints. 
+ + Example: + >>> pool = ActorPool(resolver, "inference") + >>> result = pool.call().predict(data) # Round-robin to one endpoint + >>> broadcast = pool.broadcast().reload_model() # Send to all endpoints + >>> results = broadcast.wait_all() + """ + + def __init__(self, resolver: Resolver, name: str, timeout: float = 30.0): + """Initialize actor pool. + + Args: + resolver: Resolver to discover endpoints + name: Actor name to resolve + timeout: RPC timeout in seconds + """ + self._resolver = resolver + self._name = name + self._timeout = timeout + self._endpoint_index = 0 + self._cached_result: ResolveResult | None = None + self._lock = threading.Lock() + self._executor = ThreadPoolExecutor(max_workers=32) + + def _resolve(self) -> ResolveResult: + """Resolve endpoints, caching result.""" + result = self._resolver.resolve(self._name) + with self._lock: + self._cached_result = result + return result + + def _get_next_endpoint(self) -> ResolvedEndpoint: + """Get the next endpoint in round-robin order. + + Thread-safe: uses a lock to protect the endpoint index. + """ + endpoints = self._resolve().endpoints + with self._lock: + if not endpoints: + raise RuntimeError(f"No endpoints for '{self._name}'") + endpoint = endpoints[self._endpoint_index % len(endpoints)] + self._endpoint_index += 1 + return endpoint + + def shutdown(self) -> None: + """Shutdown the thread pool executor.""" + self._executor.shutdown(wait=True) + + def __enter__(self) -> "ActorPool[T]": + return self + + def __exit__(self, *args) -> None: + self.shutdown() + + @property + def size(self) -> int: + """Number of endpoints in the pool.""" + return len(self._resolve().endpoints) + + @property + def endpoints(self) -> list[ResolvedEndpoint]: + """List of resolved endpoints.""" + return list(self._resolve().endpoints) + + def _call_endpoint( + self, + endpoint: ResolvedEndpoint, + method_name: str, + args: tuple, + kwargs: dict, + ) -> Any: + """Make an RPC call to a specific endpoint. + + Args: + endpoint: Target endpoint + method_name: Method to call + args: Positional arguments + kwargs: Keyword arguments + + Returns: + Deserialized return value + + Raises: + Exception from the remote actor method + """ + client = ActorServiceClientSync( + address=endpoint.url, + timeout_ms=int(self._timeout * 1000), + ) + + call = actor_pb2.ActorCall( + method_name=method_name, + actor_name=self._name, + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + ) + + resp = client.call(call) + + if resp.HasField("error"): + if resp.error.serialized_exception: + raise cloudpickle.loads(resp.error.serialized_exception) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) + + def call(self) -> "_PoolCallProxy[T]": + """Create a proxy for round-robin calls. + + Returns: + Proxy that distributes method calls across endpoints + """ + return _PoolCallProxy(self) + + def broadcast(self) -> "_PoolBroadcastProxy[T]": + """Create a proxy for broadcast calls. 
+ + Returns: + Proxy that sends method calls to all endpoints + """ + return _PoolBroadcastProxy(self) + + +class _PoolCallProxy(Generic[T]): + """Proxy for round-robin calls to a pool.""" + + def __init__(self, pool: ActorPool[T]): + self._pool = pool + + def __getattr__(self, method_name: str) -> Callable[..., Any]: + """Create a callable that invokes the method on the next endpoint in round-robin.""" + + def call(*args, **kwargs): + endpoint = self._pool._get_next_endpoint() + return self._pool._call_endpoint(endpoint, method_name, args, kwargs) + + return call + + +class _PoolBroadcastProxy(Generic[T]): + """Proxy for broadcast calls to all endpoints in a pool.""" + + def __init__(self, pool: ActorPool[T]): + self._pool = pool + + def __getattr__(self, method_name: str) -> Callable[..., BroadcastFuture]: + """Create a callable that invokes the method on all endpoints in parallel.""" + + def broadcast(*args, **kwargs) -> BroadcastFuture: + result = self._pool._resolve() + futures = [] + for endpoint in result.endpoints: + future = self._pool._executor.submit( + self._pool._call_endpoint, + endpoint, + method_name, + args, + kwargs, + ) + futures.append((endpoint, future)) + return BroadcastFuture(futures) + + return broadcast diff --git a/lib/fluster/src/fluster/actor/proto/actor.proto b/lib/fluster/src/fluster/actor/proto/actor.proto new file mode 100644 index 0000000000..e8b9d90bce --- /dev/null +++ b/lib/fluster/src/fluster/actor/proto/actor.proto @@ -0,0 +1,89 @@ +// Copyright 2025 The Marin Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package fluster.actor; + +option py_generic_services = true; + +// ============ Actor RPC ============ +// This is the protocol for calling actor methods. +// Actors register via the cluster's endpoint registry. 
+ +message ActorCall { + string method_name = 1; + string actor_name = 2; // Which actor on this server + bytes serialized_args = 3; // cloudpickle(args) + bytes serialized_kwargs = 4; // cloudpickle(kwargs) +} + +message ActorResponse { + oneof result { + bytes serialized_value = 1; // cloudpickle(return_value) + ActorError error = 2; + } +} + +message ActorError { + string error_type = 1; + string message = 2; + bytes serialized_exception = 3; // cloudpickle(exception) for re-raise +} + +message Empty {} + +message HealthResponse { + bool healthy = 1; +} + +// ============ Introspection ============ +// Methods for debugging and discovering actors + +message ListMethodsRequest { + string actor_name = 1; +} + +message MethodInfo { + string name = 1; + string signature = 2; + string docstring = 3; +} + +message ListMethodsResponse { + repeated MethodInfo methods = 1; +} + +message ListActorsRequest {} + +message ActorInfo { + string name = 1; + string actor_id = 2; + int64 registered_at_ms = 3; + map metadata = 4; +} + +message ListActorsResponse { + repeated ActorInfo actors = 1; +} + +// ============ Actor Service ============ +// Each ActorServer exposes this service + +service ActorService { + rpc Call(ActorCall) returns (ActorResponse); + rpc HealthCheck(Empty) returns (HealthResponse); + rpc ListMethods(ListMethodsRequest) returns (ListMethodsResponse); + rpc ListActors(ListActorsRequest) returns (ListActorsResponse); +} diff --git a/lib/fluster/src/fluster/actor/resolver.py b/lib/fluster/src/fluster/actor/resolver.py new file mode 100644 index 0000000000..4245459ec5 --- /dev/null +++ b/lib/fluster/src/fluster/actor/resolver.py @@ -0,0 +1,276 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Resolver types and implementations for actor discovery.""" + +import os +from dataclasses import dataclass, field +from typing import Protocol as TypingProtocol + +from fluster import cluster_pb2 +from fluster.cluster.types import Namespace +from fluster.cluster_connect import ControllerServiceClientSync + + +@dataclass +class ResolvedEndpoint: + """A single resolved endpoint.""" + + url: str # e.g., "http://host:port" + actor_id: str # Unique handle for staleness detection + metadata: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ResolveResult: + """Result of resolving an actor name.""" + + name: str + namespace: Namespace + endpoints: list[ResolvedEndpoint] = field(default_factory=list) + + @property + def is_empty(self) -> bool: + return len(self.endpoints) == 0 + + def first(self) -> ResolvedEndpoint: + if not self.endpoints: + raise ValueError(f"No endpoints for '{self.name}' in namespace '{self.namespace}'") + return self.endpoints[0] + + +class Resolver(TypingProtocol): + """Protocol for actor name resolution.""" + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: ... + + @property + def default_namespace(self) -> Namespace: ... 
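+
+
+# Illustrative usage sketch: callers should type-hint against ``Resolver``
+# rather than a concrete resolver class (``lookup`` below is a hypothetical
+# helper, not part of this module):
+#
+#     def lookup(resolver: Resolver, name: str) -> ResolvedEndpoint:
+#         return resolver.resolve(name).first()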
+ + +class FixedResolver: + """Resolver with statically configured endpoints.""" + + def __init__( + self, + endpoints: dict[str, str | list[str]], + namespace: Namespace = Namespace(""), + ): + self._namespace = namespace + self._endpoints: dict[str, list[str]] = {} + for name, urls in endpoints.items(): + if isinstance(urls, str): + self._endpoints[name] = [urls] + else: + self._endpoints[name] = list(urls) + + @property + def default_namespace(self) -> Namespace: + return self._namespace + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + ns = namespace or self._namespace + urls = self._endpoints.get(name, []) + endpoints = [ResolvedEndpoint(url=url, actor_id=f"fixed-{name}-{i}") for i, url in enumerate(urls)] + return ResolveResult(name=name, namespace=ns, endpoints=endpoints) + + +class ClusterResolver: + """Resolver backed by the cluster controller's endpoint registry. + + Queries the controller's ListEndpoints RPC to discover actor endpoints + registered by running jobs. Respects namespace boundaries for isolation. + + Args: + controller_address: Controller URL (e.g., "http://localhost:8080") + namespace: Namespace for actor isolation (defaults to FLUSTER_NAMESPACE env var) + timeout: HTTP request timeout in seconds + """ + + def __init__( + self, + controller_address: str, + namespace: Namespace | None = None, + timeout: float = 5.0, + ): + self._address = controller_address.rstrip("/") + self._timeout = timeout + self._namespace = namespace or Namespace(os.environ.get("FLUSTER_NAMESPACE", "")) + self._client = ControllerServiceClientSync( + address=self._address, + timeout_ms=int(timeout * 1000), + ) + + @property + def default_namespace(self) -> Namespace: + return self._namespace + + def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult: + """Resolve actor name to endpoints via controller. + + Args: + name: Actor name to resolve + namespace: Override default namespace + + Returns: + ResolveResult with matching endpoints + """ + ns = namespace or self._namespace + + request = cluster_pb2.Controller.ListEndpointsRequest( + prefix=name, + namespace=str(ns), + ) + + resp = self._client.list_endpoints(request) + + # Filter to exact name matches (controller uses prefix matching) + endpoints = [ + ResolvedEndpoint( + url=f"http://{ep.address}", + actor_id=ep.endpoint_id, + metadata=dict(ep.metadata), + ) + for ep in resp.endpoints + if ep.name == name + ] + + return ResolveResult(name=name, namespace=ns, endpoints=endpoints) + + +class GcsApi(TypingProtocol): + """Protocol for GCS Compute API operations.""" + + def list_instances(self, project: str, zone: str) -> list[dict]: + """List VM instances with metadata.""" + ... 
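+
+
+# A fake GcsApi (such as MockGcsApi below) can be injected into GcsResolver for
+# tests. Sketch with illustrative values: an instance tagged
+# "fluster_actor_echo": "9000" resolves to http://10.0.0.7:9000.
+#
+#   api = MockGcsApi([{
+#       "name": "vm-1",
+#       "internal_ip": "10.0.0.7",
+#       "status": "RUNNING",
+#       "metadata": {"fluster_namespace": "", "fluster_actor_echo": "9000"},
+#   }])
+#   resolver = GcsResolver("my-project", "us-central1-a", namespace=Namespace(""), api=api)
+#   resolver.resolve("echo").first().url  # -> "http://10.0.0.7:9000"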
+
+
+class RealGcsApi:
+    """Real GCS API using google-cloud-compute."""
+
+    def list_instances(self, project: str, zone: str) -> list[dict]:
+        from google.cloud import compute_v1
+
+        client = compute_v1.InstancesClient()
+        instances = []
+        for instance in client.list(project=project, zone=zone):
+            metadata = {}
+            if instance.metadata and instance.metadata.items:
+                for item in instance.metadata.items:
+                    metadata[item.key] = item.value
+
+            internal_ip = None
+            if instance.network_interfaces:
+                internal_ip = instance.network_interfaces[0].network_i_p
+
+            instances.append(
+                {
+                    "name": instance.name,
+                    "internal_ip": internal_ip,
+                    "metadata": metadata,
+                    "status": instance.status,
+                }
+            )
+        return instances
+
+
+class MockGcsApi:
+    """Mock GCS API for testing."""
+
+    def __init__(self, instances: list[dict] | None = None):
+        self._instances = instances or []
+
+    def set_instances(self, instances: list[dict]) -> None:
+        self._instances = instances
+
+    def list_instances(self, project: str, zone: str) -> list[dict]:
+        return self._instances
+
+
+class GcsResolver:
+    """Resolver using GCS VM instance metadata tags.
+
+    Discovers actor endpoints by querying GCP VM instance metadata. Instances must
+    have metadata tags in the format:
+    - `fluster_actor_<name>`: port number for the actor
+    - `fluster_namespace`: namespace for isolation (defaults to "")
+
+    Only RUNNING instances are considered for resolution.
+
+    Args:
+        project: GCP project ID
+        zone: GCP zone (e.g., "us-central1-a")
+        namespace: Namespace for actor isolation (defaults to FLUSTER_NAMESPACE env var)
+        api: GcsApi implementation (defaults to RealGcsApi)
+    """
+
+    ACTOR_PREFIX = "fluster_actor_"
+    NAMESPACE_KEY = "fluster_namespace"
+
+    def __init__(
+        self,
+        project: str,
+        zone: str,
+        namespace: Namespace | None = None,
+        api: GcsApi | None = None,
+    ):
+        self._project = project
+        self._zone = zone
+        self._api = api or RealGcsApi()
+        self._namespace = namespace or Namespace(os.environ.get("FLUSTER_NAMESPACE", ""))
+
+    @property
+    def default_namespace(self) -> Namespace:
+        return self._namespace
+
+    def resolve(self, name: str, namespace: Namespace | None = None) -> ResolveResult:
+        """Resolve actor name to endpoints via GCS instance metadata.
+
+        Args:
+            name: Actor name to resolve
+            namespace: Override default namespace
+
+        Returns:
+            ResolveResult with matching endpoints from RUNNING instances
+        """
+        ns = namespace or self._namespace
+        endpoints = []
+
+        instances = self._api.list_instances(self._project, self._zone)
+
+        for instance in instances:
+            if instance.get("status") != "RUNNING":
+                continue
+
+            metadata = instance.get("metadata", {})
+            instance_ns = metadata.get(self.NAMESPACE_KEY, "")
+
+            if instance_ns != str(ns):
+                continue
+
+            actor_key = f"{self.ACTOR_PREFIX}{name}"
+            if actor_key in metadata:
+                port = metadata[actor_key]
+                ip = instance.get("internal_ip")
+                if ip:
+                    endpoints.append(
+                        ResolvedEndpoint(
+                            url=f"http://{ip}:{port}",
+                            actor_id=f"gcs-{instance['name']}-{name}",
+                            metadata={"instance": instance["name"]},
+                        )
+                    )
+
+        return ResolveResult(name=name, namespace=ns, endpoints=endpoints)
diff --git a/lib/fluster/src/fluster/actor/server.py b/lib/fluster/src/fluster/actor/server.py
new file mode 100644
index 0000000000..9331daa7f9
--- /dev/null
+++ b/lib/fluster/src/fluster/actor/server.py
@@ -0,0 +1,232 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Actor server implementation for hosting actor instances.""" + +import inspect +import socket +import threading +import time +import uuid +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import cloudpickle +import uvicorn + +from fluster import actor_pb2 +from fluster.actor.types import ActorContext, ActorId, _set_actor_context +from fluster.actor_connect import ActorServiceASGIApplication +from connectrpc.request import RequestContext + + +@dataclass +class RegisteredActor: + """An actor registered with the server.""" + + name: str + actor_id: ActorId + instance: Any + methods: dict[str, Callable] + registered_at_ms: int = field(default_factory=lambda: int(time.time() * 1000)) + + +class ActorServer: + """Server for hosting actor instances and handling RPC calls.""" + + def __init__(self, host: str = "0.0.0.0", port: int = 0): + """Initialize the actor server. + + Args: + host: Host address to bind to + port: Port to bind to (0 for auto-assign) + """ + self._host = host + self._port = port + self._actors: dict[str, RegisteredActor] = {} + self._context: ActorContext | None = None + self._app: ActorServiceASGIApplication | None = None + self._actual_port: int | None = None + + @property + def address(self) -> str: + """Get the server address as host:port.""" + port = self._actual_port or self._port + return f"{self._host}:{port}" + + def register(self, name: str, actor: Any) -> ActorId: + """Register an actor instance with the server. 
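+
+        Every public (non-underscore) method of ``actor`` becomes callable via
+        the ``Call`` RPC. A minimal sketch (``EchoActor`` is a hypothetical
+        user-defined class):
+
+        ```python
+        server = ActorServer(port=0)
+        server.register("echo", EchoActor())
+        ```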
+ + Args: + name: Name for actor discovery + actor: Actor instance with public methods + + Returns: + Unique actor ID + """ + actor_id = ActorId(f"{name}-{uuid.uuid4().hex[:8]}") + methods = {m: getattr(actor, m) for m in dir(actor) if not m.startswith("_") and callable(getattr(actor, m))} + self._actors[name] = RegisteredActor( + name=name, + actor_id=actor_id, + instance=actor, + methods=methods, + ) + return actor_id + + async def call(self, request: actor_pb2.ActorCall, ctx: RequestContext) -> actor_pb2.ActorResponse: + """Handle actor RPC call.""" + # Find actor + actor_name = request.actor_name or next(iter(self._actors), "") + actor = self._actors.get(actor_name) + if not actor: + error = actor_pb2.ActorError( + error_type="NotFound", + message=f"Actor '{actor_name}' not found", + ) + return actor_pb2.ActorResponse(error=error) + + method = actor.methods.get(request.method_name) + if not method: + error = actor_pb2.ActorError( + error_type="NotFound", + message=f"Method '{request.method_name}' not found", + ) + return actor_pb2.ActorResponse(error=error) + + try: + args = cloudpickle.loads(request.serialized_args) if request.serialized_args else () + kwargs = cloudpickle.loads(request.serialized_kwargs) if request.serialized_kwargs else {} + + # Set context for this call + _set_actor_context(self._context) + try: + result = method(*args, **kwargs) + finally: + _set_actor_context(None) + + return actor_pb2.ActorResponse(serialized_value=cloudpickle.dumps(result)) + + except Exception as e: + error = actor_pb2.ActorError( + error_type=type(e).__name__, + message=str(e), + serialized_exception=cloudpickle.dumps(e), + ) + return actor_pb2.ActorResponse(error=error) + + async def health_check(self, request: actor_pb2.Empty, ctx: RequestContext) -> actor_pb2.HealthResponse: + """Handle health check.""" + return actor_pb2.HealthResponse(healthy=True) + + async def list_methods( + self, request: actor_pb2.ListMethodsRequest, ctx: RequestContext + ) -> actor_pb2.ListMethodsResponse: + """List all methods available on an actor. + + Returns method names, signatures, and docstrings for debugging. + """ + actor_name = request.actor_name or next(iter(self._actors), "") + actor = self._actors.get(actor_name) + if not actor: + return actor_pb2.ListMethodsResponse() + + methods = [] + for name, method in actor.methods.items(): + try: + sig = str(inspect.signature(method)) + except (ValueError, TypeError): + sig = "()" + + docstring = inspect.getdoc(method) or "" + + methods.append( + actor_pb2.MethodInfo( + name=name, + signature=sig, + docstring=docstring, + ) + ) + + return actor_pb2.ListMethodsResponse(methods=methods) + + async def list_actors( + self, request: actor_pb2.ListActorsRequest, ctx: RequestContext + ) -> actor_pb2.ListActorsResponse: + """List all actors registered with this server. + + Returns actor names, IDs, and registration timestamps for debugging. + """ + actors = [] + for actor in self._actors.values(): + actors.append( + actor_pb2.ActorInfo( + name=actor.name, + actor_id=actor.actor_id, + registered_at_ms=actor.registered_at_ms, + metadata={}, + ) + ) + + return actor_pb2.ListActorsResponse(actors=actors) + + def _create_app(self) -> ActorServiceASGIApplication: + """Create the Connect RPC ASGI application for the server.""" + return ActorServiceASGIApplication(service=self) + + def serve_background(self, context: ActorContext | None = None) -> int: + """Start server in background thread. 
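+
+        A typical sequence (sketch) is register-then-serve, then advertise the
+        bound address:
+
+        ```python
+        port = server.serve_background()
+        print(server.address)  # "host:port", using the actual bound port
+        ```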
+ + Args: + context: ActorContext to inject into actor method calls + + Returns: + Actual port the server is listening on + """ + self._context = context + self._app = self._create_app() + + # Find available port if port=0 + if self._port == 0: + with socket.socket() as s: + s.bind(("", 0)) + self._actual_port = s.getsockname()[1] + else: + self._actual_port = self._port + + config = uvicorn.Config( + self._app, + host=self._host, + port=self._actual_port, + log_level="error", + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + for _ in range(50): + try: + import httpx + + httpx.get(f"http://{self._host}:{self._actual_port}/", timeout=0.1) + except Exception: + pass + time.sleep(0.1) + if server.started: + break + + return self._actual_port diff --git a/lib/fluster/src/fluster/actor/types.py b/lib/fluster/src/fluster/actor/types.py new file mode 100644 index 0000000000..ae33f9d0c6 --- /dev/null +++ b/lib/fluster/src/fluster/actor/types.py @@ -0,0 +1,90 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core types for the fluster actor layer. + +This module contains actor-specific types that depend on the cluster layer. +""" + +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NewType + +if TYPE_CHECKING: + from fluster.actor.resolver import Resolver + from fluster.cluster.types import JobId, Namespace + + +# Type aliases +ActorId = NewType("ActorId", str) + +# Context variable for actor context injection +_actor_context: ContextVar["ActorContext | None"] = ContextVar("actor_context", default=None) + + +def current_ctx() -> "ActorContext": + """Get the current ActorContext. Raises if not in an actor call.""" + ctx = _actor_context.get() + if ctx is None: + raise RuntimeError("current_ctx() called outside of actor method") + return ctx + + +def _set_actor_context(ctx: "ActorContext | None") -> None: + """Internal: set the actor context for the current call.""" + _actor_context.set(ctx) + + +@dataclass +class ActorEndpoint: + """Actor endpoint for discovery and RPC. + + Wraps a cluster Endpoint with actor-specific semantics. + + Args: + actor_id: Unique actor identifier + name: Actor name for discovery + address: Network address (host:port) + job_id: Job hosting this actor + namespace: Namespace for scoping + metadata: Optional key-value metadata + """ + + actor_id: ActorId + name: str + address: str + job_id: "JobId" + namespace: "Namespace" + metadata: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ActorContext: + """Context passed to actor methods as first argument. + + Enables actors to call other actors and access cluster services. 
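+
+    When not passed explicitly, the active context can also be fetched inside
+    a call via ``current_ctx()`` (sketch; the peer actor name is hypothetical
+    and a resolver is assumed to be configured):
+
+    ```python
+    def handle(self, payload):
+        ctx = current_ctx()
+        peer = ctx.resolver.resolve("peer-actor").first()
+    ```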
+ + Args: + cluster: Cluster client for job management (or None for Stage 1) + resolver: Resolver for actor discovery (or None for Stage 1) + job_id: Current job ID + namespace: Current namespace + """ + + cluster: Any + resolver: "Resolver | None" + job_id: str + namespace: str + + # TODO: Stage 2+: from_environment() will be implemented when ClusterResolver exists diff --git a/lib/fluster/src/fluster/actor_connect.py b/lib/fluster/src/fluster/actor_connect.py new file mode 100644 index 0000000000..f2f87af45f --- /dev/null +++ b/lib/fluster/src/fluster/actor_connect.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! +# source: actor.proto + +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping +from typing import Protocol + +from connectrpc.client import ConnectClient, ConnectClientSync +from connectrpc.code import Code +from connectrpc.errors import ConnectError +from connectrpc.interceptor import Interceptor, InterceptorSync +from connectrpc.method import IdempotencyLevel, MethodInfo +from connectrpc.request import Headers, RequestContext +from connectrpc.server import ConnectASGIApplication, ConnectWSGIApplication, Endpoint, EndpointSync +from . import actor_pb2 as actor__pb2 + + +class ActorService(Protocol): + async def call(self, request: actor__pb2.ActorCall, ctx: RequestContext) -> actor__pb2.ActorResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def health_check(self, request: actor__pb2.Empty, ctx: RequestContext) -> actor__pb2.HealthResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_methods(self, request: actor__pb2.ListMethodsRequest, ctx: RequestContext) -> actor__pb2.ListMethodsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_actors(self, request: actor__pb2.ListActorsRequest, ctx: RequestContext) -> actor__pb2.ListActorsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + +class ActorServiceASGIApplication(ConnectASGIApplication[ActorService]): + def __init__(self, service: ActorService | AsyncGenerator[ActorService], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + service=service, + endpoints=lambda svc: { + "/fluster.actor.ActorService/Call": Endpoint.unary( + method=MethodInfo( + name="Call", + service_name="fluster.actor.ActorService", + input=actor__pb2.ActorCall, + output=actor__pb2.ActorResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.call, + ), + "/fluster.actor.ActorService/HealthCheck": Endpoint.unary( + method=MethodInfo( + name="HealthCheck", + service_name="fluster.actor.ActorService", + input=actor__pb2.Empty, + output=actor__pb2.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.health_check, + ), + "/fluster.actor.ActorService/ListMethods": Endpoint.unary( + method=MethodInfo( + name="ListMethods", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListMethodsRequest, + output=actor__pb2.ListMethodsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_methods, + ), + "/fluster.actor.ActorService/ListActors": Endpoint.unary( + method=MethodInfo( + name="ListActors", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListActorsRequest, + output=actor__pb2.ListActorsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_actors, + ), + }, 
+ interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.actor.ActorService" + + +class ActorServiceClient(ConnectClient): + async def call( + self, + request: actor__pb2.ActorCall, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ActorResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="Call", + service_name="fluster.actor.ActorService", + input=actor__pb2.ActorCall, + output=actor__pb2.ActorResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def health_check( + self, + request: actor__pb2.Empty, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.HealthResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="HealthCheck", + service_name="fluster.actor.ActorService", + input=actor__pb2.Empty, + output=actor__pb2.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_methods( + self, + request: actor__pb2.ListMethodsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ListMethodsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListMethods", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListMethodsRequest, + output=actor__pb2.ListMethodsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_actors( + self, + request: actor__pb2.ListActorsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ListActorsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListActors", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListActorsRequest, + output=actor__pb2.ListActorsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + +class ActorServiceSync(Protocol): + def call(self, request: actor__pb2.ActorCall, ctx: RequestContext) -> actor__pb2.ActorResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def health_check(self, request: actor__pb2.Empty, ctx: RequestContext) -> actor__pb2.HealthResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_methods(self, request: actor__pb2.ListMethodsRequest, ctx: RequestContext) -> actor__pb2.ListMethodsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_actors(self, request: actor__pb2.ListActorsRequest, ctx: RequestContext) -> actor__pb2.ListActorsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + +class ActorServiceWSGIApplication(ConnectWSGIApplication): + def __init__(self, service: ActorServiceSync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + endpoints={ + "/fluster.actor.ActorService/Call": EndpointSync.unary( + method=MethodInfo( + name="Call", + service_name="fluster.actor.ActorService", + input=actor__pb2.ActorCall, + output=actor__pb2.ActorResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.call, + ), + 
"/fluster.actor.ActorService/HealthCheck": EndpointSync.unary( + method=MethodInfo( + name="HealthCheck", + service_name="fluster.actor.ActorService", + input=actor__pb2.Empty, + output=actor__pb2.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.health_check, + ), + "/fluster.actor.ActorService/ListMethods": EndpointSync.unary( + method=MethodInfo( + name="ListMethods", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListMethodsRequest, + output=actor__pb2.ListMethodsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_methods, + ), + "/fluster.actor.ActorService/ListActors": EndpointSync.unary( + method=MethodInfo( + name="ListActors", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListActorsRequest, + output=actor__pb2.ListActorsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_actors, + ), + }, + interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.actor.ActorService" + + +class ActorServiceClientSync(ConnectClientSync): + def call( + self, + request: actor__pb2.ActorCall, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ActorResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="Call", + service_name="fluster.actor.ActorService", + input=actor__pb2.ActorCall, + output=actor__pb2.ActorResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def health_check( + self, + request: actor__pb2.Empty, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.HealthResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="HealthCheck", + service_name="fluster.actor.ActorService", + input=actor__pb2.Empty, + output=actor__pb2.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_methods( + self, + request: actor__pb2.ListMethodsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ListMethodsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListMethods", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListMethodsRequest, + output=actor__pb2.ListMethodsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_actors( + self, + request: actor__pb2.ListActorsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> actor__pb2.ListActorsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListActors", + service_name="fluster.actor.ActorService", + input=actor__pb2.ListActorsRequest, + output=actor__pb2.ListActorsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) diff --git a/lib/fluster/src/fluster/actor_pb2.py b/lib/fluster/src/fluster/actor_pb2.py new file mode 100644 index 0000000000..a17cd91d85 --- /dev/null +++ b/lib/fluster/src/fluster/actor_pb2.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: actor.proto +# Protobuf Python Version: 6.33.4 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 4, + '', + 'actor.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x61\x63tor.proto\x12\rfluster.actor\"\xa1\x01\n\tActorCall\x12\x1f\n\x0bmethod_name\x18\x01 \x01(\tR\nmethodName\x12\x1d\n\nactor_name\x18\x02 \x01(\tR\tactorName\x12\'\n\x0fserialized_args\x18\x03 \x01(\x0cR\x0eserializedArgs\x12+\n\x11serialized_kwargs\x18\x04 \x01(\x0cR\x10serializedKwargs\"y\n\rActorResponse\x12+\n\x10serialized_value\x18\x01 \x01(\x0cH\x00R\x0fserializedValue\x12\x31\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x19.fluster.actor.ActorErrorH\x00R\x05\x65rrorB\x08\n\x06result\"x\n\nActorError\x12\x1d\n\nerror_type\x18\x01 \x01(\tR\terrorType\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12\x31\n\x14serialized_exception\x18\x03 \x01(\x0cR\x13serializedException\"\x07\n\x05\x45mpty\"*\n\x0eHealthResponse\x12\x18\n\x07healthy\x18\x01 \x01(\x08R\x07healthy\"3\n\x12ListMethodsRequest\x12\x1d\n\nactor_name\x18\x01 \x01(\tR\tactorName\"\\\n\nMethodInfo\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\tsignature\x18\x02 \x01(\tR\tsignature\x12\x1c\n\tdocstring\x18\x03 \x01(\tR\tdocstring\"J\n\x13ListMethodsResponse\x12\x33\n\x07methods\x18\x01 \x03(\x0b\x32\x19.fluster.actor.MethodInfoR\x07methods\"\x13\n\x11ListActorsRequest\"\xe5\x01\n\tActorInfo\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x19\n\x08\x61\x63tor_id\x18\x02 \x01(\tR\x07\x61\x63torId\x12(\n\x10registered_at_ms\x18\x03 \x01(\x03R\x0eregisteredAtMs\x12\x42\n\x08metadata\x18\x04 \x03(\x0b\x32&.fluster.actor.ActorInfo.MetadataEntryR\x08metadata\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"F\n\x12ListActorsResponse\x12\x30\n\x06\x61\x63tors\x18\x01 \x03(\x0b\x32\x18.fluster.actor.ActorInfoR\x06\x61\x63tors2\xbb\x02\n\x0c\x41\x63torService\x12>\n\x04\x43\x61ll\x12\x18.fluster.actor.ActorCall\x1a\x1c.fluster.actor.ActorResponse\x12\x42\n\x0bHealthCheck\x12\x14.fluster.actor.Empty\x1a\x1d.fluster.actor.HealthResponse\x12T\n\x0bListMethods\x12!.fluster.actor.ListMethodsRequest\x1a\".fluster.actor.ListMethodsResponse\x12Q\n\nListActors\x12 .fluster.actor.ListActorsRequest\x1a!.fluster.actor.ListActorsResponseBw\n\x11\x63om.fluster.actorB\nActorProtoP\x01\x90\x01\x01\xa2\x02\x03\x46\x41X\xaa\x02\rFluster.Actor\xca\x02\rFluster\\Actor\xe2\x02\x19\x46luster\\Actor\\GPBMetadata\xea\x02\x0e\x46luster::Actorb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'actor_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\021com.fluster.actorB\nActorProtoP\001\220\001\001\242\002\003FAX\252\002\rFluster.Actor\312\002\rFluster\\Actor\342\002\031Fluster\\Actor\\GPBMetadata\352\002\016Fluster::Actor' + _globals['_ACTORINFO_METADATAENTRY']._loaded_options 
= None + _globals['_ACTORINFO_METADATAENTRY']._serialized_options = b'8\001' + _globals['_ACTORCALL']._serialized_start=31 + _globals['_ACTORCALL']._serialized_end=192 + _globals['_ACTORRESPONSE']._serialized_start=194 + _globals['_ACTORRESPONSE']._serialized_end=315 + _globals['_ACTORERROR']._serialized_start=317 + _globals['_ACTORERROR']._serialized_end=437 + _globals['_EMPTY']._serialized_start=439 + _globals['_EMPTY']._serialized_end=446 + _globals['_HEALTHRESPONSE']._serialized_start=448 + _globals['_HEALTHRESPONSE']._serialized_end=490 + _globals['_LISTMETHODSREQUEST']._serialized_start=492 + _globals['_LISTMETHODSREQUEST']._serialized_end=543 + _globals['_METHODINFO']._serialized_start=545 + _globals['_METHODINFO']._serialized_end=637 + _globals['_LISTMETHODSRESPONSE']._serialized_start=639 + _globals['_LISTMETHODSRESPONSE']._serialized_end=713 + _globals['_LISTACTORSREQUEST']._serialized_start=715 + _globals['_LISTACTORSREQUEST']._serialized_end=734 + _globals['_ACTORINFO']._serialized_start=737 + _globals['_ACTORINFO']._serialized_end=966 + _globals['_ACTORINFO_METADATAENTRY']._serialized_start=907 + _globals['_ACTORINFO_METADATAENTRY']._serialized_end=966 + _globals['_LISTACTORSRESPONSE']._serialized_start=968 + _globals['_LISTACTORSRESPONSE']._serialized_end=1038 + _globals['_ACTORSERVICE']._serialized_start=1041 + _globals['_ACTORSERVICE']._serialized_end=1356 +_builder.BuildServices(DESCRIPTOR, 'actor_pb2', _globals) +# @@protoc_insertion_point(module_scope) diff --git a/lib/fluster/src/fluster/actor_pb2.pyi b/lib/fluster/src/fluster/actor_pb2.pyi new file mode 100644 index 0000000000..29897d0b03 --- /dev/null +++ b/lib/fluster/src/fluster/actor_pb2.pyi @@ -0,0 +1,103 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import service as _service +from collections.abc import Iterable as _Iterable, Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class ActorCall(_message.Message): + __slots__ = ("method_name", "actor_name", "serialized_args", "serialized_kwargs") + METHOD_NAME_FIELD_NUMBER: _ClassVar[int] + ACTOR_NAME_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_ARGS_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_KWARGS_FIELD_NUMBER: _ClassVar[int] + method_name: str + actor_name: str + serialized_args: bytes + serialized_kwargs: bytes + def __init__(self, method_name: _Optional[str] = ..., actor_name: _Optional[str] = ..., serialized_args: _Optional[bytes] = ..., serialized_kwargs: _Optional[bytes] = ...) -> None: ... + +class ActorResponse(_message.Message): + __slots__ = ("serialized_value", "error") + SERIALIZED_VALUE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + serialized_value: bytes + error: ActorError + def __init__(self, serialized_value: _Optional[bytes] = ..., error: _Optional[_Union[ActorError, _Mapping]] = ...) -> None: ... + +class ActorError(_message.Message): + __slots__ = ("error_type", "message", "serialized_exception") + ERROR_TYPE_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_EXCEPTION_FIELD_NUMBER: _ClassVar[int] + error_type: str + message: str + serialized_exception: bytes + def __init__(self, error_type: _Optional[str] = ..., message: _Optional[str] = ..., serialized_exception: _Optional[bytes] = ...) -> None: ... 
+ +class Empty(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class HealthResponse(_message.Message): + __slots__ = ("healthy",) + HEALTHY_FIELD_NUMBER: _ClassVar[int] + healthy: bool + def __init__(self, healthy: _Optional[bool] = ...) -> None: ... + +class ListMethodsRequest(_message.Message): + __slots__ = ("actor_name",) + ACTOR_NAME_FIELD_NUMBER: _ClassVar[int] + actor_name: str + def __init__(self, actor_name: _Optional[str] = ...) -> None: ... + +class MethodInfo(_message.Message): + __slots__ = ("name", "signature", "docstring") + NAME_FIELD_NUMBER: _ClassVar[int] + SIGNATURE_FIELD_NUMBER: _ClassVar[int] + DOCSTRING_FIELD_NUMBER: _ClassVar[int] + name: str + signature: str + docstring: str + def __init__(self, name: _Optional[str] = ..., signature: _Optional[str] = ..., docstring: _Optional[str] = ...) -> None: ... + +class ListMethodsResponse(_message.Message): + __slots__ = ("methods",) + METHODS_FIELD_NUMBER: _ClassVar[int] + methods: _containers.RepeatedCompositeFieldContainer[MethodInfo] + def __init__(self, methods: _Optional[_Iterable[_Union[MethodInfo, _Mapping]]] = ...) -> None: ... + +class ListActorsRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ActorInfo(_message.Message): + __slots__ = ("name", "actor_id", "registered_at_ms", "metadata") + class MetadataEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + ACTOR_ID_FIELD_NUMBER: _ClassVar[int] + REGISTERED_AT_MS_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + name: str + actor_id: str + registered_at_ms: int + metadata: _containers.ScalarMap[str, str] + def __init__(self, name: _Optional[str] = ..., actor_id: _Optional[str] = ..., registered_at_ms: _Optional[int] = ..., metadata: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ListActorsResponse(_message.Message): + __slots__ = ("actors",) + ACTORS_FIELD_NUMBER: _ClassVar[int] + actors: _containers.RepeatedCompositeFieldContainer[ActorInfo] + def __init__(self, actors: _Optional[_Iterable[_Union[ActorInfo, _Mapping]]] = ...) -> None: ... + +class ActorService(_service.service): ... + +class ActorService_Stub(ActorService): ... diff --git a/lib/fluster/src/fluster/cluster/__init__.py b/lib/fluster/src/fluster/cluster/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/lib/fluster/src/fluster/cluster/backend/__init__.py b/lib/fluster/src/fluster/cluster/backend/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/backend/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/src/fluster/cluster/backend/base.py b/lib/fluster/src/fluster/cluster/backend/base.py new file mode 100644 index 0000000000..89c85d543f --- /dev/null +++ b/lib/fluster/src/fluster/cluster/backend/base.py @@ -0,0 +1,18 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""VM backend abstraction. + +TODO: Implement in Stage 3 +""" diff --git a/lib/fluster/src/fluster/cluster/client.py b/lib/fluster/src/fluster/cluster/client.py new file mode 100644 index 0000000000..77fb7c0804 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/client.py @@ -0,0 +1,293 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cluster client for job management. + +This module provides: +- ClusterClient: Protocol for cluster job operations +- RpcClusterClient: Default implementation using RPC to controller +- BundleCreator: Helper for creating workspace bundles +""" + +import shutil +import subprocess +import tempfile +import time +import zipfile +from pathlib import Path +from typing import Protocol + +import cloudpickle + +from fluster import cluster_pb2 +from fluster.cluster.types import Entrypoint, JobId, is_job_finished +from fluster.cluster_connect import ControllerServiceClientSync + + +class ClusterClient(Protocol): + """Protocol for cluster job operations. + + This is the interface WorkerPool and other clients use to interact + with a cluster. Default implementation is RpcClusterClient. 
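+
+    A typical wiring (sketch; the controller address is illustrative) pairs an
+    implementation with ``BundleCreator`` below:
+
+    ```python
+    bundle = BundleCreator().create_bundle()
+    client: ClusterClient = RpcClusterClient("http://localhost:8080", bundle)
+    ```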
+ """ + + def submit( + self, + entrypoint: Entrypoint, + name: str, + resources: cluster_pb2.ResourceSpec, + environment: cluster_pb2.EnvironmentConfig | None = None, + namespace: str = "", + ports: list[str] | None = None, + ) -> JobId: + """Submit a job to the cluster. + + Args: + entrypoint: Job entrypoint (callable + args/kwargs) + name: Job name + resources: Resource requirements + environment: Environment configuration + namespace: Namespace for actor isolation + ports: Port names to allocate (e.g., ["actor", "metrics"]) + + Returns: + Job ID + """ + ... + + def status(self, job_id: JobId) -> cluster_pb2.JobStatus: + """Get job status. + + Args: + job_id: Job ID to query + + Returns: + JobInfo proto with current state + """ + ... + + def wait( + self, + job_id: JobId, + timeout: float = 300.0, + poll_interval: float = 0.5, + ) -> cluster_pb2.JobStatus: + """Wait for job to complete. + + Args: + job_id: Job ID to wait for + timeout: Maximum time to wait in seconds + poll_interval: Time between status checks + + Returns: + Final JobInfo + + Raises: + TimeoutError: If job doesn't complete within timeout + """ + ... + + def terminate(self, job_id: JobId) -> None: + """Terminate a running job. + + Args: + job_id: Job ID to terminate + """ + ... + + @property + def controller_address(self) -> str: + """Address of the cluster controller (for resolver).""" + ... + + +MINIMAL_PYPROJECT = """\ +[project] +name = "fluster-bundle" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [ + "cloudpickle", + "fluster", +] + +[tool.uv.sources] +fluster = { path = "./fluster" } +""" + + +class BundleCreator: + """Helper for creating workspace bundles. + + Creates minimal workspace bundles with pyproject.toml, uv.lock, + and fluster source for job execution. + """ + + def __init__(self, fluster_root: Path | None = None): + """Initialize bundle creator. + + Args: + fluster_root: Path to fluster project root. If None, auto-detects + from this file's location. + """ + if fluster_root is None: + # This file is at: lib/fluster/src/fluster/cluster/client.py + # Fluster root is: lib/fluster/ + fluster_root = Path(__file__).parent.parent.parent.parent + self._fluster_root = fluster_root + + def create_bundle(self, temp_dir: Path | None = None) -> bytes: + """Create a workspace bundle. + + Creates a zip file containing: + - pyproject.toml with fluster dependency + - uv.lock generated from the workspace + - fluster source code + + Args: + temp_dir: Optional temp directory for workspace. Creates one if None. 
+ + Returns: + Bundle as bytes (zip file contents) + """ + if temp_dir is None: + with tempfile.TemporaryDirectory(prefix="bundle_") as td: + return self._create_bundle_in_dir(Path(td)) + return self._create_bundle_in_dir(temp_dir) + + def _create_bundle_in_dir(self, temp_dir: Path) -> bytes: + """Create bundle in the given temp directory.""" + workspace = temp_dir / "workspace" + workspace.mkdir(exist_ok=True) + + # Write minimal pyproject.toml + (workspace / "pyproject.toml").write_text(MINIMAL_PYPROJECT) + + # Copy fluster source + fluster_dest = workspace / "fluster" + fluster_dest.mkdir(exist_ok=True) + shutil.copy2(self._fluster_root / "pyproject.toml", fluster_dest / "pyproject.toml") + shutil.copytree( + self._fluster_root / "src", + fluster_dest / "src", + ignore=shutil.ignore_patterns("__pycache__", "*.pyc", "*.egg-info"), + ) + + # Generate uv.lock + subprocess.run( + ["uv", "lock"], + cwd=workspace, + check=True, + capture_output=True, + ) + + # Create zip bundle + bundle_path = temp_dir / "bundle.zip" + with zipfile.ZipFile(bundle_path, "w", zipfile.ZIP_DEFLATED) as zf: + for file in workspace.rglob("*"): + if file.is_file(): + zf.write(file, file.relative_to(workspace)) + + return bundle_path.read_bytes() + + +class RpcClusterClient: + """ClusterClient implementation using RPC to controller.""" + + def __init__( + self, + controller_address: str, + bundle_blob: bytes, + timeout_ms: int = 30000, + ): + """Initialize RPC cluster client. + + Args: + controller_address: Controller URL (e.g., "http://localhost:8080") + bundle_blob: Workspace bundle bytes (use BundleCreator to create) + timeout_ms: RPC timeout in milliseconds + """ + self._address = controller_address + self._bundle_blob = bundle_blob + self._timeout_ms = timeout_ms + self._client = ControllerServiceClientSync( + address=controller_address, + timeout_ms=timeout_ms, + ) + + @property + def controller_address(self) -> str: + return self._address + + def submit( + self, + entrypoint: Entrypoint, + name: str, + resources: cluster_pb2.ResourceSpec, + environment: cluster_pb2.EnvironmentConfig | None = None, + namespace: str = "", + ports: list[str] | None = None, + ) -> JobId: + """Submit a job to the cluster.""" + serialized = cloudpickle.dumps(entrypoint) + + # Build environment with namespace + env = dict(environment.env_vars) if environment else {} + env["FLUSTER_NAMESPACE"] = namespace + + env_config = cluster_pb2.EnvironmentConfig( + workspace=environment.workspace if environment else "/app", + pip_packages=list(environment.pip_packages) if environment else [], + env_vars=env, + extras=list(environment.extras) if environment else [], + ) + + request = cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=serialized, + resources=resources, + environment=env_config, + bundle_blob=self._bundle_blob, + ports=ports or [], + ) + response = self._client.launch_job(request) + return JobId(response.job_id) + + def status(self, job_id: JobId) -> cluster_pb2.JobStatus: + """Get job status.""" + request = cluster_pb2.Controller.GetJobStatusRequest(job_id=job_id) + response = self._client.get_job_status(request) + return response.job + + def wait( + self, + job_id: JobId, + timeout: float = 300.0, + poll_interval: float = 0.5, + ) -> cluster_pb2.JobStatus: + """Wait for job to complete.""" + start = time.time() + + while time.time() - start < timeout: + job_info = self.status(job_id) + if is_job_finished(job_info.state): + return job_info + time.sleep(poll_interval) + + raise TimeoutError(f"Job {job_id} 
did not complete in {timeout}s") + + def terminate(self, job_id: JobId) -> None: + """Terminate a running job.""" + request = cluster_pb2.Controller.TerminateJobRequest(job_id=job_id) + self._client.terminate_job(request) diff --git a/lib/fluster/src/fluster/cluster/controller/__init__.py b/lib/fluster/src/fluster/cluster/controller/__init__.py new file mode 100644 index 0000000000..dc30560db0 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fluster.cluster.controller.controller import Controller, ControllerConfig + +__all__ = ["Controller", "ControllerConfig"] diff --git a/lib/fluster/src/fluster/cluster/controller/controller.py b/lib/fluster/src/fluster/cluster/controller/controller.py new file mode 100644 index 0000000000..fac6c7d956 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/controller.py @@ -0,0 +1,630 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified Controller class for managing all controller components. + +Provides a single Controller class that encapsulates and manages the lifecycle +of all controller components: +- ControllerState: In-memory job and worker state +- Scheduler: Pure scheduling logic (shallow interface) +- WorkerHealthTracker: Worker health logic (shallow interface) +- ControllerServiceImpl: RPC service implementation +- ControllerDashboard: Web dashboard and HTTP server + +The Controller owns the background thread that drives scheduling and heartbeat +loops, calling the shallow Scheduler and WorkerHealthTracker as needed. + +This simplifies controller initialization and ensures consistent lifecycle +management across all components. 
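+
+For tests, the worker RPC layer can be swapped by injecting a stub factory
+(sketch; ``FakeWorkerStub`` is a hypothetical test double):
+
+```python
+class FakeWorkerFactory:
+    def get_stub(self, address: str) -> "WorkerClient":
+        return FakeWorkerStub()
+
+controller = Controller(ControllerConfig(), worker_stub_factory=FakeWorkerFactory())
+```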
+""" + +import logging +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Protocol + +import uvicorn + +from fluster import cluster_pb2 +from fluster.cluster.controller.dashboard import ControllerDashboard +from fluster.cluster.controller.retry import handle_job_failure +from fluster.cluster.controller.scheduler import ScheduleResult, Scheduler +from fluster.cluster.types import JobId +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster_connect import WorkerServiceClientSync + +logger = logging.getLogger(__name__) + + +class WorkerClient(Protocol): + """Protocol for worker RPC client. + + This matches the WorkerServiceClientSync signature (client-side). + The server Protocol (WorkerServiceSync) has different signatures. + """ + + def run_job( + self, + request: cluster_pb2.Worker.RunJobRequest, + ) -> cluster_pb2.Worker.RunJobResponse: ... + + def get_job_status( + self, + request: cluster_pb2.Worker.GetJobStatusRequest, + ) -> cluster_pb2.JobStatus: ... + + def list_jobs( + self, + request: cluster_pb2.Worker.ListJobsRequest, + ) -> cluster_pb2.Worker.ListJobsResponse: ... + + def health_check( + self, + request: cluster_pb2.Empty, + ) -> cluster_pb2.Worker.HealthResponse: ... + + +class WorkerStubFactory(Protocol): + """Factory for getting worker RPC stubs. + + This protocol allows injecting mock stubs for testing. In production, + use DefaultWorkerStubFactory which creates real RPC clients. + """ + + def get_stub(self, address: str) -> WorkerClient: + """Get a worker stub for the given address. + + Args: + address: Worker address in "host:port" format + + Returns: + A WorkerClient stub for making RPC calls + """ + ... + + +class DefaultWorkerStubFactory: + """Default factory that creates real RPC client stubs.""" + + def get_stub(self, address: str) -> WorkerClient: + """Create a real RPC client for the given address.""" + return WorkerServiceClientSync( + address=f"http://{address}", + timeout_ms=10000, + ) + + +MAX_CONSECUTIVE_HEARTBEAT_FAILURES = 3 + + +@dataclass +class JobStateUpdate: + """A job state change from heartbeat response.""" + + job_id: JobId + new_state: int # JobState enum value + exit_code: int = 0 + error: str | None = None + finished_at_ms: int = 0 + + +@dataclass +class HeartbeatResult: + """Result of processing a single heartbeat. + + Contains all information to update state after a heartbeat check. + """ + + worker_id: str + success: bool + consecutive_failures: int = 0 + worker_failed: bool = False + failed_job_ids: list[JobId] = field(default_factory=list) + job_updates: list[JobStateUpdate] = field(default_factory=list) + + +@dataclass +class ControllerConfig: + """Controller configuration. + + Args: + host: Host to bind the HTTP server to (default: "127.0.0.1") + port: Port to bind the HTTP server to (default: 0 for auto-assign) + bundle_dir: Directory for storing uploaded job bundles (optional) + scheduler_interval_seconds: How often the scheduler checks for pending jobs (default: 0.5) + heartbeat_interval_seconds: How often to check worker health (default: 2.0) + """ + + host: str = "127.0.0.1" + port: int = 0 + bundle_dir: Path | None = None + scheduler_interval_seconds: float = 0.5 + heartbeat_interval_seconds: float = 2.0 + + +class Controller: + """Unified controller managing all components and lifecycle. 
+ + Encapsulates all controller components and provides a clean API for + job submission, status queries, and worker registration. The controller + handles all background threads and ensures proper cleanup on shutdown. + + Components managed: + - ControllerState: Thread-safe state for jobs, workers, and endpoints + - Scheduler: Pure job-to-worker matching (shallow interface) + - ControllerServiceImpl: RPC service implementation + - ControllerDashboard: Web dashboard and HTTP server + + The Controller owns a single background loop that periodically: + 1. Runs the scheduler to find job assignments + 2. Dispatches assigned jobs to workers + 3. Checks worker health via heartbeats + 4. Applies state changes from heartbeat results + + Example: + ```python + config = ControllerConfig(port=8080) + controller = Controller( + config=config, + worker_stub_factory=DefaultWorkerStubFactory(), + ) + controller.start() + try: + job_id = controller.launch_job(request) + status = controller.get_job_status(job_id) + finally: + controller.stop() + ``` + + Args: + config: Controller configuration + worker_stub_factory: Factory for creating worker RPC stubs. Use + DefaultWorkerStubFactory for production or inject a mock for testing. + """ + + def __init__( + self, + config: ControllerConfig, + worker_stub_factory: WorkerStubFactory, + ): + """Initialize controller components. + + Args: + config: Controller configuration + worker_stub_factory: Factory for creating worker RPC stubs + """ + self._config = config + self._stub_factory = worker_stub_factory + + # Initialize state first + self._state = ControllerState() + + # Scheduler: shallow interface, no threads, no callbacks + self._scheduler = Scheduler(self._state) + + # Service and dashboard + self._service = ControllerServiceImpl( + self._state, + self, # Controller implements the scheduling wake interface + bundle_dir=config.bundle_dir, + ) + self._dashboard = ControllerDashboard( + self._service, + host=config.host, + port=config.port, + ) + + # Background loop state + self._stop = False + self._wake_event = threading.Event() + self._loop_thread: threading.Thread | None = None + self._server_thread: threading.Thread | None = None + + # Track timing for scheduler and heartbeat + self._last_heartbeat_time = 0.0 + + def wake(self) -> None: + """Signal the controller loop to run immediately. + + Called when events occur that may make scheduling possible: + - New job submitted + - New worker registered + - Job finished (freeing capacity) + """ + self._wake_event.set() + + def start(self) -> None: + """Start all background components. + + Starts the main controller loop and dashboard server. + Both run in background daemon threads. + """ + self._stop = False + + # Start main controller loop + self._loop_thread = threading.Thread( + target=self._run_loop, + daemon=True, + ) + self._loop_thread.start() + + # Start dashboard server in background thread + self._server_thread = threading.Thread( + target=self._run_server, + daemon=True, + ) + self._server_thread.start() + + # Wait for server startup + time.sleep(1.0) + + def stop(self) -> None: + """Stop all background components gracefully. + + Signals the loop to stop, wakes it, and waits for termination. + The dashboard server stops automatically when the daemon thread exits. + """ + self._stop = True + self._wake_event.set() + if self._loop_thread: + self._loop_thread.join(timeout=5.0) + + def _run_loop(self) -> None: + """Main controller loop. + + Runs scheduling and heartbeat checks on configured intervals. 
+ Uses an event for wake signaling to allow immediate scheduling + after job submissions or worker registrations. + """ + while not self._stop: + # Wait for wake signal or timeout (use scheduler interval) + self._wake_event.wait(timeout=self._config.scheduler_interval_seconds) + self._wake_event.clear() + + if self._stop: + break + + # Run scheduling + self._run_scheduling() + + # Check if heartbeats are due + now = time.time() + if now - self._last_heartbeat_time >= self._config.heartbeat_interval_seconds: + self._run_heartbeats() + self._last_heartbeat_time = now + + def _run_scheduling(self) -> None: + """Run one scheduling cycle. + + Gets pending jobs and available workers, calls the scheduler to + find assignments, then dispatches each assignment. + """ + now_ms = int(time.time() * 1000) + pending_jobs = self._state.peek_pending_jobs() + workers = self._state.get_available_workers() + + if not pending_jobs: + return + + result = self._scheduler.find_assignments(pending_jobs, workers, now_ms) + self._apply_schedule_result(result, now_ms) + + def _apply_schedule_result(self, result: ScheduleResult, now_ms: int) -> None: + """Apply scheduling results: dispatch jobs and handle timeouts. + + Args: + result: ScheduleResult from scheduler + now_ms: Current timestamp in milliseconds + """ + # Handle timed-out jobs + for job in result.timed_out_jobs: + self._mark_job_unschedulable(job, now_ms) + + # Dispatch assignments + for job, worker in result.assignments: + success = self._dispatch_job(job, worker) + if success: + self._handle_successful_dispatch(job, worker, now_ms) + else: + self._handle_failed_dispatch(job, worker) + + def _dispatch_job(self, job: ControllerJob, worker: ControllerWorker) -> bool: + """Dispatch a job to a worker via RPC. + + Args: + job: Job to dispatch + worker: Worker to dispatch to + + Returns: + True if dispatch succeeded, False on failure + """ + try: + stub = self._stub_factory.get_stub(worker.address) + request = cluster_pb2.Worker.RunJobRequest( + job_id=str(job.job_id), + serialized_entrypoint=job.request.serialized_entrypoint, + environment=cluster_pb2.EnvironmentConfig( + workspace=job.request.environment.workspace, + env_vars=dict(job.request.environment.env_vars), + ), + bundle_gcs_path=job.request.bundle_gcs_path, + resources=cluster_pb2.ResourceSpec( + cpu=job.request.resources.cpu, + memory=job.request.resources.memory, + ), + ports=list(job.request.ports), + ) + stub.run_job(request) + return True + except Exception: + return False + + def _handle_successful_dispatch(self, job: ControllerJob, worker: ControllerWorker, now_ms: int) -> None: + """Update state after successful dispatch.""" + job.state = cluster_pb2.JOB_STATE_RUNNING + job.worker_id = worker.worker_id + job.started_at_ms = now_ms + + worker.running_jobs.add(job.job_id) + self._state.remove_from_queue(job.job_id) + + logger.info(f"Dispatched job {job.job_id} to worker {worker.worker_id}") + self._state.log_action( + "job_dispatched", + job_id=job.job_id, + worker_id=worker.worker_id, + ) + + def _handle_failed_dispatch(self, job: ControllerJob, worker: ControllerWorker) -> None: + """Handle dispatch failure - mark worker unhealthy, keep job in queue.""" + worker.healthy = False + logger.warning(f"Failed to dispatch job {job.job_id} to {worker.worker_id}, " "marking worker unhealthy") + self._state.log_action( + "dispatch_failed", + job_id=job.job_id, + worker_id=worker.worker_id, + ) + + def _mark_job_unschedulable(self, job: ControllerJob, now_ms: int) -> None: + """Mark job as 
unschedulable and remove from queue.""" + logger.warning( + f"Job {job.job_id} exceeded scheduling timeout " + f"({job.request.scheduling_timeout_seconds}s), marking as UNSCHEDULABLE" + ) + job.state = cluster_pb2.JOB_STATE_UNSCHEDULABLE + job.finished_at_ms = now_ms + job.error = f"Scheduling timeout exceeded ({job.request.scheduling_timeout_seconds}s)" + self._state.remove_from_queue(job.job_id) + self._state.log_action( + "job_unschedulable", + job_id=job.job_id, + details=f"timeout={job.request.scheduling_timeout_seconds}s", + ) + + def _run_heartbeats(self) -> None: + """Run heartbeat checks for all workers.""" + workers = self._state.list_all_workers() + now_ms = int(time.time() * 1000) + + for worker in workers: + if not worker.healthy: + continue + + response = self._send_heartbeat(worker.address) + result = self._process_heartbeat(worker, response) + self._apply_heartbeat_result(worker, result, now_ms) + + def _send_heartbeat(self, address: str) -> cluster_pb2.Worker.ListJobsResponse | None: + """Send heartbeat to a worker and get job status. + + Args: + address: Worker address in "host:port" format + + Returns: + ListJobsResponse with job statuses, or None on failure + """ + try: + stub = self._stub_factory.get_stub(address) + stub.health_check(cluster_pb2.Empty()) + + jobs_response = stub.list_jobs(cluster_pb2.Worker.ListJobsRequest()) + return jobs_response + except Exception: + return None + + def _process_heartbeat( + self, + worker: ControllerWorker, + response: cluster_pb2.Worker.ListJobsResponse | None, + ) -> HeartbeatResult: + """Process heartbeat response and return what changed.""" + if response is None: + new_failure_count = worker.consecutive_failures + 1 + worker_failed = new_failure_count >= MAX_CONSECUTIVE_HEARTBEAT_FAILURES + + return HeartbeatResult( + worker_id=worker.worker_id, + success=False, + consecutive_failures=new_failure_count, + worker_failed=worker_failed, + failed_job_ids=list(worker.running_jobs) if worker_failed else [], + ) + + job_updates = [] + for status in response.jobs: + if status.state in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + ): + job_updates.append( + JobStateUpdate( + job_id=JobId(status.job_id), + new_state=status.state, + exit_code=status.exit_code, + error=status.error or None, + finished_at_ms=status.finished_at_ms, + ) + ) + + return HeartbeatResult( + worker_id=worker.worker_id, + success=True, + consecutive_failures=0, + worker_failed=False, + job_updates=job_updates, + ) + + def _apply_heartbeat_result(self, worker: ControllerWorker, result: HeartbeatResult, now_ms: int) -> None: + """Apply heartbeat result to worker and job state. 
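Because the dispatch and heartbeat paths above reach workers only through `worker_stub_factory.get_stub(address)` and the stub's `run_job`, `health_check`, and `list_jobs` calls, they can be exercised without real workers. A minimal sketch of an in-memory fake, assuming the `WorkerStubFactory` protocol needs nothing beyond `get_stub` (the class names here are illustrative, not part of this patch):

```python
from fluster import cluster_pb2


class FakeWorkerStub:
    """Accepts every dispatched job and reports no finished jobs."""

    def __init__(self) -> None:
        self.launched: list[str] = []

    def run_job(self, request: cluster_pb2.Worker.RunJobRequest) -> None:
        self.launched.append(request.job_id)

    def health_check(self, request: cluster_pb2.Empty) -> cluster_pb2.Empty:
        return cluster_pb2.Empty()

    def list_jobs(self, request: cluster_pb2.Worker.ListJobsRequest) -> cluster_pb2.Worker.ListJobsResponse:
        # No terminal jobs reported, so the controller keeps its jobs RUNNING.
        return cluster_pb2.Worker.ListJobsResponse(jobs=[])


class FakeWorkerStubFactory:
    """Stands in for DefaultWorkerStubFactory when testing the Controller."""

    def __init__(self) -> None:
        self.stubs: dict[str, FakeWorkerStub] = {}

    def get_stub(self, address: str) -> FakeWorkerStub:
        return self.stubs.setdefault(address, FakeWorkerStub())
```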
+ + Args: + worker: Worker that was checked + result: HeartbeatResult from health tracker + now_ms: Current timestamp in milliseconds + """ + worker.consecutive_failures = result.consecutive_failures + + if result.success: + worker.last_heartbeat_ms = now_ms + + if result.worker_failed: + worker.healthy = False + logger.warning(f"Worker {worker.worker_id} failed health check") + self._state.log_action("worker_failed", worker_id=worker.worker_id) + + # Retry jobs that were running on the failed worker + for job_id in result.failed_job_ids: + handle_job_failure(self._state, job_id, is_worker_failure=True) + + # Apply job state updates from heartbeat response + for update in result.job_updates: + job = self._state.get_job(update.job_id) + if job: + job.state = update.new_state + job.exit_code = update.exit_code + job.error = update.error + job.finished_at_ms = update.finished_at_ms + + # Remove from worker's running jobs + worker.running_jobs.discard(update.job_id) + + self._state.log_action( + "job_completed", + job_id=update.job_id, + details=f"state={update.new_state}, exit_code={update.exit_code}", + ) + + def _run_server(self) -> None: + """Run dashboard server (blocking, for thread).""" + try: + uvicorn.run( + self._dashboard._app, + host=self._config.host, + port=self._config.port, + log_level="error", + ) + except Exception as e: + print(f"Controller server error: {e}") + + # Delegate key service methods + + def launch_job( + self, + request: cluster_pb2.Controller.LaunchJobRequest, + ) -> cluster_pb2.Controller.LaunchJobResponse: + """Submit a job to the controller. + + Creates a new job, adds it to the queue, and wakes the scheduler + to attempt immediate dispatch. + + Args: + request: Job launch request with entrypoint and resources + + Returns: + LaunchJobResponse containing the assigned job_id + """ + return self._service.launch_job(request, None) + + def get_job_status( + self, + job_id: str, + ) -> cluster_pb2.Controller.GetJobStatusResponse: + """Get the status of a job. + + Args: + job_id: Job identifier + + Returns: + GetJobStatusResponse with current job status + """ + request = cluster_pb2.Controller.GetJobStatusRequest(job_id=job_id) + return self._service.get_job_status(request, None) + + def register_worker( + self, + request: cluster_pb2.Controller.RegisterWorkerRequest, + ) -> cluster_pb2.Controller.RegisterWorkerResponse: + """Register a worker with the controller. + + Adds the worker to the registry and wakes the scheduler to + potentially dispatch pending jobs. + + Args: + request: Worker registration request + + Returns: + RegisterWorkerResponse with acceptance status + """ + return self._service.register_worker(request, None) + + def terminate_job( + self, + job_id: str, + ) -> cluster_pb2.Empty: + """Terminate a running job. + + Marks the job as killed in the controller state. + + Args: + job_id: Job identifier + + Returns: + Empty response + """ + request = cluster_pb2.Controller.TerminateJobRequest(job_id=job_id) + return self._service.terminate_job(request, None) + + # Properties + + @property + def state(self) -> ControllerState: + """Access to controller state (for advanced usage). + + Returns: + The controller's internal state + """ + return self._state + + @property + def url(self) -> str: + """Controller URL. 
+ + Returns: + HTTP URL for the controller dashboard and RPC service + """ + return f"http://{self._config.host}:{self._config.port}" diff --git a/lib/fluster/src/fluster/cluster/controller/dashboard.py b/lib/fluster/src/fluster/cluster/controller/dashboard.py new file mode 100644 index 0000000000..69411aada8 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/dashboard.py @@ -0,0 +1,893 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HTTP dashboard for controller visibility with Connect RPC mounting. + +Provides: +- Web dashboard at / with auto-refresh +- REST API at /api/* for dashboard consumption +- Health endpoint at /health +- Connect RPC at /fluster.cluster.ControllerService/* +""" + +# TODO: observability, aggregate stats over jobs, log to stable storage + +from starlette.applications import Starlette +from starlette.middleware.wsgi import WSGIMiddleware +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse +from starlette.routing import Mount, Route + +from fluster import cluster_pb2 +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.types import JobId +from fluster.cluster_connect import ControllerServiceWSGIApplication + + +class FakeRequestContext: + """Minimal stub RequestContext for internal REST-to-RPC bridging. + + The ControllerDashboard translates REST API calls to RPC method calls, + which require a RequestContext parameter. Since the RPC methods never + actually access the context, this minimal stub satisfies the type signature. + """ + + pass + + +def _job_state_name(state: int) -> str: + """Convert job state integer to human-readable name.""" + state_map: dict[int, str] = { + cluster_pb2.JOB_STATE_PENDING: "pending", + cluster_pb2.JOB_STATE_BUILDING: "building", + cluster_pb2.JOB_STATE_RUNNING: "running", + cluster_pb2.JOB_STATE_SUCCEEDED: "succeeded", + cluster_pb2.JOB_STATE_FAILED: "failed", + cluster_pb2.JOB_STATE_KILLED: "killed", + cluster_pb2.JOB_STATE_WORKER_FAILED: "worker_failed", + } + return state_map.get(state, f"unknown({state})") + + +DASHBOARD_HTML = """

+[DASHBOARD_HTML body: page titled "Fluster Controller" with a "Fluster Controller Dashboard" heading; stat cards for Jobs Pending, Jobs Running, Jobs Completed, Jobs Building, Workers Healthy, Workers Total, and Endpoints; a Recent Actions log; a Workers table (ID, Healthy, CPU, Memory, Running Jobs, Last Heartbeat); a Job Queue table (ID, Name, State, Resources, Worker, Error); an Endpoints table (Name, Address, Job, Namespace); and Users and Reservations sections marked "Coming in future release".]
+"""
+
+
+JOB_DETAIL_HTML = """
+[JOB_DETAIL_HTML body: page titled "Job Detail - {{job_id}}" with a "← Back to Dashboard" link and a "Job: {{job_id}}" heading; a Status panel (State, Exit Code, Started, Finished, Duration); a Resources panel (Memory Used, CPU Usage, Disk Used); a Build Info panel (Image Tag, Cache Status); and a Logs panel that initially shows "Loading logs...".]
+ + + + +""" + + +class ControllerDashboard: + """HTTP dashboard with Connect RPC and web UI. + + Connect RPC is mounted at /fluster.cluster.ControllerService + Web dashboard at / + REST API for dashboard at /api/* + """ + + def __init__( + self, + service: ControllerServiceImpl, + host: str = "0.0.0.0", + port: int = 8080, + ): + self._service = service + self._state = service._state + self._host = host + self._port = port + self._app = self._create_app() + self._server = None + + @property + def port(self) -> int: + return self._port + + def _create_app(self) -> Starlette: + """Create Starlette application with all routes.""" + rpc_wsgi_app = ControllerServiceWSGIApplication(service=self._service) + rpc_app = WSGIMiddleware(rpc_wsgi_app) + + routes = [ + # Web dashboard + Route("/", self._dashboard), + Route("/job/{job_id}", self._job_detail_page), + # REST API (for dashboard) + Route("/api/stats", self._api_stats), + Route("/api/actions", self._api_actions), + Route("/api/workers", self._api_workers), + Route("/api/jobs", self._api_jobs), + Route("/api/endpoints", self._api_endpoints), + Route("/health", self._health), + # Connect RPC - mount WSGI app wrapped for ASGI + Mount(rpc_wsgi_app.path, app=rpc_app), + ] + return Starlette(routes=routes) + + def _dashboard(self, _request: Request) -> HTMLResponse: + """Serve web dashboard HTML.""" + return HTMLResponse(DASHBOARD_HTML) + + def _api_stats(self, _request: Request) -> JSONResponse: + """Return aggregated statistics for the dashboard.""" + ctx = FakeRequestContext() + jobs_response = self._service.list_jobs(cluster_pb2.Controller.ListJobsRequest(), ctx) + workers = self._state.list_all_workers() + + jobs_pending = sum(1 for j in jobs_response.jobs if j.state == cluster_pb2.JOB_STATE_PENDING) + jobs_running = sum(1 for j in jobs_response.jobs if j.state == cluster_pb2.JOB_STATE_RUNNING) + jobs_building = sum(1 for j in jobs_response.jobs if j.state == cluster_pb2.JOB_STATE_BUILDING) + jobs_completed = sum( + 1 + for j in jobs_response.jobs + if j.state + in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + cluster_pb2.JOB_STATE_WORKER_FAILED, + ) + ) + workers_healthy = sum(1 for w in workers if w.healthy) + + # Count endpoints for running jobs + endpoints_count = sum( + 1 + for ep in self._state._endpoints.values() + if (job := self._state.get_job(ep.job_id)) and job.state == cluster_pb2.JOB_STATE_RUNNING + ) + + return JSONResponse( + { + "jobs_pending": jobs_pending, + "jobs_running": jobs_running, + "jobs_building": jobs_building, + "jobs_completed": jobs_completed, + "workers_healthy": workers_healthy, + "workers_total": len(workers), + "endpoints_count": endpoints_count, + } + ) + + def _api_actions(self, _request: Request) -> JSONResponse: + """Return recent actions log.""" + actions = self._state.get_recent_actions(limit=50) + return JSONResponse( + [ + { + "timestamp_ms": a.timestamp_ms, + "action": a.action, + "job_id": str(a.job_id) if a.job_id else None, + "worker_id": str(a.worker_id) if a.worker_id else None, + "details": a.details, + } + for a in actions + ] + ) + + def _api_workers(self, _request: Request) -> JSONResponse: + """Return all workers with status.""" + workers = self._state.list_all_workers() + return JSONResponse( + [ + { + "worker_id": str(w.worker_id), + "address": w.address, + "healthy": w.healthy, + "running_jobs": len(w.running_jobs), + "consecutive_failures": w.consecutive_failures, + "last_heartbeat_ms": w.last_heartbeat_ms, + "resources": { + 
"cpu": w.resources.cpu if w.resources else 0, + "memory": w.resources.memory if w.resources else "", + }, + } + for w in workers + ] + ) + + def _api_jobs(self, _request: Request) -> JSONResponse: + """Return all jobs with status.""" + jobs = self._state.list_all_jobs() + return JSONResponse( + [ + { + "job_id": str(j.job_id), + "name": j.request.name, + "state": _job_state_name(j.state), + "worker_id": str(j.worker_id) if j.worker_id else None, + "error": j.error, + "submitted_at_ms": j.submitted_at_ms, + "started_at_ms": j.started_at_ms, + "finished_at_ms": j.finished_at_ms, + "resources": { + "cpu": j.request.resources.cpu if j.request.resources else 0, + "memory": j.request.resources.memory if j.request.resources else "", + }, + } + for j in jobs + ] + ) + + def _api_endpoints(self, _request: Request) -> JSONResponse: + """Return all active endpoints for RUNNING jobs.""" + endpoints = [] + for ep in self._state._endpoints.values(): + job = self._state.get_job(ep.job_id) + if job and job.state == cluster_pb2.JOB_STATE_RUNNING: + endpoints.append( + { + "endpoint_id": ep.endpoint_id, + "name": ep.name, + "address": ep.address, + "job_id": str(ep.job_id), + "namespace": ep.namespace, + "metadata": dict(ep.metadata), + } + ) + return JSONResponse(endpoints) + + def _job_detail_page(self, request: Request) -> HTMLResponse: + """Serve job detail page - fetches from worker via JS.""" + job_id = request.path_params["job_id"] + job = self._state.get_job(JobId(job_id)) + worker_address = "" + if job and job.worker_id: + worker = self._state.get_worker(job.worker_id) + if worker: + worker_address = worker.address + return HTMLResponse(JOB_DETAIL_HTML.replace("{{job_id}}", job_id).replace("{{worker_address}}", worker_address)) + + def _health(self, _request: Request) -> JSONResponse: + """Return health check status.""" + workers = self._state.list_all_workers() + jobs = self._state.list_all_jobs() + healthy_count = sum(1 for w in workers if w.healthy) + + return JSONResponse( + { + "status": "ok", + "workers": len(workers), + "healthy_workers": healthy_count, + "jobs": len(jobs), + } + ) + + def run(self) -> None: + """Run server (blocking).""" + import uvicorn + + uvicorn.run(self._app, host=self._host, port=self._port) + + async def run_async(self) -> None: + """Run server asynchronously (for use with asyncio.create_task).""" + import uvicorn + + config = uvicorn.Config(self._app, host=self._host, port=self._port) + self._server = uvicorn.Server(config) + await self._server.serve() + + async def shutdown(self) -> None: + """Shutdown the async server gracefully.""" + if self._server: + self._server.should_exit = True diff --git a/lib/fluster/src/fluster/cluster/controller/resources.py b/lib/fluster/src/fluster/cluster/controller/resources.py new file mode 100644 index 0000000000..85e9e3e0a4 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/resources.py @@ -0,0 +1,103 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Resource parsing and comparison utilities. + +This module provides helpers for parsing memory strings (e.g., "8g", "512m"), +extracting device types and variants from DeviceConfig, and other resource-related +utilities used by the scheduler for resource-aware job matching. +""" + +import humanfriendly + +from fluster import cluster_pb2 + + +def parse_memory_string(memory_str: str) -> int: + """Parse memory string like '8g', '16gb', '512m' to bytes. + + Uses humanfriendly library for robust parsing. Supports various formats: + - "8G", "8GB", "8 GB", "8 gigabytes" + - "512M", "512MB", "512 megabytes" + - "1024K", "1024KB", "1024 kilobytes" + - Plain numbers treated as bytes + + Args: + memory_str: Memory string (e.g., "8g", "16gb", "512m", "1024mb") + + Returns: + Memory in bytes + + Raises: + ValueError: If format is invalid + """ + if not memory_str: + return 0 + + memory_str = memory_str.strip() + if not memory_str or memory_str == "0": + return 0 + + try: + return humanfriendly.parse_size(memory_str, binary=True) + except humanfriendly.InvalidSize as e: + raise ValueError(str(e)) from e + + +def get_device_type(device: cluster_pb2.DeviceConfig) -> str: + """Extract device type from DeviceConfig. + + Args: + device: DeviceConfig protobuf message + + Returns: + "cpu", "gpu", or "tpu" + """ + if device.HasField("cpu"): + return "cpu" + elif device.HasField("gpu"): + return "gpu" + elif device.HasField("tpu"): + return "tpu" + return "cpu" # Default to CPU if no device specified + + +def get_device_variant(device: cluster_pb2.DeviceConfig) -> str | None: + """Extract device variant from DeviceConfig. + + Args: + device: DeviceConfig protobuf message + + Returns: + Variant string (e.g., "A100", "v5litepod-16") or None if not specified + """ + if device.HasField("gpu"): + return device.gpu.variant if device.gpu.variant else None + elif device.HasField("tpu"): + return device.tpu.variant if device.tpu.variant else None + return None + + +def get_gpu_count(device: cluster_pb2.DeviceConfig) -> int: + """Get GPU count from DeviceConfig. + + Args: + device: DeviceConfig protobuf message + + Returns: + Number of GPUs (0 if not a GPU device) + """ + if device.HasField("gpu"): + return device.gpu.count or 1 + return 0 diff --git a/lib/fluster/src/fluster/cluster/controller/retry.py b/lib/fluster/src/fluster/cluster/controller/retry.py new file mode 100644 index 0000000000..ce137ba7b3 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/retry.py @@ -0,0 +1,143 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Job failure and retry logic. + +This module provides functions to handle job failures, distinguishing between: +- Worker failures (external): worker died or became unhealthy +- Job failures (internal): job exited with non-zero exit code + +Each failure type has separate retry limits (max_retries_preemption for worker +failures, max_retries_failure for job failures). 
Failed jobs are reset to +PENDING state and re-queued for another scheduling attempt. + +Gang scheduling requires all-or-nothing retry: the entire gang is only retried +if ALL jobs in the gang have retries remaining. This maintains gang consistency. +""" + +import logging + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerState +from fluster.cluster.types import JobId + +logger = logging.getLogger(__name__) + + +def handle_job_failure( + state: ControllerState, + job_id: JobId, + is_worker_failure: bool, +) -> bool: + """Handle a job failure, potentially retrying. + + Args: + state: Controller state + job_id: ID of the failed job + is_worker_failure: True if external failure (worker died), + False if internal (job exit code != 0) + + Returns: + True if job was re-queued for retry, False otherwise + """ + job = state.get_job(job_id) + if not job: + return False + + # Increment counter to track this failure + if is_worker_failure: + job.preemption_count += 1 + can_retry = job.preemption_count <= job.max_retries_preemption + else: + job.failure_count += 1 + can_retry = job.failure_count <= job.max_retries_failure + + if not can_retry: + logger.warning(f"Job {job_id} exceeded retry limit, not retrying") + return False + + logger.info(f"Retrying job {job_id} (failures={job.failure_count}, preemptions={job.preemption_count})") + + # Reset job state for retry + job.state = cluster_pb2.JOB_STATE_PENDING + job.worker_id = None + job.started_at_ms = None + job.finished_at_ms = None + job.error = None + + # Re-queue the job + state.add_job(job) + return True + + +def handle_gang_failure( + state: ControllerState, + gang_id: str, + is_worker_failure: bool, +) -> list[JobId]: + """Handle gang failure - terminate all jobs, optionally retry. + + All-or-nothing retry: only retries if ALL jobs in gang have retries left. 
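To make the retry accounting concrete, here is a hedged sketch using the defaults from `ControllerJob` (`max_retries_failure=0`, `max_retries_preemption=100`); the job id and request values are made up:

```python
from fluster import cluster_pb2
from fluster.cluster.controller.retry import handle_job_failure
from fluster.cluster.controller.state import ControllerJob, ControllerState
from fluster.cluster.types import JobId

ctrl = ControllerState()
job = ControllerJob(
    job_id=JobId("job-1"),
    request=cluster_pb2.Controller.LaunchJobRequest(name="demo"),
    state=cluster_pb2.JOB_STATE_RUNNING,
)
ctrl.add_job(job)

# Worker death counts against max_retries_preemption (default 100):
# the job is reset to PENDING and re-queued.
assert handle_job_failure(ctrl, job.job_id, is_worker_failure=True)
assert job.state == cluster_pb2.JOB_STATE_PENDING

# A non-zero exit counts against max_retries_failure (default 0):
# the job is not retried.
assert not handle_job_failure(ctrl, job.job_id, is_worker_failure=False)
```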
+ + Args: + state: Controller state + gang_id: Gang identifier + is_worker_failure: True if external failure (worker died), + False if internal (job failure) + + Returns: + List of job IDs that were re-queued (empty if gang couldn't retry) + """ + jobs = state.get_gang_jobs(gang_id) + if not jobs: + return [] + + # Check if ALL jobs in gang have retries remaining (all-or-nothing) + # Check before modifying any state to avoid partial updates on failure + if is_worker_failure: + can_retry = all(job.preemption_count < job.max_retries_preemption for job in jobs) + else: + can_retry = all(job.failure_count < job.max_retries_failure for job in jobs) + + if not can_retry: + # Mark all running jobs as KILLED (no retry possible) + for job in jobs: + if job.state == cluster_pb2.JOB_STATE_RUNNING: + job.state = cluster_pb2.JOB_STATE_KILLED + job.error = f"Gang {gang_id} failed" + logger.warning(f"Gang {gang_id} exceeded retry limit, not retrying") + return [] + + # Retry all jobs in gang + retried = [] + for job in jobs: + # Increment appropriate counter + if is_worker_failure: + job.preemption_count += 1 + else: + job.failure_count += 1 + + # Reset job state for retry + job.state = cluster_pb2.JOB_STATE_PENDING + job.worker_id = None + job.started_at_ms = None + job.finished_at_ms = None + job.error = None + + # Re-queue the job + state.add_job(job) + retried.append(job.job_id) + + logger.info(f"Retrying gang {gang_id} with {len(retried)} jobs") + return retried diff --git a/lib/fluster/src/fluster/cluster/controller/scheduler.py b/lib/fluster/src/fluster/cluster/controller/scheduler.py new file mode 100644 index 0000000000..0d01e60448 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/scheduler.py @@ -0,0 +1,148 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure job scheduler with shallow interface. + +This module provides the Scheduler class, which implements job-to-worker matching +logic without any threading, dispatch, or state mutation. The scheduler takes +inputs (pending jobs, workers, current time) and returns outputs (assignments, +timed-out jobs). All side effects are the caller's responsibility. + +This design follows the "shallow interface" pattern: the scheduler is a pure +function-like object that can be easily tested and composed without mocking +threads or callbacks. +""" + +from dataclasses import dataclass, field + +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.controller.workers import worker_can_fit_job + + +@dataclass +class ScheduleResult: + """Result of a scheduling attempt. + + Contains the job-to-worker assignments that the caller should dispatch, + and jobs that have exceeded their scheduling timeout. 
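For orientation, a sketch of how a caller drives this shallow interface end to end, mirroring what `Controller._run_scheduling` does above; the worker address, job name, and resource numbers are invented for the example:

```python
import time

from fluster import cluster_pb2
from fluster.cluster.controller.scheduler import Scheduler
from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker
from fluster.cluster.types import JobId, WorkerId

state = ControllerState()
state.add_worker(
    ControllerWorker(
        worker_id=WorkerId("w1"),
        address="10.0.0.1:9000",
        resources=cluster_pb2.ResourceSpec(cpu=8, memory="16g"),
    )
)
state.add_job(
    ControllerJob(
        job_id=JobId("job-1"),
        request=cluster_pb2.Controller.LaunchJobRequest(
            name="demo",
            resources=cluster_pb2.ResourceSpec(cpu=2, memory="1g"),
        ),
        submitted_at_ms=int(time.time() * 1000),
    )
)

scheduler = Scheduler(state)
result = scheduler.find_assignments(
    state.peek_pending_jobs(),
    state.get_available_workers(),
    int(time.time() * 1000),
)

# The scheduler only proposes assignments; dispatching and all state updates
# remain the caller's responsibility (see the Controller loop above).
for job, worker in result.assignments:
    print(f"dispatch {job.job_id} -> {worker.worker_id}")
```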
+ + Args: + assignments: List of (job, worker) pairs to dispatch + timed_out_jobs: Jobs that exceeded their scheduling_timeout_seconds + """ + + assignments: list[tuple[ControllerJob, ControllerWorker]] = field(default_factory=list) + timed_out_jobs: list[ControllerJob] = field(default_factory=list) + + +class Scheduler: + """Pure job-to-worker matching logic. + + The scheduler matches pending jobs to available workers based on resource + requirements and availability. It does NOT: + - Dispatch jobs (caller does this) + - Modify state (caller does this) + - Run any threads (Controller owns threading) + + This is a stateless utility class - all inputs are passed to find_assignments() + and all outputs are returned in ScheduleResult. + """ + + def __init__(self, state: ControllerState): + """Initialize scheduler with controller state for resource lookups. + + Args: + state: Controller state used for looking up job resources when + computing committed resources on workers. + """ + self._state = state + + def find_assignments( + self, + pending_jobs: list[ControllerJob], + workers: list[ControllerWorker], + now_ms: int, + ) -> ScheduleResult: + """Match pending jobs to available workers. + + Uses first-fit algorithm, skipping jobs that don't fit any worker. + Also identifies jobs that have exceeded their scheduling timeout. + + The algorithm prevents head-of-line blocking: if a large job at the + front of the queue doesn't fit, smaller jobs behind it can still be + scheduled. + + Args: + pending_jobs: Jobs waiting to be scheduled (in FIFO order) + workers: Available workers (only healthy ones should be passed) + now_ms: Current timestamp in milliseconds + + Returns: + ScheduleResult with assignments and timed-out jobs + """ + result = ScheduleResult() + + # Track which workers have been assigned jobs in this scheduling round + # so we account for their capacity correctly + assigned_jobs_by_worker: dict[str, list[ControllerJob]] = {} + + for job in pending_jobs: + if self._is_job_timed_out(job, now_ms): + result.timed_out_jobs.append(job) + continue + + worker = self._find_worker_for_job(job, workers, assigned_jobs_by_worker) + if worker: + result.assignments.append((job, worker)) + assigned_jobs_by_worker.setdefault(worker.worker_id, []).append(job) + + return result + + def _is_job_timed_out(self, job: ControllerJob, now_ms: int) -> bool: + """Check if job has exceeded its scheduling timeout.""" + timeout_seconds = job.request.scheduling_timeout_seconds + if timeout_seconds <= 0: + return False + + pending_duration_ms = now_ms - job.submitted_at_ms + timeout_ms = timeout_seconds * 1000 + return pending_duration_ms > timeout_ms + + def _find_worker_for_job( + self, + job: ControllerJob, + workers: list[ControllerWorker], + assigned_jobs_by_worker: dict[str, list[ControllerJob]], + ) -> ControllerWorker | None: + """Find first worker that can fit the job. + + Takes into account both jobs already running on each worker AND jobs + assigned earlier in this scheduling round (tracked in assigned_jobs_by_worker). 
+ + Args: + job: Job to find a worker for + workers: Available workers to consider + assigned_jobs_by_worker: Jobs assigned in this round, by worker_id + + Returns: + First matching worker, or None if no worker can fit the job + """ + for worker in workers: + if not worker.healthy: + continue + # Check if worker can fit this job, considering jobs assigned this round + jobs_assigned_this_round = assigned_jobs_by_worker.get(worker.worker_id, []) + if worker_can_fit_job(self._state, worker, job, jobs_assigned_this_round): + return worker + return None diff --git a/lib/fluster/src/fluster/cluster/controller/service.py b/lib/fluster/src/fluster/cluster/controller/service.py new file mode 100644 index 0000000000..c788675735 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/service.py @@ -0,0 +1,442 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Controller RPC service implementation. + +This module provides the ControllerServiceImpl class, which implements the +RPC handlers for the ControllerService. It handles: +- Job submission (launch_job) +- Job status queries (get_job_status) +- Job termination (terminate_job) +- Job listing (list_jobs) +- Worker registration (register_worker) +- Worker listing (list_workers) +- Endpoint registry operations (register/unregister/lookup/list endpoints) + +The service layer is thin, delegating most logic to the ControllerState and +Scheduler. It focuses on proto message conversion and error handling. +""" + +import time +import uuid +from pathlib import Path +from typing import Any, Protocol + +from connectrpc.code import Code +from connectrpc.errors import ConnectError + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerEndpoint, ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.types import JobId, WorkerId, is_job_finished + + +class SchedulerProtocol(Protocol): + """Protocol for scheduler operations used by ControllerServiceImpl.""" + + def wake(self) -> None: + """Signal scheduler to run immediately.""" + ... + + +class ControllerServiceImpl: + """ControllerService RPC implementation. + + Provides HTTP handlers for job management operations. Each method accepts + a protobuf request message and returns a protobuf response message. + + Args: + state: Controller state containing jobs and workers + scheduler: Background scheduler for job dispatch (any object with wake() method) + bundle_dir: Directory for storing uploaded bundles (optional) + """ + + def __init__( + self, + state: ControllerState, + scheduler: SchedulerProtocol, + bundle_dir: str | Path | None = None, + ): + self._state = state + self._scheduler = scheduler + self._bundle_dir = Path(bundle_dir) if bundle_dir else None + + def launch_job( + self, + request: cluster_pb2.Controller.LaunchJobRequest, + ctx: Any, + ) -> cluster_pb2.Controller.LaunchJobResponse: + """Submit a new job to the controller. 
+ + Creates a new job with a unique ID, adds it to the controller state, + and wakes the scheduler to attempt immediate dispatch. + + If bundle_blob is provided and bundle_dir is configured, writes the + bundle to disk and updates bundle_gcs_path to a file:// URL. + + Args: + request: Job launch request with entrypoint and resource spec + ctx: Request context (unused in v0) + + Returns: + LaunchJobResponse containing the assigned job_id + """ + job_id = str(uuid.uuid4()) + + # Handle bundle_blob: write to bundle_dir if provided + if request.bundle_blob and self._bundle_dir: + bundle_path = self._bundle_dir / job_id / "bundle.zip" + bundle_path.parent.mkdir(parents=True, exist_ok=True) + bundle_path.write_bytes(request.bundle_blob) + + # Update the request with file:// path + request = cluster_pb2.Controller.LaunchJobRequest( + name=request.name, + serialized_entrypoint=request.serialized_entrypoint, + resources=request.resources, + environment=request.environment, + bundle_gcs_path=f"file://{bundle_path}", + bundle_hash=request.bundle_hash, + ports=list(request.ports), + scheduling_timeout_seconds=request.scheduling_timeout_seconds, + ) + + job = ControllerJob( + job_id=JobId(job_id), + request=request, + submitted_at_ms=int(time.time() * 1000), + ) + + self._state.add_job(job) + self._state.log_action("job_submitted", job_id=job.job_id, details=request.name) + self._scheduler.wake() # Try to schedule immediately + + return cluster_pb2.Controller.LaunchJobResponse(job_id=job_id) + + def get_job_status( + self, + request: cluster_pb2.Controller.GetJobStatusRequest, + ctx: Any, + ) -> cluster_pb2.Controller.GetJobStatusResponse: + """Get status of a specific job. + + Args: + request: Request containing job_id + ctx: Request context (unused in v0) + + Returns: + GetJobStatusResponse with JobStatus proto + + Raises: + ConnectError: If job is not found (Code.NOT_FOUND) + """ + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + worker_address = "" + if job.worker_id: + worker = self._state.get_worker(job.worker_id) + if worker: + worker_address = worker.address + + return cluster_pb2.Controller.GetJobStatusResponse( + job=cluster_pb2.JobStatus( + job_id=job.job_id, + state=job.state, + error=job.error or "", + exit_code=job.exit_code or 0, + started_at_ms=job.started_at_ms or 0, + finished_at_ms=job.finished_at_ms or 0, + worker_id=job.worker_id or "", + worker_address=worker_address, + ) + ) + + def terminate_job( + self, + request: cluster_pb2.Controller.TerminateJobRequest, + ctx: Any, + ) -> cluster_pb2.Empty: + """Terminate a running job. + + Marks the job as KILLED in the controller state. Note that in v0, + this does not send an actual kill signal to the worker - that is + deferred to a future implementation. 
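Lookups surface missing jobs as Connect errors rather than empty responses, so callers of the service layer are expected to handle `ConnectError`. A small self-contained sketch; the `NoopScheduler` stub exists only to satisfy `SchedulerProtocol` for the example:

```python
from connectrpc.errors import ConnectError

from fluster import cluster_pb2
from fluster.cluster.controller.service import ControllerServiceImpl
from fluster.cluster.controller.state import ControllerState


class NoopScheduler:
    """Minimal SchedulerProtocol implementation for a standalone example."""

    def wake(self) -> None:
        pass


service = ControllerServiceImpl(ControllerState(), NoopScheduler())

try:
    service.get_job_status(cluster_pb2.Controller.GetJobStatusRequest(job_id="missing"), None)
except ConnectError:
    # Raised with Code.NOT_FOUND when the job id is unknown.
    print("job not found")
```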
+ + Args: + request: Request containing job_id + ctx: Request context (unused in v0) + + Returns: + Empty response + + Raises: + ConnectError: If job is not found (Code.NOT_FOUND) + """ + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + # Idempotent: if already in terminal state, do nothing + if is_job_finished(job.state): + return cluster_pb2.Empty() + + # TODO: Send kill to worker + job.state = cluster_pb2.JOB_STATE_KILLED + job.finished_at_ms = int(time.time() * 1000) + self._state.log_action("job_killed", job_id=job.job_id) + + return cluster_pb2.Empty() + + def list_jobs( + self, + request: cluster_pb2.Controller.ListJobsRequest, + ctx: Any, + ) -> cluster_pb2.Controller.ListJobsResponse: + """List all jobs. + + Returns a list of all jobs in the controller, regardless of state. + Note that the namespace field in the request is ignored in v0. + + Args: + request: List request (namespace field currently ignored) + ctx: Request context (unused in v0) + + Returns: + ListJobsResponse containing all jobs as JobStatus protos + """ + jobs = [] + for j in self._state.list_all_jobs(): + worker_address = "" + if j.worker_id: + worker = self._state.get_worker(j.worker_id) + if worker: + worker_address = worker.address + jobs.append( + cluster_pb2.JobStatus( + job_id=j.job_id, + state=j.state, + worker_id=j.worker_id or "", + worker_address=worker_address, + error=j.error or "", + exit_code=j.exit_code or 0, + started_at_ms=j.started_at_ms or 0, + finished_at_ms=j.finished_at_ms or 0, + ) + ) + return cluster_pb2.Controller.ListJobsResponse(jobs=jobs) + + def register_worker( + self, + request: cluster_pb2.Controller.RegisterWorkerRequest, + ctx: Any, + ) -> cluster_pb2.Controller.RegisterWorkerResponse: + """Register a new worker with the controller. + + Workers register themselves on startup and provide their address and + resource capabilities. The controller adds them to the worker registry + and wakes the scheduler to potentially dispatch pending jobs. + + Args: + request: Worker registration request + ctx: Request context (unused in v0) + + Returns: + RegisterWorkerResponse with acceptance status + """ + worker = ControllerWorker( + worker_id=WorkerId(request.worker_id), + address=request.address, + resources=request.resources, + last_heartbeat_ms=int(time.time() * 1000), + ) + self._state.add_worker(worker) + self._state.log_action( + "worker_registered", + worker_id=worker.worker_id, + details=f"address={request.address}", + ) + self._scheduler.wake() # Try to schedule jobs on new worker + + return cluster_pb2.Controller.RegisterWorkerResponse(accepted=True) + + def list_workers( + self, + request: cluster_pb2.Controller.ListWorkersRequest, + ctx: Any, + ) -> cluster_pb2.Controller.ListWorkersResponse: + """List all registered workers. + + Returns health status for all workers in the controller, including + healthy and unhealthy workers. 
+ + Args: + request: List workers request (currently ignored) + ctx: Request context (unused in v0) + + Returns: + ListWorkersResponse with worker health statuses + """ + workers = [ + cluster_pb2.Controller.WorkerHealthStatus( + worker_id=w.worker_id, + healthy=w.healthy, + consecutive_failures=w.consecutive_failures, + last_heartbeat_ms=w.last_heartbeat_ms, + running_job_ids=list(w.running_jobs), + ) + for w in self._state.list_all_workers() + ] + return cluster_pb2.Controller.ListWorkersResponse(workers=workers) + + # Endpoint registry methods + + def register_endpoint( + self, + request: cluster_pb2.Controller.RegisterEndpointRequest, + ctx: Any, + ) -> cluster_pb2.Controller.RegisterEndpointResponse: + """Register a service endpoint. + + Validates that the job exists and is RUNNING before registering. + Endpoints are automatically removed when jobs terminate. + + Args: + request: Endpoint registration request + ctx: Request context (unused) + + Returns: + RegisterEndpointResponse with assigned endpoint_id + + Raises: + ConnectError: If job is not found or not running + """ + endpoint_id = str(uuid.uuid4()) + + # Validate job exists and is running + job = self._state.get_job(JobId(request.job_id)) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + if job.state != cluster_pb2.JOB_STATE_RUNNING: + raise ConnectError(Code.FAILED_PRECONDITION, f"Job {request.job_id} is not running") + + endpoint = ControllerEndpoint( + endpoint_id=endpoint_id, + name=request.name, + address=request.address, + job_id=JobId(request.job_id), + namespace=request.namespace or "", + metadata=dict(request.metadata), + registered_at_ms=int(time.time() * 1000), + ) + self._state.add_endpoint(endpoint) + self._state.log_action( + "endpoint_registered", + job_id=job.job_id, + details=f"{request.name} at {request.address}", + ) + return cluster_pb2.Controller.RegisterEndpointResponse(endpoint_id=endpoint_id) + + def unregister_endpoint( + self, + request: cluster_pb2.Controller.UnregisterEndpointRequest, + ctx: Any, + ) -> cluster_pb2.Empty: + """Unregister a service endpoint. + + Removes an endpoint from the registry. This is idempotent - no error + if the endpoint doesn't exist. + + Args: + request: Endpoint unregistration request + ctx: Request context (unused) + + Returns: + Empty response + """ + endpoint = self._state.remove_endpoint(request.endpoint_id) + if endpoint: + self._state.log_action( + "endpoint_unregistered", + job_id=endpoint.job_id, + details=endpoint.name, + ) + return cluster_pb2.Empty() + + def lookup_endpoint( + self, + request: cluster_pb2.Controller.LookupEndpointRequest, + ctx: Any, + ) -> cluster_pb2.Controller.LookupEndpointResponse: + """Look up a service endpoint by name. + + Returns the first endpoint matching the name in the given namespace. + Only endpoints for RUNNING jobs are returned. 
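A sketch of the endpoint lifecycle these handlers enforce: registration is refused unless the owning job is RUNNING, and lookups only see endpoints of RUNNING jobs. The job id, service name, and address below are made up, and `NoopScheduler` is the same kind of stand-in as above:

```python
from fluster import cluster_pb2
from fluster.cluster.controller.service import ControllerServiceImpl
from fluster.cluster.controller.state import ControllerJob, ControllerState
from fluster.cluster.types import JobId


class NoopScheduler:
    def wake(self) -> None:
        pass


state = ControllerState()
state.add_job(
    ControllerJob(
        job_id=JobId("job-1"),
        request=cluster_pb2.Controller.LaunchJobRequest(name="inference"),
        state=cluster_pb2.JOB_STATE_RUNNING,
    )
)
service = ControllerServiceImpl(state, NoopScheduler())

service.register_endpoint(
    cluster_pb2.Controller.RegisterEndpointRequest(
        name="inference",
        address="10.0.0.5:9001",
        job_id="job-1",
        namespace="default",
    ),
    None,
)

found = service.lookup_endpoint(
    cluster_pb2.Controller.LookupEndpointRequest(name="inference", namespace="default"),
    None,
)
print(found.endpoint.address)  # -> "10.0.0.5:9001"
```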
+ + Args: + request: Lookup request with name and namespace + ctx: Request context (unused) + + Returns: + LookupEndpointResponse with first matching endpoint (empty if not found) + """ + namespace = request.namespace or "" + endpoints = self._state.lookup_endpoints(request.name, namespace) + if not endpoints: + return cluster_pb2.Controller.LookupEndpointResponse() + + e = endpoints[0] + return cluster_pb2.Controller.LookupEndpointResponse( + endpoint=cluster_pb2.Controller.Endpoint( + endpoint_id=e.endpoint_id, + name=e.name, + address=e.address, + job_id=e.job_id, + namespace=e.namespace, + metadata=e.metadata, + ) + ) + + def list_endpoints( + self, + request: cluster_pb2.Controller.ListEndpointsRequest, + ctx: Any, + ) -> cluster_pb2.Controller.ListEndpointsResponse: + """List endpoints by name prefix. + + Returns all endpoints matching the prefix in the given namespace. + Only endpoints for RUNNING jobs are returned. + + Args: + request: List request with prefix and namespace + ctx: Request context (unused) + + Returns: + ListEndpointsResponse with matching endpoints + """ + namespace = request.namespace or "" + endpoints = self._state.list_endpoints_by_prefix(request.prefix, namespace) + return cluster_pb2.Controller.ListEndpointsResponse( + endpoints=[ + cluster_pb2.Controller.Endpoint( + endpoint_id=e.endpoint_id, + name=e.name, + address=e.address, + job_id=e.job_id, + namespace=e.namespace, + metadata=e.metadata, + ) + for e in endpoints + ] + ) diff --git a/lib/fluster/src/fluster/cluster/controller/state.py b/lib/fluster/src/fluster/cluster/controller/state.py new file mode 100644 index 0000000000..69145e7ed1 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/state.py @@ -0,0 +1,478 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Controller core data structures. + +This module provides the in-memory state management for the controller, including: +- ControllerJob: Controller's view of a job with retry tracking and gang info +- ControllerWorker: Controller's view of a worker with health and capacity +- ControllerEndpoint: An endpoint registered with the controller for service discovery +- ActionLogEntry: Record of a controller action for the dashboard +- ControllerState: Thread-safe state container for jobs, workers, endpoints, and the queue + +All state mutations are protected by a reentrant lock (RLock) to support +concurrent access from the scheduler, heartbeat monitor, and RPC handlers. +""" + +import time +from collections import deque +from dataclasses import dataclass, field +from threading import RLock + +from fluster import cluster_pb2 +from fluster.cluster.types import JobId, WorkerId + + +@dataclass +class ControllerJob: + """Controller's view of a job. + + Tracks job state, retry counts, gang scheduling information, and timestamps. + Used by the scheduler to determine which jobs to dispatch and by the + heartbeat monitor to update job state based on worker reports. 
+ + Args: + job_id: Unique job identifier + request: Original job launch request from the client + state: Current job state (defaults to PENDING) + worker_id: Worker assigned to this job (if any) + failure_count: Number of internal failures (job exit code != 0) + preemption_count: Number of external failures (worker died) + max_retries_failure: Maximum internal failures before giving up + max_retries_preemption: Maximum external failures before giving up + gang_id: Gang identifier for gang-scheduled jobs (None for solo jobs) + submitted_at_ms: Timestamp when job was submitted + started_at_ms: Timestamp when job started running + finished_at_ms: Timestamp when job reached terminal state + error: Error message if job failed + exit_code: Process exit code if job completed + """ + + job_id: JobId + request: cluster_pb2.Controller.LaunchJobRequest + state: cluster_pb2.JobState = cluster_pb2.JOB_STATE_PENDING + worker_id: WorkerId | None = None + + # Retry tracking + failure_count: int = 0 + preemption_count: int = 0 + max_retries_failure: int = 0 + max_retries_preemption: int = 100 + + # Gang scheduling + gang_id: str | None = None + + # Timestamps + submitted_at_ms: int = 0 + started_at_ms: int | None = None + finished_at_ms: int | None = None + + error: str | None = None + exit_code: int | None = None + + +@dataclass +class ControllerWorker: + """Controller's view of a worker. + + Tracks worker capabilities, health status, and current job assignments. + The heartbeat monitor uses this to detect worker failures and the scheduler + uses it to find available capacity. + + Args: + worker_id: Unique worker identifier + address: Worker RPC address (host:port) + resources: Worker's available resources + healthy: Whether worker is currently healthy + consecutive_failures: Number of consecutive heartbeat failures + last_heartbeat_ms: Timestamp of last successful heartbeat + running_jobs: Set of job IDs currently running on this worker + """ + + worker_id: WorkerId + address: str + resources: cluster_pb2.ResourceSpec + + # Health tracking + healthy: bool = True + consecutive_failures: int = 0 + last_heartbeat_ms: int = 0 + + # Current assignments + running_jobs: set[JobId] = field(default_factory=set) + + +@dataclass +class ControllerEndpoint: + """An endpoint registered with the controller. + + Endpoints are associated with jobs and used for service discovery. + When a job transitions to a terminal state, all its endpoints are + automatically removed. + + Args: + endpoint_id: Unique endpoint identifier + name: Service name for discovery + address: Network address (host:port) + job_id: Job that registered this endpoint + namespace: Namespace for isolation + metadata: Additional key-value metadata + registered_at_ms: Timestamp when endpoint was registered + """ + + endpoint_id: str + name: str + address: str + job_id: JobId + namespace: str + metadata: dict[str, str] = field(default_factory=dict) + registered_at_ms: int = 0 + + +@dataclass +class ActionLogEntry: + """Record of a controller action for the dashboard. + + Actions are logged when significant events occur (job submitted, job started, + worker registered, etc.) to provide visibility into controller activity. 
+ + Args: + timestamp_ms: Unix timestamp in milliseconds when action occurred + action: Action type (e.g., "job_submitted", "job_started", "worker_failed") + job_id: Associated job ID, if any + worker_id: Associated worker ID, if any + details: Additional human-readable details + """ + + timestamp_ms: int + action: str + job_id: JobId | None = None + worker_id: WorkerId | None = None + details: str = "" + + +class ControllerState: + """Thread-safe controller state. + + Manages in-memory state for jobs, workers, endpoints, job queue, and gang tracking. + All mutations go through methods that acquire the lock to ensure consistency + during concurrent access from multiple threads (scheduler, heartbeat monitor, + RPC handlers). + + The job queue is FIFO by default, with pop_next_pending() automatically + skipping jobs that are no longer in PENDING state. + + Endpoints are associated with jobs and automatically removed when jobs + transition to terminal states. Lookup and list operations filter to only + return endpoints for jobs in RUNNING state. + """ + + def __init__(self): + self._lock = RLock() + self._jobs: dict[JobId, ControllerJob] = {} + self._workers: dict[WorkerId, ControllerWorker] = {} + self._queue: deque[JobId] = deque() # FIFO queue of job IDs + self._gangs: dict[str, set[JobId]] = {} # gang_id -> job_ids + self._actions: deque[ActionLogEntry] = deque(maxlen=100) # Recent actions log + self._endpoints: dict[str, ControllerEndpoint] = {} # endpoint_id -> endpoint + self._endpoints_by_job: dict[JobId, set[str]] = {} # job_id -> endpoint_ids + + def add_job(self, job: ControllerJob) -> None: + """Add a job to the controller state and queue. + + The job is added to the jobs dict, appended to the FIFO queue, and + registered with its gang (if it has one). + + Args: + job: Job to add + """ + with self._lock: + self._jobs[job.job_id] = job + self._queue.append(job.job_id) + if job.gang_id: + self._gangs.setdefault(job.gang_id, set()).add(job.job_id) + + def get_job(self, job_id: JobId) -> ControllerJob | None: + """Get a job by ID. + + Args: + job_id: Job identifier + + Returns: + Job if found, None otherwise + """ + with self._lock: + return self._jobs.get(job_id) + + def pop_next_pending(self) -> ControllerJob | None: + """Pop next PENDING job from the queue. + + Iterates through the queue until finding a job in PENDING state, + skipping jobs that have transitioned to other states. + + Returns: + Next pending job, or None if queue is empty or no pending jobs + """ + with self._lock: + while self._queue: + job_id = self._queue.popleft() + job = self._jobs.get(job_id) + if job and job.state == cluster_pb2.JOB_STATE_PENDING: + return job + return None + + def add_worker(self, worker: ControllerWorker) -> None: + """Add or update a worker in the registry. + + Args: + worker: Worker to add + """ + with self._lock: + self._workers[worker.worker_id] = worker + + def get_worker(self, worker_id: WorkerId) -> ControllerWorker | None: + """Get a worker by ID. + + Args: + worker_id: Worker identifier + + Returns: + Worker if found, None otherwise + """ + with self._lock: + return self._workers.get(worker_id) + + def remove_worker(self, worker_id: WorkerId) -> ControllerWorker | None: + """Remove a worker from the registry. + + Used when a worker is permanently gone (e.g., after exceeding + heartbeat failure threshold). 
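The queue semantics above are easiest to see in a small sketch: the queue is FIFO, and `pop_next_pending()` silently drops entries whose jobs have already left the PENDING state (job ids are made up):

```python
from fluster import cluster_pb2
from fluster.cluster.controller.state import ControllerJob, ControllerState
from fluster.cluster.types import JobId

ctrl = ControllerState()
for name in ("a", "b"):
    ctrl.add_job(
        ControllerJob(
            job_id=JobId(name),
            request=cluster_pb2.Controller.LaunchJobRequest(name=name),
        )
    )

# Mark "a" as already running; its queue entry is now stale.
job_a = ctrl.get_job(JobId("a"))
assert job_a is not None
job_a.state = cluster_pb2.JOB_STATE_RUNNING

assert ctrl.pop_next_pending().job_id == JobId("b")  # "a" is skipped
assert ctrl.pop_next_pending() is None  # queue is exhausted
```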
+ + Args: + worker_id: Worker identifier + + Returns: + Removed worker if found, None otherwise + """ + with self._lock: + return self._workers.pop(worker_id, None) + + def get_available_workers(self) -> list[ControllerWorker]: + """Return healthy workers with capacity. + + For v0, simply returns all healthy workers. Future versions will + filter by available resources and capacity constraints. + + Returns: + List of healthy workers + """ + with self._lock: + return [w for w in self._workers.values() if w.healthy] + + def list_all_jobs(self) -> list[ControllerJob]: + """Get all jobs in the controller. + + Returns: + List of all jobs (snapshot under lock) + """ + with self._lock: + return list(self._jobs.values()) + + def list_all_workers(self) -> list[ControllerWorker]: + """Get all workers in the controller. + + Returns: + List of all workers (snapshot under lock) + """ + with self._lock: + return list(self._workers.values()) + + def get_gang_jobs(self, gang_id: str) -> list[ControllerJob]: + """Get all jobs in a gang. + + Args: + gang_id: Gang identifier + + Returns: + List of jobs in the gang (may be empty) + """ + with self._lock: + job_ids = self._gangs.get(gang_id, set()) + return [self._jobs[jid] for jid in job_ids if jid in self._jobs] + + def log_action( + self, + action: str, + job_id: JobId | None = None, + worker_id: WorkerId | None = None, + details: str = "", + ) -> None: + """Record an action in the log. + + Actions are stored in a bounded deque (last 100 entries) for display + on the dashboard. + + Args: + action: Action type (e.g., "job_submitted", "job_started") + job_id: Associated job ID, if any + worker_id: Associated worker ID, if any + details: Additional human-readable details + """ + entry = ActionLogEntry( + timestamp_ms=int(time.time() * 1000), + action=action, + job_id=job_id, + worker_id=worker_id, + details=details, + ) + with self._lock: + self._actions.append(entry) + + def get_recent_actions(self, limit: int = 50) -> list[ActionLogEntry]: + """Get most recent actions. + + Args: + limit: Maximum number of actions to return + + Returns: + List of recent actions, most recent last + """ + with self._lock: + actions = list(self._actions) + return actions[-limit:] if limit < len(actions) else actions + + def peek_pending_jobs(self) -> list[ControllerJob]: + """Return all PENDING jobs in queue order without removing them. + + Used by the scheduler to iterate through the queue and find schedulable + jobs. Unlike pop_next_pending(), this returns all pending jobs so the + scheduler can skip jobs that don't fit and continue to the next. + + Returns: + List of pending jobs in FIFO order + """ + with self._lock: + pending = [] + for job_id in self._queue: + job = self._jobs.get(job_id) + if job and job.state == cluster_pb2.JOB_STATE_PENDING: + pending.append(job) + return pending + + def remove_from_queue(self, job_id: JobId) -> None: + """Remove a specific job from the queue. + Args: + job_id: Job ID to remove from the queue + """ + with self._lock: + self._queue = deque(jid for jid in self._queue if jid != job_id) + + def add_endpoint(self, endpoint: ControllerEndpoint) -> None: + """Add an endpoint to the controller registry. + + Endpoints are tracked both by ID and by job. When a job terminates, + all its endpoints can be quickly removed. 
+ + Args: + endpoint: Endpoint to register + """ + with self._lock: + self._endpoints[endpoint.endpoint_id] = endpoint + self._endpoints_by_job.setdefault(endpoint.job_id, set()).add(endpoint.endpoint_id) + + def remove_endpoint(self, endpoint_id: str) -> ControllerEndpoint | None: + """Remove an endpoint from the registry. + + Args: + endpoint_id: Endpoint ID to remove + + Returns: + Removed endpoint if found, None otherwise + """ + with self._lock: + endpoint = self._endpoints.pop(endpoint_id, None) + if endpoint: + job_endpoints = self._endpoints_by_job.get(endpoint.job_id) + if job_endpoints: + job_endpoints.discard(endpoint_id) + return endpoint + + def lookup_endpoints(self, name: str, namespace: str) -> list[ControllerEndpoint]: + """Find endpoints by exact name match. + + Only returns endpoints for jobs in RUNNING state. Endpoints for + jobs that have terminated or are not yet running are filtered out. + + Args: + name: Service name to look up + namespace: Namespace to search in + + Returns: + List of matching endpoints (may be empty) + """ + with self._lock: + results = [] + for ep in self._endpoints.values(): + if ep.name != name or ep.namespace != namespace: + continue + # Only return endpoints for running jobs + job = self._jobs.get(ep.job_id) + if job and job.state == cluster_pb2.JOB_STATE_RUNNING: + results.append(ep) + return results + + def list_endpoints_by_prefix(self, prefix: str, namespace: str) -> list[ControllerEndpoint]: + """List endpoints matching a name prefix. + + Only returns endpoints for jobs in RUNNING state. + + Args: + prefix: Service name prefix to match + namespace: Namespace to search in + + Returns: + List of matching endpoints (may be empty) + """ + with self._lock: + results = [] + for ep in self._endpoints.values(): + if not ep.name.startswith(prefix) or ep.namespace != namespace: + continue + job = self._jobs.get(ep.job_id) + if job and job.state == cluster_pb2.JOB_STATE_RUNNING: + results.append(ep) + return results + + def remove_endpoints_for_job(self, job_id: JobId) -> list[ControllerEndpoint]: + """Remove all endpoints for a job. + + Called when a job transitions to a terminal state to clean up + all service discovery entries. + + Args: + job_id: Job ID whose endpoints should be removed + + Returns: + List of removed endpoints + """ + with self._lock: + endpoint_ids = list(self._endpoints_by_job.get(job_id, [])) + removed = [] + for eid in endpoint_ids: + ep = self.remove_endpoint(eid) + if ep: + removed.append(ep) + # Clean up the job mapping + self._endpoints_by_job.pop(job_id, None) + return removed diff --git a/lib/fluster/src/fluster/cluster/controller/workers.py b/lib/fluster/src/fluster/cluster/controller/workers.py new file mode 100644 index 0000000000..c389e3e563 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/controller/workers.py @@ -0,0 +1,212 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker registry and scheduling for the controller. 
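As a preview of the capacity math implemented by `worker_can_fit_job` below, here is a hedged sketch with invented numbers; committed resources are derived from a worker's running jobs plus any jobs already assigned in the current scheduling round:

```python
from fluster import cluster_pb2
from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker
from fluster.cluster.controller.workers import worker_can_fit_job
from fluster.cluster.types import JobId, WorkerId

state = ControllerState()
worker = ControllerWorker(
    worker_id=WorkerId("w1"),
    address="10.0.0.1:9000",
    resources=cluster_pb2.ResourceSpec(cpu=4, memory="8g"),
)


def job(name: str, cpu: int, memory: str) -> ControllerJob:
    return ControllerJob(
        job_id=JobId(name),
        request=cluster_pb2.Controller.LaunchJobRequest(
            name=name, resources=cluster_pb2.ResourceSpec(cpu=cpu, memory=memory)
        ),
    )


small, big = job("small", cpu=2, memory="4g"), job("big", cpu=4, memory="8g")

# An idle 4-CPU / 8 GiB worker fits either job on its own...
assert worker_can_fit_job(state, worker, small)
assert worker_can_fit_job(state, worker, big)
# ...but once "small" is assigned this round, "big" no longer fits.
assert not worker_can_fit_job(state, worker, big, additional_jobs=[small])
```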
+ +This module provides: +- WorkerConfig: Static worker configuration for v0 +- load_workers_from_config(): Register workers from static config at startup +- get_committed_resources(): Compute resources committed to running jobs +- worker_can_fit_job(): Check if worker has capacity for a job +- find_worker_for_job(): First-fit worker selection with resource matching +""" + +import time +from dataclasses import dataclass + +from fluster import cluster_pb2 +from fluster.cluster.controller.resources import ( + get_device_type, + get_device_variant, + get_gpu_count, + parse_memory_string, +) +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.types import WorkerId + + +@dataclass +class WorkerConfig: + """Static worker configuration for v0. + + Args: + worker_id: Unique worker identifier + address: Worker RPC address (host:port) + resources: Worker's available resources + """ + + worker_id: str + address: str + resources: cluster_pb2.ResourceSpec + + +def load_workers_from_config( + state: ControllerState, + workers: list[WorkerConfig], +) -> None: + """Register workers from static config. + + Creates ControllerWorker instances from the provided config and adds them + to the controller state. Sets last_heartbeat_ms to the current time so + workers are considered healthy at startup. + + Args: + state: Controller state to update + workers: List of worker configurations to register + """ + now_ms = int(time.time() * 1000) + + for cfg in workers: + worker = ControllerWorker( + worker_id=WorkerId(cfg.worker_id), + address=cfg.address, + resources=cfg.resources, + last_heartbeat_ms=now_ms, + ) + state.add_worker(worker) + + +def get_committed_resources( + state: ControllerState, + worker: ControllerWorker, +) -> tuple[int, int, int]: + """Compute resources committed to running jobs on this worker. + + Dynamically sums resources from all jobs in worker.running_jobs. + This approach avoids tracking committed resources incrementally and + prevents sync issues. + + Args: + state: Controller state to look up jobs + worker: Worker to compute committed resources for + + Returns: + Tuple of (cpu_cores, memory_bytes, gpu_count) + """ + cpu = 0 + memory = 0 + gpu = 0 + + for job_id in worker.running_jobs: + job = state.get_job(job_id) + if job: + resources = job.request.resources + cpu += resources.cpu + memory += parse_memory_string(resources.memory) + gpu += get_gpu_count(resources.device) + + return cpu, memory, gpu + + +def worker_can_fit_job( + state: ControllerState, + worker: ControllerWorker, + job: ControllerJob, + additional_jobs: list[ControllerJob] | None = None, +) -> bool: + """Check if worker has sufficient available capacity for job. + + Computes available headroom dynamically from running_jobs plus any + additional jobs that have been assigned in the current scheduling round + but not yet dispatched. + + Checks: + 1. CPU: job.cpu <= worker.total_cpu - committed_cpu + 2. Memory: job.memory <= worker.total_memory - committed_memory + 3. Device type: exact match (GPU job only on GPU worker) + 4. Device variant: if job specifies variant (not "auto"), worker must match + 5. 
GPU count: job.gpu_count <= available_gpus + + Args: + state: Controller state for job lookups + worker: Worker to check capacity + job: Job with resource requirements + additional_jobs: Jobs assigned this scheduling round but not yet + reflected in worker.running_jobs + + Returns: + True if worker can fit the job + """ + job_resources = job.request.resources + worker_resources = worker.resources + + # Get committed resources dynamically from running jobs + committed_cpu, committed_memory, committed_gpu = get_committed_resources(state, worker) + + # Add resources from jobs assigned this round but not yet dispatched + if additional_jobs: + for additional_job in additional_jobs: + add_resources = additional_job.request.resources + committed_cpu += add_resources.cpu + committed_memory += parse_memory_string(add_resources.memory) + committed_gpu += get_gpu_count(add_resources.device) + + # CPU check + available_cpu = worker_resources.cpu - committed_cpu + if job_resources.cpu > available_cpu: + return False + + # Memory check + worker_memory = parse_memory_string(worker_resources.memory) + job_memory = parse_memory_string(job_resources.memory) + available_memory = worker_memory - committed_memory + if job_memory > available_memory: + return False + + # Device type check + job_device_type = get_device_type(job_resources.device) + worker_device_type = get_device_type(worker_resources.device) + + if job_device_type != worker_device_type: + return False + + # Device variant check (only if job specifies one that's not "auto") + job_variant = get_device_variant(job_resources.device) + if job_variant and job_variant != "auto": + worker_variant = get_device_variant(worker_resources.device) + if worker_variant != job_variant: + return False + + # GPU count check + if job_device_type == "gpu": + job_gpu_count = get_gpu_count(job_resources.device) + worker_gpu_count = get_gpu_count(worker_resources.device) + available_gpus = worker_gpu_count - committed_gpu + if job_gpu_count > available_gpus: + return False + + return True + + +def find_worker_for_job( + state: ControllerState, + job: ControllerJob, +) -> ControllerWorker | None: + """Find a worker that can run the given job. + + Returns the first healthy worker with sufficient capacity and matching + device type/variant. Uses first-fit strategy. + + Args: + state: Controller state containing worker registry + job: Job to find a worker for + + Returns: + First matching worker, or None if no worker can fit the job + """ + workers = state.get_available_workers() + for worker in workers: + if worker_can_fit_job(state, worker, job): + return worker + return None diff --git a/lib/fluster/src/fluster/cluster/registry.py b/lib/fluster/src/fluster/cluster/registry.py new file mode 100644 index 0000000000..8bcd1e325f --- /dev/null +++ b/lib/fluster/src/fluster/cluster/registry.py @@ -0,0 +1,18 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Endpoint registry for service discovery. 
+ +TODO: Implement in Stage 5 +""" diff --git a/lib/fluster/src/fluster/cluster/types.py b/lib/fluster/src/fluster/cluster/types.py new file mode 100644 index 0000000000..89def5e4d5 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/types.py @@ -0,0 +1,205 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core types for the fluster cluster layer. + +This module contains: +- Type aliases for IDs (JobId, WorkerId, etc.) +- Helper functions for working with proto types +- TPU topology information for scheduling +- Entrypoint dataclass for job execution + +Wire-format types (ResourceSpec, JobStatus, etc.) are defined in cluster.proto. +""" + +import os +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any, NewType + +from fluster import cluster_pb2 + +# Type aliases for clarity +JobId = NewType("JobId", str) +Namespace = NewType("Namespace", str) +WorkerId = NewType("WorkerId", str) +EndpointId = NewType("EndpointId", str) + + +def is_job_finished(state: int) -> bool: + """Check if job has reached terminal state.""" + return state in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + cluster_pb2.JOB_STATE_WORKER_FAILED, + cluster_pb2.JOB_STATE_UNSCHEDULABLE, + ) + + +JobState = cluster_pb2.JobState + + +@dataclass(frozen=True) +class TpuTopologyInfo: + """TPU topology configuration. 
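
    Example (values taken from TPU_TOPOLOGIES below):

        info = get_tpu_topology("v5litepod-16")
        # info.chip_count == 16, info.host_count == 2,
        # info.vm_count == 4, info.chips_per_vm == 4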
+ + Args: + name: TPU type name (e.g., "v5litepod-16", "v4-8") + chip_count: Total number of TPU chips + host_count: Number of physical hosts + vm_count: Number of VMs in the pod + chips_per_vm: Number of chips per VM + """ + + name: str + chip_count: int + host_count: int + vm_count: int + chips_per_vm: int + + +TPU_TOPOLOGIES: list[TpuTopologyInfo] = [ + # https://cloud.google.com/tpu/docs/v4 + TpuTopologyInfo("v4-8", 4, 1, 1, 4), + TpuTopologyInfo("v4-16", 8, 2, 2, 4), + TpuTopologyInfo("v4-32", 16, 4, 4, 4), + TpuTopologyInfo("v4-64", 32, 8, 8, 4), + TpuTopologyInfo("v4-128", 64, 16, 16, 4), + TpuTopologyInfo("v4-256", 128, 32, 32, 4), + TpuTopologyInfo("v4-512", 256, 64, 64, 4), + TpuTopologyInfo("v4-1024", 512, 128, 128, 4), + TpuTopologyInfo("v4-2048", 1024, 256, 256, 4), + TpuTopologyInfo("v4-4096", 2048, 512, 512, 4), + # https://cloud.google.com/tpu/docs/v5e + TpuTopologyInfo("v5litepod-1", 1, 1, 1, 1), + TpuTopologyInfo("v5litepod-2", 2, 1, 1, 2), + TpuTopologyInfo("v5litepod-4", 4, 1, 1, 4), + TpuTopologyInfo("v5litepod-8", 8, 1, 1, 8), + TpuTopologyInfo("v5litepod-16", 16, 2, 4, 4), + TpuTopologyInfo("v5litepod-32", 32, 4, 8, 4), + TpuTopologyInfo("v5litepod-64", 64, 8, 16, 4), + TpuTopologyInfo("v5litepod-128", 128, 16, 32, 4), + TpuTopologyInfo("v5litepod-256", 256, 32, 64, 4), + # https://cloud.google.com/tpu/docs/v5p + TpuTopologyInfo("v5p-8", 4, 1, 1, 4), + TpuTopologyInfo("v5p-16", 8, 2, 2, 4), + TpuTopologyInfo("v5p-32", 16, 4, 4, 4), + TpuTopologyInfo("v5p-64", 32, 8, 8, 4), + TpuTopologyInfo("v5p-128", 64, 16, 16, 4), + TpuTopologyInfo("v5p-256", 128, 32, 32, 4), + TpuTopologyInfo("v5p-512", 256, 64, 64, 4), + TpuTopologyInfo("v5p-1024", 512, 128, 128, 4), + TpuTopologyInfo("v5p-2048", 1024, 256, 256, 4), + TpuTopologyInfo("v5p-4096", 2048, 512, 512, 4), + TpuTopologyInfo("v5p-8192", 4096, 1024, 1024, 4), + TpuTopologyInfo("v5p-12288", 6144, 1536, 1536, 4), + # https://cloud.google.com/tpu/docs/v6e + TpuTopologyInfo("v6e-1", 1, 1, 1, 1), + TpuTopologyInfo("v6e-4", 4, 1, 1, 4), + TpuTopologyInfo("v6e-8", 8, 1, 1, 8), + TpuTopologyInfo("v6e-16", 16, 4, 4, 4), + TpuTopologyInfo("v6e-32", 32, 8, 8, 4), + TpuTopologyInfo("v6e-64", 64, 16, 16, 4), + TpuTopologyInfo("v6e-128", 128, 32, 32, 4), + TpuTopologyInfo("v6e-256", 256, 64, 64, 4), +] + + +def get_tpu_topology(tpu_type: str) -> TpuTopologyInfo: + """Get TPU topology by type name. + + Args: + tpu_type: TPU type name (e.g., "v5litepod-16", "v4-8") + + Returns: + TpuTopologyInfo for the given type + + Raises: + ValueError: If TPU type is unknown + """ + for config in TPU_TOPOLOGIES: + if config.name == tpu_type: + return config + raise ValueError(f"Unknown TPU type: {tpu_type}") + + +# Job Entrypoint + + +@dataclass +class Entrypoint: + """Job entrypoint specification. + + A callable with args/kwargs that will be executed by the worker. + The callable must be picklable (via cloudpickle). + + Args: + callable: Python callable to execute + args: Positional arguments to pass + kwargs: Keyword arguments to pass + """ + + callable: Callable[..., Any] + args: tuple = () + kwargs: dict[str, Any] = field(default_factory=dict) + + +# Helper functions for creating proto messages + + +def create_environment( + workspace: str | None = None, + pip_packages: Sequence[str] | None = None, + env_vars: dict[str, str] | None = None, + extras: Sequence[str] | None = None, +) -> cluster_pb2.EnvironmentConfig: + """Create an EnvironmentConfig proto with sensible defaults. 
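
    Example (illustrative package, variable, and extras values):

        env = create_environment(
            pip_packages=["numpy"],
            env_vars={"MY_FLAG": "1"},
            extras=["dev"],
        )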
+ + Default environment variables: + - HF_DATASETS_TRUST_REMOTE_CODE: "1" (allows custom dataset code) + - TOKENIZERS_PARALLELISM: "false" (avoids tokenizer deadlocks) + - HF_TOKEN: from os.environ (if set) + - WANDB_API_KEY: from os.environ (if set) + + Args: + workspace: Path to workspace root (default: current directory) + pip_packages: Additional pip packages to install + env_vars: Custom environment variables (merged with defaults) + extras: Extra dependency groups for uv + + Returns: + EnvironmentConfig proto message with defaults applied + """ + if workspace is None: + workspace = os.getcwd() + + default_env_vars = { + "HF_DATASETS_TRUST_REMOTE_CODE": "1", + "TOKENIZERS_PARALLELISM": "false", + "HF_TOKEN": os.getenv("HF_TOKEN"), + "WANDB_API_KEY": os.getenv("WANDB_API_KEY"), + } + + # Filter out None values and merge with user-provided vars + merged_env_vars = {k: v for k, v in {**default_env_vars, **(env_vars or {})}.items() if v is not None} + + config = cluster_pb2.EnvironmentConfig( + workspace=workspace, + pip_packages=list(pip_packages or []), + env_vars=merged_env_vars, + extras=list(extras or []), + ) + + return config diff --git a/lib/fluster/src/fluster/cluster/worker/__init__.py b/lib/fluster/src/fluster/cluster/worker/__init__.py new file mode 100644 index 0000000000..3aacafcc54 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fluster worker implementation.""" + +from fluster.cluster.worker.worker import Worker, WorkerConfig + +__all__ = ["Worker", "WorkerConfig"] diff --git a/lib/fluster/src/fluster/cluster/worker/builder.py b/lib/fluster/src/fluster/cluster/worker/builder.py new file mode 100644 index 0000000000..9e2445a2e8 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/builder.py @@ -0,0 +1,202 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Virtual environment and Docker image caching with UV support.""" + +import hashlib +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +from fluster.cluster.worker.docker import DockerImageBuilder +from fluster.cluster.worker.worker_types import JobLogs + + +@dataclass +class VenvCacheEntry: + """Cached venv metadata.""" + + deps_hash: str + created_at: float + size_bytes: int + + +class VenvCache: + """Utility for computing dependency hashes for cache invalidation. 
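
    Example (illustrative path; the bundle directory is expected to contain
    pyproject.toml and/or uv.lock):

        deps_hash = VenvCache().compute_deps_hash(Path("/tmp/bundle"))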
+ + UV handles dependency caching natively via BuildKit cache mounts with + explicit global cache ID (fluster-uv-global). This ensures all workspaces + share the same BuildKit-managed cache for dependency reuse. + + This class provides utilities for computing dependency hashes from + pyproject.toml and uv.lock files to determine when Docker image layers + can be reused. + """ + + def compute_deps_hash(self, bundle_path: Path) -> str: + """Compute hash from pyproject.toml + uv.lock.""" + h = hashlib.sha256() + for fname in ["pyproject.toml", "uv.lock"]: + fpath = bundle_path / fname + if fpath.exists(): + h.update(fpath.read_bytes()) + return h.hexdigest() + + +@dataclass +class BuildResult: + """Result of Docker image build.""" + + image_tag: str + deps_hash: str + build_time_ms: int + from_cache: bool + + +class ImageProvider(Protocol): + """Protocol for Docker image management. Mock this for testing.""" + + def build( + self, + bundle_path: Path, + base_image: str, + extras: list[str], + job_id: str, + deps_hash: str, + job_logs: JobLogs | None = None, + ) -> BuildResult: + """Build Docker image for job. Returns cached image if deps_hash matches.""" + ... + + +DOCKERFILE_TEMPLATE = """FROM {base_image} + +# Install git (required for git-based dependencies) and UV +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* +COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv + +# TODO -- install Cargo here. +# How do we make Rust stuff build faster? +# We could pre-build something similar or at least fetch Rust deps to cache? + +# Configure UV +# TODO, is this wasting disk space maybe we don't care +ENV UV_CACHE_DIR=/opt/uv-cache +ENV UV_LINK_MODE=copy +ENV UV_PROJECT_ENVIRONMENT=/app/.venv +WORKDIR /app + +# TODO cache dependencies once across jobs just for usefulness + +# Copy workspace contents +# Path dependencies referenced in [tool.uv.sources] must be present before uv sync. +# We copy everything upfront to support workspaces with local path dependencies. +COPY . . + +# Install all dependencies and project +RUN --mount=type=cache,id=fluster-uv-global,sharing=locked,target=/opt/uv-cache \\ + uv sync --frozen {extras_flags} + +# Use the venv python +ENV PATH="/app/.venv/bin:$PATH" +""" + + +class ImageCache: + """Manages Docker image building with caching. + + Image tag: {registry}/fluster-job-{job_id}:{deps_hash[:8]} + Uses Docker BuildKit cache mounts with explicit global cache ID + (fluster-uv-global) to ensure all workspaces share the same UV cache. + + Cache behavior: + - All builds use id=fluster-uv-global for the UV cache mount + - sharing=locked prevents concurrent access issues during UV operations + - Different workspaces reuse cached dependencies automatically + - BuildKit manages cache storage in /var/lib/buildkit/ + + Delegates actual Docker operations to DockerImageBuilder, keeping + caching logic separate from container runtime specifics. + """ + + def __init__( + self, + cache_dir: Path, + registry: str, + max_images: int = 50, + ): + self._cache_dir = cache_dir / "images" + self._registry = registry + self._max_images = max_images + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._docker = DockerImageBuilder(registry) + + def build( + self, + bundle_path: Path, + base_image: str, + extras: list[str], + job_id: str, + deps_hash: str, + job_logs: JobLogs | None = None, + ) -> BuildResult: + """Build Docker image for job. + + Returns cached image if deps_hash matches. 
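
        Example (illustrative values; deps_hash typically comes from
        VenvCache.compute_deps_hash):

            result = cache.build(
                bundle_path=Path("/tmp/bundle"),
                base_image="python:3.11-slim",
                extras=[],
                job_id="job-123",
                deps_hash="0123abcd",
            )
            # result.from_cache is True when a matching image already exists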
+ """ + image_tag = f"{self._registry}/fluster-job-{job_id}:{deps_hash[:8]}" + + # Check if image exists locally + if self._docker.exists(image_tag): + if job_logs: + job_logs.add("build", f"Using cached image: {image_tag}") + return BuildResult( + image_tag=image_tag, + deps_hash=deps_hash, + build_time_ms=0, + from_cache=True, + ) + + # Build image + start = time.time() + extras_flags = " ".join(f"--extra {e}" for e in extras) if extras else "" + dockerfile = DOCKERFILE_TEMPLATE.format( + base_image=base_image, + extras_flags=extras_flags, + ) + self._docker.build(bundle_path, dockerfile, image_tag, job_logs) + build_time_ms = int((time.time() - start) * 1000) + + self._evict_old_images() + + return BuildResult( + image_tag=image_tag, + deps_hash=deps_hash, + build_time_ms=build_time_ms, + from_cache=False, + ) + + def _evict_old_images(self) -> None: + """Remove old fluster images when over limit.""" + pattern = f"{self._registry}/fluster-job-*" + images = self._docker.list_images(pattern) + + if len(images) <= self._max_images: + return + + # Sort by creation time and remove oldest + images.sort(key=lambda x: x.created_at) + for image in images[: len(images) - self._max_images]: + self._docker.remove(image.tag) diff --git a/lib/fluster/src/fluster/cluster/worker/bundle.py b/lib/fluster/src/fluster/cluster/worker/bundle.py new file mode 100644 index 0000000000..c7c07bbead --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/bundle.py @@ -0,0 +1,141 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bundle cache for workspace bundles from GCS.""" + +import hashlib +import threading +import zipfile +from collections import defaultdict +from pathlib import Path +from typing import Protocol + +import fsspec + + +class BundleProvider(Protocol): + """Protocol for bundle retrieval. Mock this for testing.""" + + def get_bundle(self, gcs_path: str, expected_hash: str | None = None) -> Path: + """Download/retrieve bundle and return path to extracted directory.""" + ... + + +class BundleCache: + """Cache for workspace bundles downloaded from GCS. + + Assumes GCS paths are unique - uses path as cache key. + Two-level caching: zip files + extracted directories. + + Supports both gs:// paths (requires GCS auth) and file:// paths + for local testing. + """ + + def __init__(self, cache_dir: Path, max_bundles: int = 100): + self._cache_dir = cache_dir + self._bundles_dir = cache_dir / "bundles" + self._extracts_dir = cache_dir / "extracts" + self._max_bundles = max_bundles + self._extract_locks: dict[str, threading.Lock] = defaultdict(threading.Lock) + + self._bundles_dir.mkdir(parents=True, exist_ok=True) + self._extracts_dir.mkdir(parents=True, exist_ok=True) + + def _path_to_key(self, gcs_path: str) -> str: + """Convert GCS path to cache key (hash).""" + return hashlib.sha256(gcs_path.encode()).hexdigest()[:16] + + def get_bundle(self, gcs_path: str, expected_hash: str | None = None) -> Path: + """Get bundle path, downloading if needed. 
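
        Example (illustrative local path; gs:// URLs work the same way):

            cache = BundleCache(Path("/tmp/bundle-cache"))
            workspace = cache.get_bundle("file:///tmp/my-bundle.zip")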
+ + Args: + gcs_path: gs://bucket/path/bundle.zip or file:///local/path.zip + expected_hash: Optional SHA256 hash for verification + + Returns: + Path to extracted bundle directory + """ + key = self._path_to_key(gcs_path) + extract_path = self._extracts_dir / key + + if extract_path.exists(): + # Update access time for LRU + extract_path.touch() + return extract_path + + # Use a lock per bundle to prevent concurrent extractions to the same path + with self._extract_locks[key]: + # Double-check after acquiring lock - another task may have extracted it + if extract_path.exists(): + extract_path.touch() + return extract_path + + # Download and extract + zip_path = self._bundles_dir / f"{key}.zip" + if not zip_path.exists(): + self._download(gcs_path, zip_path) + + if expected_hash: + actual_hash = self._compute_hash(zip_path) + if actual_hash != expected_hash: + raise ValueError(f"Bundle hash mismatch: {actual_hash} != {expected_hash}") + + self._extract(zip_path, extract_path) + self._evict_old_bundles() + + return extract_path + + def _download(self, gcs_path: str, local_path: Path) -> None: + """Synchronous download implementation.""" + # fsspec handles gs://, file://, and other protocols + with fsspec.open(gcs_path, "rb") as src: + with open(local_path, "wb") as dst: + dst.write(src.read()) + + def _extract(self, zip_path: Path, extract_path: Path) -> None: + """Synchronous extraction implementation with zip slip protection.""" + extract_path.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(zip_path, "r") as zf: + # Validate all paths to prevent zip slip attacks + for member in zf.namelist(): + member_path = (extract_path / member).resolve() + if not member_path.is_relative_to(extract_path.resolve()): + raise ValueError(f"Zip slip detected: {member} attempts to write outside extract path") + zf.extractall(extract_path) + + def _compute_hash(self, path: Path) -> str: + """Synchronous hash computation implementation.""" + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + def _evict_old_bundles(self) -> None: + """LRU eviction when over max_bundles.""" + extracts = list(self._extracts_dir.iterdir()) + if len(extracts) <= self._max_bundles: + return + + # Sort by mtime, remove oldest + extracts.sort(key=lambda p: p.stat().st_mtime) + for path in extracts[: len(extracts) - self._max_bundles]: + if path.is_dir(): + import shutil + + shutil.rmtree(path) + # Also remove corresponding zip + zip_path = self._bundles_dir / f"{path.name}.zip" + if zip_path.exists(): + zip_path.unlink() diff --git a/lib/fluster/src/fluster/cluster/worker/dashboard.py b/lib/fluster/src/fluster/cluster/worker/dashboard.py new file mode 100644 index 0000000000..cc60651025 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/dashboard.py @@ -0,0 +1,437 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HTTP dashboard with Connect RPC and web UI for worker monitoring. 
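
Example REST queries (paths and query parameters as wired up in _create_app
and _get_logs below; host and port are illustrative):

    GET http://worker:8080/api/jobs
    GET http://worker:8080/api/jobs/<job_id>/logs?tail=100&source=stderr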
+ +The WorkerDashboard provides: +- Connect RPC at /fluster.cluster.WorkerService +- Web dashboard at / with live job statistics +- REST API at /api/* for dashboard consumption + +REST endpoints are implemented by calling the canonical RPC methods and +converting proto responses to JSON for browser consumption. +""" + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse +from starlette.routing import Mount, Route + +from fluster import cluster_pb2 +from starlette.middleware.wsgi import WSGIMiddleware + +from fluster.cluster_connect import WorkerServiceWSGIApplication +from fluster.cluster.worker.service import WorkerServiceImpl + + +class FakeRequestContext: + """Minimal stub RequestContext for internal REST-to-RPC bridging. + + The WorkerDashboard translates REST API calls to RPC method calls, + which require a RequestContext parameter. Since the RPC methods never + actually access the context, this minimal stub satisfies the type signature. + """ + + pass + + +DASHBOARD_HTML = """ + + + + Fluster Worker + + + +

Fluster Worker Dashboard

[dashboard HTML body elided in extraction: a Jobs table with columns
ID, Status, Exit, Memory, CPU, Started, Finished, Error]
"""


JOB_DETAIL_HTML = """
[job detail HTML elided in extraction: page titled "Job {{job_id}} - Fluster
Worker", a "← Back to Dashboard" link, and Status, Resources, Build, and Logs
sections with log filter tabs ALL, STDOUT, STDERR, BUILD]
+ + + + +""" + + +class WorkerDashboard: + """HTTP dashboard with Connect RPC and web UI. + + Connect RPC is mounted at /fluster.cluster.WorkerService + Web dashboard at / + REST API for dashboard at /api/* + """ + + def __init__( + self, + service: WorkerServiceImpl, + host: str = "0.0.0.0", + port: int = 8080, + ): + self._service = service + self._host = host + self._port = port + self._app = self._create_app() + + @property + def port(self) -> int: + return self._port + + def _create_app(self) -> Starlette: + """Create Starlette application with all routes.""" + # Use WSGI application for sync RPC handlers, wrapped for ASGI compatibility + rpc_wsgi_app = WorkerServiceWSGIApplication(service=self._service) + rpc_app = WSGIMiddleware(rpc_wsgi_app) + + routes = [ + # Web dashboard + Route("/", self._dashboard), + Route("/job/{job_id}", self._job_detail_page), + # REST API (for dashboard) + Route("/api/stats", self._stats), + Route("/api/jobs", self._list_jobs), + Route("/api/jobs/{job_id}", self._get_job), + Route("/api/jobs/{job_id}/logs", self._get_logs), + # Connect RPC - mount WSGI app wrapped for ASGI + Mount(rpc_wsgi_app.path, app=rpc_app), + ] + return Starlette(routes=routes) + + def _dashboard(self, _request: Request) -> HTMLResponse: + """Serve web dashboard HTML.""" + return HTMLResponse(DASHBOARD_HTML) + + def _job_detail_page(self, request: Request) -> HTMLResponse: + """Serve detailed job view page.""" + job_id = request.path_params["job_id"] + return HTMLResponse(JOB_DETAIL_HTML.replace("{{job_id}}", job_id)) + + def _stats(self, _request: Request) -> JSONResponse: + """Return job statistics by status.""" + # Call canonical RPC method + ctx = FakeRequestContext() + response = self._service.list_jobs(cluster_pb2.Worker.ListJobsRequest(), ctx) + jobs = response.jobs + + return JSONResponse( + { + "running": sum(1 for j in jobs if j.state == cluster_pb2.JOB_STATE_RUNNING), + "pending": sum(1 for j in jobs if j.state == cluster_pb2.JOB_STATE_PENDING), + "building": sum(1 for j in jobs if j.state == cluster_pb2.JOB_STATE_BUILDING), + "completed": sum( + 1 + for j in jobs + if j.state + in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + ) + ), + } + ) + + def _list_jobs(self, _request: Request) -> JSONResponse: + """List all jobs as JSON.""" + # Call canonical RPC method + ctx = FakeRequestContext() + response = self._service.list_jobs(cluster_pb2.Worker.ListJobsRequest(), ctx) + jobs = response.jobs + + return JSONResponse( + [ + { + "job_id": j.job_id, + "status": self._status_name(j.state), + "started_at": j.started_at_ms, + "finished_at": j.finished_at_ms, + "exit_code": j.exit_code, + "error": j.error, + # Add resource metrics + "memory_mb": j.resource_usage.memory_mb, + "memory_peak_mb": j.resource_usage.memory_peak_mb, + "cpu_percent": j.resource_usage.cpu_percent, + "process_count": j.resource_usage.process_count, + "disk_mb": j.resource_usage.disk_mb, + # Add build metrics + "build_from_cache": j.build_metrics.from_cache, + "image_tag": j.build_metrics.image_tag, + } + for j in jobs + ] + ) + + def _get_job(self, request: Request) -> JSONResponse: + """Get single job by ID.""" + job_id = request.path_params["job_id"] + + # Call canonical RPC method + ctx = FakeRequestContext() + try: + job = self._service.get_job_status(cluster_pb2.Worker.GetJobStatusRequest(job_id=job_id), ctx) + except Exception: + # RPC raises ConnectError with NOT_FOUND for missing jobs + return JSONResponse({"error": "Not found"}, 
status_code=404) + + return JSONResponse( + { + "job_id": job.job_id, + "status": self._status_name(job.state), + "started_at": job.started_at_ms, + "finished_at": job.finished_at_ms, + "exit_code": job.exit_code, + "error": job.error, + "ports": dict(job.ports), + "resources": { + "memory_mb": job.resource_usage.memory_mb, + "memory_peak_mb": job.resource_usage.memory_peak_mb, + "cpu_percent": job.resource_usage.cpu_percent, + "disk_mb": job.resource_usage.disk_mb, + "process_count": job.resource_usage.process_count, + }, + "build": { + "started_ms": job.build_metrics.build_started_ms, + "finished_ms": job.build_metrics.build_finished_ms, + "duration_ms": ( + (job.build_metrics.build_finished_ms - job.build_metrics.build_started_ms) + if job.build_metrics.build_started_ms + else 0 + ), + "from_cache": job.build_metrics.from_cache, + "image_tag": job.build_metrics.image_tag, + }, + } + ) + + def _get_logs(self, request: Request) -> JSONResponse: + """Get logs with optional tail and source parameters.""" + job_id = request.path_params["job_id"] + + # Support ?tail=N for last N lines + tail = request.query_params.get("tail") + start_line = -int(tail) if tail else 0 + + # Support ?source=stdout|stderr|build for filtering + source = request.query_params.get("source") + + # Call canonical RPC method + ctx = FakeRequestContext() + log_filter = cluster_pb2.Worker.FetchLogsFilter(start_line=start_line) + try: + response = self._service.fetch_logs( + cluster_pb2.Worker.FetchLogsRequest(job_id=job_id, filter=log_filter), ctx + ) + except Exception: + # RPC raises ConnectError with NOT_FOUND for missing jobs + return JSONResponse({"error": "Not found"}, status_code=404) + + logs = [ + { + "timestamp": entry.timestamp_ms, + "source": entry.source, + "data": entry.data, + } + for entry in response.logs + ] + + # Apply source filter if specified + if source: + logs = [log for log in logs if log["source"] == source] + + return JSONResponse(logs) + + def _status_name(self, status: cluster_pb2.JobState) -> str: + """Convert status enum to string name.""" + status_map = { + cluster_pb2.JOB_STATE_PENDING: "pending", + cluster_pb2.JOB_STATE_BUILDING: "building", + cluster_pb2.JOB_STATE_RUNNING: "running", + cluster_pb2.JOB_STATE_SUCCEEDED: "succeeded", + cluster_pb2.JOB_STATE_FAILED: "failed", + cluster_pb2.JOB_STATE_KILLED: "killed", + } + return status_map.get(status, "unknown") + + def run(self) -> None: + """Run server (blocking).""" + import uvicorn + + uvicorn.run(self._app, host=self._host, port=self._port) + + async def run_async(self) -> None: + """Run server asynchronously (for use with asyncio.create_task).""" + import uvicorn + + config = uvicorn.Config(self._app, host=self._host, port=self._port) + self._server = uvicorn.Server(config) + await self._server.serve() + + async def shutdown(self) -> None: + """Shutdown the async server gracefully.""" + if hasattr(self, "_server") and self._server: + self._server.should_exit = True diff --git a/lib/fluster/src/fluster/cluster/worker/docker.py b/lib/fluster/src/fluster/cluster/worker/docker.py new file mode 100644 index 0000000000..9d0f21d010 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/docker.py @@ -0,0 +1,520 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Docker container runtime and image builder implementations.""" + +# TODO - set things up some memray/pyspy/etc work as expected +# these need to be installed at least, and then need maybe some permissions + +import json +import os +import re +import subprocess +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Protocol + +import docker.errors + +import docker +from fluster import cluster_pb2 +from fluster.cluster.worker.worker_types import JobLogs, LogLine + + +@dataclass +class ContainerConfig: + """Configuration for container execution.""" + + image: str + command: list[str] + env: dict[str, str] + workdir: str = "/app" + resources: cluster_pb2.ResourceSpec | None = None + timeout_seconds: int | None = None + mounts: list[tuple[str, str, str]] = field(default_factory=list) # (host, container, mode) + ports: dict[str, int] = field(default_factory=dict) # name -> host_port + + def _parse_memory_mb(self, memory_str: str) -> int: + """Parse memory string like '8g', '128m' to MB.""" + # TODO humansomething parser + match = re.match(r"^(\d+)([gmk]?)$", memory_str.lower()) + if not match: + raise ValueError(f"Invalid memory format: {memory_str}") + + value, unit = match.groups() + value = int(value) + + if unit == "g": + return value * 1024 + elif unit == "k": + return value // 1024 + else: # 'm' or no unit (assume MB) + return value + + def get_cpu_millicores(self) -> int | None: + """Get CPU millicores from ResourceSpec.""" + if not self.resources or not self.resources.cpu: + return None + return self.resources.cpu * 1000 # Convert cores to millicores + + def get_memory_mb(self) -> int | None: + """Get memory in MB from ResourceSpec.""" + if not self.resources or not self.resources.memory: + return None + return self._parse_memory_mb(self.resources.memory) + + +@dataclass +class ContainerResult: + """Result of container execution.""" + + container_id: str + exit_code: int + started_at: float + finished_at: float + error: str | None = None + + +@dataclass +class ContainerStats: + """Parsed container statistics. + + Attributes: + memory_mb: Memory usage in megabytes + cpu_percent: CPU usage as percentage (0-100) + process_count: Number of processes running in container + available: False if container stopped or unavailable + """ + + memory_mb: int + cpu_percent: int + process_count: int + available: bool + + +@dataclass +class ContainerStatus: + """Container state from docker inspect. 
+ + Attributes: + running: True if container is currently running + exit_code: Exit code if container has exited, None if still running + error: Error message if container failed to start + """ + + running: bool + exit_code: int | None = None + error: str | None = None + + +@dataclass +class ImageInfo: + """Information about a container image.""" + + tag: str + created_at: str + + +# ============================================================================= +# Protocols +# ============================================================================= + + +class ContainerRuntime(Protocol): + """Protocol for container runtimes (Docker, Firecracker, Podman, etc.).""" + + def create_container(self, config: ContainerConfig) -> str: + """Create container and return container_id.""" + ... + + def start_container(self, container_id: str) -> None: + """Start a created container (non-blocking).""" + ... + + def inspect(self, container_id: str) -> ContainerStatus: + """Check container status.""" + ... + + def kill(self, container_id: str, force: bool = False) -> None: + """Kill container (SIGTERM or SIGKILL).""" + ... + + def remove(self, container_id: str) -> None: + """Remove stopped container.""" + ... + + def get_logs(self, container_id: str) -> list[LogLine]: + """Fetch logs from container.""" + ... + + def get_stats(self, container_id: str) -> ContainerStats: + """Collect container resource statistics.""" + ... + + +class ImageBuilder(Protocol): + """Protocol for image building (Docker build, rootfs creation, etc.).""" + + def build( + self, + context: Path, + dockerfile_content: str, + tag: str, + job_logs: JobLogs | None = None, + ) -> None: + """Build image from context directory.""" + ... + + def exists(self, tag: str) -> bool: + """Check if image exists locally.""" + ... + + def remove(self, tag: str) -> None: + """Remove image.""" + ... + + def list_images(self, pattern: str) -> list[ImageInfo]: + """List images matching pattern.""" + ... + + +class DockerRuntime: + """Execute containers via Docker CLI with cgroups v2 resource limits. + + Security hardening: + - no-new-privileges + - cap-drop ALL + + Uses subprocess.run() for synchronous container lifecycle operations, and the Docker + Python library for stats/logs retrieval. 
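
    Example lifecycle (illustrative image and command):

        runtime = DockerRuntime()
        config = ContainerConfig(
            image="python:3.11-slim",
            command=["python", "-c", "print('hello')"],
            env={},
        )
        container_id = runtime.create_container(config)
        runtime.start_container(container_id)
        status = runtime.inspect(container_id)
        runtime.remove(container_id)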
+ """ + + def create_container(self, config: ContainerConfig) -> str: + """Create container with cgroups v2 resource limits.""" + cmd = [ + "docker", + "create", + "--add-host=host.docker.internal:host-gateway", + "--security-opt", + "no-new-privileges", + "--cap-drop", + "ALL", + "-w", + config.workdir, + ] + + # Resource limits (cgroups v2) + cpu_millicores = config.get_cpu_millicores() + if cpu_millicores: + cpus = cpu_millicores / 1000 + cmd.extend(["--cpus", str(cpus)]) + memory_mb = config.get_memory_mb() + if memory_mb: + cmd.extend(["--memory", f"{memory_mb}m"]) + + # Environment variables + for k, v in config.env.items(): + cmd.extend(["-e", f"{k}={v}"]) + + # Mounts + for host, container, mode in config.mounts: + cmd.extend(["-v", f"{host}:{container}:{mode}"]) + + # Port mappings (name is for reference only, bind host->container) + for host_port in config.ports.values(): + cmd.extend(["-p", f"{host_port}:{host_port}"]) + + cmd.append(config.image) + cmd.extend(config.command) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to create container: {result.stderr}") + return result.stdout.strip() + + def start_container(self, container_id: str) -> None: + """Start a created container (non-blocking).""" + result = subprocess.run( + ["docker", "start", container_id], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to start container: {result.stderr}") + + def inspect(self, container_id: str) -> ContainerStatus: + """Check container status via docker inspect.""" + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{json .State}}", + ], + capture_output=True, + text=True, + check=False, + ) + + if result.returncode != 0: + return ContainerStatus(running=False, error="Container not found") + + try: + state = json.loads(result.stdout.strip()) + running = state.get("Running", False) + exit_code = state.get("ExitCode") + error_msg = state.get("Error", "") or None + + return ContainerStatus( + running=running, + exit_code=exit_code if not running else None, + error=error_msg, + ) + except (json.JSONDecodeError, KeyError) as e: + return ContainerStatus(running=False, error=f"Failed to parse inspect output: {e}") + + def kill(self, container_id: str, force: bool = False) -> None: + """Kill container. 
+ + Args: + container_id: Container ID to kill + force: Use SIGKILL instead of SIGTERM + """ + signal = "SIGKILL" if force else "SIGTERM" + result = subprocess.run( + ["docker", "kill", f"--signal={signal}", container_id], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to kill container: {result.stderr}") + + def remove(self, container_id: str) -> None: + """Remove container.""" + result = subprocess.run( + ["docker", "rm", "-f", container_id], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to remove container: {result.stderr}") + + def get_logs(self, container_id: str) -> list[LogLine]: + """Fetch logs from container.""" + client = docker.from_env() # type: ignore[attr-defined] + try: + container = client.containers.get(container_id) + except docker.errors.NotFound: + return [] + + logs: list[LogLine] = [] + + # Fetch stdout with timestamps + stdout_logs = container.logs(stdout=True, stderr=False, timestamps=True) + for line in stdout_logs.decode().splitlines(): + if line: + timestamp, data = self._parse_docker_log_line(line) + logs.append(LogLine(timestamp=timestamp, source="stdout", data=data)) + + # Fetch stderr with timestamps + stderr_logs = container.logs(stdout=False, stderr=True, timestamps=True) + for line in stderr_logs.decode().splitlines(): + if line: + timestamp, data = self._parse_docker_log_line(line) + logs.append(LogLine(timestamp=timestamp, source="stderr", data=data)) + + return logs + + def _parse_docker_log_line(self, line: str) -> tuple[datetime, str]: + """Parse Docker log line with timestamp.""" + if len(line) > 30 and line[10] == "T": + z_idx = line.find("Z") + if 20 < z_idx < 35: + ts_str = line[: z_idx + 1] + # Truncate nanoseconds to microseconds for fromisoformat + if len(ts_str) > 27: + ts_str = ts_str[:26] + "Z" + try: + ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00")) + return ts, line[z_idx + 2 :] + except ValueError: + pass + return datetime.now(timezone.utc), line + + def get_stats(self, container_id: str) -> ContainerStats: + """Collect resource usage from a Docker container.""" + client = docker.from_env() # type: ignore[attr-defined] + try: + container = client.containers.get(container_id) + stats = container.stats(decode=True, stream=False) + + # Parse memory usage (bytes to MB) + memory_bytes = stats.get("memory_stats", {}).get("usage", 0) + memory_mb = int(memory_bytes / (1024 * 1024)) + + # Calculate CPU percentage from deltas + cpu_percent = _calculate_cpu_percent(stats) + + # Parse process count + process_count = stats.get("pids_stats", {}).get("current", 0) + + return ContainerStats( + memory_mb=memory_mb, + cpu_percent=cpu_percent, + process_count=process_count, + available=True, + ) + except (docker.errors.NotFound, docker.errors.APIError): + return ContainerStats( + memory_mb=0, + cpu_percent=0, + process_count=0, + available=False, + ) + + +def _calculate_cpu_percent(stats: dict) -> int: + """Calculate CPU percentage from stats deltas. + + Docker stats format provides cpu_stats and precpu_stats for delta calculation. 
+ CPU percentage = (cpu_delta / system_delta) * num_cpus * 100 + """ + cpu_stats = stats.get("cpu_stats", {}) + precpu_stats = stats.get("precpu_stats", {}) + + cpu_delta = cpu_stats.get("cpu_usage", {}).get("total_usage", 0) - precpu_stats.get("cpu_usage", {}).get( + "total_usage", 0 + ) + system_delta = cpu_stats.get("system_cpu_usage", 0) - precpu_stats.get("system_cpu_usage", 0) + + if system_delta == 0 or cpu_delta == 0: + return 0 + + num_cpus = cpu_stats.get("online_cpus", 1) + cpu_percent = (cpu_delta / system_delta) * num_cpus * 100.0 + + return int(cpu_percent) + + +class DockerImageBuilder: + """Build Docker images using Docker CLI with BuildKit.""" + + def __init__(self, registry: str): + self._registry = registry + + def build( + self, + context: Path, + dockerfile_content: str, + tag: str, + job_logs: JobLogs | None = None, + ) -> None: + """Run docker build with BuildKit.""" + dockerfile_path = context / "Dockerfile.fluster" + dockerfile_path.write_text(dockerfile_content) + + try: + if job_logs: + job_logs.add("build", f"Starting build for image: {tag}") + + cmd = [ + "docker", + "build", + "-f", + str(dockerfile_path), + "-t", + tag, + "--build-arg", + "BUILDKIT_INLINE_CACHE=1", + str(context), + ] + + proc = subprocess.Popen( + cmd, + env={**os.environ, "DOCKER_BUILDKIT": "1"}, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + # Stream output to job_logs + if proc.stdout: + for line in proc.stdout: + if job_logs: + job_logs.add("build", line.rstrip()) + + returncode = proc.wait() + + if job_logs: + if returncode == 0: + job_logs.add("build", "Build completed successfully") + else: + job_logs.add("build", f"Build failed with exit code {returncode}") + + if returncode != 0: + raise RuntimeError(f"Docker build failed with exit code {returncode}") + finally: + # Cleanup generated dockerfile + dockerfile_path.unlink(missing_ok=True) + + def exists(self, tag: str) -> bool: + """Check if image exists locally.""" + result = subprocess.run( + ["docker", "image", "inspect", tag], + capture_output=True, + check=False, + ) + return result.returncode == 0 + + def remove(self, tag: str) -> None: + """Remove image.""" + subprocess.run( + ["docker", "rmi", tag], + capture_output=True, + check=False, + ) + + def list_images(self, pattern: str) -> list[ImageInfo]: + """List images matching pattern.""" + result = subprocess.run( + [ + "docker", + "images", + "--format", + "{{.Repository}}:{{.Tag}}\t{{.CreatedAt}}", + "--filter", + f"reference={pattern}", + ], + capture_output=True, + text=True, + check=False, + ) + + images = [] + for line in result.stdout.strip().split("\n"): + if line and "\t" in line: + tag, created = line.split("\t", 1) + images.append(ImageInfo(tag=tag, created_at=created)) + + return images diff --git a/lib/fluster/src/fluster/cluster/worker/main.py b/lib/fluster/src/fluster/cluster/worker/main.py new file mode 100644 index 0000000000..15a6b20d0e --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/main.py @@ -0,0 +1,85 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Click-based CLI for the Fluster worker daemon. + +Provides two commands: +- serve: Start the worker service +- cleanup: Remove cached bundles, venvs, and images +""" + +import shutil +from pathlib import Path + +import click + +from fluster.cluster.worker.worker import Worker, WorkerConfig + + +@click.group() +def cli(): + """Fluster Worker - Job execution daemon.""" + pass + + +@cli.command() +@click.option("--host", default="0.0.0.0", help="Bind host") +@click.option("--port", default=8080, type=int, help="Bind port") +@click.option("--cache-dir", default="~/.cache/fluster-worker", help="Cache directory") +@click.option("--registry", required=True, help="Docker registry for built images") +@click.option("--max-concurrent-jobs", default=10, type=int, help="Max concurrent jobs") +@click.option("--port-range", default="30000-40000", help="Port range for job ports (start-end)") +def serve( + host: str, + port: int, + cache_dir: str, + registry: str, + max_concurrent_jobs: int, + port_range: str, +): + """Start the Fluster worker service.""" + port_start, port_end = map(int, port_range.split("-")) + + config = WorkerConfig( + host=host, + port=port, + cache_dir=Path(cache_dir).expanduser(), + registry=registry, + max_concurrent_jobs=max_concurrent_jobs, + port_range=(port_start, port_end), + ) + + worker = Worker(config) + + click.echo(f"Starting Fluster worker on {host}:{port}") + click.echo(f" Registry: {registry}") + click.echo(f" Cache dir: {config.cache_dir}") + click.echo(f" Max concurrent jobs: {max_concurrent_jobs}") + worker._run_server() + + +@cli.command() +@click.option("--cache-dir", default="~/.cache/fluster-worker", help="Cache directory") +def cleanup(cache_dir: str): + """Clean up cached bundles, venvs, and images.""" + cache_path = Path(cache_dir).expanduser() + if cache_path.exists(): + shutil.rmtree(cache_path) + click.echo(f"Removed cache directory: {cache_path}") + else: + click.echo(f"Cache directory does not exist: {cache_path}") + + +if __name__ == "__main__": + cli() diff --git a/lib/fluster/src/fluster/cluster/worker/service.py b/lib/fluster/src/fluster/cluster/worker/service.py new file mode 100644 index 0000000000..b68e2a098f --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/service.py @@ -0,0 +1,188 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WorkerService RPC implementation using Connect RPC. + +Implements the WorkerService protocol defined in cluster.proto. +Provides job execution, status, logs, and health monitoring endpoints. +""" + +import re +import time +from typing import Protocol + +from connectrpc.code import Code +from connectrpc.errors import ConnectError +from connectrpc.request import RequestContext + +from fluster import cluster_pb2 +from fluster.cluster.worker.worker_types import Job + + +class JobProvider(Protocol): + """Protocol for job management operations. 
+ + Implemented by Worker to provide job lifecycle management. + """ + + def submit_job(self, request: cluster_pb2.Worker.RunJobRequest) -> str: ... + def get_job(self, job_id: str) -> Job | None: ... + def list_jobs(self) -> list[Job]: ... + def kill_job(self, job_id: str, term_timeout_ms: int = 5000) -> bool: ... + def get_logs(self, job_id: str, start_line: int = 0) -> list[cluster_pb2.Worker.LogEntry]: ... + + +class WorkerServiceImpl: + """Implementation of WorkerService RPC interface. + + Provides endpoints for: + - run_job: Submit job for execution + - get_job_status: Query job status + - list_jobs: List jobs (optionally filtered) + - fetch_logs: Get logs with filtering + - kill_job: Terminate job + - health_check: Worker health status + """ + + def __init__(self, provider: JobProvider): + self._provider = provider + self._start_time = time.time() + + def run_job( + self, + request: cluster_pb2.Worker.RunJobRequest, + _ctx: RequestContext, + ) -> cluster_pb2.Worker.RunJobResponse: + """Submit job for execution.""" + job_id = self._provider.submit_job(request) + job = self._provider.get_job(job_id) + + if not job: + raise ConnectError(Code.INTERNAL, f"Job {job_id} not found after submission") + + return cluster_pb2.Worker.RunJobResponse( + job_id=job_id, + state=job.to_proto().state, + ) + + def get_job_status( + self, + request: cluster_pb2.Worker.GetJobStatusRequest, + _ctx: RequestContext, + ) -> cluster_pb2.JobStatus: + """Get job status.""" + job = self._provider.get_job(request.job_id) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + status = job.to_proto() + if request.include_result and job.result: + status.serialized_result = job.result + return status + + def list_jobs( + self, + _request: cluster_pb2.Worker.ListJobsRequest, + _ctx: RequestContext, + ) -> cluster_pb2.Worker.ListJobsResponse: + """List jobs. + + Note: namespace filtering is not implemented in this stage as jobs + are stored without namespace information. Empty string is treated + as "list all jobs". + """ + jobs = self._provider.list_jobs() + return cluster_pb2.Worker.ListJobsResponse( + jobs=[job.to_proto() for job in jobs], + ) + + def fetch_logs( + self, + request: cluster_pb2.Worker.FetchLogsRequest, + _ctx: RequestContext, + ) -> cluster_pb2.Worker.FetchLogsResponse: + """Fetch job logs with filtering. + + Supports: + - start_line: Line offset. Negative values for tailing (e.g., -100 for last 100 lines) + - start_ms/end_ms: Time range filter + - regex: Content filter + - max_lines: Limit results + """ + # Get logs with start_line handling (negative = tail) + start_line = request.filter.start_line if request.filter.start_line else 0 + logs = self._provider.get_logs(request.job_id, start_line=start_line) + + # Apply additional filters + result = [] + for entry in logs: + # Time range filter (start_ms is exclusive for incremental polling) + if request.filter.start_ms and entry.timestamp_ms <= request.filter.start_ms: + continue + if request.filter.end_ms and entry.timestamp_ms > request.filter.end_ms: + continue + # TODO: Regex filter is vulnerable to DoS via catastrophic backtracking. + # Malicious regex like (a+)+ can cause minutes of CPU time. Consider using + # the re2 library or adding timeout/complexity limits. 
+ # Regex filter + if request.filter.regex: + if not re.search(request.filter.regex, entry.data): + continue + + result.append(entry) + + # Max lines limit + if request.filter.max_lines and len(result) >= request.filter.max_lines: + break + + return cluster_pb2.Worker.FetchLogsResponse(logs=result) + + def kill_job( + self, + request: cluster_pb2.Worker.KillJobRequest, + _ctx: RequestContext, + ) -> cluster_pb2.Empty: + """Kill running job.""" + # Check if job exists first + job = self._provider.get_job(request.job_id) + if not job: + raise ConnectError(Code.NOT_FOUND, f"Job {request.job_id} not found") + + success = self._provider.kill_job( + request.job_id, + term_timeout_ms=request.term_timeout_ms or 5000, + ) + if not success: + # Job exists but is already in terminal state + state_name = cluster_pb2.JobState.Name(job.status) + raise ConnectError( + Code.FAILED_PRECONDITION, + f"Job {request.job_id} already completed with state {state_name}", + ) + return cluster_pb2.Empty() + + def health_check( + self, + _request: cluster_pb2.Empty, + _ctx: RequestContext, + ) -> cluster_pb2.Worker.HealthResponse: + """Worker health status.""" + jobs = self._provider.list_jobs() + running = sum(1 for j in jobs if j.status == cluster_pb2.JOB_STATE_RUNNING) + + return cluster_pb2.Worker.HealthResponse( + healthy=True, + uptime_ms=int((time.time() - self._start_time) * 1000), + running_jobs=running, + ) diff --git a/lib/fluster/src/fluster/cluster/worker/worker.py b/lib/fluster/src/fluster/cluster/worker/worker.py new file mode 100644 index 0000000000..43ed95320c --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/worker.py @@ -0,0 +1,617 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified worker managing all components and lifecycle. + +The Worker class encapsulates all worker components (caches, runtime, service, +dashboard) and provides a clean interface for lifecycle management. It directly +owns job state and execution logic. + +Example: + config = WorkerConfig(port=8081) + worker = Worker(config) + worker.start() + try: + # Use worker + job_id = worker.submit_job(request) + finally: + worker.stop() +""" + +import base64 +import shutil +import socket +import tempfile +import threading +import time +import uuid +from dataclasses import dataclass +from pathlib import Path + +import cloudpickle +import uvicorn + +from fluster import cluster_pb2 +from fluster.cluster.worker.builder import ImageCache, ImageProvider, VenvCache +from fluster.cluster.worker.bundle import BundleCache, BundleProvider +from fluster.cluster.worker.dashboard import WorkerDashboard +from fluster.cluster.worker.docker import ContainerConfig, ContainerRuntime, DockerRuntime +from fluster.cluster.worker.service import WorkerServiceImpl +from fluster.cluster.worker.worker_types import Job, collect_workdir_size_mb + + +def _rewrite_address_for_container(address: str) -> str: + """Rewrite localhost addresses to host.docker.internal for container access. 
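
    Example:

        _rewrite_address_for_container("http://127.0.0.1:8080")
        # -> "http://host.docker.internal:8080"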
+ + Docker containers on Mac/Windows cannot reach host localhost directly. + Using host.docker.internal works cross-platform when combined with + --add-host=host.docker.internal:host-gateway on Linux. + """ + for localhost in ("127.0.0.1", "localhost", "0.0.0.0"): + if localhost in address: + return address.replace(localhost, "host.docker.internal") + return address + + +class PortAllocator: + """Allocate ephemeral ports for jobs. + + Tracks allocated ports to avoid conflicts. + Ports are released when jobs terminate. + """ + + def __init__(self, port_range: tuple[int, int] = (30000, 40000)): + self._range = port_range + self._allocated: set[int] = set() + self._lock = threading.Lock() + + def allocate(self, count: int = 1) -> list[int]: + """Allocate N unused ports.""" + with self._lock: + ports = [] + for _ in range(count): + port = self._find_free_port() + self._allocated.add(port) + ports.append(port) + return ports + + def release(self, ports: list[int]) -> None: + """Release allocated ports.""" + with self._lock: + for port in ports: + self._allocated.discard(port) + + def _find_free_port(self) -> int: + """Find an unused port in range.""" + for port in range(self._range[0], self._range[1]): + if port in self._allocated: + continue + if self._is_port_free(port): + return port + raise RuntimeError("No free ports available") + + def _is_port_free(self, port: int) -> bool: + """Check if port is free on host.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + return True + except OSError: + return False + + +@dataclass +class WorkerConfig: + """Worker configuration. + + Args: + host: Host to bind to (default: "127.0.0.1") + port: Port to bind to (default: 0 for ephemeral) + cache_dir: Cache directory for bundles and images (default: temp directory) + registry: Docker registry for images (default: "localhost:5000") + max_concurrent_jobs: Maximum concurrent jobs (default: 10) + port_range: Port range for job ports (default: (30000, 40000)) + controller_address: Controller URL for endpoint registration (default: None) + worker_id: Worker ID (default: None) + """ + + host: str = "127.0.0.1" + port: int = 0 + cache_dir: Path | None = None + registry: str = "localhost:5000" + max_concurrent_jobs: int = 10 + port_range: tuple[int, int] = (30000, 40000) + controller_address: str | None = None + worker_id: str | None = None + + +class Worker: + """Unified worker managing all components and lifecycle. + + Directly owns job state and execution logic, following the Controller pattern. + + Example: + config = WorkerConfig(port=8081) + worker = Worker(config) + worker.start() + try: + job_id = worker.submit_job(request) + finally: + worker.stop() + """ + + def __init__( + self, + config: WorkerConfig, + cache_dir: Path | None = None, + bundle_provider: BundleProvider | None = None, + image_provider: ImageProvider | None = None, + container_runtime: ContainerRuntime | None = None, + ): + """Initialize worker components. 
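+
+        If no cache directory is provided (neither the cache_dir argument nor
+        config.cache_dir), a temporary directory is created and cleaned up
+        best-effort when stop() is called.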
+ + Args: + config: Worker configuration + cache_dir: Override cache directory from config + bundle_provider: Optional bundle provider for testing + image_provider: Optional image provider for testing + container_runtime: Optional container runtime for testing + """ + self._config = config + + # Setup cache directory + if cache_dir: + self._cache_dir = cache_dir + self._temp_dir = None + elif config.cache_dir: + self._cache_dir = config.cache_dir + self._temp_dir = None + else: + # Create temporary cache + self._temp_dir = tempfile.TemporaryDirectory(prefix="worker_cache_") + self._cache_dir = Path(self._temp_dir.name) + + self._cache_dir.mkdir(parents=True, exist_ok=True) + + # Use overrides if provided, otherwise create defaults + self._bundle_cache = bundle_provider or BundleCache(self._cache_dir, max_bundles=100) + self._venv_cache = VenvCache() + self._image_cache = image_provider or ImageCache( + self._cache_dir, + registry=config.registry, + max_images=50, + ) + self._runtime = container_runtime or DockerRuntime() + self._port_allocator = PortAllocator(config.port_range) + + # Job state + self._jobs: dict[str, Job] = {} + self._lock = threading.Lock() + self._semaphore = threading.Semaphore(config.max_concurrent_jobs) + + self._service = WorkerServiceImpl(self) + self._dashboard = WorkerDashboard( + self._service, + host=config.host, + port=config.port, + ) + + self._server_thread: threading.Thread | None = None + + def start(self) -> None: + """Start worker server.""" + self._server_thread = threading.Thread( + target=self._run_server, + daemon=True, + ) + self._server_thread.start() + time.sleep(1.0) # Wait for startup + + def stop(self) -> None: + """Stop worker server and cleanup. + + Note: Cleanup of temp directory is best-effort. If jobs have created + files in the cache, cleanup may fail. This is acceptable since the + temp directory will be cleaned up by the OS eventually. + """ + # Cleanup temp directory (best-effort) + if self._temp_dir: + try: + self._temp_dir.cleanup() + except OSError: + # Cleanup may fail if cache has files from running jobs + # This is acceptable - temp directories are cleaned by OS + pass + # Dashboard stops when thread exits (daemon) + + def _run_server(self) -> None: + """Run worker server (blocking, for thread).""" + try: + uvicorn.run( + self._dashboard._app, + host=self._config.host, + port=self._config.port, + log_level="error", + ) + except Exception as e: + print(f"Worker server error: {e}") + + # Job management methods + + def submit_job(self, request: cluster_pb2.Worker.RunJobRequest) -> str: + """Submit job for execution. + + Returns job_id immediately, execution happens in background. + + Fluster auto-injects system environment variables (these override user-provided values): + - FLUSTER_JOB_ID: The job's unique identifier + - FLUSTER_WORKER_ID: ID of the worker running this job + - FLUSTER_CONTROLLER_ADDRESS: Controller URL for endpoint registration + - FLUSTER_PORT_: Allocated port numbers (e.g., FLUSTER_PORT_HTTP) + - FRAY_PORT_MAPPING: All port mappings as "name:port,name:port" + + User-provided environment variables (including FLUSTER_NAMESPACE for actors) are + passed through from the RunJobRequest.environment.env_vars. 
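+
+        Example (illustrative sketch; assumes a port named "http" was requested
+        in RunJobRequest.ports):
+
+            import os
+
+            job_id = os.environ["FLUSTER_JOB_ID"]
+            http_port = int(os.environ["FLUSTER_PORT_HTTP"])
+            # All allocated ports, formatted as "name:port,name:port"
+            port_mapping = os.environ["FRAY_PORT_MAPPING"]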
+ """ + job_id = request.job_id or str(uuid.uuid4()) + + # Allocate requested ports + port_names = list(request.ports) + allocated_ports = self._port_allocator.allocate(len(port_names)) if port_names else [] + ports = dict(zip(port_names, allocated_ports, strict=True)) + + # Create job working directory + workdir = Path(tempfile.gettempdir()) / "fluster-worker" / "jobs" / job_id + workdir.mkdir(parents=True, exist_ok=True) + + job = Job( + job_id=job_id, + request=request, + status=cluster_pb2.JOB_STATE_PENDING, + ports=ports, + workdir=workdir, + ) + + with self._lock: + self._jobs[job_id] = job + + # Start execution in background + job.thread = threading.Thread(target=self._execute_job, args=(job,), daemon=True) + job.thread.start() + + return job_id + + def _execute_job(self, job: Job) -> None: + """Execute job through all phases with integrated stats collection.""" + import sys + + try: + # Acquire semaphore to limit concurrent jobs + self._semaphore.acquire() + + # Phase 1: Download bundle + job.transition_to(cluster_pb2.JOB_STATE_BUILDING, message="downloading bundle") + job.started_at_ms = int(time.time() * 1000) + + bundle_path = self._bundle_cache.get_bundle( + job.request.bundle_gcs_path, + expected_hash=None, + ) + + # Phase 2: Build image + job.transition_to(cluster_pb2.JOB_STATE_BUILDING, message="building image") + job.build_started_ms = int(time.time() * 1000) + env_config = job.request.environment + extras = list(env_config.extras) + + # Compute deps_hash for caching + deps_hash = self._venv_cache.compute_deps_hash(bundle_path) + + job.transition_to(cluster_pb2.JOB_STATE_BUILDING, message="populating uv cache") + job.logs.add("build", "Building Docker image...") + + # Detect host Python version for container compatibility + # cloudpickle serializes bytecode which is version-specific + py_version = f"{sys.version_info.major}.{sys.version_info.minor}" + base_image = f"python:{py_version}-slim" + + build_result = self._image_cache.build( + bundle_path=bundle_path, + base_image=base_image, + extras=extras, + job_id=job.job_id, + deps_hash=deps_hash, + job_logs=job.logs, + ) + + job.build_finished_ms = int(time.time() * 1000) + job.build_from_cache = build_result.from_cache + job.image_tag = build_result.image_tag + + # Phase 3: Create and start container + job.transition_to(cluster_pb2.JOB_STATE_RUNNING) + + # Deserialize entrypoint + entrypoint = cloudpickle.loads(job.request.serialized_entrypoint) + command = self._build_command(entrypoint) + + # Build environment from user-provided vars + EnvironmentConfig + env = dict(env_config.env_vars) + + # Auto-inject Fluster system variables (these override user-provided values) + env["FLUSTER_JOB_ID"] = job.job_id + + if self._config.worker_id: + env["FLUSTER_WORKER_ID"] = self._config.worker_id + + if self._config.controller_address: + env["FLUSTER_CONTROLLER_ADDRESS"] = _rewrite_address_for_container(self._config.controller_address) + + # Inject allocated ports + for name, port in job.ports.items(): + env[f"FLUSTER_PORT_{name.upper()}"] = str(port) + + if job.ports: + port_mapping = ",".join(f"{name}:{port}" for name, port in job.ports.items()) + env["FRAY_PORT_MAPPING"] = port_mapping + + config = ContainerConfig( + image=build_result.image_tag, + command=command, + env=env, + resources=job.request.resources if job.request.HasField("resources") else None, + timeout_seconds=job.request.timeout_seconds or None, + ports=job.ports, + mounts=[(str(job.workdir), "/workdir", "rw")], + ) + + # Create and start container with retry on port 
binding failures + container_id = None + max_port_retries = 3 + for attempt in range(max_port_retries): + try: + container_id = self._runtime.create_container(config) + job.container_id = container_id + self._runtime.start_container(container_id) + break + except RuntimeError as e: + if "address already in use" in str(e) and attempt < max_port_retries - 1: + job.logs.add("build", f"Port conflict, retrying with new ports (attempt {attempt + 2})") + # Release current ports and allocate new ones + self._port_allocator.release(list(job.ports.values())) + port_names = list(job.ports.keys()) + new_ports = self._port_allocator.allocate(len(port_names)) + job.ports = dict(zip(port_names, new_ports, strict=True)) + + # Update config with new ports + config.ports = job.ports + for name, port in job.ports.items(): + config.env[f"FLUSTER_PORT_{name.upper()}"] = str(port) + if job.ports: + config.env["FRAY_PORT_MAPPING"] = ",".join(f"{n}:{p}" for n, p in job.ports.items()) + + # Try to remove failed container if it was created + if container_id: + try: + self._runtime.remove(container_id) + except RuntimeError: + pass + container_id = None + else: + raise + + # container_id is guaranteed to be set here (loop breaks on success, raises on failure) + assert container_id is not None + + # Phase 4: Poll loop - check status and collect stats + timeout = config.timeout_seconds + start_time = time.time() + + while True: + # Check if we should stop + if job.should_stop: + job.transition_to(cluster_pb2.JOB_STATE_KILLED) + break + + # Check timeout + if timeout and (time.time() - start_time) > timeout: + self._runtime.kill(container_id, force=True) + job.transition_to( + cluster_pb2.JOB_STATE_FAILED, + error="Timeout exceeded", + exit_code=-1, + ) + break + + # Check container status + status = self._runtime.inspect(container_id) + if not status.running: + # Read result file only if container succeeded + if status.exit_code == 0 and job.workdir: + result_path = job.workdir / "_result.pkl" + if result_path.exists(): + try: + job.result = result_path.read_bytes() + except Exception as e: + job.logs.add("error", f"Failed to read result file: {e}") + + # Container has stopped + if status.error: + job.transition_to( + cluster_pb2.JOB_STATE_FAILED, + error=status.error, + exit_code=status.exit_code or -1, + ) + elif status.exit_code == 0: + job.transition_to(cluster_pb2.JOB_STATE_SUCCEEDED, exit_code=0) + else: + job.transition_to( + cluster_pb2.JOB_STATE_FAILED, + error=f"Exit code: {status.exit_code}", + exit_code=status.exit_code or -1, + ) + break + + # Collect stats + try: + stats = self._runtime.get_stats(container_id) + if stats.available: + job.current_memory_mb = stats.memory_mb + job.current_cpu_percent = stats.cpu_percent + job.process_count = stats.process_count + if stats.memory_mb > job.peak_memory_mb: + job.peak_memory_mb = stats.memory_mb + + if job.workdir: + job.disk_mb = collect_workdir_size_mb(job.workdir) + except Exception: + pass # Don't fail job on stats collection errors + + # Sleep before next poll + time.sleep(5.0) + + except Exception as e: + job.logs.add("error", f"Job failed: {e!r}") + job.transition_to(cluster_pb2.JOB_STATE_FAILED, error=repr(e)) + finally: + # Release semaphore + self._semaphore.release() + + # Cleanup: release ports, remove workdir (keep container for logs) + if not job.cleanup_done: + job.cleanup_done = True + self._port_allocator.release(list(job.ports.values())) + # Keep container around for log retrieval via docker logs + # Remove working directory (no longer needed 
since logs come from Docker) + if job.workdir and job.workdir.exists(): + shutil.rmtree(job.workdir, ignore_errors=True) + + def _build_command(self, entrypoint) -> list[str]: + """Build command to run entrypoint. + + Serializes only the raw callable/args/kwargs tuple rather than the Entrypoint + dataclass to avoid requiring fluster.cluster.types in the container. + """ + data = (entrypoint.callable, entrypoint.args, entrypoint.kwargs) + serialized = cloudpickle.dumps(data) + encoded = base64.b64encode(serialized).decode() + + # Thunk that executes the function and writes result to file. + # Exceptions propagate naturally (container exits non-zero). + thunk = f""" +import cloudpickle +import base64 + +fn, args, kwargs = cloudpickle.loads(base64.b64decode('{encoded}')) +result = fn(*args, **kwargs) +with open('/workdir/_result.pkl', 'wb') as f: + f.write(cloudpickle.dumps(result)) +""" + return ["python", "-c", thunk] + + def get_job(self, job_id: str) -> Job | None: + """Get job by ID.""" + return self._jobs.get(job_id) + + def list_jobs(self) -> list[Job]: + """List all jobs.""" + return list(self._jobs.values()) + + def kill_job(self, job_id: str, term_timeout_ms: int = 5000) -> bool: + """Kill a running job by setting should_stop flag. + + The poll loop in _execute_job will handle the actual termination. + """ + job = self._jobs.get(job_id) + if not job: + return False + + # Check if already in terminal state + if job.status not in ( + cluster_pb2.JOB_STATE_RUNNING, + cluster_pb2.JOB_STATE_BUILDING, + cluster_pb2.JOB_STATE_PENDING, + ): + return False + + # Set flag to signal thread to stop + job.should_stop = True + + # If container exists, try to kill it + if job.container_id: + try: + # Send SIGTERM + self._runtime.kill(job.container_id, force=False) + + # Wait for graceful shutdown + timeout_sec = term_timeout_ms / 1000 + start_time = time.time() + while job.status in ( + cluster_pb2.JOB_STATE_RUNNING, + cluster_pb2.JOB_STATE_BUILDING, + ): + if (time.time() - start_time) > timeout_sec: + # Force kill + try: + self._runtime.kill(job.container_id, force=True) + except RuntimeError: + pass + break + time.sleep(0.1) + except RuntimeError: + # Container may have already been removed or stopped + pass + + return True + + def get_logs(self, job_id: str, start_line: int = 0) -> list[cluster_pb2.Worker.LogEntry]: + """Get logs for a job. + + Combines build logs (from job.logs) with container logs (from Docker). + + Args: + job_id: Job ID + start_line: Starting line number. If negative, returns last N lines + (e.g., start_line=-100 returns last 100 lines for tailing). 
+ + Returns: + List of log entries sorted by timestamp + """ + job = self._jobs.get(job_id) + if not job: + return [] + + logs: list[cluster_pb2.Worker.LogEntry] = [] + + # Add build logs from job.logs (these have proper timestamps) + for log_line in job.logs.lines: + logs.append(log_line.to_proto()) + + # Fetch container stdout/stderr from Docker if container exists + if job.container_id: + container_logs = self._runtime.get_logs(job.container_id) + for log_line in container_logs: + logs.append(log_line.to_proto()) + + # Sort by timestamp + logs.sort(key=lambda x: x.timestamp_ms) + + return logs[start_line:] + + # Properties + + @property + def url(self) -> str: + """Worker URL.""" + return f"http://{self._config.host}:{self._config.port}" diff --git a/lib/fluster/src/fluster/cluster/worker/worker_types.py b/lib/fluster/src/fluster/cluster/worker/worker_types.py new file mode 100644 index 0000000000..ce7b48cf50 --- /dev/null +++ b/lib/fluster/src/fluster/cluster/worker/worker_types.py @@ -0,0 +1,176 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal worker types for job tracking and statistics collection.""" + +import subprocess +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path + +from pydantic import BaseModel + +from fluster import cluster_pb2 +from fluster.cluster.types import is_job_finished +from fluster.cluster_pb2 import JobState + + +def collect_workdir_size_mb(workdir: Path) -> int: + """Calculate workdir size in MB using du -sm command. 
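+    Shells out to du; a failed invocation is treated as a size of 0.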
+ + Args: + workdir: Path to directory to measure + + Returns: + Directory size in megabytes, or 0 if directory doesn't exist + """ + if not workdir.exists(): + return 0 + + result = subprocess.run( + ["du", "-sm", str(workdir)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + return 0 + + # du -sm output format: "SIZE\tPATH" + output = result.stdout.strip() + size_str = output.split("\t")[0] + + return int(size_str) + + +class LogLine(BaseModel): + """A single log line with timestamp and source.""" + + timestamp: datetime + source: str # "build", "stdout", "stderr" + data: str + + @classmethod + def now(cls, source: str, data: str) -> "LogLine": + return cls(timestamp=datetime.now(timezone.utc), source=source, data=data) + + def to_proto(self) -> cluster_pb2.Worker.LogEntry: + return cluster_pb2.Worker.LogEntry( + timestamp_ms=int(self.timestamp.timestamp() * 1000), + source=self.source, + data=self.data, + ) + + +class JobLogs(BaseModel): + """All logs for a job, stored as structured data.""" + + lines: list[LogLine] = [] + + def add(self, source: str, data: str) -> None: + self.lines.append(LogLine.now(source, data)) + + +@dataclass(kw_only=True) +class Job: + """Internal job tracking state.""" + + job_id: str + request: cluster_pb2.Worker.RunJobRequest + status: JobState = cluster_pb2.JOB_STATE_PENDING + exit_code: int | None = None + error: str | None = None + started_at_ms: int | None = None + finished_at_ms: int | None = None + ports: dict[str, int] = field(default_factory=dict) + status_message: str = "" + + # Resource tracking + current_memory_mb: int = 0 + peak_memory_mb: int = 0 + current_cpu_percent: int = 0 + process_count: int = 0 + disk_mb: int = 0 + + # Build tracking + build_started_ms: int | None = None + build_finished_ms: int | None = None + build_from_cache: bool = False + image_tag: str = "" + + # Internals + container_id: str | None = None + workdir: Path | None = None # Job working directory with logs + thread: threading.Thread | None = None + cleanup_done: bool = False + should_stop: bool = False + + # Structured logs (build logs stored here, container logs fetched from Docker) + logs: JobLogs = field(default_factory=JobLogs) + + result: bytes | None = None # cloudpickle serialized return value from container + + def transition_to( + self, + state: JobState, + *, + message: str = "", + error: str | None = None, + exit_code: int | None = None, + ) -> None: + """Transition to a new state with appropriate side effects. 
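+        Terminal states record finished_at_ms; error and exit_code are stored
+        when provided.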
+ + Args: + state: Target state + message: Progress message (only retained in BUILDING state) + error: Error message (for FAILED state) + exit_code: Process exit code (for terminal states) + """ + self.status = state + self.status_message = message + if is_job_finished(state): + self.finished_at_ms = int(time.time() * 1000) + if error: + self.error = error + if exit_code is not None: + self.exit_code = exit_code + + def to_proto(self) -> cluster_pb2.JobStatus: + """Convert job to JobStatus proto.""" + return cluster_pb2.JobStatus( + job_id=self.job_id, + state=self.status, + exit_code=self.exit_code or 0, + error=self.error or "", + started_at_ms=self.started_at_ms or 0, + finished_at_ms=self.finished_at_ms or 0, + ports=self.ports, + status_message=self.status_message, + resource_usage=cluster_pb2.ResourceUsage( + memory_mb=self.current_memory_mb, + memory_peak_mb=self.peak_memory_mb, + disk_mb=self.disk_mb, + cpu_millicores=self.current_cpu_percent * 10, + cpu_percent=self.current_cpu_percent, + process_count=self.process_count, + ), + build_metrics=cluster_pb2.BuildMetrics( + build_started_ms=self.build_started_ms or 0, + build_finished_ms=self.build_finished_ms or 0, + from_cache=self.build_from_cache, + image_tag=self.image_tag, + ), + ) diff --git a/lib/fluster/src/fluster/cluster_connect.py b/lib/fluster/src/fluster/cluster_connect.py new file mode 100644 index 0000000000..569586c52f --- /dev/null +++ b/lib/fluster/src/fluster/cluster_connect.py @@ -0,0 +1,1135 @@ +# -*- coding: utf-8 -*- +# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! +# source: cluster.proto + +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping +from typing import Protocol + +from connectrpc.client import ConnectClient, ConnectClientSync +from connectrpc.code import Code +from connectrpc.errors import ConnectError +from connectrpc.interceptor import Interceptor, InterceptorSync +from connectrpc.method import IdempotencyLevel, MethodInfo +from connectrpc.request import Headers, RequestContext +from connectrpc.server import ConnectASGIApplication, ConnectWSGIApplication, Endpoint, EndpointSync +from . 
import cluster_pb2 as cluster__pb2 + + +class ControllerService(Protocol): + async def launch_job(self, request: cluster__pb2.Controller.LaunchJobRequest, ctx: RequestContext) -> cluster__pb2.Controller.LaunchJobResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def get_job_status(self, request: cluster__pb2.Controller.GetJobStatusRequest, ctx: RequestContext) -> cluster__pb2.Controller.GetJobStatusResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def terminate_job(self, request: cluster__pb2.Controller.TerminateJobRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_jobs(self, request: cluster__pb2.Controller.ListJobsRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListJobsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def register_worker(self, request: cluster__pb2.Controller.RegisterWorkerRequest, ctx: RequestContext) -> cluster__pb2.Controller.RegisterWorkerResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_workers(self, request: cluster__pb2.Controller.ListWorkersRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListWorkersResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def register_endpoint(self, request: cluster__pb2.Controller.RegisterEndpointRequest, ctx: RequestContext) -> cluster__pb2.Controller.RegisterEndpointResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def unregister_endpoint(self, request: cluster__pb2.Controller.UnregisterEndpointRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def lookup_endpoint(self, request: cluster__pb2.Controller.LookupEndpointRequest, ctx: RequestContext) -> cluster__pb2.Controller.LookupEndpointResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_endpoints(self, request: cluster__pb2.Controller.ListEndpointsRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListEndpointsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + +class ControllerServiceASGIApplication(ConnectASGIApplication[ControllerService]): + def __init__(self, service: ControllerService | AsyncGenerator[ControllerService], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + service=service, + endpoints=lambda svc: { + "/fluster.cluster.ControllerService/LaunchJob": Endpoint.unary( + method=MethodInfo( + name="LaunchJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LaunchJobRequest, + output=cluster__pb2.Controller.LaunchJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.launch_job, + ), + "/fluster.cluster.ControllerService/GetJobStatus": Endpoint.unary( + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.GetJobStatusRequest, + output=cluster__pb2.Controller.GetJobStatusResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.get_job_status, + ), + "/fluster.cluster.ControllerService/TerminateJob": Endpoint.unary( + method=MethodInfo( + name="TerminateJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.TerminateJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + 
function=svc.terminate_job, + ), + "/fluster.cluster.ControllerService/ListJobs": Endpoint.unary( + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListJobsRequest, + output=cluster__pb2.Controller.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_jobs, + ), + "/fluster.cluster.ControllerService/RegisterWorker": Endpoint.unary( + method=MethodInfo( + name="RegisterWorker", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterWorkerRequest, + output=cluster__pb2.Controller.RegisterWorkerResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.register_worker, + ), + "/fluster.cluster.ControllerService/ListWorkers": Endpoint.unary( + method=MethodInfo( + name="ListWorkers", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListWorkersRequest, + output=cluster__pb2.Controller.ListWorkersResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_workers, + ), + "/fluster.cluster.ControllerService/RegisterEndpoint": Endpoint.unary( + method=MethodInfo( + name="RegisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterEndpointRequest, + output=cluster__pb2.Controller.RegisterEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.register_endpoint, + ), + "/fluster.cluster.ControllerService/UnregisterEndpoint": Endpoint.unary( + method=MethodInfo( + name="UnregisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.UnregisterEndpointRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.unregister_endpoint, + ), + "/fluster.cluster.ControllerService/LookupEndpoint": Endpoint.unary( + method=MethodInfo( + name="LookupEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LookupEndpointRequest, + output=cluster__pb2.Controller.LookupEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.lookup_endpoint, + ), + "/fluster.cluster.ControllerService/ListEndpoints": Endpoint.unary( + method=MethodInfo( + name="ListEndpoints", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListEndpointsRequest, + output=cluster__pb2.Controller.ListEndpointsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_endpoints, + ), + }, + interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.cluster.ControllerService" + + +class ControllerServiceClient(ConnectClient): + async def launch_job( + self, + request: cluster__pb2.Controller.LaunchJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.LaunchJobResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="LaunchJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LaunchJobRequest, + output=cluster__pb2.Controller.LaunchJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def get_job_status( + self, + request: cluster__pb2.Controller.GetJobStatusRequest, + *, + headers: 
Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.GetJobStatusResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.GetJobStatusRequest, + output=cluster__pb2.Controller.GetJobStatusResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def terminate_job( + self, + request: cluster__pb2.Controller.TerminateJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="TerminateJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.TerminateJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_jobs( + self, + request: cluster__pb2.Controller.ListJobsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListJobsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListJobsRequest, + output=cluster__pb2.Controller.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def register_worker( + self, + request: cluster__pb2.Controller.RegisterWorkerRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.RegisterWorkerResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="RegisterWorker", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterWorkerRequest, + output=cluster__pb2.Controller.RegisterWorkerResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_workers( + self, + request: cluster__pb2.Controller.ListWorkersRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListWorkersResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListWorkers", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListWorkersRequest, + output=cluster__pb2.Controller.ListWorkersResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def register_endpoint( + self, + request: cluster__pb2.Controller.RegisterEndpointRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.RegisterEndpointResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="RegisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterEndpointRequest, + output=cluster__pb2.Controller.RegisterEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def unregister_endpoint( + self, + request: cluster__pb2.Controller.UnregisterEndpointRequest, + *, + headers: Headers | 
Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="UnregisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.UnregisterEndpointRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def lookup_endpoint( + self, + request: cluster__pb2.Controller.LookupEndpointRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.LookupEndpointResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="LookupEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LookupEndpointRequest, + output=cluster__pb2.Controller.LookupEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_endpoints( + self, + request: cluster__pb2.Controller.ListEndpointsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListEndpointsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListEndpoints", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListEndpointsRequest, + output=cluster__pb2.Controller.ListEndpointsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + + +class WorkerService(Protocol): + async def run_job(self, request: cluster__pb2.Worker.RunJobRequest, ctx: RequestContext) -> cluster__pb2.Worker.RunJobResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def get_job_status(self, request: cluster__pb2.Worker.GetJobStatusRequest, ctx: RequestContext) -> cluster__pb2.JobStatus: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def list_jobs(self, request: cluster__pb2.Worker.ListJobsRequest, ctx: RequestContext) -> cluster__pb2.Worker.ListJobsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def fetch_logs(self, request: cluster__pb2.Worker.FetchLogsRequest, ctx: RequestContext) -> cluster__pb2.Worker.FetchLogsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def kill_job(self, request: cluster__pb2.Worker.KillJobRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + async def health_check(self, request: cluster__pb2.Empty, ctx: RequestContext) -> cluster__pb2.Worker.HealthResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + +class WorkerServiceASGIApplication(ConnectASGIApplication[WorkerService]): + def __init__(self, service: WorkerService | AsyncGenerator[WorkerService], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + service=service, + endpoints=lambda svc: { + "/fluster.cluster.WorkerService/RunJob": Endpoint.unary( + method=MethodInfo( + name="RunJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.RunJobRequest, + output=cluster__pb2.Worker.RunJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.run_job, + ), + "/fluster.cluster.WorkerService/GetJobStatus": Endpoint.unary( + method=MethodInfo( + name="GetJobStatus", 
+ service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.GetJobStatusRequest, + output=cluster__pb2.JobStatus, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.get_job_status, + ), + "/fluster.cluster.WorkerService/ListJobs": Endpoint.unary( + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.ListJobsRequest, + output=cluster__pb2.Worker.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.list_jobs, + ), + "/fluster.cluster.WorkerService/FetchLogs": Endpoint.unary( + method=MethodInfo( + name="FetchLogs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.FetchLogsRequest, + output=cluster__pb2.Worker.FetchLogsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.fetch_logs, + ), + "/fluster.cluster.WorkerService/KillJob": Endpoint.unary( + method=MethodInfo( + name="KillJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.KillJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.kill_job, + ), + "/fluster.cluster.WorkerService/HealthCheck": Endpoint.unary( + method=MethodInfo( + name="HealthCheck", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Empty, + output=cluster__pb2.Worker.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=svc.health_check, + ), + }, + interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.cluster.WorkerService" + + +class WorkerServiceClient(ConnectClient): + async def run_job( + self, + request: cluster__pb2.Worker.RunJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.RunJobResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="RunJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.RunJobRequest, + output=cluster__pb2.Worker.RunJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def get_job_status( + self, + request: cluster__pb2.Worker.GetJobStatusRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.JobStatus: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.GetJobStatusRequest, + output=cluster__pb2.JobStatus, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def list_jobs( + self, + request: cluster__pb2.Worker.ListJobsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.ListJobsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.ListJobsRequest, + output=cluster__pb2.Worker.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def fetch_logs( + self, + request: cluster__pb2.Worker.FetchLogsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + 
timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.FetchLogsResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="FetchLogs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.FetchLogsRequest, + output=cluster__pb2.Worker.FetchLogsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def kill_job( + self, + request: cluster__pb2.Worker.KillJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="KillJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.KillJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + async def health_check( + self, + request: cluster__pb2.Empty, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.HealthResponse: + return await self.execute_unary( + request=request, + method=MethodInfo( + name="HealthCheck", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Empty, + output=cluster__pb2.Worker.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + +class ControllerServiceSync(Protocol): + def launch_job(self, request: cluster__pb2.Controller.LaunchJobRequest, ctx: RequestContext) -> cluster__pb2.Controller.LaunchJobResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def get_job_status(self, request: cluster__pb2.Controller.GetJobStatusRequest, ctx: RequestContext) -> cluster__pb2.Controller.GetJobStatusResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def terminate_job(self, request: cluster__pb2.Controller.TerminateJobRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_jobs(self, request: cluster__pb2.Controller.ListJobsRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListJobsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def register_worker(self, request: cluster__pb2.Controller.RegisterWorkerRequest, ctx: RequestContext) -> cluster__pb2.Controller.RegisterWorkerResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_workers(self, request: cluster__pb2.Controller.ListWorkersRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListWorkersResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def register_endpoint(self, request: cluster__pb2.Controller.RegisterEndpointRequest, ctx: RequestContext) -> cluster__pb2.Controller.RegisterEndpointResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def unregister_endpoint(self, request: cluster__pb2.Controller.UnregisterEndpointRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def lookup_endpoint(self, request: cluster__pb2.Controller.LookupEndpointRequest, ctx: RequestContext) -> cluster__pb2.Controller.LookupEndpointResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_endpoints(self, request: cluster__pb2.Controller.ListEndpointsRequest, ctx: RequestContext) -> cluster__pb2.Controller.ListEndpointsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + 
+class ControllerServiceWSGIApplication(ConnectWSGIApplication): + def __init__(self, service: ControllerServiceSync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + endpoints={ + "/fluster.cluster.ControllerService/LaunchJob": EndpointSync.unary( + method=MethodInfo( + name="LaunchJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LaunchJobRequest, + output=cluster__pb2.Controller.LaunchJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.launch_job, + ), + "/fluster.cluster.ControllerService/GetJobStatus": EndpointSync.unary( + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.GetJobStatusRequest, + output=cluster__pb2.Controller.GetJobStatusResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.get_job_status, + ), + "/fluster.cluster.ControllerService/TerminateJob": EndpointSync.unary( + method=MethodInfo( + name="TerminateJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.TerminateJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.terminate_job, + ), + "/fluster.cluster.ControllerService/ListJobs": EndpointSync.unary( + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListJobsRequest, + output=cluster__pb2.Controller.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_jobs, + ), + "/fluster.cluster.ControllerService/RegisterWorker": EndpointSync.unary( + method=MethodInfo( + name="RegisterWorker", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterWorkerRequest, + output=cluster__pb2.Controller.RegisterWorkerResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.register_worker, + ), + "/fluster.cluster.ControllerService/ListWorkers": EndpointSync.unary( + method=MethodInfo( + name="ListWorkers", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListWorkersRequest, + output=cluster__pb2.Controller.ListWorkersResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_workers, + ), + "/fluster.cluster.ControllerService/RegisterEndpoint": EndpointSync.unary( + method=MethodInfo( + name="RegisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterEndpointRequest, + output=cluster__pb2.Controller.RegisterEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.register_endpoint, + ), + "/fluster.cluster.ControllerService/UnregisterEndpoint": EndpointSync.unary( + method=MethodInfo( + name="UnregisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.UnregisterEndpointRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.unregister_endpoint, + ), + "/fluster.cluster.ControllerService/LookupEndpoint": EndpointSync.unary( + method=MethodInfo( + name="LookupEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LookupEndpointRequest, + output=cluster__pb2.Controller.LookupEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.lookup_endpoint, + ), + 
"/fluster.cluster.ControllerService/ListEndpoints": EndpointSync.unary( + method=MethodInfo( + name="ListEndpoints", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListEndpointsRequest, + output=cluster__pb2.Controller.ListEndpointsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_endpoints, + ), + }, + interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.cluster.ControllerService" + + +class ControllerServiceClientSync(ConnectClientSync): + def launch_job( + self, + request: cluster__pb2.Controller.LaunchJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.LaunchJobResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="LaunchJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LaunchJobRequest, + output=cluster__pb2.Controller.LaunchJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def get_job_status( + self, + request: cluster__pb2.Controller.GetJobStatusRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.GetJobStatusResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.GetJobStatusRequest, + output=cluster__pb2.Controller.GetJobStatusResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def terminate_job( + self, + request: cluster__pb2.Controller.TerminateJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return self.execute_unary( + request=request, + method=MethodInfo( + name="TerminateJob", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.TerminateJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_jobs( + self, + request: cluster__pb2.Controller.ListJobsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListJobsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListJobsRequest, + output=cluster__pb2.Controller.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def register_worker( + self, + request: cluster__pb2.Controller.RegisterWorkerRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.RegisterWorkerResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="RegisterWorker", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterWorkerRequest, + output=cluster__pb2.Controller.RegisterWorkerResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_workers( + self, + request: 
cluster__pb2.Controller.ListWorkersRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListWorkersResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListWorkers", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListWorkersRequest, + output=cluster__pb2.Controller.ListWorkersResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def register_endpoint( + self, + request: cluster__pb2.Controller.RegisterEndpointRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.RegisterEndpointResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="RegisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.RegisterEndpointRequest, + output=cluster__pb2.Controller.RegisterEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def unregister_endpoint( + self, + request: cluster__pb2.Controller.UnregisterEndpointRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return self.execute_unary( + request=request, + method=MethodInfo( + name="UnregisterEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.UnregisterEndpointRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def lookup_endpoint( + self, + request: cluster__pb2.Controller.LookupEndpointRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.LookupEndpointResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="LookupEndpoint", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.LookupEndpointRequest, + output=cluster__pb2.Controller.LookupEndpointResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_endpoints( + self, + request: cluster__pb2.Controller.ListEndpointsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Controller.ListEndpointsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListEndpoints", + service_name="fluster.cluster.ControllerService", + input=cluster__pb2.Controller.ListEndpointsRequest, + output=cluster__pb2.Controller.ListEndpointsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + +class WorkerServiceSync(Protocol): + def run_job(self, request: cluster__pb2.Worker.RunJobRequest, ctx: RequestContext) -> cluster__pb2.Worker.RunJobResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def get_job_status(self, request: cluster__pb2.Worker.GetJobStatusRequest, ctx: RequestContext) -> cluster__pb2.JobStatus: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def list_jobs(self, request: cluster__pb2.Worker.ListJobsRequest, ctx: RequestContext) -> cluster__pb2.Worker.ListJobsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def fetch_logs(self, request: cluster__pb2.Worker.FetchLogsRequest, ctx: 
RequestContext) -> cluster__pb2.Worker.FetchLogsResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def kill_job(self, request: cluster__pb2.Worker.KillJobRequest, ctx: RequestContext) -> cluster__pb2.Empty: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + def health_check(self, request: cluster__pb2.Empty, ctx: RequestContext) -> cluster__pb2.Worker.HealthResponse: + raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") + + +class WorkerServiceWSGIApplication(ConnectWSGIApplication): + def __init__(self, service: WorkerServiceSync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None) -> None: + super().__init__( + endpoints={ + "/fluster.cluster.WorkerService/RunJob": EndpointSync.unary( + method=MethodInfo( + name="RunJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.RunJobRequest, + output=cluster__pb2.Worker.RunJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.run_job, + ), + "/fluster.cluster.WorkerService/GetJobStatus": EndpointSync.unary( + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.GetJobStatusRequest, + output=cluster__pb2.JobStatus, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.get_job_status, + ), + "/fluster.cluster.WorkerService/ListJobs": EndpointSync.unary( + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.ListJobsRequest, + output=cluster__pb2.Worker.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.list_jobs, + ), + "/fluster.cluster.WorkerService/FetchLogs": EndpointSync.unary( + method=MethodInfo( + name="FetchLogs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.FetchLogsRequest, + output=cluster__pb2.Worker.FetchLogsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.fetch_logs, + ), + "/fluster.cluster.WorkerService/KillJob": EndpointSync.unary( + method=MethodInfo( + name="KillJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.KillJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.kill_job, + ), + "/fluster.cluster.WorkerService/HealthCheck": EndpointSync.unary( + method=MethodInfo( + name="HealthCheck", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Empty, + output=cluster__pb2.Worker.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + function=service.health_check, + ), + }, + interceptors=interceptors, + read_max_bytes=read_max_bytes, + ) + + @property + def path(self) -> str: + """Returns the URL path to mount the application to when serving multiple applications.""" + return "/fluster.cluster.WorkerService" + + +class WorkerServiceClientSync(ConnectClientSync): + def run_job( + self, + request: cluster__pb2.Worker.RunJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.RunJobResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="RunJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.RunJobRequest, + output=cluster__pb2.Worker.RunJobResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def get_job_status( + self, + request: 
cluster__pb2.Worker.GetJobStatusRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.JobStatus: + return self.execute_unary( + request=request, + method=MethodInfo( + name="GetJobStatus", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.GetJobStatusRequest, + output=cluster__pb2.JobStatus, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def list_jobs( + self, + request: cluster__pb2.Worker.ListJobsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.ListJobsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="ListJobs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.ListJobsRequest, + output=cluster__pb2.Worker.ListJobsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def fetch_logs( + self, + request: cluster__pb2.Worker.FetchLogsRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.FetchLogsResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="FetchLogs", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.FetchLogsRequest, + output=cluster__pb2.Worker.FetchLogsResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def kill_job( + self, + request: cluster__pb2.Worker.KillJobRequest, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Empty: + return self.execute_unary( + request=request, + method=MethodInfo( + name="KillJob", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Worker.KillJobRequest, + output=cluster__pb2.Empty, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) + + def health_check( + self, + request: cluster__pb2.Empty, + *, + headers: Headers | Mapping[str, str] | None = None, + timeout_ms: int | None = None, + ) -> cluster__pb2.Worker.HealthResponse: + return self.execute_unary( + request=request, + method=MethodInfo( + name="HealthCheck", + service_name="fluster.cluster.WorkerService", + input=cluster__pb2.Empty, + output=cluster__pb2.Worker.HealthResponse, + idempotency_level=IdempotencyLevel.UNKNOWN, + ), + headers=headers, + timeout_ms=timeout_ms, + ) diff --git a/lib/fluster/src/fluster/cluster_pb2.py b/lib/fluster/src/fluster/cluster_pb2.py new file mode 100644 index 0000000000..0f60aab7d4 --- /dev/null +++ b/lib/fluster/src/fluster/cluster_pb2.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: cluster.proto +# Protobuf Python Version: 6.33.4 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 4, + '', + 'cluster.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rcluster.proto\x12\x0f\x66luster.cluster\"\x07\n\x05\x45mpty\"\xd8\x01\n\rResourceUsage\x12\x1b\n\tmemory_mb\x18\x01 \x01(\x03R\x08memoryMb\x12\x17\n\x07\x64isk_mb\x18\x02 \x01(\x03R\x06\x64iskMb\x12%\n\x0e\x63pu_millicores\x18\x03 \x01(\x05R\rcpuMillicores\x12$\n\x0ememory_peak_mb\x18\x04 \x01(\x03R\x0cmemoryPeakMb\x12\x1f\n\x0b\x63pu_percent\x18\x05 \x01(\x05R\ncpuPercent\x12#\n\rprocess_count\x18\x06 \x01(\x05R\x0cprocessCount\"\xa0\x01\n\x0c\x42uildMetrics\x12(\n\x10\x62uild_started_ms\x18\x01 \x01(\x03R\x0e\x62uildStartedMs\x12*\n\x11\x62uild_finished_ms\x18\x02 \x01(\x03R\x0f\x62uildFinishedMs\x12\x1d\n\nfrom_cache\x18\x03 \x01(\x08R\tfromCache\x12\x1b\n\timage_tag\x18\x04 \x01(\tR\x08imageTag\"\xea\x04\n\tJobStatus\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12/\n\x05state\x18\x02 \x01(\x0e\x32\x19.fluster.cluster.JobStateR\x05state\x12\x1b\n\texit_code\x18\x03 \x01(\x05R\x08\x65xitCode\x12\x14\n\x05\x65rror\x18\x04 \x01(\tR\x05\x65rror\x12\"\n\rstarted_at_ms\x18\x05 \x01(\x03R\x0bstartedAtMs\x12$\n\x0e\x66inished_at_ms\x18\x06 \x01(\x03R\x0c\x66inishedAtMs\x12;\n\x05ports\x18\x07 \x03(\x0b\x32%.fluster.cluster.JobStatus.PortsEntryR\x05ports\x12\x45\n\x0eresource_usage\x18\x08 \x01(\x0b\x32\x1e.fluster.cluster.ResourceUsageR\rresourceUsage\x12%\n\x0estatus_message\x18\t \x01(\tR\rstatusMessage\x12\x42\n\rbuild_metrics\x18\n \x01(\x0b\x32\x1d.fluster.cluster.BuildMetricsR\x0c\x62uildMetrics\x12\x1b\n\tworker_id\x18\x0b \x01(\tR\x08workerId\x12%\n\x0eworker_address\x18\x0c \x01(\tR\rworkerAddress\x12+\n\x11serialized_result\x18\r \x01(\x0cR\x10serializedResult\x1a\x38\n\nPortsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\x05R\x05value:\x02\x38\x01\"\xa8\x01\n\x0c\x44\x65viceConfig\x12.\n\x03\x63pu\x18\x01 \x01(\x0b\x32\x1a.fluster.cluster.CpuDeviceH\x00R\x03\x63pu\x12.\n\x03gpu\x18\x02 \x01(\x0b\x32\x1a.fluster.cluster.GpuDeviceH\x00R\x03gpu\x12.\n\x03tpu\x18\x03 \x01(\x0b\x32\x1a.fluster.cluster.TpuDeviceH\x00R\x03tpuB\x08\n\x06\x64\x65vice\"%\n\tCpuDevice\x12\x18\n\x07variant\x18\x01 \x01(\tR\x07variant\";\n\tGpuDevice\x12\x18\n\x07variant\x18\x01 \x01(\tR\x07variant\x12\x14\n\x05\x63ount\x18\x02 \x01(\x05R\x05\x63ount\"A\n\tTpuDevice\x12\x18\n\x07variant\x18\x01 \x01(\tR\x07variant\x12\x1a\n\x08topology\x18\x02 \x01(\tR\x08topology\"\xdb\x01\n\x0cResourceSpec\x12\x10\n\x03\x63pu\x18\x01 \x01(\x05R\x03\x63pu\x12\x16\n\x06memory\x18\x02 \x01(\tR\x06memory\x12\x12\n\x04\x64isk\x18\x03 \x01(\tR\x04\x64isk\x12\x35\n\x06\x64\x65vice\x18\x05 \x01(\x0b\x32\x1d.fluster.cluster.DeviceConfigR\x06\x64\x65vice\x12\x1a\n\x08replicas\x18\x06 \x01(\x05R\x08replicas\x12 \n\x0bpreemptible\x18\x07 \x01(\x08R\x0bpreemptible\x12\x18\n\x07regions\x18\x08 \x03(\tR\x07regions\"\x80\x02\n\x11\x45nvironmentConfig\x12\x1e\n\tworkspace\x18\x01 
\x01(\tH\x00R\tworkspace\x12!\n\x0cpip_packages\x18\x02 \x03(\tR\x0bpipPackages\x12J\n\x08\x65nv_vars\x18\x03 \x03(\x0b\x32/.fluster.cluster.EnvironmentConfig.EnvVarsEntryR\x07\x65nvVars\x12\x16\n\x06\x65xtras\x18\x04 \x03(\tR\x06\x65xtras\x1a:\n\x0c\x45nvVarsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06source\"\x87\x14\n\nController\x1a\x9c\x03\n\x10LaunchJobRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x33\n\x15serialized_entrypoint\x18\x02 \x01(\x0cR\x14serializedEntrypoint\x12;\n\tresources\x18\x03 \x01(\x0b\x32\x1d.fluster.cluster.ResourceSpecR\tresources\x12\x44\n\x0b\x65nvironment\x18\x04 \x01(\x0b\x32\".fluster.cluster.EnvironmentConfigR\x0b\x65nvironment\x12&\n\x0f\x62undle_gcs_path\x18\x05 \x01(\tR\rbundleGcsPath\x12\x1f\n\x0b\x62undle_hash\x18\x06 \x01(\tR\nbundleHash\x12\x1f\n\x0b\x62undle_blob\x18\x07 \x01(\x0cR\nbundleBlob\x12<\n\x1ascheduling_timeout_seconds\x18\x08 \x01(\x05R\x18schedulingTimeoutSeconds\x12\x14\n\x05ports\x18\t \x03(\tR\x05ports\x1a*\n\x11LaunchJobResponse\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x1aS\n\x13GetJobStatusRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12%\n\x0einclude_result\x18\x02 \x01(\x08R\rincludeResult\x1a\x44\n\x14GetJobStatusResponse\x12,\n\x03job\x18\x01 \x01(\x0b\x32\x1a.fluster.cluster.JobStatusR\x03job\x1a,\n\x13TerminateJobRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x1a/\n\x0fListJobsRequest\x12\x1c\n\tnamespace\x18\x01 \x01(\tR\tnamespace\x1a\x42\n\x10ListJobsResponse\x12.\n\x04jobs\x18\x01 \x03(\x0b\x32\x1a.fluster.cluster.JobStatusR\x04jobs\x1a\xaa\x01\n\nWorkerInfo\x12\x1b\n\tworker_id\x18\x01 \x01(\tR\x08workerId\x12\x18\n\x07\x61\x64\x64ress\x18\x02 \x01(\tR\x07\x61\x64\x64ress\x12;\n\tresources\x18\x03 \x01(\x0b\x32\x1d.fluster.cluster.ResourceSpecR\tresources\x12(\n\x10registered_at_ms\x18\x04 \x01(\x03R\x0eregisteredAtMs\x1a\x8b\x01\n\x15RegisterWorkerRequest\x12\x1b\n\tworker_id\x18\x01 \x01(\tR\x08workerId\x12\x18\n\x07\x61\x64\x64ress\x18\x02 \x01(\tR\x07\x61\x64\x64ress\x12;\n\tresources\x18\x03 \x01(\x0b\x32\x1d.fluster.cluster.ResourceSpecR\tresources\x1a\x63\n\x16RegisterWorkerResponse\x12\x1a\n\x08\x61\x63\x63\x65pted\x18\x01 \x01(\x08R\x08\x61\x63\x63\x65pted\x12-\n\x12\x63ontroller_address\x18\x02 \x01(\tR\x11\x63ontrollerAddress\x1a\xd2\x01\n\x12WorkerHealthStatus\x12\x1b\n\tworker_id\x18\x01 \x01(\tR\x08workerId\x12\x18\n\x07healthy\x18\x02 \x01(\x08R\x07healthy\x12\x31\n\x14\x63onsecutive_failures\x18\x03 \x01(\x05R\x13\x63onsecutiveFailures\x12*\n\x11last_heartbeat_ms\x18\x04 \x01(\x03R\x0flastHeartbeatMs\x12&\n\x0frunning_job_ids\x18\x05 \x03(\tR\rrunningJobIds\x1a\x14\n\x12ListWorkersRequest\x1a_\n\x13ListWorkersResponse\x12H\n\x07workers\x18\x01 \x03(\x0b\x32..fluster.cluster.Controller.WorkerHealthStatusR\x07workers\x1a\x9b\x02\n\x08\x45ndpoint\x12\x1f\n\x0b\x65ndpoint_id\x18\x01 \x01(\tR\nendpointId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x18\n\x07\x61\x64\x64ress\x18\x03 \x01(\tR\x07\x61\x64\x64ress\x12\x15\n\x06job_id\x18\x04 \x01(\tR\x05jobId\x12\x1c\n\tnamespace\x18\x05 \x01(\tR\tnamespace\x12N\n\x08metadata\x18\x06 \x03(\x0b\x32\x32.fluster.cluster.Controller.Endpoint.MetadataEntryR\x08metadata\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x98\x02\n\x17RegisterEndpointRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x18\n\x07\x61\x64\x64ress\x18\x02 \x01(\tR\x07\x61\x64\x64ress\x12\x15\n\x06job_id\x18\x03 
\x01(\tR\x05jobId\x12\x1c\n\tnamespace\x18\x04 \x01(\tR\tnamespace\x12]\n\x08metadata\x18\x05 \x03(\x0b\x32\x41.fluster.cluster.Controller.RegisterEndpointRequest.MetadataEntryR\x08metadata\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a;\n\x18RegisterEndpointResponse\x12\x1f\n\x0b\x65ndpoint_id\x18\x01 \x01(\tR\nendpointId\x1a<\n\x19UnregisterEndpointRequest\x12\x1f\n\x0b\x65ndpoint_id\x18\x01 \x01(\tR\nendpointId\x1aI\n\x15LookupEndpointRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\tnamespace\x18\x02 \x01(\tR\tnamespace\x1aZ\n\x16LookupEndpointResponse\x12@\n\x08\x65ndpoint\x18\x01 \x01(\x0b\x32$.fluster.cluster.Controller.EndpointR\x08\x65ndpoint\x1aL\n\x14ListEndpointsRequest\x12\x16\n\x06prefix\x18\x01 \x01(\tR\x06prefix\x12\x1c\n\tnamespace\x18\x02 \x01(\tR\tnamespace\x1a[\n\x15ListEndpointsResponse\x12\x42\n\tendpoints\x18\x01 \x03(\x0b\x32$.fluster.cluster.Controller.EndpointR\tendpoints\"\xdb\t\n\x06Worker\x1a\xc5\x02\n\rRunJobRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12\x33\n\x15serialized_entrypoint\x18\x02 \x01(\x0cR\x14serializedEntrypoint\x12\x44\n\x0b\x65nvironment\x18\x03 \x01(\x0b\x32\".fluster.cluster.EnvironmentConfigR\x0b\x65nvironment\x12&\n\x0f\x62undle_gcs_path\x18\x04 \x01(\tR\rbundleGcsPath\x12;\n\tresources\x18\x06 \x01(\x0b\x32\x1d.fluster.cluster.ResourceSpecR\tresources\x12\'\n\x0ftimeout_seconds\x18\x08 \x01(\x05R\x0etimeoutSeconds\x12\x14\n\x05ports\x18\t \x03(\tR\x05ports\x1aX\n\x0eRunJobResponse\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12/\n\x05state\x18\x02 \x01(\x0e\x32\x19.fluster.cluster.JobStateR\x05state\x1aS\n\x13GetJobStatusRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12%\n\x0einclude_result\x18\x02 \x01(\x08R\rincludeResult\x1a/\n\x0fListJobsRequest\x12\x1c\n\tnamespace\x18\x01 \x01(\tR\tnamespace\x1a\x42\n\x10ListJobsResponse\x12.\n\x04jobs\x18\x01 \x03(\x0b\x32\x1a.fluster.cluster.JobStatusR\x04jobs\x1aY\n\x08LogEntry\x12!\n\x0ctimestamp_ms\x18\x01 \x01(\x03R\x0btimestampMs\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\tR\x04\x64\x61ta\x1a\x95\x01\n\x0f\x46\x65tchLogsFilter\x12\x14\n\x05regex\x18\x01 \x01(\tR\x05regex\x12\x1d\n\nstart_line\x18\x02 \x01(\x03R\tstartLine\x12\x19\n\x08start_ms\x18\x03 \x01(\x03R\x07startMs\x12\x15\n\x06\x65nd_ms\x18\x04 \x01(\x03R\x05\x65ndMs\x12\x1b\n\tmax_lines\x18\x05 \x01(\x03R\x08maxLines\x1aj\n\x10\x46\x65tchLogsRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12?\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\'.fluster.cluster.Worker.FetchLogsFilterR\x06\x66ilter\x1aI\n\x11\x46\x65tchLogsResponse\x12\x34\n\x04logs\x18\x01 \x03(\x0b\x32 .fluster.cluster.Worker.LogEntryR\x04logs\x1aO\n\x0eKillJobRequest\x12\x15\n\x06job_id\x18\x01 \x01(\tR\x05jobId\x12&\n\x0fterm_timeout_ms\x18\x02 \x01(\x05R\rtermTimeoutMs\x1aj\n\x0eHealthResponse\x12\x18\n\x07healthy\x18\x01 \x01(\x08R\x07healthy\x12\x1b\n\tuptime_ms\x18\x02 \x01(\x03R\x08uptimeMs\x12!\n\x0crunning_jobs\x18\x03 
\x01(\x05R\x0brunningJobs*\xea\x01\n\x08JobState\x12\x19\n\x15JOB_STATE_UNSPECIFIED\x10\x00\x12\x15\n\x11JOB_STATE_PENDING\x10\x01\x12\x16\n\x12JOB_STATE_BUILDING\x10\x02\x12\x15\n\x11JOB_STATE_RUNNING\x10\x03\x12\x17\n\x13JOB_STATE_SUCCEEDED\x10\x04\x12\x14\n\x10JOB_STATE_FAILED\x10\x05\x12\x14\n\x10JOB_STATE_KILLED\x10\x06\x12\x1b\n\x17JOB_STATE_WORKER_FAILED\x10\x07\x12\x1b\n\x17JOB_STATE_UNSCHEDULABLE\x10\x08\x32\xec\x08\n\x11\x43ontrollerService\x12h\n\tLaunchJob\x12,.fluster.cluster.Controller.LaunchJobRequest\x1a-.fluster.cluster.Controller.LaunchJobResponse\x12q\n\x0cGetJobStatus\x12/.fluster.cluster.Controller.GetJobStatusRequest\x1a\x30.fluster.cluster.Controller.GetJobStatusResponse\x12W\n\x0cTerminateJob\x12/.fluster.cluster.Controller.TerminateJobRequest\x1a\x16.fluster.cluster.Empty\x12\x65\n\x08ListJobs\x12+.fluster.cluster.Controller.ListJobsRequest\x1a,.fluster.cluster.Controller.ListJobsResponse\x12w\n\x0eRegisterWorker\x12\x31.fluster.cluster.Controller.RegisterWorkerRequest\x1a\x32.fluster.cluster.Controller.RegisterWorkerResponse\x12n\n\x0bListWorkers\x12..fluster.cluster.Controller.ListWorkersRequest\x1a/.fluster.cluster.Controller.ListWorkersResponse\x12}\n\x10RegisterEndpoint\x12\x33.fluster.cluster.Controller.RegisterEndpointRequest\x1a\x34.fluster.cluster.Controller.RegisterEndpointResponse\x12\x63\n\x12UnregisterEndpoint\x12\x35.fluster.cluster.Controller.UnregisterEndpointRequest\x1a\x16.fluster.cluster.Empty\x12w\n\x0eLookupEndpoint\x12\x31.fluster.cluster.Controller.LookupEndpointRequest\x1a\x32.fluster.cluster.Controller.LookupEndpointResponse\x12t\n\rListEndpoints\x12\x30.fluster.cluster.Controller.ListEndpointsRequest\x1a\x31.fluster.cluster.Controller.ListEndpointsResponse2\x9c\x04\n\rWorkerService\x12W\n\x06RunJob\x12%.fluster.cluster.Worker.RunJobRequest\x1a&.fluster.cluster.Worker.RunJobResponse\x12W\n\x0cGetJobStatus\x12+.fluster.cluster.Worker.GetJobStatusRequest\x1a\x1a.fluster.cluster.JobStatus\x12]\n\x08ListJobs\x12\'.fluster.cluster.Worker.ListJobsRequest\x1a(.fluster.cluster.Worker.ListJobsResponse\x12`\n\tFetchLogs\x12(.fluster.cluster.Worker.FetchLogsRequest\x1a).fluster.cluster.Worker.FetchLogsResponse\x12I\n\x07KillJob\x12&.fluster.cluster.Worker.KillJobRequest\x1a\x16.fluster.cluster.Empty\x12M\n\x0bHealthCheck\x12\x16.fluster.cluster.Empty\x1a&.fluster.cluster.Worker.HealthResponseB\x80\x01\n\x13\x63om.fluster.clusterB\x0c\x43lusterProtoP\x01\xa2\x02\x03\x46\x43X\xaa\x02\x0f\x46luster.Cluster\xca\x02\x0f\x46luster\\Cluster\xe2\x02\x1b\x46luster\\Cluster\\GPBMetadata\xea\x02\x10\x46luster::Clusterb\x08\x65\x64itionsp\xe8\x07') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cluster_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\023com.fluster.clusterB\014ClusterProtoP\001\242\002\003FCX\252\002\017Fluster.Cluster\312\002\017Fluster\\Cluster\342\002\033Fluster\\Cluster\\GPBMetadata\352\002\020Fluster::Cluster' + _globals['_JOBSTATUS_PORTSENTRY']._loaded_options = None + _globals['_JOBSTATUS_PORTSENTRY']._serialized_options = b'8\001' + _globals['_ENVIRONMENTCONFIG_ENVVARSENTRY']._loaded_options = None + _globals['_ENVIRONMENTCONFIG_ENVVARSENTRY']._serialized_options = b'8\001' + _globals['_CONTROLLER_ENDPOINT_METADATAENTRY']._loaded_options = None + _globals['_CONTROLLER_ENDPOINT_METADATAENTRY']._serialized_options = b'8\001' + 
_globals['_CONTROLLER_REGISTERENDPOINTREQUEST_METADATAENTRY']._loaded_options = None + _globals['_CONTROLLER_REGISTERENDPOINTREQUEST_METADATAENTRY']._serialized_options = b'8\001' + _globals['_JOBSTATE']._serialized_start=5682 + _globals['_JOBSTATE']._serialized_end=5916 + _globals['_EMPTY']._serialized_start=34 + _globals['_EMPTY']._serialized_end=41 + _globals['_RESOURCEUSAGE']._serialized_start=44 + _globals['_RESOURCEUSAGE']._serialized_end=260 + _globals['_BUILDMETRICS']._serialized_start=263 + _globals['_BUILDMETRICS']._serialized_end=423 + _globals['_JOBSTATUS']._serialized_start=426 + _globals['_JOBSTATUS']._serialized_end=1044 + _globals['_JOBSTATUS_PORTSENTRY']._serialized_start=988 + _globals['_JOBSTATUS_PORTSENTRY']._serialized_end=1044 + _globals['_DEVICECONFIG']._serialized_start=1047 + _globals['_DEVICECONFIG']._serialized_end=1215 + _globals['_CPUDEVICE']._serialized_start=1217 + _globals['_CPUDEVICE']._serialized_end=1254 + _globals['_GPUDEVICE']._serialized_start=1256 + _globals['_GPUDEVICE']._serialized_end=1315 + _globals['_TPUDEVICE']._serialized_start=1317 + _globals['_TPUDEVICE']._serialized_end=1382 + _globals['_RESOURCESPEC']._serialized_start=1385 + _globals['_RESOURCESPEC']._serialized_end=1604 + _globals['_ENVIRONMENTCONFIG']._serialized_start=1607 + _globals['_ENVIRONMENTCONFIG']._serialized_end=1863 + _globals['_ENVIRONMENTCONFIG_ENVVARSENTRY']._serialized_start=1795 + _globals['_ENVIRONMENTCONFIG_ENVVARSENTRY']._serialized_end=1853 + _globals['_CONTROLLER']._serialized_start=1866 + _globals['_CONTROLLER']._serialized_end=4433 + _globals['_CONTROLLER_LAUNCHJOBREQUEST']._serialized_start=1881 + _globals['_CONTROLLER_LAUNCHJOBREQUEST']._serialized_end=2293 + _globals['_CONTROLLER_LAUNCHJOBRESPONSE']._serialized_start=2295 + _globals['_CONTROLLER_LAUNCHJOBRESPONSE']._serialized_end=2337 + _globals['_CONTROLLER_GETJOBSTATUSREQUEST']._serialized_start=2339 + _globals['_CONTROLLER_GETJOBSTATUSREQUEST']._serialized_end=2422 + _globals['_CONTROLLER_GETJOBSTATUSRESPONSE']._serialized_start=2424 + _globals['_CONTROLLER_GETJOBSTATUSRESPONSE']._serialized_end=2492 + _globals['_CONTROLLER_TERMINATEJOBREQUEST']._serialized_start=2494 + _globals['_CONTROLLER_TERMINATEJOBREQUEST']._serialized_end=2538 + _globals['_CONTROLLER_LISTJOBSREQUEST']._serialized_start=2540 + _globals['_CONTROLLER_LISTJOBSREQUEST']._serialized_end=2587 + _globals['_CONTROLLER_LISTJOBSRESPONSE']._serialized_start=2589 + _globals['_CONTROLLER_LISTJOBSRESPONSE']._serialized_end=2655 + _globals['_CONTROLLER_WORKERINFO']._serialized_start=2658 + _globals['_CONTROLLER_WORKERINFO']._serialized_end=2828 + _globals['_CONTROLLER_REGISTERWORKERREQUEST']._serialized_start=2831 + _globals['_CONTROLLER_REGISTERWORKERREQUEST']._serialized_end=2970 + _globals['_CONTROLLER_REGISTERWORKERRESPONSE']._serialized_start=2972 + _globals['_CONTROLLER_REGISTERWORKERRESPONSE']._serialized_end=3071 + _globals['_CONTROLLER_WORKERHEALTHSTATUS']._serialized_start=3074 + _globals['_CONTROLLER_WORKERHEALTHSTATUS']._serialized_end=3284 + _globals['_CONTROLLER_LISTWORKERSREQUEST']._serialized_start=3286 + _globals['_CONTROLLER_LISTWORKERSREQUEST']._serialized_end=3306 + _globals['_CONTROLLER_LISTWORKERSRESPONSE']._serialized_start=3308 + _globals['_CONTROLLER_LISTWORKERSRESPONSE']._serialized_end=3403 + _globals['_CONTROLLER_ENDPOINT']._serialized_start=3406 + _globals['_CONTROLLER_ENDPOINT']._serialized_end=3689 + _globals['_CONTROLLER_ENDPOINT_METADATAENTRY']._serialized_start=3630 + 
_globals['_CONTROLLER_ENDPOINT_METADATAENTRY']._serialized_end=3689 + _globals['_CONTROLLER_REGISTERENDPOINTREQUEST']._serialized_start=3692 + _globals['_CONTROLLER_REGISTERENDPOINTREQUEST']._serialized_end=3972 + _globals['_CONTROLLER_REGISTERENDPOINTREQUEST_METADATAENTRY']._serialized_start=3630 + _globals['_CONTROLLER_REGISTERENDPOINTREQUEST_METADATAENTRY']._serialized_end=3689 + _globals['_CONTROLLER_REGISTERENDPOINTRESPONSE']._serialized_start=3974 + _globals['_CONTROLLER_REGISTERENDPOINTRESPONSE']._serialized_end=4033 + _globals['_CONTROLLER_UNREGISTERENDPOINTREQUEST']._serialized_start=4035 + _globals['_CONTROLLER_UNREGISTERENDPOINTREQUEST']._serialized_end=4095 + _globals['_CONTROLLER_LOOKUPENDPOINTREQUEST']._serialized_start=4097 + _globals['_CONTROLLER_LOOKUPENDPOINTREQUEST']._serialized_end=4170 + _globals['_CONTROLLER_LOOKUPENDPOINTRESPONSE']._serialized_start=4172 + _globals['_CONTROLLER_LOOKUPENDPOINTRESPONSE']._serialized_end=4262 + _globals['_CONTROLLER_LISTENDPOINTSREQUEST']._serialized_start=4264 + _globals['_CONTROLLER_LISTENDPOINTSREQUEST']._serialized_end=4340 + _globals['_CONTROLLER_LISTENDPOINTSRESPONSE']._serialized_start=4342 + _globals['_CONTROLLER_LISTENDPOINTSRESPONSE']._serialized_end=4433 + _globals['_WORKER']._serialized_start=4436 + _globals['_WORKER']._serialized_end=5679 + _globals['_WORKER_RUNJOBREQUEST']._serialized_start=4447 + _globals['_WORKER_RUNJOBREQUEST']._serialized_end=4772 + _globals['_WORKER_RUNJOBRESPONSE']._serialized_start=4774 + _globals['_WORKER_RUNJOBRESPONSE']._serialized_end=4862 + _globals['_WORKER_GETJOBSTATUSREQUEST']._serialized_start=2339 + _globals['_WORKER_GETJOBSTATUSREQUEST']._serialized_end=2422 + _globals['_WORKER_LISTJOBSREQUEST']._serialized_start=2540 + _globals['_WORKER_LISTJOBSREQUEST']._serialized_end=2587 + _globals['_WORKER_LISTJOBSRESPONSE']._serialized_start=2589 + _globals['_WORKER_LISTJOBSRESPONSE']._serialized_end=2655 + _globals['_WORKER_LOGENTRY']._serialized_start=5066 + _globals['_WORKER_LOGENTRY']._serialized_end=5155 + _globals['_WORKER_FETCHLOGSFILTER']._serialized_start=5158 + _globals['_WORKER_FETCHLOGSFILTER']._serialized_end=5307 + _globals['_WORKER_FETCHLOGSREQUEST']._serialized_start=5309 + _globals['_WORKER_FETCHLOGSREQUEST']._serialized_end=5415 + _globals['_WORKER_FETCHLOGSRESPONSE']._serialized_start=5417 + _globals['_WORKER_FETCHLOGSRESPONSE']._serialized_end=5490 + _globals['_WORKER_KILLJOBREQUEST']._serialized_start=5492 + _globals['_WORKER_KILLJOBREQUEST']._serialized_end=5571 + _globals['_WORKER_HEALTHRESPONSE']._serialized_start=5573 + _globals['_WORKER_HEALTHRESPONSE']._serialized_end=5679 + _globals['_CONTROLLERSERVICE']._serialized_start=5919 + _globals['_CONTROLLERSERVICE']._serialized_end=7051 + _globals['_WORKERSERVICE']._serialized_start=7054 + _globals['_WORKERSERVICE']._serialized_end=7594 +# @@protoc_insertion_point(module_scope) diff --git a/lib/fluster/src/fluster/cluster_pb2.pyi b/lib/fluster/src/fluster/cluster_pb2.pyi new file mode 100644 index 0000000000..6c5b6f0c2e --- /dev/null +++ b/lib/fluster/src/fluster/cluster_pb2.pyi @@ -0,0 +1,443 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from collections.abc import Iterable as _Iterable, Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: 
_descriptor.FileDescriptor + +class JobState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + JOB_STATE_UNSPECIFIED: _ClassVar[JobState] + JOB_STATE_PENDING: _ClassVar[JobState] + JOB_STATE_BUILDING: _ClassVar[JobState] + JOB_STATE_RUNNING: _ClassVar[JobState] + JOB_STATE_SUCCEEDED: _ClassVar[JobState] + JOB_STATE_FAILED: _ClassVar[JobState] + JOB_STATE_KILLED: _ClassVar[JobState] + JOB_STATE_WORKER_FAILED: _ClassVar[JobState] + JOB_STATE_UNSCHEDULABLE: _ClassVar[JobState] +JOB_STATE_UNSPECIFIED: JobState +JOB_STATE_PENDING: JobState +JOB_STATE_BUILDING: JobState +JOB_STATE_RUNNING: JobState +JOB_STATE_SUCCEEDED: JobState +JOB_STATE_FAILED: JobState +JOB_STATE_KILLED: JobState +JOB_STATE_WORKER_FAILED: JobState +JOB_STATE_UNSCHEDULABLE: JobState + +class Empty(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ResourceUsage(_message.Message): + __slots__ = ("memory_mb", "disk_mb", "cpu_millicores", "memory_peak_mb", "cpu_percent", "process_count") + MEMORY_MB_FIELD_NUMBER: _ClassVar[int] + DISK_MB_FIELD_NUMBER: _ClassVar[int] + CPU_MILLICORES_FIELD_NUMBER: _ClassVar[int] + MEMORY_PEAK_MB_FIELD_NUMBER: _ClassVar[int] + CPU_PERCENT_FIELD_NUMBER: _ClassVar[int] + PROCESS_COUNT_FIELD_NUMBER: _ClassVar[int] + memory_mb: int + disk_mb: int + cpu_millicores: int + memory_peak_mb: int + cpu_percent: int + process_count: int + def __init__(self, memory_mb: _Optional[int] = ..., disk_mb: _Optional[int] = ..., cpu_millicores: _Optional[int] = ..., memory_peak_mb: _Optional[int] = ..., cpu_percent: _Optional[int] = ..., process_count: _Optional[int] = ...) -> None: ... + +class BuildMetrics(_message.Message): + __slots__ = ("build_started_ms", "build_finished_ms", "from_cache", "image_tag") + BUILD_STARTED_MS_FIELD_NUMBER: _ClassVar[int] + BUILD_FINISHED_MS_FIELD_NUMBER: _ClassVar[int] + FROM_CACHE_FIELD_NUMBER: _ClassVar[int] + IMAGE_TAG_FIELD_NUMBER: _ClassVar[int] + build_started_ms: int + build_finished_ms: int + from_cache: bool + image_tag: str + def __init__(self, build_started_ms: _Optional[int] = ..., build_finished_ms: _Optional[int] = ..., from_cache: _Optional[bool] = ..., image_tag: _Optional[str] = ...) -> None: ... + +class JobStatus(_message.Message): + __slots__ = ("job_id", "state", "exit_code", "error", "started_at_ms", "finished_at_ms", "ports", "resource_usage", "status_message", "build_metrics", "worker_id", "worker_address", "serialized_result") + class PortsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: int + def __init__(self, key: _Optional[str] = ..., value: _Optional[int] = ...) -> None: ... 
+ JOB_ID_FIELD_NUMBER: _ClassVar[int] + STATE_FIELD_NUMBER: _ClassVar[int] + EXIT_CODE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + STARTED_AT_MS_FIELD_NUMBER: _ClassVar[int] + FINISHED_AT_MS_FIELD_NUMBER: _ClassVar[int] + PORTS_FIELD_NUMBER: _ClassVar[int] + RESOURCE_USAGE_FIELD_NUMBER: _ClassVar[int] + STATUS_MESSAGE_FIELD_NUMBER: _ClassVar[int] + BUILD_METRICS_FIELD_NUMBER: _ClassVar[int] + WORKER_ID_FIELD_NUMBER: _ClassVar[int] + WORKER_ADDRESS_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_RESULT_FIELD_NUMBER: _ClassVar[int] + job_id: str + state: JobState + exit_code: int + error: str + started_at_ms: int + finished_at_ms: int + ports: _containers.ScalarMap[str, int] + resource_usage: ResourceUsage + status_message: str + build_metrics: BuildMetrics + worker_id: str + worker_address: str + serialized_result: bytes + def __init__(self, job_id: _Optional[str] = ..., state: _Optional[_Union[JobState, str]] = ..., exit_code: _Optional[int] = ..., error: _Optional[str] = ..., started_at_ms: _Optional[int] = ..., finished_at_ms: _Optional[int] = ..., ports: _Optional[_Mapping[str, int]] = ..., resource_usage: _Optional[_Union[ResourceUsage, _Mapping]] = ..., status_message: _Optional[str] = ..., build_metrics: _Optional[_Union[BuildMetrics, _Mapping]] = ..., worker_id: _Optional[str] = ..., worker_address: _Optional[str] = ..., serialized_result: _Optional[bytes] = ...) -> None: ... + +class DeviceConfig(_message.Message): + __slots__ = ("cpu", "gpu", "tpu") + CPU_FIELD_NUMBER: _ClassVar[int] + GPU_FIELD_NUMBER: _ClassVar[int] + TPU_FIELD_NUMBER: _ClassVar[int] + cpu: CpuDevice + gpu: GpuDevice + tpu: TpuDevice + def __init__(self, cpu: _Optional[_Union[CpuDevice, _Mapping]] = ..., gpu: _Optional[_Union[GpuDevice, _Mapping]] = ..., tpu: _Optional[_Union[TpuDevice, _Mapping]] = ...) -> None: ... + +class CpuDevice(_message.Message): + __slots__ = ("variant",) + VARIANT_FIELD_NUMBER: _ClassVar[int] + variant: str + def __init__(self, variant: _Optional[str] = ...) -> None: ... + +class GpuDevice(_message.Message): + __slots__ = ("variant", "count") + VARIANT_FIELD_NUMBER: _ClassVar[int] + COUNT_FIELD_NUMBER: _ClassVar[int] + variant: str + count: int + def __init__(self, variant: _Optional[str] = ..., count: _Optional[int] = ...) -> None: ... + +class TpuDevice(_message.Message): + __slots__ = ("variant", "topology") + VARIANT_FIELD_NUMBER: _ClassVar[int] + TOPOLOGY_FIELD_NUMBER: _ClassVar[int] + variant: str + topology: str + def __init__(self, variant: _Optional[str] = ..., topology: _Optional[str] = ...) -> None: ... + +class ResourceSpec(_message.Message): + __slots__ = ("cpu", "memory", "disk", "device", "replicas", "preemptible", "regions") + CPU_FIELD_NUMBER: _ClassVar[int] + MEMORY_FIELD_NUMBER: _ClassVar[int] + DISK_FIELD_NUMBER: _ClassVar[int] + DEVICE_FIELD_NUMBER: _ClassVar[int] + REPLICAS_FIELD_NUMBER: _ClassVar[int] + PREEMPTIBLE_FIELD_NUMBER: _ClassVar[int] + REGIONS_FIELD_NUMBER: _ClassVar[int] + cpu: int + memory: str + disk: str + device: DeviceConfig + replicas: int + preemptible: bool + regions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, cpu: _Optional[int] = ..., memory: _Optional[str] = ..., disk: _Optional[str] = ..., device: _Optional[_Union[DeviceConfig, _Mapping]] = ..., replicas: _Optional[int] = ..., preemptible: _Optional[bool] = ..., regions: _Optional[_Iterable[str]] = ...) -> None: ... 
+ +class EnvironmentConfig(_message.Message): + __slots__ = ("workspace", "pip_packages", "env_vars", "extras") + class EnvVarsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + WORKSPACE_FIELD_NUMBER: _ClassVar[int] + PIP_PACKAGES_FIELD_NUMBER: _ClassVar[int] + ENV_VARS_FIELD_NUMBER: _ClassVar[int] + EXTRAS_FIELD_NUMBER: _ClassVar[int] + workspace: str + pip_packages: _containers.RepeatedScalarFieldContainer[str] + env_vars: _containers.ScalarMap[str, str] + extras: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, workspace: _Optional[str] = ..., pip_packages: _Optional[_Iterable[str]] = ..., env_vars: _Optional[_Mapping[str, str]] = ..., extras: _Optional[_Iterable[str]] = ...) -> None: ... + +class Controller(_message.Message): + __slots__ = () + class LaunchJobRequest(_message.Message): + __slots__ = ("name", "serialized_entrypoint", "resources", "environment", "bundle_gcs_path", "bundle_hash", "bundle_blob", "scheduling_timeout_seconds", "ports") + NAME_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_ENTRYPOINT_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] + ENVIRONMENT_FIELD_NUMBER: _ClassVar[int] + BUNDLE_GCS_PATH_FIELD_NUMBER: _ClassVar[int] + BUNDLE_HASH_FIELD_NUMBER: _ClassVar[int] + BUNDLE_BLOB_FIELD_NUMBER: _ClassVar[int] + SCHEDULING_TIMEOUT_SECONDS_FIELD_NUMBER: _ClassVar[int] + PORTS_FIELD_NUMBER: _ClassVar[int] + name: str + serialized_entrypoint: bytes + resources: ResourceSpec + environment: EnvironmentConfig + bundle_gcs_path: str + bundle_hash: str + bundle_blob: bytes + scheduling_timeout_seconds: int + ports: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, name: _Optional[str] = ..., serialized_entrypoint: _Optional[bytes] = ..., resources: _Optional[_Union[ResourceSpec, _Mapping]] = ..., environment: _Optional[_Union[EnvironmentConfig, _Mapping]] = ..., bundle_gcs_path: _Optional[str] = ..., bundle_hash: _Optional[str] = ..., bundle_blob: _Optional[bytes] = ..., scheduling_timeout_seconds: _Optional[int] = ..., ports: _Optional[_Iterable[str]] = ...) -> None: ... + class LaunchJobResponse(_message.Message): + __slots__ = ("job_id",) + JOB_ID_FIELD_NUMBER: _ClassVar[int] + job_id: str + def __init__(self, job_id: _Optional[str] = ...) -> None: ... + class GetJobStatusRequest(_message.Message): + __slots__ = ("job_id", "include_result") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + INCLUDE_RESULT_FIELD_NUMBER: _ClassVar[int] + job_id: str + include_result: bool + def __init__(self, job_id: _Optional[str] = ..., include_result: _Optional[bool] = ...) -> None: ... + class GetJobStatusResponse(_message.Message): + __slots__ = ("job",) + JOB_FIELD_NUMBER: _ClassVar[int] + job: JobStatus + def __init__(self, job: _Optional[_Union[JobStatus, _Mapping]] = ...) -> None: ... + class TerminateJobRequest(_message.Message): + __slots__ = ("job_id",) + JOB_ID_FIELD_NUMBER: _ClassVar[int] + job_id: str + def __init__(self, job_id: _Optional[str] = ...) -> None: ... + class ListJobsRequest(_message.Message): + __slots__ = ("namespace",) + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + namespace: str + def __init__(self, namespace: _Optional[str] = ...) -> None: ... 
+ class ListJobsResponse(_message.Message): + __slots__ = ("jobs",) + JOBS_FIELD_NUMBER: _ClassVar[int] + jobs: _containers.RepeatedCompositeFieldContainer[JobStatus] + def __init__(self, jobs: _Optional[_Iterable[_Union[JobStatus, _Mapping]]] = ...) -> None: ... + class WorkerInfo(_message.Message): + __slots__ = ("worker_id", "address", "resources", "registered_at_ms") + WORKER_ID_FIELD_NUMBER: _ClassVar[int] + ADDRESS_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] + REGISTERED_AT_MS_FIELD_NUMBER: _ClassVar[int] + worker_id: str + address: str + resources: ResourceSpec + registered_at_ms: int + def __init__(self, worker_id: _Optional[str] = ..., address: _Optional[str] = ..., resources: _Optional[_Union[ResourceSpec, _Mapping]] = ..., registered_at_ms: _Optional[int] = ...) -> None: ... + class RegisterWorkerRequest(_message.Message): + __slots__ = ("worker_id", "address", "resources") + WORKER_ID_FIELD_NUMBER: _ClassVar[int] + ADDRESS_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] + worker_id: str + address: str + resources: ResourceSpec + def __init__(self, worker_id: _Optional[str] = ..., address: _Optional[str] = ..., resources: _Optional[_Union[ResourceSpec, _Mapping]] = ...) -> None: ... + class RegisterWorkerResponse(_message.Message): + __slots__ = ("accepted", "controller_address") + ACCEPTED_FIELD_NUMBER: _ClassVar[int] + CONTROLLER_ADDRESS_FIELD_NUMBER: _ClassVar[int] + accepted: bool + controller_address: str + def __init__(self, accepted: _Optional[bool] = ..., controller_address: _Optional[str] = ...) -> None: ... + class WorkerHealthStatus(_message.Message): + __slots__ = ("worker_id", "healthy", "consecutive_failures", "last_heartbeat_ms", "running_job_ids") + WORKER_ID_FIELD_NUMBER: _ClassVar[int] + HEALTHY_FIELD_NUMBER: _ClassVar[int] + CONSECUTIVE_FAILURES_FIELD_NUMBER: _ClassVar[int] + LAST_HEARTBEAT_MS_FIELD_NUMBER: _ClassVar[int] + RUNNING_JOB_IDS_FIELD_NUMBER: _ClassVar[int] + worker_id: str + healthy: bool + consecutive_failures: int + last_heartbeat_ms: int + running_job_ids: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, worker_id: _Optional[str] = ..., healthy: _Optional[bool] = ..., consecutive_failures: _Optional[int] = ..., last_heartbeat_ms: _Optional[int] = ..., running_job_ids: _Optional[_Iterable[str]] = ...) -> None: ... + class ListWorkersRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + class ListWorkersResponse(_message.Message): + __slots__ = ("workers",) + WORKERS_FIELD_NUMBER: _ClassVar[int] + workers: _containers.RepeatedCompositeFieldContainer[Controller.WorkerHealthStatus] + def __init__(self, workers: _Optional[_Iterable[_Union[Controller.WorkerHealthStatus, _Mapping]]] = ...) -> None: ... + class Endpoint(_message.Message): + __slots__ = ("endpoint_id", "name", "address", "job_id", "namespace", "metadata") + class MetadataEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... 
+ ENDPOINT_ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + ADDRESS_FIELD_NUMBER: _ClassVar[int] + JOB_ID_FIELD_NUMBER: _ClassVar[int] + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + endpoint_id: str + name: str + address: str + job_id: str + namespace: str + metadata: _containers.ScalarMap[str, str] + def __init__(self, endpoint_id: _Optional[str] = ..., name: _Optional[str] = ..., address: _Optional[str] = ..., job_id: _Optional[str] = ..., namespace: _Optional[str] = ..., metadata: _Optional[_Mapping[str, str]] = ...) -> None: ... + class RegisterEndpointRequest(_message.Message): + __slots__ = ("name", "address", "job_id", "namespace", "metadata") + class MetadataEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + ADDRESS_FIELD_NUMBER: _ClassVar[int] + JOB_ID_FIELD_NUMBER: _ClassVar[int] + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + name: str + address: str + job_id: str + namespace: str + metadata: _containers.ScalarMap[str, str] + def __init__(self, name: _Optional[str] = ..., address: _Optional[str] = ..., job_id: _Optional[str] = ..., namespace: _Optional[str] = ..., metadata: _Optional[_Mapping[str, str]] = ...) -> None: ... + class RegisterEndpointResponse(_message.Message): + __slots__ = ("endpoint_id",) + ENDPOINT_ID_FIELD_NUMBER: _ClassVar[int] + endpoint_id: str + def __init__(self, endpoint_id: _Optional[str] = ...) -> None: ... + class UnregisterEndpointRequest(_message.Message): + __slots__ = ("endpoint_id",) + ENDPOINT_ID_FIELD_NUMBER: _ClassVar[int] + endpoint_id: str + def __init__(self, endpoint_id: _Optional[str] = ...) -> None: ... + class LookupEndpointRequest(_message.Message): + __slots__ = ("name", "namespace") + NAME_FIELD_NUMBER: _ClassVar[int] + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + name: str + namespace: str + def __init__(self, name: _Optional[str] = ..., namespace: _Optional[str] = ...) -> None: ... + class LookupEndpointResponse(_message.Message): + __slots__ = ("endpoint",) + ENDPOINT_FIELD_NUMBER: _ClassVar[int] + endpoint: Controller.Endpoint + def __init__(self, endpoint: _Optional[_Union[Controller.Endpoint, _Mapping]] = ...) -> None: ... + class ListEndpointsRequest(_message.Message): + __slots__ = ("prefix", "namespace") + PREFIX_FIELD_NUMBER: _ClassVar[int] + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + prefix: str + namespace: str + def __init__(self, prefix: _Optional[str] = ..., namespace: _Optional[str] = ...) -> None: ... + class ListEndpointsResponse(_message.Message): + __slots__ = ("endpoints",) + ENDPOINTS_FIELD_NUMBER: _ClassVar[int] + endpoints: _containers.RepeatedCompositeFieldContainer[Controller.Endpoint] + def __init__(self, endpoints: _Optional[_Iterable[_Union[Controller.Endpoint, _Mapping]]] = ...) -> None: ... + def __init__(self) -> None: ... 
+ +class Worker(_message.Message): + __slots__ = () + class RunJobRequest(_message.Message): + __slots__ = ("job_id", "serialized_entrypoint", "environment", "bundle_gcs_path", "resources", "timeout_seconds", "ports") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + SERIALIZED_ENTRYPOINT_FIELD_NUMBER: _ClassVar[int] + ENVIRONMENT_FIELD_NUMBER: _ClassVar[int] + BUNDLE_GCS_PATH_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] + TIMEOUT_SECONDS_FIELD_NUMBER: _ClassVar[int] + PORTS_FIELD_NUMBER: _ClassVar[int] + job_id: str + serialized_entrypoint: bytes + environment: EnvironmentConfig + bundle_gcs_path: str + resources: ResourceSpec + timeout_seconds: int + ports: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, job_id: _Optional[str] = ..., serialized_entrypoint: _Optional[bytes] = ..., environment: _Optional[_Union[EnvironmentConfig, _Mapping]] = ..., bundle_gcs_path: _Optional[str] = ..., resources: _Optional[_Union[ResourceSpec, _Mapping]] = ..., timeout_seconds: _Optional[int] = ..., ports: _Optional[_Iterable[str]] = ...) -> None: ... + class RunJobResponse(_message.Message): + __slots__ = ("job_id", "state") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + STATE_FIELD_NUMBER: _ClassVar[int] + job_id: str + state: JobState + def __init__(self, job_id: _Optional[str] = ..., state: _Optional[_Union[JobState, str]] = ...) -> None: ... + class GetJobStatusRequest(_message.Message): + __slots__ = ("job_id", "include_result") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + INCLUDE_RESULT_FIELD_NUMBER: _ClassVar[int] + job_id: str + include_result: bool + def __init__(self, job_id: _Optional[str] = ..., include_result: _Optional[bool] = ...) -> None: ... + class ListJobsRequest(_message.Message): + __slots__ = ("namespace",) + NAMESPACE_FIELD_NUMBER: _ClassVar[int] + namespace: str + def __init__(self, namespace: _Optional[str] = ...) -> None: ... + class ListJobsResponse(_message.Message): + __slots__ = ("jobs",) + JOBS_FIELD_NUMBER: _ClassVar[int] + jobs: _containers.RepeatedCompositeFieldContainer[JobStatus] + def __init__(self, jobs: _Optional[_Iterable[_Union[JobStatus, _Mapping]]] = ...) -> None: ... + class LogEntry(_message.Message): + __slots__ = ("timestamp_ms", "source", "data") + TIMESTAMP_MS_FIELD_NUMBER: _ClassVar[int] + SOURCE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + timestamp_ms: int + source: str + data: str + def __init__(self, timestamp_ms: _Optional[int] = ..., source: _Optional[str] = ..., data: _Optional[str] = ...) -> None: ... + class FetchLogsFilter(_message.Message): + __slots__ = ("regex", "start_line", "start_ms", "end_ms", "max_lines") + REGEX_FIELD_NUMBER: _ClassVar[int] + START_LINE_FIELD_NUMBER: _ClassVar[int] + START_MS_FIELD_NUMBER: _ClassVar[int] + END_MS_FIELD_NUMBER: _ClassVar[int] + MAX_LINES_FIELD_NUMBER: _ClassVar[int] + regex: str + start_line: int + start_ms: int + end_ms: int + max_lines: int + def __init__(self, regex: _Optional[str] = ..., start_line: _Optional[int] = ..., start_ms: _Optional[int] = ..., end_ms: _Optional[int] = ..., max_lines: _Optional[int] = ...) -> None: ... + class FetchLogsRequest(_message.Message): + __slots__ = ("job_id", "filter") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + FILTER_FIELD_NUMBER: _ClassVar[int] + job_id: str + filter: Worker.FetchLogsFilter + def __init__(self, job_id: _Optional[str] = ..., filter: _Optional[_Union[Worker.FetchLogsFilter, _Mapping]] = ...) -> None: ... 
+ class FetchLogsResponse(_message.Message): + __slots__ = ("logs",) + LOGS_FIELD_NUMBER: _ClassVar[int] + logs: _containers.RepeatedCompositeFieldContainer[Worker.LogEntry] + def __init__(self, logs: _Optional[_Iterable[_Union[Worker.LogEntry, _Mapping]]] = ...) -> None: ... + class KillJobRequest(_message.Message): + __slots__ = ("job_id", "term_timeout_ms") + JOB_ID_FIELD_NUMBER: _ClassVar[int] + TERM_TIMEOUT_MS_FIELD_NUMBER: _ClassVar[int] + job_id: str + term_timeout_ms: int + def __init__(self, job_id: _Optional[str] = ..., term_timeout_ms: _Optional[int] = ...) -> None: ... + class HealthResponse(_message.Message): + __slots__ = ("healthy", "uptime_ms", "running_jobs") + HEALTHY_FIELD_NUMBER: _ClassVar[int] + UPTIME_MS_FIELD_NUMBER: _ClassVar[int] + RUNNING_JOBS_FIELD_NUMBER: _ClassVar[int] + healthy: bool + uptime_ms: int + running_jobs: int + def __init__(self, healthy: _Optional[bool] = ..., uptime_ms: _Optional[int] = ..., running_jobs: _Optional[int] = ...) -> None: ... + def __init__(self) -> None: ... diff --git a/lib/fluster/src/fluster/proto/cluster.proto b/lib/fluster/src/fluster/proto/cluster.proto new file mode 100644 index 0000000000..83fa4ede9b --- /dev/null +++ b/lib/fluster/src/fluster/proto/cluster.proto @@ -0,0 +1,372 @@ +// Copyright 2025 The Marin Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+edition = "2023";
+
+package fluster.cluster;
+
+// ============================================================================
+// SHARED TYPES
+// ============================================================================
+
+message Empty {}
+
+enum JobState {
+  JOB_STATE_UNSPECIFIED = 0;
+  JOB_STATE_PENDING = 1;
+  JOB_STATE_BUILDING = 2;
+  JOB_STATE_RUNNING = 3;
+  JOB_STATE_SUCCEEDED = 4;
+  JOB_STATE_FAILED = 5;
+  JOB_STATE_KILLED = 6;
+  JOB_STATE_WORKER_FAILED = 7; // Worker died, job may be retried
+  JOB_STATE_UNSCHEDULABLE = 8; // Couldn't be scheduled within timeout
+}
+
+message ResourceUsage {
+  int64 memory_mb = 1;
+  int64 disk_mb = 2;
+  int32 cpu_millicores = 3;
+  int64 memory_peak_mb = 4;
+  int32 cpu_percent = 5;
+  int32 process_count = 6;
+}
+
+message BuildMetrics {
+  int64 build_started_ms = 1;
+  int64 build_finished_ms = 2;
+  bool from_cache = 3;
+  string image_tag = 4;
+}
+
+message JobStatus {
+  string job_id = 1;
+  JobState state = 2;
+  int32 exit_code = 3;
+  string error = 4;
+  int64 started_at_ms = 5;
+  int64 finished_at_ms = 6;
+
+  // Port allocations
+  map<string, int32> ports = 7;
+
+  ResourceUsage resource_usage = 8;
+
+  // Status message for current phase (e.g., "downloading bundle", "populating uv cache")
+  string status_message = 9;
+
+  BuildMetrics build_metrics = 10;
+
+  // Worker ID (populated by controller, not worker)
+  string worker_id = 11;
+  string worker_address = 12; // host:port for direct worker connection
+  bytes serialized_result = 13; // cloudpickle serialized return value (if requested)
+}
+
+// ============================================================================
+// RESOURCE SPECIFICATION
+// ============================================================================
+
+// Device configuration - used in ResourceSpec
+message DeviceConfig {
+  oneof device {
+    CpuDevice cpu = 1;
+    GpuDevice gpu = 2;
+    TpuDevice tpu = 3;
+  }
+}
+
+message CpuDevice {
+  string variant = 1; // Always "cpu"
+}
+
+message GpuDevice {
+  string variant = 1; // e.g., "A100", "H100", "auto"
+  int32 count = 2; // Number of GPUs
+}
+
+message TpuDevice {
+  string variant = 1; // e.g., "v5litepod-16", "v4-8"
+  string topology = 2; // topology spec (e.g., "2x2x1")
+}
+
+// Unified resource specification for jobs
+// Used by both controller (for scheduling) and worker (for enforcement)
+message ResourceSpec {
+  // Compute resources
+  int32 cpu = 1; // Number of CPU cores
+  string memory = 2; // RAM (e.g., "8g", "16g", "128m")
+  string disk = 3; // Disk space (e.g., "1g", "100g")
+
+  // Device configuration
+  DeviceConfig device = 5;
+
+  // Multi-instance configuration
+  int32 replicas = 6; // Number of replicas/slices
+
+  // Scheduling preferences
+  bool preemptible = 7;
+  repeated string regions = 8; // Preferred cloud regions
+}
+
+// ============================================================================
+// ENVIRONMENT CONFIGURATION
+// ============================================================================
+
+// Job environment configuration
+// Exactly one of workspace or docker_image must be set (enforced in client)
+message EnvironmentConfig {
+  oneof source {
+    string workspace = 1; // Path to workspace root for uv-based execution
+  }
+
+  repeated string pip_packages = 2; // Additional pip packages to install
+  map<string, string> env_vars = 3; // Environment variables to set
+  repeated string extras = 4; // Extra dependency groups for uv (e.g., ["tpu", "eval"])
+}
+
+// ============================================================================
+// CONTROLLER SERVICE MESSAGES
+// ============================================================================
+
+message Controller {
+  // --- Job Lifecycle ---
+  message LaunchJobRequest {
+    string name = 1;
+    bytes serialized_entrypoint = 2; // cloudpickle(Entrypoint)
+    ResourceSpec resources = 3; // Full resource specification
+    EnvironmentConfig environment = 4; // Environment configuration
+
+    // Bundle information - either provide gcs_path OR blob (not both)
+    string bundle_gcs_path = 5; // gs://bucket/path/bundle.zip
+    string bundle_hash = 6; // SHA256 hash for caching
+    bytes bundle_blob = 7; // Direct bundle upload (controller writes to bundle_dir)
+
+    // Scheduling timeout - job fails with UNSCHEDULABLE if not scheduled within this time
+    // 0 means no timeout (wait forever)
+    int32 scheduling_timeout_seconds = 8;
+
+    // Named ports to allocate (e.g., ["http", "grpc", "actor"])
+    // Worker allocates ports and injects FLUSTER_PORT_<NAME> env vars into the container
+    repeated string ports = 9;
+  }
+
+  message LaunchJobResponse {
+    string job_id = 1;
+  }
+
+  message GetJobStatusRequest {
+    string job_id = 1;
+    bool include_result = 2; // If true, include serialized_result in response
+  }
+
+  message GetJobStatusResponse {
+    JobStatus job = 1;
+  }
+
+  message TerminateJobRequest {
+    string job_id = 1;
+  }
+
+  message ListJobsRequest {
+    string namespace = 1;
+  }
+
+  message ListJobsResponse {
+    repeated JobStatus jobs = 1;
+  }
+
+  // --- Worker Management ---
+  message WorkerInfo {
+    string worker_id = 1;
+    string address = 2; // host:port for WorkerService
+    ResourceSpec resources = 3; // Worker capabilities
+    int64 registered_at_ms = 4;
+  }
+
+  message RegisterWorkerRequest {
+    string worker_id = 1;
+    string address = 2;
+    ResourceSpec resources = 3;
+  }
+
+  message RegisterWorkerResponse {
+    bool accepted = 1;
+    string controller_address = 2; // For callbacks
+  }
+
+  message WorkerHealthStatus {
+    string worker_id = 1;
+    bool healthy = 2;
+    int32 consecutive_failures = 3;
+    int64 last_heartbeat_ms = 4;
+    repeated string running_job_ids = 5;
+  }
+
+  message ListWorkersRequest {}
+
+  message ListWorkersResponse {
+    repeated WorkerHealthStatus workers = 1;
+  }
+
+  // --- Endpoint Registry ---
+  message Endpoint {
+    string endpoint_id = 1;
+    string name = 2;
+    string address = 3; // host:port
+    string job_id = 4;
+    string namespace = 5;
+    map<string, string> metadata = 6;
+  }
+
+  message RegisterEndpointRequest {
+    string name = 1;
+    string address = 2;
+    string job_id = 3;
+    string namespace = 4;
+    map<string, string> metadata = 5;
+  }
+
+  message RegisterEndpointResponse {
+    string endpoint_id = 1;
+  }
+
+  message UnregisterEndpointRequest {
+    string endpoint_id = 1;
+  }
+
+  message LookupEndpointRequest {
+    string name = 1;
+    string namespace = 2;
+  }
+
+  message LookupEndpointResponse {
+    Endpoint endpoint = 1;
+  }
+
+  message ListEndpointsRequest {
+    string prefix = 1;
+    string namespace = 2;
+  }
+
+  message ListEndpointsResponse {
+    repeated Endpoint endpoints = 1;
+  }
+}
+
+// ============================================================================
+// WORKER SERVICE MESSAGES
+// ============================================================================
+
+message Worker {
+  message RunJobRequest {
+    string job_id = 1;
+
+    // Serialized Python objects and workspace bundle for execution
+    bytes serialized_entrypoint = 2; // cloudpickle(Entrypoint)
+    EnvironmentConfig environment = 3; // Environment configuration
+    string bundle_gcs_path = 4; // gs://bucket/path/bundle.zip
+
+    // Resource specification
+    ResourceSpec resources =
6; + + int32 timeout_seconds = 8; + + // Port requests + repeated string ports = 9; + } + + message RunJobResponse { + string job_id = 1; + JobState state = 2; + } + + message GetJobStatusRequest { + string job_id = 1; + bool include_result = 2; + } + + message ListJobsRequest { + string namespace = 1; + } + + message ListJobsResponse { + repeated JobStatus jobs = 1; + } + + message LogEntry { + int64 timestamp_ms = 1; + string source = 2; // "stdout", "stderr", or "build" + string data = 3; // Log line content + } + + message FetchLogsFilter { + string regex = 1; + int64 start_line = 2; + int64 start_ms = 3; + int64 end_ms = 4; + int64 max_lines = 5; + } + + message FetchLogsRequest { + string job_id = 1; + FetchLogsFilter filter = 2; + } + + message FetchLogsResponse { + repeated LogEntry logs = 1; + } + + message KillJobRequest { + string job_id = 1; + int32 term_timeout_ms = 2; // Time to wait for graceful termination before SIGKILL + } + + message HealthResponse { + bool healthy = 1; + int64 uptime_ms = 2; + int32 running_jobs = 3; + } +} + +// ============================================================================ +// SERVICES +// ============================================================================ + +service ControllerService { + // Job lifecycle + rpc LaunchJob(Controller.LaunchJobRequest) returns (Controller.LaunchJobResponse); + rpc GetJobStatus(Controller.GetJobStatusRequest) returns (Controller.GetJobStatusResponse); + rpc TerminateJob(Controller.TerminateJobRequest) returns (Empty); + rpc ListJobs(Controller.ListJobsRequest) returns (Controller.ListJobsResponse); + + // Worker management (controller polls workers internally via Worker.HealthCheck) + rpc RegisterWorker(Controller.RegisterWorkerRequest) returns (Controller.RegisterWorkerResponse); + rpc ListWorkers(Controller.ListWorkersRequest) returns (Controller.ListWorkersResponse); + + // Endpoint registry (generic service discovery) + rpc RegisterEndpoint(Controller.RegisterEndpointRequest) returns (Controller.RegisterEndpointResponse); + rpc UnregisterEndpoint(Controller.UnregisterEndpointRequest) returns (Empty); + rpc LookupEndpoint(Controller.LookupEndpointRequest) returns (Controller.LookupEndpointResponse); + rpc ListEndpoints(Controller.ListEndpointsRequest) returns (Controller.ListEndpointsResponse); +} + +service WorkerService { + rpc RunJob(Worker.RunJobRequest) returns (Worker.RunJobResponse); + rpc GetJobStatus(Worker.GetJobStatusRequest) returns (JobStatus); + rpc ListJobs(Worker.ListJobsRequest) returns (Worker.ListJobsResponse); + rpc FetchLogs(Worker.FetchLogsRequest) returns (Worker.FetchLogsResponse); + rpc KillJob(Worker.KillJobRequest) returns (Empty); + rpc HealthCheck(Empty) returns (Worker.HealthResponse); +} diff --git a/lib/fluster/src/fluster/worker_pool.py b/lib/fluster/src/fluster/worker_pool.py new file mode 100644 index 0000000000..eb5c98d9d6 --- /dev/null +++ b/lib/fluster/src/fluster/worker_pool.py @@ -0,0 +1,757 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""WorkerPool for task dispatch. + +WorkerPool provides a high-level interface for dispatching arbitrary callables +to a pool of stateless workers. Unlike ActorPool (which load-balances calls to +pre-existing actors with known methods), WorkerPool creates and manages worker +jobs that can execute any callable. + +Example: + from fluster.worker_pool import WorkerPool, WorkerPoolConfig + from fluster.cluster.client import RpcClusterClient, BundleCreator + + bundle = BundleCreator().create_bundle() + client = RpcClusterClient("http://controller:8080", bundle) + + config = WorkerPoolConfig( + num_workers=3, + resources=cluster_pb2.ResourceSpec(cpu=1, memory="512m"), + ) + + with WorkerPool(client, config) as pool: + # Submit tasks + futures = [pool.submit(expensive_fn, i) for i in range(10)] + results = [f.result() for f in futures] +""" + +import os +import threading +import time +import uuid +from collections.abc import Callable, Sequence +from concurrent.futures import Future +from dataclasses import dataclass +from enum import Enum, auto +from queue import Empty, Queue +from typing import Any, Generic, TypeVar + +import cloudpickle + +from fluster import actor_pb2, cluster_pb2 +from fluster.actor import ActorServer +from fluster.actor.resolver import ClusterResolver, Resolver +from fluster.actor_connect import ActorServiceClientSync +from fluster.cluster.client import ClusterClient +from fluster.cluster.types import Entrypoint, JobId, Namespace +from fluster.cluster_connect import ControllerServiceClientSync + +T = TypeVar("T") + + +class WorkerStatus(Enum): + """Status of a worker in the pool.""" + + PENDING = auto() # Worker job launched, not yet registered + IDLE = auto() # Ready to accept tasks + BUSY = auto() # Currently executing a task + FAILED = auto() # Worker has failed/disconnected + + +class UserException(Exception): + """Wrapper for exceptions raised by user code (not infrastructure failures).""" + + def __init__(self, inner: BaseException): + self.inner = inner + super().__init__(str(inner)) + + +@dataclass +class WorkerState: + """Client-side state for a single worker.""" + + worker_id: str + worker_name: str + endpoint_url: str | None = None + status: WorkerStatus = WorkerStatus.PENDING + current_task_id: str | None = None + tasks_completed: int = 0 + tasks_failed: int = 0 + + +@dataclass +class PendingTask: + """A task waiting to be dispatched to a worker.""" + + task_id: str + serialized_fn: bytes + serialized_args: bytes + serialized_kwargs: bytes + future: Future + fn_name: str + submitted_at: float + retries_remaining: int = 0 + + +@dataclass +class PoolStatus: + """Status snapshot of the worker pool.""" + + pool_id: str + num_workers: int + workers_idle: int + workers_busy: int + workers_pending: int + workers_failed: int + tasks_queued: int + tasks_completed: int + tasks_failed: int + worker_details: list[dict] + + +class TaskExecutorActor: + """Actor that executes arbitrary callables. + + This is the server-side component of WorkerPool. Each worker job runs + one of these actors to execute tasks dispatched by the pool. + + The callable and arguments are received as cloudpickle-serialized bytes. + The return value is returned raw - ActorServer handles serialization. + """ + + def execute( + self, + serialized_callable: bytes, + serialized_args: bytes, + serialized_kwargs: bytes, + ) -> Any: + """Execute a pickled callable and return the result. 
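+
+        A minimal illustrative sketch (a direct in-process call; in the pool
+        these bytes arrive via the actor RPC, and the callable and values here
+        are made up for illustration):
+
+            actor = TaskExecutorActor()
+            result = actor.execute(
+                cloudpickle.dumps(lambda x, y: x + y),  # callable
+                cloudpickle.dumps((2, 3)),              # positional args
+                cloudpickle.dumps({}),                  # keyword args
+            )
+            assert result == 5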
+ + Args: + serialized_callable: cloudpickle-serialized callable + serialized_args: cloudpickle-serialized tuple of positional args + serialized_kwargs: cloudpickle-serialized dict of keyword args + + Returns: + The return value of calling fn(*args, **kwargs). + ActorServer handles serialization of this value. + + Raises: + Any exception raised by the callable (propagated to client). + """ + fn = cloudpickle.loads(serialized_callable) + args = cloudpickle.loads(serialized_args) + kwargs = cloudpickle.loads(serialized_kwargs) + return fn(*args, **kwargs) + + +def register_endpoint( + controller_url: str, + name: str, + address: str, + job_id: str, + namespace: str, +) -> str: + """Register an endpoint with the cluster controller. + + Args: + controller_url: Controller URL (e.g., "http://localhost:8080") + name: Endpoint name for discovery + address: Address where the endpoint is listening (host:port) + job_id: Job ID that owns this endpoint + namespace: Namespace for isolation + + Returns: + Endpoint ID assigned by the controller + """ + client = ControllerServiceClientSync(address=controller_url, timeout_ms=10000) + request = cluster_pb2.Controller.RegisterEndpointRequest( + name=name, + address=address, + job_id=job_id, + namespace=namespace, + ) + response = client.register_endpoint(request) + return response.endpoint_id + + +def worker_job_entrypoint(pool_id: str, worker_index: int) -> None: + """Job entrypoint that starts a TaskExecutor actor with a unique name. + + This function is called when a worker job starts. It: + 1. Reads cluster configuration from environment variables + 2. Starts an ActorServer with a TaskExecutorActor + 3. Registers the endpoint with the controller under a unique name + 4. Runs forever, serving requests + + Each worker registers with a unique name so the client can target + specific idle workers when dispatching tasks. + + Environment variables (injected by the cluster): + FLUSTER_JOB_ID: Unique job identifier + FLUSTER_NAMESPACE: Namespace for actor isolation + FLUSTER_PORT_ACTOR: Port allocated for the actor server + FLUSTER_CONTROLLER_ADDRESS: Controller URL for registration + + Args: + pool_id: Unique identifier for the worker pool + worker_index: Index of this worker (0, 1, 2, ...) + """ + job_id = os.environ["FLUSTER_JOB_ID"] + namespace = os.environ["FLUSTER_NAMESPACE"] + port = int(os.environ["FLUSTER_PORT_ACTOR"]) + controller_url = os.environ["FLUSTER_CONTROLLER_ADDRESS"] + + # Unique name per worker + worker_name = f"_workerpool_{pool_id}:worker-{worker_index}" + + print(f"Worker starting: pool_id={pool_id}, worker_index={worker_index}") + print(f"Worker name: {worker_name}") + + # Start actor server + server = ActorServer(host="0.0.0.0", port=port) + server.register(worker_name, TaskExecutorActor()) + actual_port = server.serve_background() + print(f"ActorServer started on port {actual_port}") + + # Register with controller + # Use localhost since the port is mapped from host to container via Docker -p + endpoint_address = f"localhost:{actual_port}" + print(f"Registering endpoint: {worker_name} at {endpoint_address}") + + endpoint_id = register_endpoint( + controller_url, + worker_name, + endpoint_address, + job_id, + namespace, + ) + print(f"Endpoint registered: {endpoint_id}") + + # Serve forever + print("Worker ready, waiting for tasks...") + while True: + time.sleep(1) + + +class WorkerDispatcher: + """Dispatch thread for a single worker. + + Handles endpoint discovery and task dispatch in a dedicated thread. 
+ State transitions: + - PENDING: Poll resolver for endpoint registration + - IDLE: Wait for task from queue + - BUSY: Execute task on worker endpoint + - FAILED: Worker has failed, thread exits + + On infrastructure failure (connection error), the task is re-queued for + another worker if retries remain. User exceptions propagate immediately. + """ + + def __init__( + self, + state: WorkerState, + task_queue: "Queue[PendingTask]", + resolver: Resolver, + timeout: float, + ): + self.state = state + self._task_queue = task_queue + self._resolver = resolver + self._timeout = timeout + self._shutdown = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + """Start the dispatch thread.""" + self._thread = threading.Thread( + target=self._run, + daemon=True, + name=f"dispatch-{self.state.worker_id}", + ) + self._thread.start() + + def stop(self) -> None: + """Signal the dispatch thread to stop.""" + self._shutdown.set() + + def join(self, timeout: float | None = None) -> None: + """Wait for the dispatch thread to finish.""" + if self._thread: + self._thread.join(timeout=timeout) + + def _run(self) -> None: + """Main dispatch loop.""" + while not self._shutdown.is_set(): + if self.state.status == WorkerStatus.PENDING: + self._discover_endpoint() + continue + + if self.state.status == WorkerStatus.FAILED: + break + + task = self._get_task() + if task: + self._execute_task(task) + + def _discover_endpoint(self) -> None: + """Poll resolver for endpoint registration.""" + result = self._resolver.resolve(self.state.worker_name) + if not result.is_empty: + endpoint = result.first() + self.state.endpoint_url = endpoint.url + self.state.status = WorkerStatus.IDLE + print(f"Worker {self.state.worker_id} discovered at {endpoint.url}") + else: + time.sleep(0.1) + + def _get_task(self) -> PendingTask | None: + """Try to get a task from the queue.""" + try: + return self._task_queue.get(timeout=0.5) + except Empty: + return None + + def _execute_task(self, task: PendingTask) -> None: + """Execute a task on the worker endpoint.""" + self.state.status = WorkerStatus.BUSY + self.state.current_task_id = task.task_id + + try: + result = _call_worker_endpoint( + endpoint_url=self.state.endpoint_url, + actor_name=self.state.worker_name, + task=task, + timeout=self._timeout, + ) + task.future.set_result(result) + self.state.tasks_completed += 1 + except UserException as e: + task.future.set_exception(e.inner) + self.state.tasks_failed += 1 + except Exception as e: + if task.retries_remaining > 0: + task.retries_remaining -= 1 + self._task_queue.put(task) + print( + f"Worker {self.state.worker_id} failed, re-queuing task {task.task_id} " + f"({task.retries_remaining} retries left)" + ) + self.state.status = WorkerStatus.FAILED + self.state.current_task_id = None + self._task_queue.task_done() + return + else: + task.future.set_exception(e) + self.state.tasks_failed += 1 + finally: + if self.state.status == WorkerStatus.BUSY: + self.state.status = WorkerStatus.IDLE + self.state.current_task_id = None + self._task_queue.task_done() + + +def _call_worker_endpoint( + endpoint_url: str, + actor_name: str, + task: PendingTask, + timeout: float, +) -> Any: + """Make a direct RPC call to a specific worker endpoint.""" + client = ActorServiceClientSync( + address=endpoint_url, + timeout_ms=int(timeout * 1000), + ) + + call = actor_pb2.ActorCall( + method_name="execute", + actor_name=actor_name, + serialized_args=cloudpickle.dumps( + ( + task.serialized_fn, + task.serialized_args, + 
task.serialized_kwargs, + ) + ), + serialized_kwargs=cloudpickle.dumps({}), + ) + + resp = client.call(call) + + if resp.HasField("error"): + if resp.error.serialized_exception: + # User exception - wrap it so we know not to retry + raise UserException(cloudpickle.loads(resp.error.serialized_exception)) + raise RuntimeError(f"{resp.error.error_type}: {resp.error.message}") + + return cloudpickle.loads(resp.serialized_value) + + +@dataclass +class WorkerPoolConfig: + """Configuration for a WorkerPool. + + Attributes: + num_workers: Number of worker jobs to launch + resources: Resource requirements per worker + environment: Optional environment configuration + name_prefix: Prefix for worker job names + max_retries: Number of retries for failed tasks (worker failures only) + """ + + num_workers: int + resources: cluster_pb2.ResourceSpec + environment: cluster_pb2.EnvironmentConfig | None = None + name_prefix: str = "worker" + max_retries: int = 0 + + +@dataclass +class WorkerFuture(Generic[T]): + """Future representing an in-flight task. + + Wraps a concurrent.futures.Future with a simpler interface. + ActorClient handles cloudpickle deserialization, so result() returns + the value directly. + """ + + _future: Future + _fn_name: str + + def result(self, timeout: float | None = None) -> T: + """Block until result is available. + + Args: + timeout: Maximum time to wait in seconds + + Returns: + The return value of the submitted callable + + Raises: + TimeoutError: If result not available within timeout + Exception: Any exception raised by the callable + """ + return self._future.result(timeout=timeout) + + def done(self) -> bool: + """Check if the task has completed.""" + return self._future.done() + + def exception(self) -> BaseException | None: + """Get the exception if the task failed, None otherwise.""" + if not self._future.done(): + return None + return self._future.exception() + + +class WorkerPool: + """Pool of stateless workers for task dispatch with idle worker targeting. + + WorkerPool manages a set of worker jobs that execute arbitrary callables. + Each worker registers with a unique name, and tasks are dispatched only + to idle workers. Tasks queue internally when all workers are busy. + + Usage: + with WorkerPool(client, config) as pool: + # Submit single task + future = pool.submit(fn, arg1, arg2) + result = future.result() + + # Map over items + futures = pool.map(fn, items) + results = [f.result() for f in futures] + + # Check pool status + pool.print_status() + """ + + def __init__( + self, + client: ClusterClient, + config: WorkerPoolConfig, + timeout: float = 30.0, + resolver: Resolver | None = None, + ): + """Create a worker pool. + + Args: + client: ClusterClient for launching worker jobs + config: Pool configuration (workers, resources, etc.) 
+ timeout: RPC timeout in seconds for worker calls + resolver: Optional resolver override (for testing) + """ + self._client = client + self._config = config + self._timeout = timeout + self._pool_id = uuid.uuid4().hex[:8] + + # Worker management + self._workers: dict[str, WorkerState] = {} + self._job_ids: list[JobId] = [] + + # Task queue and dispatch + self._task_queue: Queue[PendingTask] = Queue() + self._dispatchers: list[WorkerDispatcher] = [] + self._shutdown = False + + # Resolver for endpoint discovery (injectable for testing) + self._resolver: Resolver | None = resolver + self._namespace = Namespace("") + + def __enter__(self) -> "WorkerPool": + """Start workers and wait for at least one to register.""" + self._launch_workers() + self._wait_for_workers(min_workers=1) + return self + + def __exit__(self, *_): + """Shutdown all workers.""" + self.shutdown(wait=False) + + @property + def pool_id(self) -> str: + """Unique identifier for this pool.""" + return self._pool_id + + @property + def size(self) -> int: + """Number of workers that have registered (IDLE or BUSY).""" + return sum(1 for w in self._workers.values() if w.status in (WorkerStatus.IDLE, WorkerStatus.BUSY)) + + @property + def idle_count(self) -> int: + """Number of idle workers ready for tasks.""" + return sum(1 for w in self._workers.values() if w.status == WorkerStatus.IDLE) + + @property + def job_ids(self) -> list[JobId]: + """List of worker job IDs.""" + return list(self._job_ids) + + def _launch_workers(self) -> None: + """Launch worker jobs and start dispatch threads.""" + # Create resolver for worker discovery if not injected + if self._resolver is None: + self._resolver = ClusterResolver( + self._client.controller_address, + namespace=self._namespace, + ) + + # Initialize worker state and launch jobs + for i in range(self._config.num_workers): + worker_id = f"worker-{i}" + worker_name = f"_workerpool_{self._pool_id}:{worker_id}" + self._workers[worker_id] = WorkerState( + worker_id=worker_id, + worker_name=worker_name, + status=WorkerStatus.PENDING, + ) + + entrypoint = Entrypoint( + callable=worker_job_entrypoint, + args=(self._pool_id, i), + ) + + job_id = self._client.submit( + entrypoint=entrypoint, + name=f"{self._config.name_prefix}-{self._pool_id}-{i}", + resources=self._config.resources, + environment=self._config.environment, + namespace=self._namespace, + ports=["actor"], + ) + self._job_ids.append(job_id) + + # Start dispatchers (one per worker) + for worker_state in self._workers.values(): + dispatcher = WorkerDispatcher( + state=worker_state, + task_queue=self._task_queue, + resolver=self._resolver, + timeout=self._timeout, + ) + dispatcher.start() + self._dispatchers.append(dispatcher) + + def _wait_for_workers( + self, + min_workers: int = 1, + timeout: float = 60.0, + ) -> None: + """Wait for workers to register. + + Args: + min_workers: Minimum number of workers required + timeout: Maximum time to wait in seconds + + Raises: + TimeoutError: If min_workers not available within timeout + """ + start = time.time() + while time.time() - start < timeout: + if self.size >= min_workers: + return + time.sleep(0.5) + + raise TimeoutError(f"Only {self.size} of {min_workers} workers registered within {timeout}s") + + def wait_for_workers( + self, + min_workers: int | None = None, + timeout: float = 60.0, + ) -> None: + """Wait for workers to become available. 
+ + Args: + min_workers: Minimum workers required (default: all workers) + timeout: Maximum time to wait in seconds + """ + if min_workers is None: + min_workers = self._config.num_workers + self._wait_for_workers(min_workers=min_workers, timeout=timeout) + + def submit( + self, + fn: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> WorkerFuture[T]: + """Submit a task for execution. + + Tasks are queued internally and dispatched to idle workers. + Returns immediately with a Future that resolves when the task completes. + + Args: + fn: Callable to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Future that resolves to the function's return value + """ + if self._shutdown: + raise RuntimeError("WorkerPool has been shutdown") + + if not self._workers: + raise RuntimeError("No workers available") + + task = PendingTask( + task_id=uuid.uuid4().hex[:8], + serialized_fn=cloudpickle.dumps(fn), + serialized_args=cloudpickle.dumps(args), + serialized_kwargs=cloudpickle.dumps(kwargs), + future=Future(), + fn_name=getattr(fn, "__name__", "lambda"), + submitted_at=time.monotonic(), + retries_remaining=self._config.max_retries, + ) + + self._task_queue.put(task) + return WorkerFuture(_future=task.future, _fn_name=task.fn_name) + + def map( + self, + fn: Callable[[Any], T], + items: Sequence[Any], + ) -> list[WorkerFuture[T]]: + """Map a function over items in parallel. + + Args: + fn: Function to apply to each item + items: Items to process + + Returns: + List of futures, one per item + """ + return [self.submit(fn, item) for item in items] + + def status(self) -> PoolStatus: + """Get current pool status.""" + workers_by_status = {s: 0 for s in WorkerStatus} + total_completed = 0 + total_failed = 0 + worker_details = [] + + for worker in self._workers.values(): + workers_by_status[worker.status] += 1 + total_completed += worker.tasks_completed + total_failed += worker.tasks_failed + worker_details.append( + { + "worker_id": worker.worker_id, + "worker_name": worker.worker_name, + "status": worker.status.name, + "endpoint_url": worker.endpoint_url, + "current_task_id": worker.current_task_id, + "tasks_completed": worker.tasks_completed, + "tasks_failed": worker.tasks_failed, + } + ) + + return PoolStatus( + pool_id=self._pool_id, + num_workers=len(self._workers), + workers_idle=workers_by_status[WorkerStatus.IDLE], + workers_busy=workers_by_status[WorkerStatus.BUSY], + workers_pending=workers_by_status[WorkerStatus.PENDING], + workers_failed=workers_by_status[WorkerStatus.FAILED], + tasks_queued=self._task_queue.qsize(), + tasks_completed=total_completed, + tasks_failed=total_failed, + worker_details=worker_details, + ) + + def print_status(self) -> None: + """Print current pool status to stdout.""" + s = self.status() + print(f"WorkerPool[{s.pool_id}]") + print( + f" Workers: {s.num_workers} total " + f"({s.workers_idle} idle, {s.workers_busy} busy, " + f"{s.workers_pending} pending, {s.workers_failed} failed)" + ) + print(f" Tasks: {s.tasks_queued} queued, " f"{s.tasks_completed} completed, {s.tasks_failed} failed") + print(" Worker details:") + for w in s.worker_details: + task_info = f", task={w['current_task_id']}" if w["current_task_id"] else "" + print( + f" {w['worker_id']}: {w['status']}{task_info} " + f"(done={w['tasks_completed']}, err={w['tasks_failed']})" + ) + + def shutdown(self, wait: bool = True) -> None: + """Shutdown the worker pool. 
+ + Args: + wait: If True, wait for pending tasks to complete before + terminating workers + """ + self._shutdown = True + + # Stop all dispatchers + for dispatcher in self._dispatchers: + dispatcher.stop() + + if wait: + self._task_queue.join() + for dispatcher in self._dispatchers: + dispatcher.join(timeout=5.0) + + # Terminate worker jobs + for job_id in self._job_ids: + try: + self._client.terminate(job_id) + except Exception: + pass diff --git a/lib/fluster/tests/__init__.py b/lib/fluster/tests/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/tests/actor/__init__.py b/lib/fluster/tests/actor/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/tests/actor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/tests/actor/test_actor_e2e.py b/lib/fluster/tests/actor/test_actor_e2e.py new file mode 100644 index 0000000000..096dafd7e8 --- /dev/null +++ b/lib/fluster/tests/actor/test_actor_e2e.py @@ -0,0 +1,172 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
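A note on the task wire format defined in `worker_pool.py` above: `WorkerPool.submit()` cloudpickles the callable, the args tuple, and the kwargs dict separately; `_call_worker_endpoint()` then wraps those three blobs in one outer cloudpickle layer as the `ActorCall.serialized_args`; and `TaskExecutorActor.execute()` unpickles each blob before invoking the function. The sketch below walks through that round trip locally, with no RPC involved (the `add` function is purely illustrative, and the comment about the server unwrapping the outer layer is an assumption based on the `execute()` signature, not on code shown here):

```python
import cloudpickle

def add(a: int, b: int) -> int:
    return a + b

# Client side (WorkerPool.submit): each piece is pickled separately.
serialized_fn = cloudpickle.dumps(add)
serialized_args = cloudpickle.dumps((2, 3))
serialized_kwargs = cloudpickle.dumps({})

# _call_worker_endpoint wraps the three blobs in one outer cloudpickle layer
# as the ActorCall's serialized_args; the actor server presumably unwraps
# that layer before calling execute() with three positional bytes arguments.
payload = cloudpickle.dumps((serialized_fn, serialized_args, serialized_kwargs))

# Worker side (TaskExecutorActor.execute): unpickle each blob and call the fn.
fn_blob, args_blob, kwargs_blob = cloudpickle.loads(payload)
fn = cloudpickle.loads(fn_blob)
args = cloudpickle.loads(args_blob)
kwargs = cloudpickle.loads(kwargs_blob)
assert fn(*args, **kwargs) == 5
```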
+ +"""End-to-end tests for actor server and client.""" + +import pytest + +from fluster.actor.client import ActorClient +from fluster.actor.resolver import FixedResolver +from fluster.actor.server import ActorServer +from fluster.actor.types import ActorContext, current_ctx + + +class Calculator: + """Test actor with basic arithmetic operations.""" + + def add(self, a: int, b: int) -> int: + return a + b + + def multiply(self, a: int, b: int) -> int: + return a * b + + def divide(self, a: int, b: int) -> float: + return a / b # May raise ZeroDivisionError + + +class ContextAwareActor: + """Test actor that accesses the injected context.""" + + def get_job_id(self) -> str: + return current_ctx().job_id + + +def test_basic_actor_call(): + """Test basic actor method calls work correctly.""" + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + port = server.serve_background() + + resolver = FixedResolver({"calc": f"http://127.0.0.1:{port}"}) + client = ActorClient(resolver, "calc") + assert client.add(2, 3) == 5 + assert client.multiply(4, 5) == 20 + + +def test_actor_exception_propagation(): + """Test that exceptions from actor methods propagate to the client.""" + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + port = server.serve_background() + + resolver = FixedResolver({"calc": f"http://127.0.0.1:{port}"}) + client = ActorClient(resolver, "calc") + with pytest.raises(ZeroDivisionError): + client.divide(1, 0) + + +def test_actor_context_injection(): + """Test that ActorContext is properly injected and accessible.""" + server = ActorServer(host="127.0.0.1") + server.register("ctx_actor", ContextAwareActor()) + + ctx = ActorContext(cluster=None, resolver=None, job_id="test-job-123", namespace="") + port = server.serve_background(context=ctx) + + resolver = FixedResolver({"ctx_actor": f"http://127.0.0.1:{port}"}) + client = ActorClient(resolver, "ctx_actor") + assert client.get_job_id() == "test-job-123" + + +@pytest.mark.asyncio +async def test_list_actors(): + """Test that list_actors returns registered actors.""" + from fluster import actor_pb2 + + server = ActorServer(host="127.0.0.1") + actor_id1 = server.register("calc", Calculator()) + actor_id2 = server.register("ctx", ContextAwareActor()) + server.serve_background() + + request = actor_pb2.ListActorsRequest() + response = await server.list_actors(request, None) + + assert len(response.actors) == 2 + + actor_names = {a.name for a in response.actors} + assert "calc" in actor_names + assert "ctx" in actor_names + + actor_ids = {a.actor_id for a in response.actors} + assert actor_id1 in actor_ids + assert actor_id2 in actor_ids + + for actor in response.actors: + assert actor.registered_at_ms > 0 + + +@pytest.mark.asyncio +async def test_list_methods(): + """Test that list_methods returns method info for an actor.""" + from fluster import actor_pb2 + + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + server.serve_background() + + request = actor_pb2.ListMethodsRequest(actor_name="calc") + response = await server.list_methods(request, None) + + method_names = {m.name for m in response.methods} + assert "add" in method_names + assert "multiply" in method_names + assert "divide" in method_names + + for method in response.methods: + assert method.signature + assert "(" in method.signature + + +@pytest.mark.asyncio +async def test_list_methods_with_docstring(): + """Test that list_methods includes docstrings when present.""" + from fluster import 
actor_pb2 + + class DocumentedActor: + def documented_method(self) -> str: + """This method has documentation.""" + return "result" + + def undocumented_method(self) -> str: + return "result" + + server = ActorServer(host="127.0.0.1") + server.register("doc", DocumentedActor()) + server.serve_background() + + request = actor_pb2.ListMethodsRequest(actor_name="doc") + response = await server.list_methods(request, None) + + methods_by_name = {m.name: m for m in response.methods} + + assert "documented_method" in methods_by_name + assert "This method has documentation" in methods_by_name["documented_method"].docstring + + assert "undocumented_method" in methods_by_name + assert methods_by_name["undocumented_method"].docstring == "" + + +@pytest.mark.asyncio +async def test_list_methods_missing_actor(): + """Test that list_methods returns empty response for missing actor.""" + from fluster import actor_pb2 + + server = ActorServer(host="127.0.0.1") + server.register("calc", Calculator()) + server.serve_background() + + request = actor_pb2.ListMethodsRequest(actor_name="nonexistent") + response = await server.list_methods(request, None) + + assert len(response.methods) == 0 diff --git a/lib/fluster/tests/actor/test_actor_pool.py b/lib/fluster/tests/actor/test_actor_pool.py new file mode 100644 index 0000000000..48afd6fdbc --- /dev/null +++ b/lib/fluster/tests/actor/test_actor_pool.py @@ -0,0 +1,84 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ActorPool round-robin and broadcast functionality.""" + +from fluster.actor.pool import ActorPool +from fluster.actor.resolver import FixedResolver +from fluster.actor.server import ActorServer + + +class Counter: + """Test actor with stateful counter.""" + + def __init__(self, start: int = 0): + self._value = start + + def get(self) -> int: + """Get current counter value.""" + return self._value + + def increment(self) -> int: + """Increment and return new value.""" + self._value += 1 + return self._value + + +def test_pool_round_robin(): + """Test that pool.call() cycles through endpoints in round-robin fashion.""" + servers = [] + urls = [] + + # Create 3 servers with different starting counters + for i in range(3): + server = ActorServer(host="127.0.0.1") + server.register("counter", Counter(start=i * 100)) + port = server.serve_background() + servers.append(server) + urls.append(f"http://127.0.0.1:{port}") + + resolver = FixedResolver({"counter": urls}) + pool = ActorPool(resolver, "counter") + + assert pool.size == 3 + + # Round-robin should cycle through servers + results = [pool.call().get() for _ in range(6)] + # Should see values from all three servers (0, 100, 200, 0, 100, 200) + assert set(results) == {0, 100, 200} + + +def test_pool_broadcast(): + """Test that pool.broadcast() sends to all endpoints.""" + servers = [] + urls = [] + + # Create 3 servers with different starting counters + for i in range(3): + server = ActorServer(host="127.0.0.1") + server.register("counter", Counter(start=i)) + port = server.serve_background() + servers.append(server) + urls.append(f"http://127.0.0.1:{port}") + + resolver = FixedResolver({"counter": urls}) + pool = ActorPool(resolver, "counter") + + # Broadcast get() to all endpoints + broadcast = pool.broadcast().get() + results = broadcast.wait_all() + + assert len(results) == 3 + assert all(r.success for r in results) + assert {r.value for r in results} == {0, 1, 2} diff --git a/lib/fluster/tests/actor/test_cluster_resolver.py b/lib/fluster/tests/actor/test_cluster_resolver.py new file mode 100644 index 0000000000..a400f217c2 --- /dev/null +++ b/lib/fluster/tests/actor/test_cluster_resolver.py @@ -0,0 +1,322 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ClusterResolver integration with controller.""" + +import socket +import threading +import time + +import pytest +import uvicorn + +from connectrpc.request import RequestContext + +from fluster import cluster_pb2 +from fluster.actor.resolver import ClusterResolver +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.controller.state import ControllerEndpoint, ControllerJob, ControllerState +from fluster.cluster.types import JobId, Namespace +from fluster.cluster_connect import ControllerServiceASGIApplication + + +class AsyncControllerServiceWrapper: + """Async wrapper around synchronous ControllerServiceImpl for testing.""" + + def __init__(self, sync_service: ControllerServiceImpl): + self._service = sync_service + + async def launch_job( + self, request: cluster_pb2.Controller.LaunchJobRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.LaunchJobResponse: + return self._service.launch_job(request, ctx) + + async def get_job_status( + self, request: cluster_pb2.Controller.GetJobStatusRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.GetJobStatusResponse: + return self._service.get_job_status(request, ctx) + + async def terminate_job( + self, request: cluster_pb2.Controller.TerminateJobRequest, ctx: RequestContext + ) -> cluster_pb2.Empty: + return self._service.terminate_job(request, ctx) + + async def list_jobs( + self, request: cluster_pb2.Controller.ListJobsRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.ListJobsResponse: + return self._service.list_jobs(request, ctx) + + async def register_worker( + self, request: cluster_pb2.Controller.RegisterWorkerRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.RegisterWorkerResponse: + return self._service.register_worker(request, ctx) + + async def list_workers( + self, request: cluster_pb2.Controller.ListWorkersRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.ListWorkersResponse: + return self._service.list_workers(request, ctx) + + async def register_endpoint( + self, request: cluster_pb2.Controller.RegisterEndpointRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.RegisterEndpointResponse: + return self._service.register_endpoint(request, ctx) + + async def unregister_endpoint( + self, request: cluster_pb2.Controller.UnregisterEndpointRequest, ctx: RequestContext + ) -> cluster_pb2.Empty: + return self._service.unregister_endpoint(request, ctx) + + async def lookup_endpoint( + self, request: cluster_pb2.Controller.LookupEndpointRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.LookupEndpointResponse: + return self._service.lookup_endpoint(request, ctx) + + async def list_endpoints( + self, request: cluster_pb2.Controller.ListEndpointsRequest, ctx: RequestContext + ) -> cluster_pb2.Controller.ListEndpointsResponse: + return self._service.list_endpoints(request, ctx) + + +class MockSchedulerWake: + """Mock object for scheduler wake interface.""" + + def wake(self): + pass + + +def create_controller_app(state: ControllerState) -> ControllerServiceASGIApplication: + """Create a minimal controller app with ListEndpoints handler.""" + mock_scheduler = MockSchedulerWake() + service = ControllerServiceImpl(state, mock_scheduler) + async_service = AsyncControllerServiceWrapper(service) + + return ControllerServiceASGIApplication(service=async_service) + + +@pytest.fixture +def controller_with_endpoint(): + """Start a controller with a registered endpoint.""" + state = ControllerState() + + # Add a running job + job = ControllerJob( + 
job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + # Add an endpoint + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="inference", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Find free port + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + app = create_controller_app(state) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + for _ in range(50): + if server.started: + break + time.sleep(0.1) + + yield f"http://127.0.0.1:{port}", state + + +def test_cluster_resolver_finds_endpoint(controller_with_endpoint): + """Test that ClusterResolver successfully resolves a registered endpoint.""" + address, _state = controller_with_endpoint + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert "10.0.0.1:8080" in result.first().url + assert result.first().actor_id == "ep-1" + + +def test_cluster_resolver_missing_endpoint(controller_with_endpoint): + """Test that ClusterResolver returns empty result for non-existent actor.""" + address, _state = controller_with_endpoint + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("nonexistent") + + assert result.is_empty + + +def test_cluster_resolver_multiple_endpoints(controller_with_endpoint): + """Test that ClusterResolver returns all matching endpoints.""" + address, state = controller_with_endpoint + + # Add another job and endpoint with the same name + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job2) + + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="", + ) + state.add_endpoint(ep2) + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 2 + addresses = {ep.url for ep in result.endpoints} + assert "http://10.0.0.1:8080" in addresses + assert "http://10.0.0.2:8080" in addresses + + +def test_cluster_resolver_namespace_isolation(controller_with_endpoint): + """Test that ClusterResolver respects namespace boundaries.""" + address, state = controller_with_endpoint + + # Add endpoint in different namespace + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job2) + + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="other-namespace", + ) + state.add_endpoint(ep2) + + # Resolve in namespace should only find ep-1 + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert result.first().url == "http://10.0.0.1:8080" + + # Resolve in other-namespace should only find ep-2 + result_other = resolver.resolve("inference", namespace=Namespace("other-namespace")) + assert len(result_other.endpoints) == 1 + assert result_other.first().url == "http://10.0.0.2:8080" + + +def 
test_cluster_resolver_filters_exact_name_match(controller_with_endpoint): + """Test that ClusterResolver filters to exact name matches despite prefix API.""" + address, state = controller_with_endpoint + + # Add endpoint with similar but different name + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job2) + + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference-v2", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="", + ) + state.add_endpoint(ep2) + + # Resolve "inference" should not return "inference-v2" + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert result.first().url == "http://10.0.0.1:8080" + + +def test_cluster_resolver_only_running_jobs(controller_with_endpoint): + """Test that ClusterResolver only returns endpoints for RUNNING jobs.""" + address, state = controller_with_endpoint + + # Add a completed job with endpoint + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_SUCCEEDED, + ) + state.add_job(job2) + + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="", + ) + state.add_endpoint(ep2) + + # Should only find the running job's endpoint + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert result.first().url == "http://10.0.0.1:8080" + + +def test_cluster_resolver_metadata(controller_with_endpoint): + """Test that ClusterResolver preserves endpoint metadata.""" + address, state = controller_with_endpoint + + # Add endpoint with metadata + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job2) + + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="tagged-actor", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="", + metadata={"model": "gpt-4", "version": "1.0"}, + ) + state.add_endpoint(ep2) + + resolver = ClusterResolver(address, namespace=Namespace("")) + result = resolver.resolve("tagged-actor") + + assert len(result.endpoints) == 1 + assert result.first().metadata["model"] == "gpt-4" + assert result.first().metadata["version"] == "1.0" diff --git a/lib/fluster/tests/actor/test_gcs_resolver.py b/lib/fluster/tests/actor/test_gcs_resolver.py new file mode 100644 index 0000000000..dbe6845054 --- /dev/null +++ b/lib/fluster/tests/actor/test_gcs_resolver.py @@ -0,0 +1,205 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
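The ClusterResolver tests above inject endpoints directly into ControllerState; against a live controller, the same lookup pairs naturally with the `register_endpoint()` helper from `worker_pool.py`. A rough sketch of that pairing, assuming a controller is reachable at the given URL and that `job-1` is already in the RUNNING state (the address, job ID, and endpoint name are placeholders):

```python
from fluster.actor.resolver import ClusterResolver
from fluster.cluster.types import Namespace
from fluster.worker_pool import register_endpoint

controller_url = "http://localhost:8080"  # placeholder controller address

# Register an endpoint for an already-RUNNING job under the default namespace.
endpoint_id = register_endpoint(
    controller_url,
    name="inference",
    address="10.0.0.1:8080",
    job_id="job-1",
    namespace="",
)

# Resolve it back by name; lookups only return endpoints of RUNNING jobs.
resolver = ClusterResolver(controller_url, namespace=Namespace(""))
result = resolver.resolve("inference")
if not result.is_empty:
    print(f"{endpoint_id} -> {result.first().url}")
```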
+ +"""Tests for GcsResolver.""" + +from fluster.actor.resolver import GcsResolver, MockGcsApi +from fluster.cluster.types import Namespace + + +def test_gcs_resolver_finds_actors(): + """Test that GcsResolver finds actors via metadata tags.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 1 + assert "10.0.0.1:8080" in result.first().url + assert result.first().actor_id == "gcs-worker-1-inference" + assert result.first().metadata == {"instance": "worker-1"} + + +def test_gcs_resolver_filters_namespace(): + """Test that GcsResolver filters by namespace.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "other-ns", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", namespace=Namespace(""), api=api) + result = resolver.resolve("inference") + + assert result.is_empty + + +def test_gcs_resolver_ignores_non_running(): + """Test that GcsResolver only considers RUNNING instances.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "TERMINATED", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert result.is_empty + + +def test_gcs_resolver_multiple_instances(): + """Test that GcsResolver finds actors across multiple instances.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + { + "name": "worker-2", + "internal_ip": "10.0.0.2", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert len(result.endpoints) == 2 + urls = {ep.url for ep in result.endpoints} + assert "http://10.0.0.1:8080" in urls + assert "http://10.0.0.2:8080" in urls + + +def test_gcs_resolver_no_matching_actor(): + """Test that GcsResolver returns empty when no actor matches.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_training": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert result.is_empty + + +def test_gcs_resolver_missing_internal_ip(): + """Test that GcsResolver skips instances without internal IP.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": None, + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", api=api) + result = resolver.resolve("inference") + + assert result.is_empty + + +def test_gcs_resolver_default_namespace(): + """Test that GcsResolver uses default namespace correctly.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "", + "fluster_actor_inference": "8080", + }, + }, + 
] + ) + resolver = GcsResolver("project", "zone", namespace=Namespace(""), api=api) + + assert resolver.default_namespace == Namespace("") + + +def test_gcs_resolver_custom_namespace(): + """Test that GcsResolver can override namespace.""" + api = MockGcsApi( + [ + { + "name": "worker-1", + "internal_ip": "10.0.0.1", + "status": "RUNNING", + "metadata": { + "fluster_namespace": "custom-ns", + "fluster_actor_inference": "8080", + }, + }, + ] + ) + resolver = GcsResolver("project", "zone", namespace=Namespace(""), api=api) + + # Should not find with default namespace + result = resolver.resolve("inference") + assert result.is_empty + + # Should find with custom namespace + result = resolver.resolve("inference", namespace=Namespace("custom-ns")) + assert len(result.endpoints) == 1 diff --git a/lib/fluster/tests/actor/test_resolver.py b/lib/fluster/tests/actor/test_resolver.py new file mode 100644 index 0000000000..8baf3d50a6 --- /dev/null +++ b/lib/fluster/tests/actor/test_resolver.py @@ -0,0 +1,54 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for resolver functionality.""" + +from fluster.actor.client import ActorClient +from fluster.actor.resolver import FixedResolver +from fluster.actor.server import ActorServer + + +class Echo: + def echo(self, msg: str) -> str: + return f"echo: {msg}" + + +def test_fixed_resolver_single(): + resolver = FixedResolver({"svc": "http://localhost:8080"}) + result = resolver.resolve("svc") + assert len(result.endpoints) == 1 + assert result.first().url == "http://localhost:8080" + + +def test_fixed_resolver_multiple(): + resolver = FixedResolver({"svc": ["http://h1:8080", "http://h2:8080"]}) + result = resolver.resolve("svc") + assert len(result.endpoints) == 2 + + +def test_fixed_resolver_missing(): + resolver = FixedResolver({}) + result = resolver.resolve("missing") + assert result.is_empty + + +def test_client_with_resolver(): + server = ActorServer(host="127.0.0.1") + server.register("echo", Echo()) + port = server.serve_background() + + resolver = FixedResolver({"echo": f"http://127.0.0.1:{port}"}) + client = ActorClient(resolver, "echo") + + assert client.echo("hello") == "echo: hello" diff --git a/lib/fluster/tests/cluster/__init__.py b/lib/fluster/tests/cluster/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/tests/cluster/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
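The GcsResolver tests above all rely on the same instance-metadata convention: a RUNNING instance advertises an actor through a `fluster_actor_<name>` metadata key whose value is the serving port, alongside a `fluster_namespace` key. A small fixture-building sketch of that convention (the `make_instance` helper is hypothetical, introduced only for illustration):

```python
from fluster.actor.resolver import GcsResolver, MockGcsApi

def make_instance(name: str, ip: str, actor: str, port: int, namespace: str = "") -> dict:
    # Shape of the instance records consumed by MockGcsApi in the tests above.
    return {
        "name": name,
        "internal_ip": ip,
        "status": "RUNNING",
        "metadata": {
            "fluster_namespace": namespace,
            f"fluster_actor_{actor}": str(port),
        },
    }

api = MockGcsApi(
    [
        make_instance("worker-1", "10.0.0.1", "inference", 8080),
        make_instance("worker-2", "10.0.0.2", "inference", 8080),
    ]
)
resolver = GcsResolver("project", "zone", api=api)
result = resolver.resolve("inference")
assert {ep.url for ep in result.endpoints} == {"http://10.0.0.1:8080", "http://10.0.0.2:8080"}
```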
diff --git a/lib/fluster/tests/cluster/controller/__init__.py b/lib/fluster/tests/cluster/controller/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/tests/cluster/controller/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/tests/cluster/controller/test_dashboard.py b/lib/fluster/tests/cluster/controller/test_dashboard.py new file mode 100644 index 0000000000..5d0b2c1bee --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_dashboard.py @@ -0,0 +1,209 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for controller dashboard behavioral logic.""" + +from unittest.mock import Mock + +import pytest +from starlette.testclient import TestClient + +from fluster import cluster_pb2 +from fluster.cluster.controller.dashboard import ControllerDashboard +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.controller.state import ( + ControllerEndpoint, + ControllerJob, + ControllerState, + ControllerWorker, +) +from fluster.cluster.types import JobId, WorkerId + + +@pytest.fixture +def state(): + return ControllerState() + + +@pytest.fixture +def service(state): + scheduler = Mock() + scheduler.wake = Mock() + return ControllerServiceImpl(state, scheduler) + + +@pytest.fixture +def client(service): + dashboard = ControllerDashboard(service) + return TestClient(dashboard._app) + + +@pytest.fixture +def job_request(): + return cluster_pb2.Controller.LaunchJobRequest( + name="test-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=2, memory="4g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + +@pytest.fixture +def resource_spec(): + return cluster_pb2.ResourceSpec(cpu=4, memory="8g", disk="100g") + + +def test_stats_counts_building_separately_from_running(client, state, job_request): + """Building jobs should be counted separately, not as running or pending.""" + state.add_job(ControllerJob(job_id=JobId("pending"), request=job_request, state=cluster_pb2.JOB_STATE_PENDING)) + state.add_job(ControllerJob(job_id=JobId("building"), request=job_request, state=cluster_pb2.JOB_STATE_BUILDING)) + state.add_job(ControllerJob(job_id=JobId("running"), request=job_request, state=cluster_pb2.JOB_STATE_RUNNING)) + + stats = client.get("/api/stats").json() + + assert stats["jobs_pending"] == 1 + assert stats["jobs_building"] == 1 + assert 
stats["jobs_running"] == 1 + + +def test_stats_groups_terminal_states_as_completed(client, state, job_request): + """Succeeded, failed, killed, and worker_failed all count as completed.""" + for job_state in [ + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + cluster_pb2.JOB_STATE_WORKER_FAILED, + ]: + state.add_job(ControllerJob(job_id=JobId(f"job-{job_state}"), request=job_request, state=job_state)) + + stats = client.get("/api/stats").json() + + assert stats["jobs_completed"] == 4 + assert stats["jobs_pending"] == 0 + assert stats["jobs_running"] == 0 + + +def test_stats_counts_only_healthy_workers(client, state, resource_spec): + """Healthy worker count excludes unhealthy workers.""" + state.add_worker( + ControllerWorker(worker_id=WorkerId("healthy1"), address="h1:8080", resources=resource_spec, healthy=True) + ) + state.add_worker( + ControllerWorker(worker_id=WorkerId("healthy2"), address="h2:8080", resources=resource_spec, healthy=True) + ) + state.add_worker( + ControllerWorker(worker_id=WorkerId("unhealthy"), address="h3:8080", resources=resource_spec, healthy=False) + ) + + stats = client.get("/api/stats").json() + + assert stats["workers_healthy"] == 2 + assert stats["workers_total"] == 3 + + +def test_stats_counts_endpoints_for_running_jobs_only(client, state, job_request): + """Endpoint count only includes endpoints for RUNNING jobs.""" + # Running job with endpoint + state.add_job(ControllerJob(job_id=JobId("running"), request=job_request, state=cluster_pb2.JOB_STATE_RUNNING)) + state.add_endpoint( + ControllerEndpoint( + endpoint_id="ep1", name="svc", address="host:80", job_id=JobId("running"), namespace="default" + ) + ) + + # Pending job with endpoint (should not count) + state.add_job(ControllerJob(job_id=JobId("pending"), request=job_request, state=cluster_pb2.JOB_STATE_PENDING)) + state.add_endpoint( + ControllerEndpoint( + endpoint_id="ep2", name="svc2", address="host:81", job_id=JobId("pending"), namespace="default" + ) + ) + + stats = client.get("/api/stats").json() + + assert stats["endpoints_count"] == 1 + + +def test_endpoints_only_returned_for_running_jobs(client, state, job_request): + """ListEndpoints filters out endpoints for non-running jobs.""" + # Create jobs in various states + state.add_job(ControllerJob(job_id=JobId("pending"), request=job_request, state=cluster_pb2.JOB_STATE_PENDING)) + state.add_job(ControllerJob(job_id=JobId("running"), request=job_request, state=cluster_pb2.JOB_STATE_RUNNING)) + state.add_job(ControllerJob(job_id=JobId("succeeded"), request=job_request, state=cluster_pb2.JOB_STATE_SUCCEEDED)) + + # Add endpoints for each + state.add_endpoint( + ControllerEndpoint(endpoint_id="ep1", name="pending-svc", address="h:1", job_id=JobId("pending"), namespace="") + ) + state.add_endpoint( + ControllerEndpoint(endpoint_id="ep2", name="running-svc", address="h:2", job_id=JobId("running"), namespace="") + ) + state.add_endpoint( + ControllerEndpoint(endpoint_id="ep3", name="done-svc", address="h:3", job_id=JobId("succeeded"), namespace="") + ) + + endpoints = client.get("/api/endpoints").json() + + assert len(endpoints) == 1 + assert endpoints[0]["name"] == "running-svc" + + +def test_job_detail_page_includes_worker_address(client, state, job_request, resource_spec): + """Job detail page injects worker address for client-side fetch.""" + state.add_worker( + ControllerWorker(worker_id=WorkerId("w1"), address="worker-host:9000", resources=resource_spec, healthy=True) + ) + state.add_job( + 
ControllerJob( + job_id=JobId("j1"), request=job_request, state=cluster_pb2.JOB_STATE_RUNNING, worker_id=WorkerId("w1") + ) + ) + + response = client.get("/job/j1") + + assert response.status_code == 200 + assert "worker-host:9000" in response.text + + +def test_job_detail_page_empty_worker_for_pending_job(client, state, job_request): + """Job detail page has empty worker address for unassigned jobs.""" + state.add_job(ControllerJob(job_id=JobId("pending-job"), request=job_request, state=cluster_pb2.JOB_STATE_PENDING)) + + response = client.get("/job/pending-job") + + assert response.status_code == 200 + # Worker address placeholder should be empty + assert "const workerAddress = '';" in response.text + + +def test_jobs_state_names_mapped_correctly(client, state, job_request): + """Proto state enums map to expected string names.""" + state_mapping = [ + (cluster_pb2.JOB_STATE_PENDING, "pending"), + (cluster_pb2.JOB_STATE_BUILDING, "building"), + (cluster_pb2.JOB_STATE_RUNNING, "running"), + (cluster_pb2.JOB_STATE_SUCCEEDED, "succeeded"), + (cluster_pb2.JOB_STATE_FAILED, "failed"), + (cluster_pb2.JOB_STATE_KILLED, "killed"), + (cluster_pb2.JOB_STATE_WORKER_FAILED, "worker_failed"), + ] + + for proto_state, _ in state_mapping: + state.add_job(ControllerJob(job_id=JobId(f"j-{proto_state}"), request=job_request, state=proto_state)) + + jobs = client.get("/api/jobs").json() + job_by_id = {j["job_id"]: j["state"] for j in jobs} + + for proto_state, expected_name in state_mapping: + assert job_by_id[f"j-{proto_state}"] == expected_name diff --git a/lib/fluster/tests/cluster/controller/test_endpoint_registry.py b/lib/fluster/tests/cluster/controller/test_endpoint_registry.py new file mode 100644 index 0000000000..ba9bfa1119 --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_endpoint_registry.py @@ -0,0 +1,299 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for endpoint registry in controller state.""" + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerEndpoint, ControllerJob, ControllerState +from fluster.cluster.types import JobId + + +@pytest.fixture +def state() -> ControllerState: + return ControllerState() + + +def test_add_and_lookup_endpoint(state: ControllerState): + """Test basic endpoint registration and lookup.""" + # Create a running job first + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + # Register endpoint + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Lookup + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 1 + assert results[0].address == "10.0.0.1:8080" + assert results[0].endpoint_id == "ep-1" + + +def test_endpoint_not_returned_for_non_running_job(state: ControllerState): + """Test that endpoints for non-RUNNING jobs are filtered out.""" + # Create a completed job + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_SUCCEEDED, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Should not return endpoint because job is not running + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 + + +def test_remove_endpoints_on_job_termination(state: ControllerState): + """Test that endpoints are removed when a job terminates.""" + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Verify endpoint is visible + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 1 + + # Simulate job termination + removed = state.remove_endpoints_for_job(JobId("job-1")) + assert len(removed) == 1 + assert removed[0].endpoint_id == "ep-1" + + # Endpoint should be gone + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 + + +def test_lookup_filters_by_namespace(state: ControllerState): + """Test that lookup respects namespace boundaries.""" + job1 = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test1"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job1) + state.add_job(job2) + + # Same name, different namespaces + ep1 = ControllerEndpoint( + endpoint_id="ep-1", + name="actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="ns-1", + ) + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="actor", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="ns-2", + ) + state.add_endpoint(ep1) + state.add_endpoint(ep2) + + # Each namespace should only see its own endpoint + results_ns1 = state.lookup_endpoints("actor", "ns-1") + assert len(results_ns1) == 1 + assert 
results_ns1[0].address == "10.0.0.1:8080" + + results_ns2 = state.lookup_endpoints("actor", "ns-2") + assert len(results_ns2) == 1 + assert results_ns2[0].address == "10.0.0.2:8080" + + +def test_list_endpoints_by_prefix(state: ControllerState): + """Test prefix-based endpoint listing.""" + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + # Register multiple endpoints with shared prefix + ep1 = ControllerEndpoint( + endpoint_id="ep-1", + name="inference/model-a", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference/model-b", + address="10.0.0.2:8080", + job_id=JobId("job-1"), + namespace="", + ) + ep3 = ControllerEndpoint( + endpoint_id="ep-3", + name="training/main", + address="10.0.0.3:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep1) + state.add_endpoint(ep2) + state.add_endpoint(ep3) + + # Lookup with prefix + results = state.list_endpoints_by_prefix("inference/", "") + assert len(results) == 2 + names = {r.name for r in results} + assert names == {"inference/model-a", "inference/model-b"} + + results_training = state.list_endpoints_by_prefix("training/", "") + assert len(results_training) == 1 + assert results_training[0].name == "training/main" + + +def test_multiple_endpoints_for_same_name(state: ControllerState): + """Test that multiple endpoints can be registered for the same name.""" + job1 = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test1"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + job2 = ControllerJob( + job_id=JobId("job-2"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job1) + state.add_job(job2) + + # Register multiple endpoints with same name (for load balancing) + ep1 = ControllerEndpoint( + endpoint_id="ep-1", + name="inference", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + ep2 = ControllerEndpoint( + endpoint_id="ep-2", + name="inference", + address="10.0.0.2:8080", + job_id=JobId("job-2"), + namespace="", + ) + state.add_endpoint(ep1) + state.add_endpoint(ep2) + + results = state.lookup_endpoints("inference", "") + assert len(results) == 2 + addresses = {r.address for r in results} + assert addresses == {"10.0.0.1:8080", "10.0.0.2:8080"} + + +def test_remove_endpoint_by_id(state: ControllerState): + """Test explicit endpoint removal by ID.""" + job = ControllerJob( + job_id=JobId("job-1"), + request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Remove by ID + removed = state.remove_endpoint("ep-1") + assert removed is not None + assert removed.endpoint_id == "ep-1" + + # Should no longer be found + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 + + # Removing again should be idempotent + removed_again = state.remove_endpoint("ep-1") + assert removed_again is None + + +def test_pending_job_endpoints_not_returned(state: ControllerState): + """Test that endpoints for PENDING jobs are not returned.""" + job = ControllerJob( + job_id=JobId("job-1"), + 
request=cluster_pb2.Controller.LaunchJobRequest(name="test"), + state=cluster_pb2.JOB_STATE_PENDING, + ) + state.add_job(job) + + ep = ControllerEndpoint( + endpoint_id="ep-1", + name="my-actor", + address="10.0.0.1:8080", + job_id=JobId("job-1"), + namespace="", + ) + state.add_endpoint(ep) + + # Should not return because job is pending + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 0 + + # Transition to running + job.state = cluster_pb2.JOB_STATE_RUNNING + + # Now should be visible + results = state.lookup_endpoints("my-actor", "") + assert len(results) == 1 diff --git a/lib/fluster/tests/cluster/controller/test_integration.py b/lib/fluster/tests/cluster/controller/test_integration.py new file mode 100644 index 0000000000..34b1da785e --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_integration.py @@ -0,0 +1,394 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for the Controller. + +These tests exercise the complete job lifecycle through the Controller, +using a mock WorkerStubFactory to simulate worker responses. + +Tests call Controller methods directly (no background threads) for +deterministic, synchronous testing. +""" + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.controller import Controller, ControllerConfig +from fluster.cluster.controller.retry import handle_job_failure +from fluster.cluster.controller.workers import WorkerConfig, load_workers_from_config +from fluster.cluster.types import JobId + + +class MockWorkerStub: + """Mock worker stub that returns configured responses.""" + + def __init__(self): + self.job_statuses: dict[str, cluster_pb2.JobStatus] = {} + self.run_job_calls: list[cluster_pb2.Worker.RunJobRequest] = [] + self.healthy = True + + def run_job(self, request: cluster_pb2.Worker.RunJobRequest) -> cluster_pb2.Worker.RunJobResponse: + self.run_job_calls.append(request) + return cluster_pb2.Worker.RunJobResponse(job_id=request.job_id, state=cluster_pb2.JOB_STATE_RUNNING) + + def get_job_status(self, request: cluster_pb2.Worker.GetJobStatusRequest) -> cluster_pb2.JobStatus: + return self.job_statuses.get(request.job_id, cluster_pb2.JobStatus()) + + def list_jobs(self, request: cluster_pb2.Worker.ListJobsRequest) -> cluster_pb2.Worker.ListJobsResponse: + return cluster_pb2.Worker.ListJobsResponse(jobs=list(self.job_statuses.values())) + + def health_check(self, request: cluster_pb2.Empty) -> cluster_pb2.Worker.HealthResponse: + if not self.healthy: + raise ConnectionError("Worker unavailable") + return cluster_pb2.Worker.HealthResponse(healthy=True) + + def set_job_completed(self, job_id: str, state: int, exit_code: int = 0, error: str = ""): + self.job_statuses[job_id] = cluster_pb2.JobStatus( + job_id=job_id, + state=state, + exit_code=exit_code, + error=error, + finished_at_ms=1000, + ) + + +class MockWorkerStubFactory: + """Factory that returns mock stubs for testing.""" + + def __init__(self): + self.stubs: dict[str, 
MockWorkerStub] = {} + + def get_stub(self, address: str) -> MockWorkerStub: + if address not in self.stubs: + self.stubs[address] = MockWorkerStub() + return self.stubs[address] + + def get_stub_for_worker(self, worker_id: str) -> MockWorkerStub: + """Helper to get stub by worker_id (assumes host{N}:8080 format).""" + address = f"host{worker_id[1:]}:8080" if worker_id.startswith("w") else f"{worker_id}:8080" + return self.get_stub(address) + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make(name: str = "test-job", cpu: int = 1) -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=cpu, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + return _make + + +@pytest.fixture +def make_resource_spec(): + """Create a minimal ResourceSpec for testing.""" + + def _make() -> cluster_pb2.ResourceSpec: + return cluster_pb2.ResourceSpec(cpu=4, memory="8g", disk="10g") + + return _make + + +@pytest.fixture +def controller_with_workers(make_resource_spec): + """Create a Controller with mock stub factory and register workers.""" + + def _make(worker_ids: list[str], resources: cluster_pb2.ResourceSpec | None = None): + stub_factory = MockWorkerStubFactory() + config = ControllerConfig(port=0) + controller = Controller(config, stub_factory) + + res = resources or make_resource_spec() + workers = [WorkerConfig(wid, f"host{wid[1:]}:8080", res) for wid in worker_ids] + load_workers_from_config(controller.state, workers) + + return controller, stub_factory + + return _make + + +def test_full_job_lifecycle(make_job_request, controller_with_workers): + """Integration test: full job lifecycle from submission to completion.""" + controller, stub_factory = controller_with_workers(["w1"]) + + # Submit job + response = controller.launch_job(make_job_request("test-job")) + job_id = response.job_id + + # Verify job is PENDING + status = controller.get_job_status(job_id) + assert status.job.state == cluster_pb2.JOB_STATE_PENDING + + # Run scheduling + controller._run_scheduling() + + # Verify job is RUNNING + status = controller.get_job_status(job_id) + assert status.job.state == cluster_pb2.JOB_STATE_RUNNING + + # Verify dispatch RPC was called + stub = stub_factory.get_stub("host1:8080") + assert len(stub.run_job_calls) == 1 + assert stub.run_job_calls[0].job_id == job_id + + # Configure mock to report job succeeded + stub.set_job_completed(job_id, cluster_pb2.JOB_STATE_SUCCEEDED) + + # Run heartbeat + controller._run_heartbeats() + + # Verify job succeeded + status = controller.get_job_status(job_id) + assert status.job.state == cluster_pb2.JOB_STATE_SUCCEEDED + assert status.job.exit_code == 0 + + +def test_job_failure_and_retry(make_job_request, controller_with_workers): + """Job fails on first attempt, succeeds after retry.""" + controller, stub_factory = controller_with_workers(["w1"]) + + # Submit job with retries enabled + response = controller.launch_job(make_job_request("test-job")) + job_id = response.job_id + + job = controller.state.get_job(JobId(job_id)) + job.max_retries_failure = 1 + + # Run scheduling - job starts + controller._run_scheduling() + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + + # Configure mock to report failure + stub = stub_factory.get_stub("host1:8080") + stub.set_job_completed(job_id, cluster_pb2.JOB_STATE_FAILED, 
exit_code=1, error="Simulated failure") + + # Run heartbeat - picks up failure + controller._run_heartbeats() + + # Job should be FAILED + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_FAILED + + # Trigger retry + handle_job_failure(controller.state, JobId(job_id), is_worker_failure=False) + + # Job should be back to PENDING + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_PENDING + assert job.failure_count == 1 + + # Clear the failure status and run scheduling again + stub.job_statuses.clear() + controller._run_scheduling() + + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + + # This time configure success + stub.set_job_completed(job_id, cluster_pb2.JOB_STATE_SUCCEEDED) + controller._run_heartbeats() + + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_SUCCEEDED + assert job.failure_count == 1 # Retained from retry + + +def test_worker_failure_triggers_retry(make_job_request, controller_with_workers): + """Worker dies, job is retried on another worker.""" + controller, stub_factory = controller_with_workers(["w1", "w2"]) + + # Submit job + response = controller.launch_job(make_job_request("test-job")) + job_id = response.job_id + + job = controller.state.get_job(JobId(job_id)) + job.max_retries_preemption = 10 + + # Run scheduling - job goes to w1 + controller._run_scheduling() + + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + assert job.worker_id == "w1" + + # Make worker 1 unhealthy (heartbeat fails) + stub1 = stub_factory.get_stub("host1:8080") + stub1.healthy = False + + # Run heartbeats 3 times to trigger failure threshold + for _ in range(3): + controller._run_heartbeats() + + # Worker should be marked unhealthy, job should be pending for retry + worker1 = controller.state.get_worker("w1") + assert worker1.healthy is False + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_PENDING + assert job.preemption_count == 1 + + # Run scheduling again - job goes to w2 + controller._run_scheduling() + + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + assert job.worker_id == "w2" + + # Worker 2 reports success + stub2 = stub_factory.get_stub("host2:8080") + stub2.set_job_completed(job_id, cluster_pb2.JOB_STATE_SUCCEEDED) + controller._run_heartbeats() + + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_SUCCEEDED + assert job.preemption_count == 1 + + +def test_multiple_jobs_scheduled(make_job_request, controller_with_workers): + """Submit multiple jobs, all complete successfully.""" + controller, stub_factory = controller_with_workers( + ["w1"], + resources=cluster_pb2.ResourceSpec(cpu=10, memory="8g"), + ) + + # Submit multiple jobs + job_ids = [] + for i in range(5): + response = controller.launch_job(make_job_request(f"job-{i}")) + job_ids.append(response.job_id) + + # Run scheduling - all jobs should be assigned + controller._run_scheduling() + + # Verify all jobs are running + for job_id in job_ids: + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + + # Configure mock to report all jobs succeeded + stub = stub_factory.get_stub("host1:8080") + for job_id in job_ids: + stub.set_job_completed(job_id, cluster_pb2.JOB_STATE_SUCCEEDED) + + controller._run_heartbeats() + + # Verify all jobs succeeded + for job_id in job_ids: + assert controller.get_job_status(job_id).job.state == 
cluster_pb2.JOB_STATE_SUCCEEDED + + +def test_job_terminated_during_execution(make_job_request, controller_with_workers): + """Terminate a running job mid-execution.""" + controller, _ = controller_with_workers(["w1"]) + + # Submit job + response = controller.launch_job(make_job_request("test-job")) + job_id = response.job_id + + # Run scheduling + controller._run_scheduling() + assert controller.get_job_status(job_id).job.state == cluster_pb2.JOB_STATE_RUNNING + + # Terminate the job + controller.terminate_job(job_id) + + # Verify job is killed + status = controller.get_job_status(job_id) + assert status.job.state == cluster_pb2.JOB_STATE_KILLED + + +def test_concurrent_job_execution_on_multiple_workers(make_job_request, controller_with_workers): + """Multiple workers can run jobs concurrently.""" + controller, stub_factory = controller_with_workers( + ["w1", "w2", "w3"], + resources=cluster_pb2.ResourceSpec(cpu=2, memory="8g"), + ) + + # Submit 6 jobs, each needing 2 CPUs + job_ids = [] + for i in range(6): + response = controller.launch_job(make_job_request(f"job-{i}", cpu=2)) + job_ids.append(response.job_id) + + # Run scheduling - each worker gets 1 job (2 CPU each) + controller._run_scheduling() + + # Count running jobs - should be 3 (one per worker) + running = [jid for jid in job_ids if controller.get_job_status(jid).job.state == cluster_pb2.JOB_STATE_RUNNING] + assert len(running) == 3 + + # Complete all running jobs + for worker_id in ["w1", "w2", "w3"]: + stub = stub_factory.get_stub(f"host{worker_id[1:]}:8080") + worker = controller.state.get_worker(worker_id) + for job_id in list(worker.running_jobs): + stub.set_job_completed(str(job_id), cluster_pb2.JOB_STATE_SUCCEEDED) + + controller._run_heartbeats() + + # Run scheduling again to pick up remaining jobs + controller._run_scheduling() + + # Remaining 3 jobs should now be running + running = [jid for jid in job_ids if controller.get_job_status(jid).job.state == cluster_pb2.JOB_STATE_RUNNING] + assert len(running) == 3 + + +def test_scheduler_respects_resource_limits(make_job_request, controller_with_workers): + """Scheduler doesn't over-commit worker resources.""" + controller, _ = controller_with_workers( + ["w1"], + resources=cluster_pb2.ResourceSpec(cpu=4, memory="8g"), + ) + + # Submit 3 jobs needing 2 CPUs each (total 6 CPUs, but only 4 available) + job_ids = [] + for i in range(3): + response = controller.launch_job(make_job_request(f"job-{i}", cpu=2)) + job_ids.append(response.job_id) + + # Run scheduling + controller._run_scheduling() + + # Only 2 jobs should be running (4 CPUs / 2 CPUs per job = 2 jobs) + running = [jid for jid in job_ids if controller.get_job_status(jid).job.state == cluster_pb2.JOB_STATE_RUNNING] + pending = [jid for jid in job_ids if controller.get_job_status(jid).job.state == cluster_pb2.JOB_STATE_PENDING] + + assert len(running) == 2 + assert len(pending) == 1 + + +def test_ports_forwarded_from_launch_to_run_request(controller_with_workers): + """Verify port names are forwarded from LaunchJobRequest to RunJobRequest.""" + controller, stub_factory = controller_with_workers(["w1"]) + + # Submit job with port requests + request = cluster_pb2.Controller.LaunchJobRequest( + name="port-test-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ports=["http", "grpc", "actor"], + ) + response = controller.launch_job(request) + job_id = response.job_id + + # Run scheduling + 
controller._run_scheduling() + + # Verify job is running + status = controller.get_job_status(job_id) + assert status.job.state == cluster_pb2.JOB_STATE_RUNNING + + # Verify the RunJobRequest sent to worker includes ports + stub = stub_factory.get_stub("host1:8080") + assert len(stub.run_job_calls) == 1 + + run_request = stub.run_job_calls[0] + assert list(run_request.ports) == ["http", "grpc", "actor"] diff --git a/lib/fluster/tests/cluster/controller/test_resources.py b/lib/fluster/tests/cluster/controller/test_resources.py new file mode 100644 index 0000000000..562221c11a --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_resources.py @@ -0,0 +1,58 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for resource parsing and comparison utilities.""" + +import pytest + +from fluster.cluster.controller.resources import ( + parse_memory_string, +) + + +@pytest.mark.parametrize( + "memory_str,expected_bytes", + [ + ("1g", 1024**3), + ("8g", 8 * 1024**3), + ("16gb", 16 * 1024**3), + ("512m", 512 * 1024**2), + ("1024mb", 1024 * 1024**2), + ("1024k", 1024 * 1024), + ("1024kb", 1024 * 1024), + ("1024b", 1024), + ("1024", 1024), # No unit defaults to bytes + ("", 0), + ("0g", 0), + ], +) +def test_parse_memory_string(memory_str, expected_bytes): + assert parse_memory_string(memory_str) == expected_bytes + + +def test_parse_memory_string_case_insensitive(): + assert parse_memory_string("8G") == parse_memory_string("8g") + assert parse_memory_string("16GB") == parse_memory_string("16gb") + assert parse_memory_string("512M") == parse_memory_string("512m") + + +def test_parse_memory_string_with_whitespace(): + assert parse_memory_string(" 8g ") == 8 * 1024**3 + + +def test_parse_memory_string_invalid(): + with pytest.raises(ValueError): + parse_memory_string("invalid") + with pytest.raises(ValueError): + parse_memory_string("8x") diff --git a/lib/fluster/tests/cluster/controller/test_retry.py b/lib/fluster/tests/cluster/controller/test_retry.py new file mode 100644 index 0000000000..280e87d21f --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_retry.py @@ -0,0 +1,406 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
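+# Rough shape of the retry contract exercised below (defaults and counters are
+# inferred from the assertions in these tests, not from the fluster source):
+#
+#     handle_job_failure(state, job_id, is_worker_failure=True)
+#       -> bumps job.preemption_count; retries while the count stays within
+#          job.max_retries_preemption (default 100 per the tests below)
+#     handle_job_failure(state, job_id, is_worker_failure=False)
+#       -> bumps job.failure_count; retries while the count stays within
+#          job.max_retries_failure (default 0: a single attempt)
+#     On retry the job returns to JOB_STATE_PENDING and its worker_id,
+#     timestamps, and error are cleared; once the limit is exceeded the call
+#     returns False and the failed state is left as-is.
+#
+#     handle_gang_failure(state, gang_id, is_worker_failure=...) is
+#     all-or-nothing: if any member has no retries left, no member is retried
+#     and RUNNING members are marked JOB_STATE_KILLED.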
+ +"""Tests for job failure and retry logic.""" + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.retry import handle_gang_failure, handle_job_failure +from fluster.cluster.controller.state import ControllerJob, ControllerState +from fluster.cluster.types import JobId + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make(name: str = "test-job") -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + return _make + + +def test_job_retry_on_worker_failure(make_job_request): + """Worker failure increments preemption_count and retries if under limit.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + max_retries_preemption=2, + ) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + state.add_job(job) + + # First worker failure - should retry + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.state == cluster_pb2.JOB_STATE_PENDING + assert job.preemption_count == 1 + assert job.failure_count == 0 # Should not increment job failure count + + +def test_job_retry_on_job_failure(make_job_request): + """Job failure increments failure_count and retries if under limit.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + max_retries_failure=2, + ) + job.state = cluster_pb2.JOB_STATE_FAILED + state.add_job(job) + + # First job failure - should retry + assert handle_job_failure(state, JobId("j1"), is_worker_failure=False) + assert job.state == cluster_pb2.JOB_STATE_PENDING + assert job.failure_count == 1 + assert job.preemption_count == 0 # Should not increment preemption count + + +def test_job_exceeds_worker_failure_limit(make_job_request): + """Stop retrying when preemption_count exceeds max_retries_preemption.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + max_retries_preemption=2, + ) + state.add_job(job) + + # First failure - should retry (count 1 <= limit 2) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == 1 + + # Second failure - should retry (count 2 <= limit 2) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == 2 + + # Third failure - should NOT retry (count 3 > limit 2) + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert not handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == 3 + assert job.state == cluster_pb2.JOB_STATE_WORKER_FAILED # State unchanged + + +def test_job_exceeds_job_failure_limit(make_job_request): + """Stop retrying when failure_count exceeds max_retries_failure.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + max_retries_failure=1, + ) + state.add_job(job) + + # First failure - should retry (count 1 <= limit 1) + job.state = cluster_pb2.JOB_STATE_FAILED + assert handle_job_failure(state, JobId("j1"), is_worker_failure=False) + assert job.failure_count == 1 + + # Second failure - should NOT retry (count 2 > limit 1) + job.state = 
cluster_pb2.JOB_STATE_FAILED + assert not handle_job_failure(state, JobId("j1"), is_worker_failure=False) + assert job.failure_count == 2 + assert job.state == cluster_pb2.JOB_STATE_FAILED # State unchanged + + +def test_job_retry_resets_state(make_job_request): + """Verify job state/worker_id/timestamps cleared on retry.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + max_retries_failure=1, + ) + job.state = cluster_pb2.JOB_STATE_FAILED + job.worker_id = "w1" + job.started_at_ms = 12345 + job.finished_at_ms = 67890 + job.error = "Something went wrong" + state.add_job(job) + + # Retry should reset all state + assert handle_job_failure(state, JobId("j1"), is_worker_failure=False) + assert job.state == cluster_pb2.JOB_STATE_PENDING + assert job.worker_id is None + assert job.started_at_ms is None + assert job.finished_at_ms is None + assert job.error is None + # exit_code is not cleared (intentional - keeps historical info) + + +def test_handle_job_failure_nonexistent_job(): + """Returns False for unknown job_id.""" + state = ControllerState() + assert not handle_job_failure(state, JobId("nonexistent"), is_worker_failure=True) + + +def test_gang_all_or_nothing_retry(make_job_request): + """Gang retry fails if any job has no retries left.""" + state = ControllerState() + job1 = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + gang_id="g1", + max_retries_failure=1, + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=make_job_request("job2"), + gang_id="g1", + max_retries_failure=0, # No retries + ) + + state.add_job(job1) + state.add_job(job2) + + # Mark both as running + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + + # Gang fails - j2 has 0 retries, so entire gang cannot retry + retried = handle_gang_failure(state, "g1", is_worker_failure=False) + assert retried == [] + + # Both jobs should be marked KILLED + assert job1.state == cluster_pb2.JOB_STATE_KILLED + assert job2.state == cluster_pb2.JOB_STATE_KILLED + assert "Gang g1 failed" in job1.error + assert "Gang g1 failed" in job2.error + + # Failure counts should not be incremented since gang couldn't retry + assert job1.failure_count == 0 + assert job2.failure_count == 0 + + +def test_gang_retry_success(make_job_request): + """All jobs in gang are retried together.""" + state = ControllerState() + job1 = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + gang_id="g1", + max_retries_preemption=2, + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=make_job_request("job2"), + gang_id="g1", + max_retries_preemption=2, + ) + + state.add_job(job1) + state.add_job(job2) + + # Mark both as running + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + + # Gang fails due to worker failure - both have retries left + retried = handle_gang_failure(state, "g1", is_worker_failure=True) + + assert set(retried) == {"j1", "j2"} + + # Both jobs should be reset to PENDING + assert job1.state == cluster_pb2.JOB_STATE_PENDING + assert job2.state == cluster_pb2.JOB_STATE_PENDING + + # Both jobs should have preemption_count incremented + assert job1.preemption_count == 1 + assert job2.preemption_count == 1 + + # State should be cleared + assert job1.worker_id is None + assert job2.worker_id is None + assert job1.error is None + assert job2.error is None + + +def test_gang_marks_running_jobs_as_killed(make_job_request): + """Running jobs in gang marked 
KILLED on failure.""" + state = ControllerState() + job1 = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + gang_id="g1", + max_retries_failure=0, # No retries + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=make_job_request("job2"), + gang_id="g1", + max_retries_failure=0, # No retries + ) + job3 = ControllerJob( + job_id=JobId("j3"), + request=make_job_request("job3"), + gang_id="g1", + max_retries_failure=0, # No retries + ) + + state.add_job(job1) + state.add_job(job2) + state.add_job(job3) + + # Mix of states + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_PENDING # Not started yet + job3.state = cluster_pb2.JOB_STATE_RUNNING + + # Gang fails - no retries available + retried = handle_gang_failure(state, "g1", is_worker_failure=False) + assert retried == [] + + # Only running jobs should be marked KILLED + assert job1.state == cluster_pb2.JOB_STATE_KILLED + assert job2.state == cluster_pb2.JOB_STATE_PENDING # Not running, so not killed + assert job3.state == cluster_pb2.JOB_STATE_KILLED + + # All running jobs should have error message + assert job1.error == "Gang g1 failed" + assert job3.error == "Gang g1 failed" + + +def test_gang_failure_with_nonexistent_gang(): + """Returns empty list for unknown gang_id.""" + state = ControllerState() + retried = handle_gang_failure(state, "nonexistent", is_worker_failure=True) + assert retried == [] + + +def test_gang_retry_tracks_correct_failure_type(make_job_request): + """Gang retry increments correct counter based on failure type.""" + state = ControllerState() + job1 = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + gang_id="g1", + max_retries_failure=2, + max_retries_preemption=2, + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=make_job_request("job2"), + gang_id="g1", + max_retries_failure=2, + max_retries_preemption=2, + ) + + state.add_job(job1) + state.add_job(job2) + + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + + # Test worker failure (preemption) + retried = handle_gang_failure(state, "g1", is_worker_failure=True) + assert len(retried) == 2 + assert job1.preemption_count == 1 + assert job2.preemption_count == 1 + assert job1.failure_count == 0 + assert job2.failure_count == 0 + + # Reset for next test + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + + # Test job failure (internal) + retried = handle_gang_failure(state, "g1", is_worker_failure=False) + assert len(retried) == 2 + assert job1.preemption_count == 1 # Unchanged + assert job2.preemption_count == 1 # Unchanged + assert job1.failure_count == 1 + assert job2.failure_count == 1 + + +def test_job_with_zero_retries_default(make_job_request): + """Default max_retries_failure=0 means one try, no retries.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + # max_retries_failure defaults to 0 + ) + job.state = cluster_pb2.JOB_STATE_FAILED + state.add_job(job) + + # First failure should NOT retry (count 1 > limit 0) + assert not handle_job_failure(state, JobId("j1"), is_worker_failure=False) + assert job.failure_count == 1 + assert job.state == cluster_pb2.JOB_STATE_FAILED + + +def test_job_with_high_preemption_retries_default(make_job_request): + """Default max_retries_preemption=100 allows many retries.""" + state = ControllerState() + job = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + # 
max_retries_preemption defaults to 100 + ) + state.add_job(job) + + # Should be able to retry many times + for i in range(100): + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == i + 1 + + # 101st failure should not retry + job.state = cluster_pb2.JOB_STATE_WORKER_FAILED + assert not handle_job_failure(state, JobId("j1"), is_worker_failure=True) + assert job.preemption_count == 101 + + +def test_gang_retry_checks_all_jobs_for_retries(make_job_request): + """Gang retry requires ALL jobs to have retries, not just majority.""" + state = ControllerState() + # Create gang with 3 jobs, only 1 has no retries + job1 = ControllerJob( + job_id=JobId("j1"), + request=make_job_request("job1"), + gang_id="g1", + max_retries_failure=5, + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=make_job_request("job2"), + gang_id="g1", + max_retries_failure=5, + ) + job3 = ControllerJob( + job_id=JobId("j3"), + request=make_job_request("job3"), + gang_id="g1", + max_retries_failure=0, # Only this one has no retries + ) + + state.add_job(job1) + state.add_job(job2) + state.add_job(job3) + + job1.state = cluster_pb2.JOB_STATE_RUNNING + job2.state = cluster_pb2.JOB_STATE_RUNNING + job3.state = cluster_pb2.JOB_STATE_RUNNING + + # Should not retry because j3 has no retries (even though 2/3 do) + retried = handle_gang_failure(state, "g1", is_worker_failure=False) + assert retried == [] diff --git a/lib/fluster/tests/cluster/controller/test_scheduler.py b/lib/fluster/tests/cluster/controller/test_scheduler.py new file mode 100644 index 0000000000..79e735c6ad --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_scheduler.py @@ -0,0 +1,323 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for job scheduler. + +The scheduler is a shallow interface that takes inputs (pending jobs, workers, +current time) and returns outputs (assignments, timed-out jobs). It does not +dispatch jobs, modify state, or run threads. 
+""" + +import time + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.scheduler import Scheduler +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.types import JobId, WorkerId + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make( + name: str = "test-job", + cpu: int = 1, + memory: str = "1g", + scheduling_timeout_seconds: int = 0, + ) -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=cpu, memory=memory), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + scheduling_timeout_seconds=scheduling_timeout_seconds, + ) + + return _make + + +@pytest.fixture +def make_resource_spec(): + """Create a ResourceSpec for testing with enough capacity for multiple jobs.""" + + def _make(cpu: int = 10, memory: str = "10g") -> cluster_pb2.ResourceSpec: + return cluster_pb2.ResourceSpec(cpu=cpu, memory=memory, disk="10g") + + return _make + + +@pytest.fixture +def state(): + """Create a fresh ControllerState for each test.""" + return ControllerState() + + +@pytest.fixture +def scheduler(state): + """Create a Scheduler instance.""" + return Scheduler(state) + + +def test_scheduler_finds_assignment_for_job(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler assigns job to available worker.""" + worker = ControllerWorker(WorkerId("w1"), "addr", make_resource_spec()) + state.add_worker(worker) + + job = ControllerJob(JobId("j1"), request=make_job_request()) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 1 + assert result.assignments[0] == (job, worker) + assert len(result.timed_out_jobs) == 0 + + +def test_scheduler_returns_empty_when_no_workers(scheduler, state, make_job_request): + """Verify scheduler returns empty result when no workers available.""" + job = ControllerJob(JobId("j1"), request=make_job_request()) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() # Empty + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 0 + assert len(result.timed_out_jobs) == 0 + + +def test_scheduler_assigns_multiple_jobs_to_worker(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler can assign multiple jobs to one worker.""" + worker = ControllerWorker( + WorkerId("w1"), + "addr", + make_resource_spec(cpu=10, memory="10g"), + ) + state.add_worker(worker) + + job1 = ControllerJob(JobId("j1"), request=make_job_request(cpu=2)) + job2 = ControllerJob(JobId("j2"), request=make_job_request(cpu=2)) + job3 = ControllerJob(JobId("j3"), request=make_job_request(cpu=2)) + state.add_job(job1) + state.add_job(job2) + state.add_job(job3) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 3 + assigned_job_ids = {job.job_id for job, _ in result.assignments} + assert assigned_job_ids == {job1.job_id, job2.job_id, job3.job_id} + + +def 
test_scheduler_skips_jobs_that_dont_fit(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler skips jobs that don't fit and continues to next.""" + # Worker with 4 CPUs + worker = ControllerWorker( + WorkerId("w1"), + "addr", + cluster_pb2.ResourceSpec(cpu=4, memory="16g"), + ) + state.add_worker(worker) + + # Job 1: needs 8 CPUs (won't fit on 4 CPU worker) + job1 = ControllerJob(JobId("j1"), request=make_job_request(cpu=8)) + # Job 2: needs 2 CPUs (will fit) + job2 = ControllerJob(JobId("j2"), request=make_job_request(cpu=2)) + state.add_job(job1) + state.add_job(job2) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + # Only job2 should be assigned + assert len(result.assignments) == 1 + assert result.assignments[0][0] == job2 + + +def test_scheduler_detects_timed_out_jobs(scheduler, state, make_resource_spec): + """Verify scheduler identifies jobs that exceeded scheduling timeout.""" + worker = ControllerWorker(WorkerId("w1"), "addr", make_resource_spec(cpu=2)) + state.add_worker(worker) + + # Job that requires 100 CPUs (will never fit) with 1 second timeout + # Submitted 2 seconds ago, so it should be timed out + job_request = cluster_pb2.Controller.LaunchJobRequest( + name="impossible-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=100, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + scheduling_timeout_seconds=1, + ) + job = ControllerJob( + JobId("j1"), + request=job_request, + submitted_at_ms=int(time.time() * 1000) - 2000, # Submitted 2s ago + ) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 0 + assert len(result.timed_out_jobs) == 1 + assert result.timed_out_jobs[0] == job + + +def test_scheduler_no_timeout_when_zero(scheduler, state, make_resource_spec): + """Verify job with scheduling_timeout_seconds=0 never times out.""" + worker = ControllerWorker(WorkerId("w1"), "addr", make_resource_spec(cpu=2)) + state.add_worker(worker) + + # Job that can't fit but has no timeout (0) + job_request = cluster_pb2.Controller.LaunchJobRequest( + name="no-timeout-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=100, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + scheduling_timeout_seconds=0, # No timeout + ) + job = ControllerJob( + JobId("j1"), + request=job_request, + submitted_at_ms=int(time.time() * 1000) - 10000, # Submitted 10s ago + ) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + # Job should not be in timed_out_jobs (just skipped, no assignment) + assert len(result.assignments) == 0 + assert len(result.timed_out_jobs) == 0 + + +def test_scheduler_respects_worker_capacity_across_assignments(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler tracks capacity used by earlier assignments in same cycle.""" + # Worker with 4 CPUs + worker = ControllerWorker(WorkerId("w1"), "addr", make_resource_spec(cpu=4)) + state.add_worker(worker) + + # Submit 4 jobs, each requiring 2 CPUs + # Only 2 should fit at a 
time + for i in range(4): + job = ControllerJob(JobId(f"j{i}"), request=make_job_request(cpu=2)) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + # Only first 2 jobs should be assigned (using all 4 CPUs) + assert len(result.assignments) == 2 + + +def test_scheduler_skips_unhealthy_workers(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler ignores unhealthy workers.""" + healthy_worker = ControllerWorker(WorkerId("w1"), "addr1", make_resource_spec()) + unhealthy_worker = ControllerWorker(WorkerId("w2"), "addr2", make_resource_spec()) + unhealthy_worker.healthy = False + + state.add_worker(healthy_worker) + state.add_worker(unhealthy_worker) + + job = ControllerJob(JobId("j1"), request=make_job_request()) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + # get_available_workers() already filters unhealthy workers + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 1 + assert result.assignments[0][1] == healthy_worker + + +def test_scheduler_considers_running_jobs_for_capacity(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler accounts for jobs already running on workers.""" + # Worker with 4 CPUs + worker = ControllerWorker(WorkerId("w1"), "addr", make_resource_spec(cpu=4)) + state.add_worker(worker) + + # Add a running job that uses 3 CPUs + running_job = ControllerJob( + JobId("running"), + request=make_job_request(cpu=3), + state=cluster_pb2.JOB_STATE_RUNNING, + worker_id=worker.worker_id, + ) + state._jobs[running_job.job_id] = running_job + worker.running_jobs.add(running_job.job_id) + + # Try to schedule a job that needs 2 CPUs (won't fit, only 1 CPU available) + job = ControllerJob(JobId("j1"), request=make_job_request(cpu=2)) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 0 + + +def test_scheduler_assigns_to_multiple_workers(scheduler, state, make_job_request, make_resource_spec): + """Verify scheduler can assign jobs across multiple workers.""" + # Two workers with 2 CPUs each + worker1 = ControllerWorker(WorkerId("w1"), "addr1", make_resource_spec(cpu=2)) + worker2 = ControllerWorker(WorkerId("w2"), "addr2", make_resource_spec(cpu=2)) + state.add_worker(worker1) + state.add_worker(worker2) + + # Three jobs needing 2 CPUs each + # Two should fit (one on each worker), third won't fit + for i in range(3): + job = ControllerJob(JobId(f"j{i}"), request=make_job_request(cpu=2)) + state.add_job(job) + + pending_jobs = state.peek_pending_jobs() + workers = state.get_available_workers() + now_ms = int(time.time() * 1000) + + result = scheduler.find_assignments(pending_jobs, workers, now_ms) + + assert len(result.assignments) == 2 + assigned_workers = {w.worker_id for _, w in result.assignments} + assert assigned_workers == {"w1", "w2"} diff --git a/lib/fluster/tests/cluster/controller/test_service.py b/lib/fluster/tests/cluster/controller/test_service.py new file mode 100644 index 0000000000..e11e2f3211 --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_service.py @@ -0,0 +1,368 @@ +# Copyright 2025 The Marin 
Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for controller RPC service implementation.""" + +from unittest.mock import Mock + +import pytest +from connectrpc.code import Code +from connectrpc.errors import ConnectError + +from fluster import cluster_pb2 +from fluster.cluster.controller.service import ControllerServiceImpl +from fluster.cluster.controller.state import ControllerJob, ControllerState +from fluster.cluster.types import JobId, WorkerId + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make(name: str = "test-job") -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + return _make + + +@pytest.fixture +def make_resource_spec(): + """Create a minimal ResourceSpec for testing.""" + + def _make() -> cluster_pb2.ResourceSpec: + return cluster_pb2.ResourceSpec(cpu=1, memory="1g", disk="10g") + + return _make + + +@pytest.fixture +def state(): + """Create a fresh ControllerState for each test.""" + return ControllerState() + + +class MockSchedulerWake: + """Mock object that just tracks wake() calls.""" + + def __init__(self): + self.wake = Mock() + + +@pytest.fixture +def mock_scheduler(): + """Create a mock scheduler with wake() method.""" + return MockSchedulerWake() + + +@pytest.fixture +def service(state, mock_scheduler): + """Create a ControllerServiceImpl for testing.""" + return ControllerServiceImpl(state, mock_scheduler) + + +def test_launch_job_returns_job_id(service, make_job_request): + """Verify launch_job returns a job_id and adds job to state.""" + request = make_job_request("test-job") + + response = service.launch_job(request, None) + + # Should return a job_id + assert response.job_id + assert len(response.job_id) > 0 + + # Job should be in state + job = service._state.get_job(JobId(response.job_id)) + assert job is not None + assert job.state == cluster_pb2.JOB_STATE_PENDING + assert job.request.name == "test-job" + + +def test_launch_job_wakes_scheduler(service, mock_scheduler, make_job_request): + """Verify launch_job wakes the scheduler.""" + request = make_job_request("test-job") + service.launch_job(request, None) + + # Should have called wake() once + mock_scheduler.wake.assert_called_once() + + +def test_get_job_status_returns_status(service, state, make_job_request): + """Verify get_job_status returns status for existing job.""" + # Add a job directly to state + job = ControllerJob( + job_id=JobId("test-job-id"), + request=make_job_request("test-job"), + state=cluster_pb2.JOB_STATE_RUNNING, + submitted_at_ms=12345, + started_at_ms=12350, + ) + job.worker_id = "worker-1" + state.add_job(job) + + # Get status + request = cluster_pb2.Controller.GetJobStatusRequest(job_id="test-job-id") + response = service.get_job_status(request, None) + + # Verify response + assert 
response.job.job_id == "test-job-id" + assert response.job.state == cluster_pb2.JOB_STATE_RUNNING + assert response.job.started_at_ms == 12350 + assert response.job.worker_id == "worker-1" + + +def test_get_job_status_not_found(service): + """Verify get_job_status raises ConnectError for unknown job.""" + request = cluster_pb2.Controller.GetJobStatusRequest(job_id="nonexistent") + + with pytest.raises(ConnectError) as exc_info: + service.get_job_status(request, None) + + assert exc_info.value.code == Code.NOT_FOUND + assert "nonexistent" in exc_info.value.message + + +def test_terminate_job_marks_as_killed(service, state, make_job_request): + """Verify terminate_job sets job state to KILLED.""" + # Add a running job + job = ControllerJob( + job_id=JobId("test-job-id"), + request=make_job_request("test-job"), + state=cluster_pb2.JOB_STATE_RUNNING, + submitted_at_ms=12345, + started_at_ms=12350, + ) + state.add_job(job) + + # Terminate it + request = cluster_pb2.Controller.TerminateJobRequest(job_id="test-job-id") + response = service.terminate_job(request, None) + + # Should return empty response + assert isinstance(response, cluster_pb2.Empty) + + # Job should be marked KILLED + assert job.state == cluster_pb2.JOB_STATE_KILLED + assert job.finished_at_ms is not None + assert job.finished_at_ms > job.started_at_ms + + +def test_terminate_job_not_found(service): + """Verify terminate_job raises ConnectError for unknown job.""" + request = cluster_pb2.Controller.TerminateJobRequest(job_id="nonexistent") + + with pytest.raises(ConnectError) as exc_info: + service.terminate_job(request, None) + + assert exc_info.value.code == Code.NOT_FOUND + assert "nonexistent" in exc_info.value.message + + +def test_list_jobs_returns_all_jobs(service, state, make_job_request): + """Verify list_jobs returns all jobs in state.""" + # Add multiple jobs with different states + job1 = ControllerJob( + job_id=JobId("job-1"), + request=make_job_request("job1"), + state=cluster_pb2.JOB_STATE_PENDING, + ) + job2 = ControllerJob( + job_id=JobId("job-2"), + request=make_job_request("job2"), + state=cluster_pb2.JOB_STATE_RUNNING, + ) + job3 = ControllerJob( + job_id=JobId("job-3"), + request=make_job_request("job3"), + state=cluster_pb2.JOB_STATE_SUCCEEDED, + ) + state.add_job(job1) + state.add_job(job2) + state.add_job(job3) + + # List jobs + request = cluster_pb2.Controller.ListJobsRequest() + response = service.list_jobs(request, None) + + # Should return all jobs + assert len(response.jobs) == 3 + job_ids = {j.job_id for j in response.jobs} + assert job_ids == {"job-1", "job-2", "job-3"} + + # Verify states are correct + states_by_id = {j.job_id: j.state for j in response.jobs} + assert states_by_id["job-1"] == cluster_pb2.JOB_STATE_PENDING + assert states_by_id["job-2"] == cluster_pb2.JOB_STATE_RUNNING + assert states_by_id["job-3"] == cluster_pb2.JOB_STATE_SUCCEEDED + + +def test_get_job_status_includes_all_fields(service, state, make_job_request): + """Verify get_job_status includes all JobStatus fields.""" + # Add a completed job with all fields populated + job = ControllerJob( + job_id=JobId("test-job-id"), + request=make_job_request("test-job"), + state=cluster_pb2.JOB_STATE_FAILED, + submitted_at_ms=12345, + started_at_ms=12350, + finished_at_ms=12400, + ) + job.worker_id = "worker-1" + job.error = "Something went wrong" + job.exit_code = 42 + state.add_job(job) + + # Get status + request = cluster_pb2.Controller.GetJobStatusRequest(job_id="test-job-id") + response = service.get_job_status(request, None) + + 
# Verify all fields + assert response.job.job_id == "test-job-id" + assert response.job.state == cluster_pb2.JOB_STATE_FAILED + assert response.job.started_at_ms == 12350 + assert response.job.finished_at_ms == 12400 + assert response.job.worker_id == "worker-1" + assert response.job.error == "Something went wrong" + assert response.job.exit_code == 42 + + +def test_launch_job_generates_unique_ids(service, make_job_request): + """Verify each launch_job call generates a unique job_id.""" + request = make_job_request("test-job") + + # Launch multiple jobs + response1 = service.launch_job(request, None) + response2 = service.launch_job(request, None) + response3 = service.launch_job(request, None) + + # All IDs should be unique + job_ids = {response1.job_id, response2.job_id, response3.job_id} + assert len(job_ids) == 3 + + +def test_terminate_pending_job(service, state, make_job_request): + """Verify terminate_job works on pending jobs (not just running).""" + # Add a pending job + job = ControllerJob( + job_id=JobId("test-job-id"), + request=make_job_request("test-job"), + state=cluster_pb2.JOB_STATE_PENDING, + submitted_at_ms=12345, + ) + state.add_job(job) + + # Terminate it + request = cluster_pb2.Controller.TerminateJobRequest(job_id="test-job-id") + service.terminate_job(request, None) + + # Job should be marked KILLED even though it was never running + assert job.state == cluster_pb2.JOB_STATE_KILLED + assert job.finished_at_ms is not None + + +def test_register_worker(service, state, make_resource_spec): + """Verify register_worker adds worker to state.""" + request = cluster_pb2.Controller.RegisterWorkerRequest( + worker_id="w1", + address="host1:8080", + resources=make_resource_spec(), + ) + + response = service.register_worker(request, None) + + assert response.accepted is True + worker = state.get_worker(WorkerId("w1")) + assert worker is not None + assert worker.address == "host1:8080" + assert worker.healthy is True + + +def test_register_worker_logs_action(service, state, make_resource_spec): + """Verify register_worker logs an action.""" + request = cluster_pb2.Controller.RegisterWorkerRequest( + worker_id="w1", + address="host1:8080", + resources=make_resource_spec(), + ) + + service.register_worker(request, None) + + actions = state.get_recent_actions() + assert len(actions) == 1 + assert actions[0].action == "worker_registered" + assert actions[0].worker_id == "w1" + + +def test_list_workers_returns_all(service, state, make_resource_spec): + """Verify list_workers returns all workers.""" + from fluster.cluster.controller.state import ControllerWorker + + # Add multiple workers + for i in range(3): + worker = ControllerWorker( + worker_id=WorkerId(f"w{i}"), + address=f"host{i}:8080", + resources=make_resource_spec(), + healthy=(i != 1), # w1 is unhealthy + ) + state.add_worker(worker) + + request = cluster_pb2.Controller.ListWorkersRequest() + response = service.list_workers(request, None) + + assert len(response.workers) == 3 + worker_ids = {w.worker_id for w in response.workers} + assert worker_ids == {"w0", "w1", "w2"} + + # Check healthy status + workers_by_id = {w.worker_id: w for w in response.workers} + assert workers_by_id["w0"].healthy is True + assert workers_by_id["w1"].healthy is False + assert workers_by_id["w2"].healthy is True + + +def test_launch_job_logs_action(service, state, make_job_request): + """Verify launch_job logs an action.""" + request = make_job_request("test-job") + response = service.launch_job(request, None) + + actions = 
state.get_recent_actions() + assert len(actions) == 1 + assert actions[0].action == "job_submitted" + assert actions[0].job_id == response.job_id + assert actions[0].details == "test-job" + + +def test_terminate_job_logs_action(service, state, make_job_request): + """Verify terminate_job logs an action.""" + # Add a running job + job = ControllerJob( + job_id=JobId("test-job-id"), + request=make_job_request("test-job"), + state=cluster_pb2.JOB_STATE_RUNNING, + submitted_at_ms=12345, + ) + state.add_job(job) + + request = cluster_pb2.Controller.TerminateJobRequest(job_id="test-job-id") + service.terminate_job(request, None) + + actions = state.get_recent_actions() + assert len(actions) == 1 + assert actions[0].action == "job_killed" + assert actions[0].job_id == "test-job-id" diff --git a/lib/fluster/tests/cluster/controller/test_state.py b/lib/fluster/tests/cluster/controller/test_state.py new file mode 100644 index 0000000000..20d4f31539 --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_state.py @@ -0,0 +1,307 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for controller core data structures.""" + +import threading + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.types import JobId, WorkerId + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make(name: str = "test-job") -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + return _make + + +@pytest.fixture +def make_resource_spec(): + """Create a minimal ResourceSpec for testing.""" + + def _make() -> cluster_pb2.ResourceSpec: + return cluster_pb2.ResourceSpec(cpu=1, memory="1g", disk="10g") + + return _make + + +def test_controller_state_fifo_order(make_job_request): + """Verify jobs are returned in FIFO order.""" + state = ControllerState() + job1 = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1"), submitted_at_ms=100) + job2 = ControllerJob(job_id=JobId("j2"), request=make_job_request("job2"), submitted_at_ms=200) + state.add_job(job1) + state.add_job(job2) + + # Jobs should be popped in the order they were added + popped1 = state.pop_next_pending() + assert popped1 is not None + assert popped1.job_id == "j1" + + popped2 = state.pop_next_pending() + assert popped2 is not None + assert popped2.job_id == "j2" + + # Queue should be empty now + assert state.pop_next_pending() is None + + +def test_controller_state_skip_non_pending(make_job_request): + """Verify pop_next_pending skips jobs that are not in PENDING state.""" + state = ControllerState() + job1 = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + job1.state = cluster_pb2.JOB_STATE_RUNNING # 
Already started + job2 = ControllerJob(job_id=JobId("j2"), request=make_job_request("job2")) + state.add_job(job1) + state.add_job(job2) + + # Should skip j1 since it's not PENDING + popped = state.pop_next_pending() + assert popped is not None + assert popped.job_id == "j2" + + # Queue should be empty now + assert state.pop_next_pending() is None + + +def test_controller_state_worker_operations(make_resource_spec): + """Test add/get/list workers.""" + state = ControllerState() + worker1 = ControllerWorker(worker_id=WorkerId("w1"), address="host1:8080", resources=make_resource_spec()) + worker2 = ControllerWorker(worker_id=WorkerId("w2"), address="host2:8080", resources=make_resource_spec()) + + # Add workers + state.add_worker(worker1) + state.add_worker(worker2) + + # Get individual worker + retrieved = state.get_worker(WorkerId("w1")) + assert retrieved is not None + assert retrieved.address == "host1:8080" + assert retrieved.healthy is True + + # Get all available workers + available = state.get_available_workers() + assert len(available) == 2 + assert {w.worker_id for w in available} == {"w1", "w2"} + + # Mark one worker unhealthy + worker1.healthy = False + available = state.get_available_workers() + assert len(available) == 1 + assert available[0].worker_id == "w2" + + +def test_controller_state_gang_tracking(make_job_request): + """Verify gang jobs are tracked correctly.""" + state = ControllerState() + job1 = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1"), gang_id="gang1") + job2 = ControllerJob(job_id=JobId("j2"), request=make_job_request("job2"), gang_id="gang1") + job3 = ControllerJob(job_id=JobId("j3"), request=make_job_request("job3"), gang_id="gang2") + + state.add_job(job1) + state.add_job(job2) + state.add_job(job3) + + # Get jobs in gang1 + gang1_jobs = state.get_gang_jobs("gang1") + assert len(gang1_jobs) == 2 + assert {j.job_id for j in gang1_jobs} == {"j1", "j2"} + + # Get jobs in gang2 + gang2_jobs = state.get_gang_jobs("gang2") + assert len(gang2_jobs) == 1 + assert gang2_jobs[0].job_id == "j3" + + # Non-existent gang returns empty list + assert state.get_gang_jobs("nonexistent") == [] + + +def test_controller_state_thread_safety(make_job_request): + """Verify concurrent access doesn't corrupt state.""" + state = ControllerState() + num_threads = 10 + jobs_per_thread = 50 + barrier = threading.Barrier(num_threads) + errors = [] + + def add_jobs(thread_id: int): + try: + # Wait for all threads to be ready + barrier.wait() + + # Add jobs + for i in range(jobs_per_thread): + job_id = f"t{thread_id}_j{i}" + job = ControllerJob(job_id=JobId(job_id), request=make_job_request(f"job-{job_id}")) + state.add_job(job) + except Exception as e: + errors.append(e) + + # Start threads + threads = [threading.Thread(target=add_jobs, args=(i,)) for i in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Check no errors occurred + assert not errors, f"Errors during concurrent execution: {errors}" + + # Verify all jobs were added + expected_count = num_threads * jobs_per_thread + popped_count = 0 + while state.pop_next_pending() is not None: + popped_count += 1 + + assert popped_count == expected_count, f"Expected {expected_count} jobs, got {popped_count}" + + +def test_controller_state_job_retrieval(make_job_request): + """Test job retrieval by ID.""" + state = ControllerState() + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1"), submitted_at_ms=12345) + state.add_job(job) + + # Retrieve by ID + retrieved 
= state.get_job(JobId("j1")) + assert retrieved is not None + assert retrieved.job_id == "j1" + assert retrieved.submitted_at_ms == 12345 + + # Non-existent job returns None + assert state.get_job(JobId("nonexistent")) is None + + +def test_controller_state_multiple_gangs(make_job_request): + """Test tracking multiple gangs simultaneously.""" + state = ControllerState() + + # Create multiple gangs with different sizes + for gang_num in range(5): + gang_id = f"gang{gang_num}" + for job_num in range(gang_num + 1): # gang0 has 1 job, gang1 has 2, etc. + job_id = JobId(f"g{gang_num}_j{job_num}") + job = ControllerJob(job_id=job_id, request=make_job_request(f"job-{job_id}"), gang_id=gang_id) + state.add_job(job) + + # Verify each gang has correct number of jobs + for gang_num in range(5): + gang_id = f"gang{gang_num}" + gang_jobs = state.get_gang_jobs(gang_id) + expected_count = gang_num + 1 + assert ( + len(gang_jobs) == expected_count + ), f"Gang {gang_id} should have {expected_count} jobs, got {len(gang_jobs)}" + + +def test_controller_state_requeue_job(make_job_request): + """Test that jobs can be re-queued by calling add_job again.""" + state = ControllerState() + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # Add job + state.add_job(job) + + # Pop it + popped = state.pop_next_pending() + assert popped is not None + assert popped.job_id == "j1" + + # Queue should be empty + assert state.pop_next_pending() is None + + # Re-queue the same job + state.add_job(job) + + # Should be available again + popped = state.pop_next_pending() + assert popped is not None + assert popped.job_id == "j1" + + +def test_controller_state_action_log(): + """Test action log functionality.""" + state = ControllerState() + + # Initially empty + assert state.get_recent_actions() == [] + + # Log some actions + state.log_action("job_submitted", job_id=JobId("j1"), details="Test job") + state.log_action("worker_registered", worker_id=WorkerId("w1")) + state.log_action("job_started", job_id=JobId("j1"), worker_id=WorkerId("w1")) + + # Should have 3 actions + actions = state.get_recent_actions() + assert len(actions) == 3 + + # Check order (oldest first) + assert actions[0].action == "job_submitted" + assert actions[0].job_id == "j1" + assert actions[0].details == "Test job" + assert actions[1].action == "worker_registered" + assert actions[1].worker_id == "w1" + assert actions[2].action == "job_started" + + # Check timestamps are set + for action in actions: + assert action.timestamp_ms > 0 + + +def test_controller_state_action_log_limit(): + """Test action log respects limit parameter.""" + state = ControllerState() + + # Log many actions + for i in range(10): + state.log_action(f"action_{i}") + + # Get with limit + actions = state.get_recent_actions(limit=3) + assert len(actions) == 3 + + # Should be most recent 3 + assert actions[0].action == "action_7" + assert actions[1].action == "action_8" + assert actions[2].action == "action_9" + + +def test_controller_state_action_log_bounded(): + """Test action log deque is bounded to 100 entries.""" + state = ControllerState() + + # Log more than 100 actions + for i in range(150): + state.log_action(f"action_{i}") + + # Should only have 100 + actions = state.get_recent_actions(limit=200) + assert len(actions) == 100 + + # Oldest should be action_50 (first 50 were evicted) + assert actions[0].action == "action_50" + assert actions[-1].action == "action_149" diff --git a/lib/fluster/tests/cluster/controller/test_workers.py 
b/lib/fluster/tests/cluster/controller/test_workers.py new file mode 100644 index 0000000000..2c8654c776 --- /dev/null +++ b/lib/fluster/tests/cluster/controller/test_workers.py @@ -0,0 +1,499 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for worker registry and scheduling.""" + +import time + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.controller.state import ControllerJob, ControllerState, ControllerWorker +from fluster.cluster.controller.workers import ( + WorkerConfig, + find_worker_for_job, + get_committed_resources, + load_workers_from_config, + worker_can_fit_job, +) +from fluster.cluster.types import JobId, WorkerId + + +@pytest.fixture +def make_resource_spec(): + """Create a minimal ResourceSpec for testing.""" + + def _make(cpu: int = 8, memory: str = "32g") -> cluster_pb2.ResourceSpec: + return cluster_pb2.ResourceSpec(cpu=cpu, memory=memory, disk="100g") + + return _make + + +@pytest.fixture +def make_job_request(): + """Create a minimal LaunchJobRequest for testing.""" + + def _make(name: str = "test-job") -> cluster_pb2.Controller.LaunchJobRequest: + return cluster_pb2.Controller.LaunchJobRequest( + name=name, + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="1g"), + environment=cluster_pb2.EnvironmentConfig(workspace="/tmp"), + ) + + return _make + + +def test_load_workers_from_config(make_resource_spec): + """Verify workers are added to state correctly.""" + state = ControllerState() + workers = [ + WorkerConfig("w1", "host1:8080", make_resource_spec()), + WorkerConfig("w2", "host2:8080", make_resource_spec()), + ] + + before_ms = int(time.time() * 1000) + load_workers_from_config(state, workers) + after_ms = int(time.time() * 1000) + + # Verify workers were added + assert len(state.get_available_workers()) == 2 + + # Verify worker details + worker1 = state.get_worker(WorkerId("w1")) + assert worker1 is not None + assert worker1.address == "host1:8080" + assert worker1.healthy is True + assert worker1.resources.cpu == 8 + assert worker1.resources.memory == "32g" + + # Verify last_heartbeat_ms was set to current time + assert before_ms <= worker1.last_heartbeat_ms <= after_ms + + worker2 = state.get_worker(WorkerId("w2")) + assert worker2 is not None + assert worker2.address == "host2:8080" + + +def test_find_worker_for_job_returns_healthy_worker(make_resource_spec, make_job_request): + """Verify healthy worker is returned.""" + state = ControllerState() + worker = ControllerWorker(worker_id=WorkerId("w1"), address="host1:8080", resources=make_resource_spec()) + state.add_worker(worker) + + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # Should return the healthy worker + result = find_worker_for_job(state, job) + assert result is not None + assert result.worker_id == "w1" + assert result.address == "host1:8080" + + +def test_find_worker_for_job_skips_unhealthy(make_resource_spec, make_job_request): + """Verify unhealthy workers 
are skipped.""" + state = ControllerState() + + # Add unhealthy worker + worker1 = ControllerWorker(worker_id=WorkerId("w1"), address="host1:8080", resources=make_resource_spec()) + worker1.healthy = False + state.add_worker(worker1) + + # Add healthy worker + worker2 = ControllerWorker(worker_id=WorkerId("w2"), address="host2:8080", resources=make_resource_spec()) + state.add_worker(worker2) + + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # Should skip w1 and return w2 + result = find_worker_for_job(state, job) + assert result is not None + assert result.worker_id == "w2" + + +def test_find_worker_for_job_no_workers_returns_none(make_job_request): + """Verify None when no workers available.""" + state = ControllerState() + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # No workers registered + result = find_worker_for_job(state, job) + assert result is None + + +def test_find_worker_for_job_all_unhealthy_returns_none(make_resource_spec, make_job_request): + """Verify None when all workers are unhealthy.""" + state = ControllerState() + + # Add multiple unhealthy workers + for i in range(3): + worker = ControllerWorker(worker_id=WorkerId(f"w{i}"), address=f"host{i}:8080", resources=make_resource_spec()) + worker.healthy = False + state.add_worker(worker) + + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # Should return None since all workers are unhealthy + result = find_worker_for_job(state, job) + assert result is None + + +def test_find_worker_for_job_returns_first_available(make_resource_spec, make_job_request): + """Verify a healthy worker is returned when several are available.""" + state = ControllerState() + + # Add multiple healthy workers + for i in range(5): + worker = ControllerWorker(worker_id=WorkerId(f"w{i}"), address=f"host{i}:8080", resources=make_resource_spec()) + state.add_worker(worker) + + job = ControllerJob(job_id=JobId("j1"), request=make_job_request("job1")) + + # Any of the healthy workers may be returned (dict iteration order is not guaranteed) + result = find_worker_for_job(state, job) + assert result is not None + assert result.worker_id in {f"w{i}" for i in range(5)} + + +# ============================================================================= +# Resource Matching Tests +# ============================================================================= + + +def test_worker_can_fit_job_cpu_constraint(): + """Job requiring more CPU than available should not fit.""" + state = ControllerState() + + # Worker with 4 CPUs total, running a job using 2 CPUs + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec(cpu=4, memory="32g"), + ) + state.add_worker(worker) + + # Job already running on worker (uses 2 CPUs) + running_job = ControllerJob( + job_id=JobId("running"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="running", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=2, memory="1g"), + ), + ) + state._jobs[running_job.job_id] = running_job + worker.running_jobs.add(running_job.job_id) + + # New job requiring 4 CPUs (only 2 available) + new_job = ControllerJob( + job_id=JobId("new"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="new", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=4, memory="1g"), + ), + ) + + assert not worker_can_fit_job(state, worker, new_job) + + +def test_worker_can_fit_job_memory_constraint(): + """Job requiring more memory than 
available should not fit.""" + state = ControllerState() + + # Worker with 16g memory, 12g already committed + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec(cpu=8, memory="16g"), + ) + state.add_worker(worker) + + # Running job uses 12g + running_job = ControllerJob( + job_id=JobId("running"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="running", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="12g"), + ), + ) + state._jobs[running_job.job_id] = running_job + worker.running_jobs.add(running_job.job_id) + + # New job requiring 8g (only 4g available) + new_job = ControllerJob( + job_id=JobId("new"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="new", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=1, memory="8g"), + ), + ) + + assert not worker_can_fit_job(state, worker, new_job) + + +def test_worker_can_fit_job_device_type_mismatch(): + """GPU job should not fit on CPU-only worker.""" + state = ControllerState() + + # CPU-only worker + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec( + cpu=8, + memory="32g", + device=cluster_pb2.DeviceConfig(cpu=cluster_pb2.CpuDevice()), + ), + ) + state.add_worker(worker) + + # GPU job + gpu_job = ControllerJob( + job_id=JobId("gpu-job"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="gpu-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=1, + memory="8g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=1)), + ), + ), + ) + + assert not worker_can_fit_job(state, worker, gpu_job) + + +def test_worker_can_fit_job_gpu_variant_match(): + """Job specifying GPU variant should match worker with same variant.""" + state = ControllerState() + + # H100 GPU worker + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec( + cpu=32, + memory="256g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="H100", count=8)), + ), + ) + state.add_worker(worker) + + # Job requiring H100 + h100_job = ControllerJob( + job_id=JobId("h100-job"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="h100-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="32g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="H100", count=2)), + ), + ), + ) + + assert worker_can_fit_job(state, worker, h100_job) + + +def test_worker_can_fit_job_gpu_variant_mismatch(): + """Job specifying specific variant should not match different variant.""" + state = ControllerState() + + # A100 GPU worker + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec( + cpu=32, + memory="256g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=8)), + ), + ) + state.add_worker(worker) + + # Job requiring H100 + h100_job = ControllerJob( + job_id=JobId("h100-job"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="h100-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="32g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="H100", count=2)), + ), + ), + ) + + assert not worker_can_fit_job(state, worker, h100_job) + + +def test_worker_can_fit_job_gpu_variant_auto(): + """Job with variant='auto' should match any GPU worker.""" + state = 
ControllerState() + + # A100 GPU worker + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec( + cpu=32, + memory="256g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=8)), + ), + ) + state.add_worker(worker) + + # Job with auto variant + auto_job = ControllerJob( + job_id=JobId("auto-job"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="auto-job", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="32g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="auto", count=1)), + ), + ), + ) + + assert worker_can_fit_job(state, worker, auto_job) + + +def test_worker_can_fit_job_gpu_count_constraint(): + """Job requiring more GPUs than available should not fit.""" + state = ControllerState() + + # Worker with 8 GPUs, 6 already in use + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec( + cpu=32, + memory="256g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=8)), + ), + ) + state.add_worker(worker) + + # Running job uses 6 GPUs + running_job = ControllerJob( + job_id=JobId("running"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="running", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="32g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=6)), + ), + ), + ) + state._jobs[running_job.job_id] = running_job + worker.running_jobs.add(running_job.job_id) + + # New job requiring 4 GPUs (only 2 available) + new_job = ControllerJob( + job_id=JobId("new"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="new", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec( + cpu=4, + memory="32g", + device=cluster_pb2.DeviceConfig(gpu=cluster_pb2.GpuDevice(variant="A100", count=4)), + ), + ), + ) + + assert not worker_can_fit_job(state, worker, new_job) + + +def test_get_committed_resources(): + """Verify committed resources are computed from running jobs.""" + state = ControllerState() + + worker = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr", + resources=cluster_pb2.ResourceSpec(cpu=16, memory="64g"), + ) + state.add_worker(worker) + + # Add two running jobs + job1 = ControllerJob( + job_id=JobId("j1"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="j1", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=2, memory="8g"), + ), + ) + job2 = ControllerJob( + job_id=JobId("j2"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="j2", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=4, memory="16g"), + ), + ) + state._jobs[job1.job_id] = job1 + state._jobs[job2.job_id] = job2 + worker.running_jobs.add(job1.job_id) + worker.running_jobs.add(job2.job_id) + + cpu, memory, _gpu = get_committed_resources(state, worker) + assert cpu == 6 # 2 + 4 + assert memory == 24 * 1024**3 # 8g + 16g + + +def test_find_worker_for_job_respects_capacity(): + """Verify find_worker_for_job skips workers without capacity.""" + state = ControllerState() + + # Worker 1: only 2 CPUs total + worker1 = ControllerWorker( + worker_id=WorkerId("w1"), + address="addr1", + resources=cluster_pb2.ResourceSpec(cpu=2, memory="16g"), + ) + state.add_worker(worker1) + + # Worker 2: has 8 CPUs + worker2 = ControllerWorker( + worker_id=WorkerId("w2"), + address="addr2", + 
resources=cluster_pb2.ResourceSpec(cpu=8, memory="32g"), + ) + state.add_worker(worker2) + + # Job requiring 4 CPUs + job = ControllerJob( + job_id=JobId("j1"), + request=cluster_pb2.Controller.LaunchJobRequest( + name="j1", + serialized_entrypoint=b"test", + resources=cluster_pb2.ResourceSpec(cpu=4, memory="1g"), + ), + ) + + result = find_worker_for_job(state, job) + assert result is not None + assert result.worker_id == "w2" # Should skip w1 (only 2 CPUs) diff --git a/lib/fluster/tests/cluster/worker/__init__.py b/lib/fluster/tests/cluster/worker/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lib/fluster/tests/cluster/worker/test_builder.py b/lib/fluster/tests/cluster/worker/test_builder.py new file mode 100644 index 0000000000..b4cb45a4e7 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_builder.py @@ -0,0 +1,416 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for VenvCache and ImageCache.""" + +import subprocess + +import pytest +from fluster.cluster.worker.builder import ImageCache, VenvCache + + +@pytest.fixture +def test_bundle(tmp_path): + """Create a test bundle with pyproject.toml and uv.lock.""" + bundle_dir = tmp_path / "test_bundle" + bundle_dir.mkdir() + + # Create minimal pyproject.toml + pyproject = """[project] +name = "test-package" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" +""" + (bundle_dir / "pyproject.toml").write_text(pyproject) + + # Create uv.lock file + # This is a minimal lock file - in real usage it would be generated by uv + uv_lock = """version = 1 +requires-python = ">=3.11" + +[[package]] +name = "test-package" +version = "0.1.0" +source = { editable = "." 
} +""" + (bundle_dir / "uv.lock").write_text(uv_lock) + + return bundle_dir + + +@pytest.fixture +def test_bundle_with_deps(tmp_path): + """Create a test bundle with actual dependencies for real uv testing.""" + bundle_dir = tmp_path / "test_bundle_deps" + bundle_dir.mkdir() + + # Create package directory structure + src_dir = bundle_dir / "src" / "test_package" + src_dir.mkdir(parents=True) + (src_dir / "__init__.py").write_text("# Test package\n") + + # Create pyproject.toml with a simple dependency + pyproject = """[project] +name = "test-package" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [ + "httpx>=0.28.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/test_package"] +""" + (bundle_dir / "pyproject.toml").write_text(pyproject) + + # Run uv lock to generate real lock file + try: + subprocess.run( + ["uv", "lock"], + cwd=bundle_dir, + check=True, + capture_output=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + pytest.skip("uv not available or failed to create lock file") + + return bundle_dir + + +def test_compute_deps_hash(test_bundle): + """Test that deps hash is computed from pyproject.toml and uv.lock.""" + cache = VenvCache() + + hash1 = cache.compute_deps_hash(test_bundle) + + # Should be consistent + hash2 = cache.compute_deps_hash(test_bundle) + assert hash1 == hash2 + + # Modify pyproject.toml + (test_bundle / "pyproject.toml").write_text("[project]\nname = 'changed'\n") + hash3 = cache.compute_deps_hash(test_bundle) + + # Hash should change + assert hash3 != hash1 + + +# ImageCache Tests + + +@pytest.fixture +def docker_bundle(tmp_path): + """Create a minimal test bundle for Docker builds.""" + bundle_dir = tmp_path / "docker_bundle" + bundle_dir.mkdir() + + # Create package directory structure + src_dir = bundle_dir / "src" / "test_app" + src_dir.mkdir(parents=True) + (src_dir / "__init__.py").write_text('"""Test app."""\n') + (src_dir / "main.py").write_text( + 'def main():\n print("Hello from Docker!")\n\nif __name__ == "__main__":\n main()\n' + ) + + # Create pyproject.toml + pyproject = """[project] +name = "test-app" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/test_app"] +""" + (bundle_dir / "pyproject.toml").write_text(pyproject) + + # Create uv.lock + uv_lock = """version = 1 +requires-python = ">=3.11" + +[[package]] +name = "test-app" +version = "0.1.0" +source = { editable = "." 
} +""" + (bundle_dir / "uv.lock").write_text(uv_lock) + + return bundle_dir + + +def check_docker_available(): + """Check if Docker is available and running.""" + try: + result = subprocess.run( + ["docker", "info"], + check=True, + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return False + + +@pytest.mark.slow +def test_image_cache_initialization(tmp_path): + """Test ImageCache initialization creates cache directory.""" + cache_dir = tmp_path / "cache" + ImageCache(cache_dir, registry="localhost:5000", max_images=10) + + assert (cache_dir / "images").exists() + + +@pytest.mark.slow +def test_image_caching(tmp_path, docker_bundle): + """Test that subsequent builds with same deps_hash use cached image.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + cache_dir = tmp_path / "cache" + builder = ImageCache(cache_dir, registry="localhost:5000") + + job_id = "cache-test-456" + deps_hash = "cachedep1234567890" + base_image = "python:3.11-slim" + + # First build - not from cache + result1 = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=job_id, + deps_hash=deps_hash, + ) + + assert result1.from_cache is False + assert result1.build_time_ms > 0 + + # Second build - should be from cache + result2 = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=job_id, + deps_hash=deps_hash, + ) + + assert result2.from_cache is True + assert result2.build_time_ms == 0 + assert result2.image_tag == result1.image_tag + + # Cleanup + subprocess.run(["docker", "rmi", result1.image_tag], stdout=subprocess.DEVNULL, check=False) + + +@pytest.mark.slow +def test_deps_hash_change_triggers_rebuild(tmp_path, docker_bundle): + """Test that changing deps_hash triggers a rebuild.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + cache_dir = tmp_path / "cache" + builder = ImageCache(cache_dir, registry="localhost:5000") + + job_id = "rebuild-test-789" + base_image = "python:3.11-slim" + + # Build with first deps_hash + deps_hash1 = "oldhash1234567890" + result1 = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=job_id, + deps_hash=deps_hash1, + ) + + assert result1.from_cache is False + expected_tag1 = f"localhost:5000/fluster-job-{job_id}:{deps_hash1[:8]}" + assert result1.image_tag == expected_tag1 + + # Build with different deps_hash - should rebuild + deps_hash2 = "newhash0987654321" + result2 = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=job_id, + deps_hash=deps_hash2, + ) + + assert result2.from_cache is False + expected_tag2 = f"localhost:5000/fluster-job-{job_id}:{deps_hash2[:8]}" + assert result2.image_tag == expected_tag2 + assert result2.image_tag != result1.image_tag + + # Both images should exist + exists1 = builder._docker.exists(expected_tag1) + exists2 = builder._docker.exists(expected_tag2) + assert exists1 is True + assert exists2 is True + + # Cleanup + subprocess.run(["docker", "rmi", expected_tag1], stdout=subprocess.DEVNULL, check=False) + subprocess.run(["docker", "rmi", expected_tag2], stdout=subprocess.DEVNULL, check=False) + + +@pytest.mark.slow +def test_buildkit_cache_mounts(tmp_path, docker_bundle): + """Test that BuildKit cache mounts are used for UV cache.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + cache_dir = 
tmp_path / "cache" + builder = ImageCache(cache_dir, registry="localhost:5000") + + # Verify DOCKER_BUILDKIT is set in environment during build + job_id = "buildkit-test-abc" + deps_hash = "buildkit1234567890" + base_image = "python:3.11-slim" + + # Build image - BuildKit should be enabled + result = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=job_id, + deps_hash=deps_hash, + ) + + # Verify image was built (BuildKit enabled by default in _docker_build) + assert result.from_cache is False + exists = builder._docker.exists(result.image_tag) + assert exists is True + + # Cleanup + subprocess.run(["docker", "rmi", result.image_tag], stdout=subprocess.DEVNULL, check=False) + + +@pytest.mark.slow +def test_image_build_with_extras(tmp_path): + """Test building image with extras.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Create bundle with extras + bundle_dir = tmp_path / "extras_bundle" + bundle_dir.mkdir() + + src_dir = bundle_dir / "src" / "test_app" + src_dir.mkdir(parents=True) + (src_dir / "__init__.py").write_text('"""Test app."""\n') + + pyproject = """[project] +name = "test-app" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [] + +[project.optional-dependencies] +dev = [] +test = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/test_app"] +""" + (bundle_dir / "pyproject.toml").write_text(pyproject) + + # Generate proper lock file + try: + subprocess.run( + ["uv", "lock"], + cwd=bundle_dir, + check=True, + capture_output=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + pytest.skip("uv not available or failed to create lock file") + + cache_dir = tmp_path / "cache" + builder = ImageCache(cache_dir, registry="localhost:5000") + + # Build with extras + result = builder.build( + bundle_path=bundle_dir, + base_image="python:3.11-slim", + extras=["dev", "test"], + job_id="extras-test", + deps_hash="extrahash123", + ) + + assert result.from_cache is False + + # Cleanup + subprocess.run(["docker", "rmi", result.image_tag], stdout=subprocess.DEVNULL, check=False) + + +@pytest.mark.slow +def test_lru_eviction_of_images(tmp_path, docker_bundle): + """Test LRU eviction removes old images when over limit.""" + import time + + if not check_docker_available(): + pytest.skip("Docker not available") + + cache_dir = tmp_path / "cache" + builder = ImageCache(cache_dir, registry="localhost:5000", max_images=2) + + base_image = "python:3.11-slim" + images_built = [] + + # Build 3 images to trigger eviction + for i in range(3): + result = builder.build( + bundle_path=docker_bundle, + base_image=base_image, + extras=[], + job_id=f"eviction-test-{i}", + deps_hash=f"evict{i:016d}", + ) + images_built.append(result.image_tag) + + # Small delay to ensure different creation times + time.sleep(0.1) + + # After building 3 images with max_images=2, oldest should be evicted + # Note: _evict_old_images is called after each build, but only when count > max_images + + # At least the newest image should exist + exists_2 = builder._docker.exists(images_built[2]) + assert exists_2 is True + + # Cleanup remaining images + for tag in images_built: + try: + subprocess.run(["docker", "rmi", tag], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False) + except Exception: + pass # Image may already be evicted diff --git a/lib/fluster/tests/cluster/worker/test_bundle.py 
b/lib/fluster/tests/cluster/worker/test_bundle.py new file mode 100644 index 0000000000..fc60be57f7 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_bundle.py @@ -0,0 +1,177 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BundleCache.""" + +import hashlib +import zipfile + +import pytest +from fluster.cluster.worker.bundle import BundleCache + + +@pytest.fixture +def temp_cache_dir(tmp_path): + """Create a temporary cache directory.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + return cache_dir + + +@pytest.fixture +def test_bundle(tmp_path): + """Create a test bundle zip file.""" + bundle_dir = tmp_path / "test_bundle" + bundle_dir.mkdir() + + # Create some test files + (bundle_dir / "pyproject.toml").write_text("[project]\nname = 'test'\n") + (bundle_dir / "main.py").write_text("print('hello')\n") + + src_dir = bundle_dir / "src" + src_dir.mkdir() + (src_dir / "module.py").write_text("def foo(): pass\n") + + # Create zip file + zip_path = tmp_path / "test_bundle.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + for file in bundle_dir.rglob("*"): + if file.is_file(): + zf.write(file, file.relative_to(bundle_dir)) + + return zip_path + + +@pytest.fixture +def test_bundle_hash(test_bundle): + """Compute hash of test bundle.""" + h = hashlib.sha256() + with open(test_bundle, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +def test_download_local_bundle(temp_cache_dir, test_bundle): + """Test downloading a local bundle using file:// path.""" + cache = BundleCache(temp_cache_dir) + + # Use file:// protocol + file_url = f"file://{test_bundle}" + + # Get bundle + extract_path = cache.get_bundle(file_url) + + # Verify extraction + assert extract_path.exists() + assert extract_path.is_dir() + assert (extract_path / "pyproject.toml").exists() + assert (extract_path / "main.py").exists() + assert (extract_path / "src" / "module.py").exists() + + +def test_caching_behavior(temp_cache_dir, test_bundle): + """Test that bundles are cached and not re-downloaded.""" + cache = BundleCache(temp_cache_dir) + + file_url = f"file://{test_bundle}" + + # First download + extract_path1 = cache.get_bundle(file_url) + + # Second request - should use cache and return same path + extract_path2 = cache.get_bundle(file_url) + + assert extract_path1 == extract_path2 + assert extract_path2.exists() + + +def test_hash_verification_success(temp_cache_dir, test_bundle, test_bundle_hash): + """Test that hash verification passes with correct hash.""" + cache = BundleCache(temp_cache_dir) + + file_url = f"file://{test_bundle}" + + # Get bundle with correct hash + extract_path = cache.get_bundle(file_url, expected_hash=test_bundle_hash) + + # Should succeed + assert extract_path.exists() + + +def test_hash_verification_failure(temp_cache_dir, test_bundle): + """Test that hash verification fails with incorrect hash.""" + cache = BundleCache(temp_cache_dir) + + file_url = f"file://{test_bundle}" + 
+ # Use wrong hash + wrong_hash = "a" * 64 + + # Should raise ValueError + with pytest.raises(ValueError, match="Bundle hash mismatch"): + cache.get_bundle(file_url, expected_hash=wrong_hash) + + +def test_lru_eviction(temp_cache_dir, tmp_path): + """Test LRU eviction when cache exceeds max_bundles.""" + # Create cache with max 2 bundles + cache = BundleCache(temp_cache_dir, max_bundles=2) + + # Create 3 test bundles + bundles = [] + for i in range(3): + bundle_dir = tmp_path / f"bundle_{i}" + bundle_dir.mkdir() + (bundle_dir / "test.txt").write_text(f"bundle {i}") + + zip_path = tmp_path / f"bundle_{i}.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.write(bundle_dir / "test.txt", "test.txt") + + bundles.append(zip_path) + + # Download all 3 bundles + paths = [] + for bundle in bundles: + file_url = f"file://{bundle}" + path = cache.get_bundle(file_url) + paths.append(path) + + # First bundle should be evicted (only 2 should remain) + assert not paths[0].exists(), "First bundle should be evicted" + assert paths[1].exists(), "Second bundle should still exist" + assert paths[2].exists(), "Third bundle should still exist" + + # Verify only 2 extracts exist + extracts = list((temp_cache_dir / "extracts").iterdir()) + assert len(extracts) == 2 + + +def test_concurrent_downloads(temp_cache_dir, test_bundle): + """Test that concurrent downloads work correctly.""" + import concurrent.futures + + cache = BundleCache(temp_cache_dir) + + file_url = f"file://{test_bundle}" + + # Request same bundle multiple times concurrently using threads + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(cache.get_bundle, file_url) for _ in range(5)] + paths = [f.result() for f in futures] + + # All should return the same path + assert all(p == paths[0] for p in paths) + assert paths[0].exists() diff --git a/lib/fluster/tests/cluster/worker/test_dashboard.py b/lib/fluster/tests/cluster/worker/test_dashboard.py new file mode 100644 index 0000000000..8cbef2b330 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_dashboard.py @@ -0,0 +1,712 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for WorkerDashboard HTTP/RPC endpoints and WorkerService implementation.""" + +import asyncio +from pathlib import Path +from unittest.mock import Mock + +import cloudpickle +import httpx +import pytest +from connectrpc.code import Code +from connectrpc.errors import ConnectError +from connectrpc.request import RequestContext +from starlette.testclient import TestClient + +from fluster import cluster_pb2 +from fluster.cluster.worker.builder import BuildResult, ImageCache, VenvCache +from fluster.cluster.worker.bundle import BundleCache +from fluster.cluster.worker.dashboard import WorkerDashboard +from fluster.cluster.worker.docker import ContainerStats, ContainerStatus, DockerRuntime +from fluster.cluster.worker.service import WorkerServiceImpl +from fluster.cluster.worker.worker import Worker, WorkerConfig +from fluster.cluster_connect import WorkerServiceClient + +# ============================================================================ +# Shared fixtures +# ============================================================================ + + +@pytest.fixture +def mock_bundle_cache(): + """Create mock BundleCache.""" + cache = Mock(spec=BundleCache) + cache.get_bundle = Mock(return_value=Path("/tmp/bundle")) + return cache + + +@pytest.fixture +def mock_venv_cache(): + """Create mock VenvCache.""" + cache = Mock(spec=VenvCache) + cache.compute_deps_hash = Mock(return_value="abc123") + return cache + + +@pytest.fixture +def mock_image_cache(): + """Create mock ImageCache.""" + cache = Mock(spec=ImageCache) + cache.build = Mock( + return_value=BuildResult( + image_tag="test-image:latest", + deps_hash="abc123", + build_time_ms=1000, + from_cache=False, + ) + ) + return cache + + +@pytest.fixture +def mock_runtime(): + """Create mock DockerRuntime for sync model. + + The sync model uses create_container/start_container/inspect pattern + instead of the blocking run() method. 
+ """ + runtime = Mock(spec=DockerRuntime) + + # Container lifecycle methods + runtime.create_container = Mock(return_value="container123") + runtime.start_container = Mock() + + # Inspect returns not-running so jobs complete immediately + runtime.inspect = Mock(return_value=ContainerStatus(running=False, exit_code=0)) + + runtime.kill = Mock() + runtime.remove = Mock() + runtime.get_stats = Mock( + return_value=ContainerStats( + memory_mb=100, + cpu_percent=50, + process_count=1, + available=True, + ) + ) + runtime.get_logs = Mock(return_value=[]) + return runtime + + +@pytest.fixture +def worker(mock_bundle_cache, mock_venv_cache, mock_image_cache, mock_runtime): + """Create Worker with mocked dependencies.""" + config = WorkerConfig( + port=0, + max_concurrent_jobs=5, + port_range=(50000, 50100), + ) + return Worker( + config, + bundle_provider=mock_bundle_cache, + image_provider=mock_image_cache, + container_runtime=mock_runtime, + ) + + +def create_test_entrypoint(): + """Create a simple test entrypoint.""" + from dataclasses import dataclass + + @dataclass + class Entrypoint: + callable: object + args: tuple = () + kwargs: dict | None = None + + def __post_init__(self): + if self.kwargs is None: + self.kwargs = {} + + def test_fn(): + print("Hello from test") + + return Entrypoint(callable=test_fn) + + +def create_run_job_request(job_id: str = "test-job-1", ports: list[str] | None = None): + """Create a RunJobRequest for testing.""" + entrypoint = create_test_entrypoint() + serialized_entrypoint = cloudpickle.dumps(entrypoint) + + env_config = cluster_pb2.EnvironmentConfig( + workspace="/workspace", + env_vars={ + "TEST_VAR": "value", + "JOB_VAR": "job_value", + }, + extras=["dev"], + ) + + resources = cluster_pb2.ResourceSpec( + cpu=2, + memory="4g", + ) + + return cluster_pb2.Worker.RunJobRequest( + job_id=job_id, + serialized_entrypoint=serialized_entrypoint, + environment=env_config, + bundle_gcs_path="gs://bucket/bundle.zip", + resources=resources, + timeout_seconds=300, + ports=ports or [], + ) + + +@pytest.fixture +def service(worker): + """Create WorkerServiceImpl.""" + return WorkerServiceImpl(provider=worker) + + +@pytest.fixture +def server(service): + """Create WorkerDashboard.""" + return WorkerDashboard(service=service, host="127.0.0.1", port=0) + + +@pytest.fixture +def client(server): + """Create test client for HTTP requests.""" + return TestClient(server._app) + + +@pytest.fixture +def request_context(): + """Create a mock RequestContext for RPC calls.""" + return Mock(spec=RequestContext) + + +# ============================================================================ +# Dashboard tests +# ============================================================================ + + +def test_dashboard_loads(client): + """Test dashboard HTML loads successfully.""" + response = client.get("/") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/html; charset=utf-8" + assert "Fluster Worker Dashboard" in response.text + + +# ============================================================================ +# Stats API tests +# ============================================================================ + + +def test_stats_empty(client, service): + """Test /api/stats with no jobs.""" + response = client.get("/api/stats") + assert response.status_code == 200 + data = response.json() + assert data == {"running": 0, "pending": 0, "building": 0, "completed": 0} + + +def test_stats_with_jobs(client, service): + """Test /api/stats with various job 
states.""" + # Submit jobs and let them complete + for i in range(5): + request = create_run_job_request(job_id=f"job-{i}") + service.run_job(request, Mock()) + + # Wait for all job threads to complete (mock makes them finish immediately) + jobs = service._provider.list_jobs() + for job in jobs: + if job.thread: + job.thread.join(timeout=5.0) + + # Now manually set states (threads are done, we control the state) + jobs[0].status = cluster_pb2.JOB_STATE_RUNNING + jobs[1].status = cluster_pb2.JOB_STATE_PENDING + jobs[2].status = cluster_pb2.JOB_STATE_BUILDING + jobs[3].status = cluster_pb2.JOB_STATE_SUCCEEDED + jobs[4].status = cluster_pb2.JOB_STATE_FAILED + + response = client.get("/api/stats") + assert response.status_code == 200 + data = response.json() + assert data["running"] == 1 + assert data["pending"] == 1 + assert data["building"] == 1 + assert data["completed"] == 2 # succeeded + failed + + +# ============================================================================ +# List jobs API tests +# ============================================================================ + + +def test_list_jobs_empty(client): + """Test /api/jobs with no jobs.""" + response = client.get("/api/jobs") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_jobs_with_data(client, service): + """Test /api/jobs returns all jobs.""" + for i in range(3): + request = create_run_job_request(job_id=f"job-{i}") + service.run_job(request, Mock()) + + response = client.get("/api/jobs") + assert response.status_code == 200 + jobs = response.json() + assert len(jobs) == 3 + + job_ids = {j["job_id"] for j in jobs} + assert job_ids == {"job-0", "job-1", "job-2"} + + +# ============================================================================ +# Get job API tests +# ============================================================================ + + +def test_get_job_not_found(client): + """Test /api/jobs/{job_id} with nonexistent job.""" + response = client.get("/api/jobs/nonexistent") + assert response.status_code == 404 + assert response.json() == {"error": "Not found"} + + +def test_get_job_success(client, service): + """Test /api/jobs/{job_id} returns job details.""" + request = create_run_job_request(job_id="job-details", ports=["http", "grpc"]) + service.run_job(request, Mock()) + + # Wait for job to complete (mock runtime returns running=False immediately) + job = service._provider.get_job("job-details") + job.thread.join(timeout=5.0) + + response = client.get("/api/jobs/job-details") + assert response.status_code == 200 + data = response.json() + + assert data["job_id"] == "job-details" + assert data["status"] == "succeeded" # Job completes immediately with mock runtime + assert data["exit_code"] == 0 + assert "http" in data["ports"] + assert "grpc" in data["ports"] + + +# ============================================================================ +# Get logs API tests +# ============================================================================ + + +def test_get_logs_with_tail_parameter(client, service): + """Test /api/jobs/{job_id}/logs?tail=N returns last N lines.""" + request = create_run_job_request(job_id="job-tail") + service.run_job(request, Mock()) + + # Add logs directly to job.logs (since we no longer use file-based logging) + job = service._provider.get_job("job-tail") + for i in range(100): + job.logs.add("stdout", f"Log line {i}") + + response = client.get("/api/jobs/job-tail/logs?tail=5") + assert response.status_code == 200 + logs = response.json() + + assert 
len(logs) == 5 + assert logs[0]["data"] == "Log line 95" + assert logs[4]["data"] == "Log line 99" + + +def test_get_logs_with_source_filter(client, service): + """Test /api/jobs/{job_id}/logs?source=stdout filters by source.""" + import time + + request = create_run_job_request(job_id="job-source-filter") + service.run_job(request, Mock()) + + # Stop the job thread so it doesn't add more logs + job = service._provider.get_job("job-source-filter") + time.sleep(0.05) + job.should_stop = True + if job.thread: + job.thread.join(timeout=1.0) + + # Clear any existing logs and add test logs + job.logs.lines.clear() + job.logs.add("stdout", "stdout line 1") + job.logs.add("stdout", "stdout line 2") + job.logs.add("stderr", "stderr line 1") + job.logs.add("stderr", "stderr line 2") + + # Test stdout filter + response = client.get("/api/jobs/job-source-filter/logs?source=stdout") + assert response.status_code == 200 + logs = response.json() + assert len(logs) == 2 + assert all(log["source"] == "stdout" for log in logs) + + # Test stderr filter + response = client.get("/api/jobs/job-source-filter/logs?source=stderr") + assert response.status_code == 200 + logs = response.json() + assert len(logs) == 2 + assert all(log["source"] == "stderr" for log in logs) + + # Test without filter - should get all logs + response = client.get("/api/jobs/job-source-filter/logs") + assert response.status_code == 200 + logs = response.json() + assert len(logs) == 4 # 2 stdout + 2 stderr + + +def test_list_jobs_includes_resource_and_build_metrics(client, service): + """Test /api/jobs includes resource and build metrics.""" + request = create_run_job_request(job_id="job-with-metrics") + service.run_job(request, Mock()) + + # Set some resource and build metrics + job = service._provider.get_job("job-with-metrics") + job.current_memory_mb = 256 + job.peak_memory_mb = 512 + job.current_cpu_percent = 45 + job.process_count = 3 + job.disk_mb = 1024 + job.build_from_cache = True + job.image_tag = "test-image:v1.0" + + response = client.get("/api/jobs") + assert response.status_code == 200 + jobs = response.json() + + job_data = next(j for j in jobs if j["job_id"] == "job-with-metrics") + assert job_data["memory_mb"] == 256 + assert job_data["memory_peak_mb"] == 512 + assert job_data["cpu_percent"] == 45 + assert job_data["process_count"] == 3 + assert job_data["disk_mb"] == 1024 + assert job_data["build_from_cache"] is True + assert job_data["image_tag"] == "test-image:v1.0" + + +def test_get_job_includes_nested_resources_and_build(client, service): + """Test /api/jobs/{job_id} includes nested resources and build objects.""" + request = create_run_job_request(job_id="job-nested-metrics") + service.run_job(request, Mock()) + + # Set some resource and build metrics + job = service._provider.get_job("job-nested-metrics") + job.current_memory_mb = 128 + job.peak_memory_mb = 256 + job.current_cpu_percent = 30 + job.process_count = 2 + job.disk_mb = 512 + job.build_started_ms = 1000 + job.build_finished_ms = 2500 + job.build_from_cache = False + job.image_tag = "test-image:v2.0" + + response = client.get("/api/jobs/job-nested-metrics") + assert response.status_code == 200 + data = response.json() + + # Check resources nested object + assert "resources" in data + resources = data["resources"] + assert resources["memory_mb"] == 128 + assert resources["memory_peak_mb"] == 256 + assert resources["cpu_percent"] == 30 + assert resources["process_count"] == 2 + assert resources["disk_mb"] == 512 + + # Check build nested object + assert 
"build" in data + build = data["build"] + assert build["started_ms"] == 1000 + assert build["finished_ms"] == 2500 + assert build["duration_ms"] == 1500 + assert build["from_cache"] is False + assert build["image_tag"] == "test-image:v2.0" + + +def test_job_detail_page_loads(client): + """Test /job/{job_id} page loads successfully.""" + response = client.get("/job/test-job-123") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/html; charset=utf-8" + assert "Job: test-job-123" in response.text + assert "Back to Dashboard" in response.text + + +# ============================================================================ +# RPC Service tests (WorkerServiceImpl) +# ============================================================================ + + +def test_run_job_generates_job_id_if_missing(service, request_context): + """Test run_job generates job_id when not provided.""" + request = create_run_job_request(job_id="") + response = service.run_job(request, request_context) + + assert response.job_id # Should have a generated ID + assert len(response.job_id) > 0 + # Job may have already transitioned from PENDING since threads start immediately + assert response.state in ( + cluster_pb2.JOB_STATE_PENDING, + cluster_pb2.JOB_STATE_BUILDING, + cluster_pb2.JOB_STATE_RUNNING, + cluster_pb2.JOB_STATE_SUCCEEDED, + ) + + +def test_run_job_with_ports(service, request_context): + """Test run_job allocates ports correctly.""" + request = create_run_job_request(job_id="job-with-ports", ports=["http", "grpc"]) + response = service.run_job(request, request_context) + + assert response.job_id == "job-with-ports" + + # Verify ports were allocated + job = service._provider.get_job("job-with-ports") + assert len(job.ports) == 2 + assert "http" in job.ports + assert "grpc" in job.ports + + +def test_get_job_status_not_found(service, request_context): + """Test get_job_status raises NOT_FOUND for nonexistent job.""" + status_request = cluster_pb2.Worker.GetJobStatusRequest(job_id="nonexistent") + + with pytest.raises(ConnectError) as exc_info: + service.get_job_status(status_request, request_context) + + assert exc_info.value.code == Code.NOT_FOUND + assert "nonexistent" in str(exc_info.value) + + +def test_get_job_status_completed_job(service, request_context): + """Test get_job_status for completed job includes timing info.""" + request = create_run_job_request(job_id="job-completed") + service.run_job(request, request_context) + + # Wait for job to complete + job = service._provider.get_job("job-completed") + job.thread.join(timeout=5.0) + + status_request = cluster_pb2.Worker.GetJobStatusRequest(job_id="job-completed") + status = service.get_job_status(status_request, request_context) + + assert status.job_id == "job-completed" + assert status.state == cluster_pb2.JOB_STATE_SUCCEEDED + assert status.exit_code == 0 + assert status.started_at_ms > 0 + assert status.finished_at_ms > 0 + + +def test_fetch_logs_tail_with_negative_start_line(service, request_context): + """Test fetch_logs with negative start_line for tailing.""" + request = create_run_job_request(job_id="job-logs-tail") + service.run_job(request, request_context) + + # Add logs directly to job.logs + job = service._provider.get_job("job-logs-tail") + for i in range(10): + job.logs.add("stdout", f"Log line {i}") + + log_filter = cluster_pb2.Worker.FetchLogsFilter(start_line=-3) + logs_request = cluster_pb2.Worker.FetchLogsRequest(job_id="job-logs-tail", filter=log_filter) + response = 
service.fetch_logs(logs_request, request_context) + + assert len(response.logs) == 3 + assert response.logs[0].data == "Log line 7" + assert response.logs[1].data == "Log line 8" + assert response.logs[2].data == "Log line 9" + + +def test_fetch_logs_with_regex_filter(service, request_context): + """Test fetch_logs with regex content filter.""" + request = create_run_job_request(job_id="job-logs-regex") + service.run_job(request, request_context) + + # Add logs with different patterns + job = service._provider.get_job("job-logs-regex") + job.logs.add("stdout", "ERROR: something bad") + job.logs.add("stdout", "INFO: normal log") + job.logs.add("stdout", "ERROR: another error") + job.logs.add("stdout", "DEBUG: details") + + log_filter = cluster_pb2.Worker.FetchLogsFilter(regex="ERROR") + logs_request = cluster_pb2.Worker.FetchLogsRequest(job_id="job-logs-regex", filter=log_filter) + response = service.fetch_logs(logs_request, request_context) + + assert len(response.logs) == 2 + assert "ERROR" in response.logs[0].data + assert "ERROR" in response.logs[1].data + + +def test_fetch_logs_combined_filters(service, request_context): + """Test fetch_logs with multiple filters combined.""" + request = create_run_job_request(job_id="job-logs-combined") + service.run_job(request, request_context) + + # Add logs + job = service._provider.get_job("job-logs-combined") + job.logs.add("stdout", "ERROR: first error") + job.logs.add("stdout", "INFO: normal") + job.logs.add("stdout", "ERROR: second error") + job.logs.add("stdout", "ERROR: third error") + job.logs.add("stdout", "ERROR: fourth error") + job.logs.add("stdout", "ERROR: fifth error") + + # Use regex to filter ERRORs, then limit to 2 + log_filter = cluster_pb2.Worker.FetchLogsFilter(regex="ERROR", max_lines=2) + logs_request = cluster_pb2.Worker.FetchLogsRequest(job_id="job-logs-combined", filter=log_filter) + response = service.fetch_logs(logs_request, request_context) + + assert len(response.logs) == 2 + assert "ERROR" in response.logs[0].data + assert "ERROR" in response.logs[1].data + + +def test_kill_job_not_found(service, request_context): + """Test kill_job raises NOT_FOUND for nonexistent job.""" + kill_request = cluster_pb2.Worker.KillJobRequest(job_id="nonexistent") + + with pytest.raises(ConnectError) as exc_info: + service.kill_job(kill_request, request_context) + + assert exc_info.value.code == Code.NOT_FOUND + assert "nonexistent" in str(exc_info.value) + + +def test_kill_job_already_completed(service, request_context): + """Test kill_job fails for already completed job.""" + request = create_run_job_request(job_id="job-completed") + service.run_job(request, request_context) + + # Wait for job to complete + job = service._provider.get_job("job-completed") + job.thread.join(timeout=5.0) + + # Try to kill completed job + kill_request = cluster_pb2.Worker.KillJobRequest(job_id="job-completed") + + with pytest.raises(ConnectError) as exc_info: + service.kill_job(kill_request, request_context) + + assert exc_info.value.code == Code.FAILED_PRECONDITION + assert "already completed" in str(exc_info.value) + + +def test_kill_job_with_custom_timeout(service, request_context): + """Test kill_job accepts custom term_timeout_ms and attempts termination. + + Note: With mocks, the job thread completes immediately. This test verifies + the API works and runtime.kill is called, not the actual kill behavior. 
+ """ + request = create_run_job_request(job_id="job-kill") + service.run_job(request, request_context) + + # Wait for job thread to finish (mock makes it complete immediately) + job = service._provider.get_job("job-kill") + if job.thread: + job.thread.join(timeout=5.0) + + # Manually set job to RUNNING to simulate mid-execution state + job.status = cluster_pb2.JOB_STATE_RUNNING + job.container_id = "container123" + + kill_request = cluster_pb2.Worker.KillJobRequest(job_id="job-kill", term_timeout_ms=100) + response = service.kill_job(kill_request, request_context) + + # Verify API response and that should_stop was set + assert isinstance(response, cluster_pb2.Empty) + assert job.should_stop is True + # The runtime.kill should have been called (may be called twice: SIGTERM then SIGKILL) + assert service._provider._runtime.kill.called + + +# ============================================================================ +# Connect RPC integration tests +# ============================================================================ + + +def test_rpc_endpoint_mounted_correctly(server): + """Test Connect RPC is mounted at correct path.""" + # Check that the RPC path is included in routes + route_paths = [route.path for route in server._app.routes] + assert "/fluster.cluster.WorkerService" in route_paths + + +@pytest.mark.asyncio +async def test_rpc_run_job_via_connect_client(service): + """Test calling run_job via Connect RPC client.""" + # Create server on ephemeral port + server = WorkerDashboard(service=service, host="127.0.0.1", port=0) + + # Run server in background + async def run_server(): + import uvicorn + + config = uvicorn.Config(server._app, host="127.0.0.1", port=18080) + server_obj = uvicorn.Server(config) + await server_obj.serve() + + server_task = asyncio.create_task(run_server()) + + try: + # Give server time to start + await asyncio.sleep(0.5) + + # Create Connect client (async client talks to WSGI server via WSGIMiddleware) + async with httpx.AsyncClient() as http_client: + client = WorkerServiceClient(address="http://127.0.0.1:18080", session=http_client) + + # Submit job via RPC + request = create_run_job_request(job_id="rpc-test-job") + response = await client.run_job(request) + + assert response.job_id == "rpc-test-job" + # Job may have already transitioned from PENDING since threads start immediately + assert response.state in ( + cluster_pb2.JOB_STATE_PENDING, + cluster_pb2.JOB_STATE_BUILDING, + cluster_pb2.JOB_STATE_RUNNING, + cluster_pb2.JOB_STATE_SUCCEEDED, + ) + + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + +# ============================================================================ +# Server properties tests +# ============================================================================ + + +def test_server_port_property(service): + """Test server port property returns configured port.""" + server = WorkerDashboard(service=service, host="127.0.0.1", port=9999) + assert server.port == 9999 + + +def test_server_default_host_and_port(service): + """Test server uses default host and port.""" + server = WorkerDashboard(service=service) + assert server._host == "0.0.0.0" + assert server._port == 8080 diff --git a/lib/fluster/tests/cluster/worker/test_main.py b/lib/fluster/tests/cluster/worker/test_main.py new file mode 100644 index 0000000000..a0e1d1bb6e --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_main.py @@ -0,0 +1,90 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Fluster worker CLI.""" + +import tempfile +from pathlib import Path + +from click.testing import CliRunner + +from fluster.cluster.worker.main import cli + + +def test_cli_help(): + """Test CLI help message.""" + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + assert result.exit_code == 0 + assert "Fluster Worker" in result.output + assert "serve" in result.output + assert "cleanup" in result.output + + +def test_serve_help(): + """Test serve command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["serve", "--help"]) + assert result.exit_code == 0 + assert "Start the Fluster worker service" in result.output + assert "--host" in result.output + assert "--port" in result.output + assert "--cache-dir" in result.output + assert "--registry" in result.output + + +def test_cleanup_help(): + """Test cleanup command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["cleanup", "--help"]) + assert result.exit_code == 0 + assert "Clean up cached bundles" in result.output + assert "--cache-dir" in result.output + + +def test_cleanup_removes_cache_directory(): + """Test cleanup command removes cache directory.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "fluster-cache" + cache_dir.mkdir() + test_file = cache_dir / "test.txt" + test_file.write_text("test data") + + assert cache_dir.exists() + + result = runner.invoke(cli, ["cleanup", "--cache-dir", str(cache_dir)]) + assert result.exit_code == 0 + assert "Removed cache directory" in result.output + assert not cache_dir.exists() + + +def test_cleanup_handles_missing_directory(): + """Test cleanup command handles missing cache directory.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + cache_dir = Path(tmpdir) / "nonexistent" + assert not cache_dir.exists() + + result = runner.invoke(cli, ["cleanup", "--cache-dir", str(cache_dir)]) + assert result.exit_code == 0 + assert "does not exist" in result.output + + +def test_serve_requires_registry(): + """Test serve command requires --registry argument.""" + runner = CliRunner() + result = runner.invoke(cli, ["serve"]) + assert result.exit_code != 0 + assert "registry" in result.output.lower() or "required" in result.output.lower() diff --git a/lib/fluster/tests/cluster/worker/test_runtime.py b/lib/fluster/tests/cluster/worker/test_runtime.py new file mode 100644 index 0000000000..2d38acec72 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_runtime.py @@ -0,0 +1,726 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DockerRuntime.""" + +import subprocess +import time + +import pytest + +from fluster import cluster_pb2 +from fluster.cluster.worker.docker import ContainerConfig, DockerRuntime + + +def check_docker_available(): + """Check if Docker is available and running.""" + try: + result = subprocess.run( + ["docker", "info"], + check=True, + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return False + + +@pytest.fixture +def runtime(): + """Create DockerRuntime instance.""" + return DockerRuntime() + + +@pytest.mark.slow +def test_create_and_start_container(runtime): + """Test creating and starting a simple container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "Hello World"], + env={}, + ) + + # Create container + container_id = runtime.create_container(config) + assert container_id is not None + assert len(container_id) > 0 + + # Start container + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + # Check final status + status = runtime.inspect(container_id) + assert not status.running + assert status.exit_code == 0 + assert status.error is None + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_container_with_failure_exit_code(runtime): + """Test container that exits with non-zero code.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "exit 42"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + # Check exit code + status = runtime.inspect(container_id) + assert status.exit_code == 42 + assert status.error is None + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_resource_limits_cpu(runtime): + """Test that CPU limits are applied to container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Set CPU limit to 1 core (1000 millicores) + resources = cluster_pb2.ResourceSpec(cpu=1) + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + resources=resources, + ) + + container_id = runtime.create_container(config) + + # Inspect container to verify CPU limit was set + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{.HostConfig.NanoCpus}}", + ], + capture_output=True, + text=True, + check=True, + ) + nano_cpus = int(result.stdout.strip()) + + # 1 core = 1000 millicores = 1000000000 nanocpus + expected_nano_cpus = 1_000_000_000 + assert nano_cpus == expected_nano_cpus + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_resource_limits_memory(runtime): + """Test that memory limits are applied to container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Set memory limit to 256 MB + resources = 
cluster_pb2.ResourceSpec(memory="256m") + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + resources=resources, + ) + + container_id = runtime.create_container(config) + + # Inspect container to verify memory limit was set + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{.HostConfig.Memory}}", + ], + capture_output=True, + text=True, + check=True, + ) + memory_bytes = int(result.stdout.strip()) + + # 256 MB = 268435456 bytes + expected_bytes = 256 * 1024 * 1024 + assert memory_bytes == expected_bytes + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_resource_limits_combined(runtime): + """Test that CPU and memory limits work together.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + resources = cluster_pb2.ResourceSpec(cpu=1, memory="512m") + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + resources=resources, + ) + + container_id = runtime.create_container(config) + + # Verify both limits + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{.HostConfig.NanoCpus}} {{.HostConfig.Memory}}", + ], + capture_output=True, + text=True, + check=True, + ) + parts = result.stdout.strip().split() + nano_cpus = int(parts[0]) + memory_bytes = int(parts[1]) + + assert nano_cpus == 1_000_000_000 + assert memory_bytes == 512 * 1024 * 1024 + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_environment_variables(runtime): + """Test that environment variables are passed to container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "echo $TEST_VAR"], + env={"TEST_VAR": "test_value_123"}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_multiple_environment_variables(runtime): + """Test multiple environment variables.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "echo $VAR1 $VAR2 $VAR3"], + env={ + "VAR1": "value1", + "VAR2": "value2", + "VAR3": "value3", + }, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_mounts(runtime, tmp_path): + """Test that volume mounts work correctly.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Create a test file on host + host_dir = tmp_path / "host_mount" + host_dir.mkdir() + test_file = host_dir / "test.txt" + test_file.write_text("test content from host") + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "cat /mnt/test.txt"], + env={}, + mounts=[(str(host_dir), "/mnt", "ro")], + ) + + 
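    # ContainerConfig mounts, as exercised in this test, are assumed to be
    # (host_path, container_path, mode) tuples; "ro" mounts the host directory
    # read-only (test_mounts_writable below uses "rw" for a writable mount).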
container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_mounts_writable(runtime, tmp_path): + """Test writable volume mounts.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Create a directory for writable mount + host_dir = tmp_path / "writable_mount" + host_dir.mkdir() + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "echo 'written from container' > /mnt/output.txt"], + env={}, + mounts=[(str(host_dir), "/mnt", "rw")], + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Verify file was written from container + output_file = host_dir / "output.txt" + assert output_file.exists() + assert "written from container" in output_file.read_text() + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_port_mapping(runtime): + """Test port mapping configuration.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Note: We can't easily test actual network connectivity in unit tests, + # but we can verify the port mapping configuration is applied + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + ports={"http": 8080, "metrics": 9090}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Inspect container to verify port mappings were configured + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{json .HostConfig.PortBindings}}", + ], + capture_output=True, + text=True, + check=True, + ) + port_bindings = result.stdout.strip() + + # Should have port mappings (exact format depends on Docker version) + # Just verify the ports appear in the output + assert "8080" in port_bindings + assert "9090" in port_bindings + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_workdir(runtime): + """Test custom working directory.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "pwd"], + env={}, + workdir="/custom/workdir", + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + status = runtime.inspect(container_id) + assert status.exit_code == 0 + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_container_cleanup_with_remove(runtime): + 
"""Test that remove() properly cleans up containers.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + # Container should exist before removal + result = subprocess.run( + [ + "docker", + "ps", + "-a", + "--filter", + f"id={container_id}", + "--format", + "{{.ID}}", + ], + capture_output=True, + text=True, + check=True, + ) + assert container_id[:12] in result.stdout + + # Remove container + runtime.remove(container_id) + + # Wait a moment for removal to complete + time.sleep(0.1) + + # Container should not exist after removal + result = subprocess.run( + [ + "docker", + "ps", + "-a", + "--filter", + f"id={container_id}", + "--format", + "{{.ID}}", + ], + capture_output=True, + text=True, + check=True, + ) + assert result.stdout.strip() == "" + + +@pytest.mark.slow +def test_security_hardening_no_new_privileges(runtime): + """Test that no-new-privileges security option is applied.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + ) + + container_id = runtime.create_container(config) + + # Inspect container security options + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{json .HostConfig.SecurityOpt}}", + ], + capture_output=True, + text=True, + check=True, + ) + security_opts = result.stdout.strip() + + assert "no-new-privileges" in security_opts + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_security_hardening_cap_drop_all(runtime): + """Test that all capabilities are dropped.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + ) + + container_id = runtime.create_container(config) + + # Inspect container capability drops + result = subprocess.run( + [ + "docker", + "inspect", + container_id, + "--format", + "{{json .HostConfig.CapDrop}}", + ], + capture_output=True, + text=True, + check=True, + ) + cap_drop = result.stdout.strip() + + assert "ALL" in cap_drop + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_kill_with_sigterm(runtime): + """Test killing container with SIGTERM.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + # Start a long-running container with a trap to handle SIGTERM + config = ContainerConfig( + image="alpine:latest", + command=["sh", "-c", "trap 'exit 0' TERM; while true; do sleep 1; done"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait a moment for container to be running + time.sleep(0.5) + + # Kill with SIGTERM + runtime.kill(container_id, force=False) + + # Wait longer for graceful shutdown + time.sleep(1.0) + + # Verify container is stopped + status = runtime.inspect(container_id) + assert not status.running + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_kill_with_sigkill(runtime): + """Test killing container with SIGKILL (force).""" + if not check_docker_available(): + pytest.skip("Docker 
not available") + + # Start a long-running container + config = ContainerConfig( + image="alpine:latest", + command=["sleep", "30"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait a moment for container to be running + time.sleep(0.5) + + # Force kill with SIGKILL + runtime.kill(container_id, force=True) + + # Wait for container to stop + time.sleep(0.5) + + # Verify container is stopped + status = runtime.inspect(container_id) + assert not status.running + + # Cleanup + runtime.remove(container_id) + + +@pytest.mark.slow +def test_inspect_running_container(runtime): + """Test inspect() on a running container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sleep", "10"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Check status while running + time.sleep(0.5) + status = runtime.inspect(container_id) + assert status.running + assert status.exit_code is None + + # Kill and cleanup + runtime.kill(container_id, force=True) + runtime.remove(container_id) + + +@pytest.mark.slow +def test_inspect_stopped_container(runtime): + """Test inspect() on a stopped container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "test"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + # Wait for completion + max_wait = 5 + start = time.time() + while time.time() - start < max_wait: + status = runtime.inspect(container_id) + if not status.running: + break + time.sleep(0.1) + + # Check final status + status = runtime.inspect(container_id) + assert not status.running + assert status.exit_code == 0 + + # Cleanup + runtime.remove(container_id) diff --git a/lib/fluster/tests/cluster/worker/test_stats.py b/lib/fluster/tests/cluster/worker/test_stats.py new file mode 100644 index 0000000000..1e70967065 --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_stats.py @@ -0,0 +1,146 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Docker container statistics collection.""" + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import docker.errors +import pytest + +from fluster.cluster.worker.docker import ContainerStats, DockerRuntime +from fluster.cluster.worker.worker_types import collect_workdir_size_mb + + +def check_docker_available(): + """Check if Docker is available and running.""" + try: + result = subprocess.run( + ["docker", "info"], + check=True, + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return False + + +def test_collect_workdir_size_mb_with_temp_directory(tmp_path): + """Test workdir size calculation with a temporary directory.""" + # Create some files in temp directory + (tmp_path / "file1.txt").write_text("x" * 1024 * 100) # 100 KB + (tmp_path / "file2.txt").write_text("y" * 1024 * 100) # 100 KB + + size_mb = collect_workdir_size_mb(tmp_path) + + # Size should be at least 1 MB (200 KB rounded up) + assert size_mb >= 1 + + +def test_collect_workdir_size_mb_nonexistent_directory(): + """Test workdir size returns 0 for non-existent directory.""" + nonexistent = Path("/nonexistent/path/does/not/exist") + + size_mb = collect_workdir_size_mb(nonexistent) + + assert size_mb == 0 + + +def test_get_stats_invalid_container(): + """Test that get_stats returns available=False for invalid container ID.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + runtime = DockerRuntime() + invalid_container_id = "nonexistent_container_12345" + + stats = runtime.get_stats(invalid_container_id) + + assert isinstance(stats, ContainerStats) + assert stats.available is False + assert stats.memory_mb == 0 + assert stats.cpu_percent == 0 + assert stats.process_count == 0 + + +def test_get_stats_with_mock(): + """Test get_stats parsing with mocked Docker client.""" + runtime = DockerRuntime() + + # Mock stats response with realistic Docker stats format + mock_stats = { + "memory_stats": { + "usage": 100 * 1024 * 1024, # 100 MB in bytes + }, + "cpu_stats": { + "cpu_usage": {"total_usage": 2000000000}, + "system_cpu_usage": 10000000000, + "online_cpus": 4, + }, + "precpu_stats": { + "cpu_usage": {"total_usage": 1000000000}, + "system_cpu_usage": 9000000000, + }, + "pids_stats": { + "current": 5, + }, + } + + mock_container = MagicMock() + mock_container.stats.return_value = mock_stats + + mock_client = MagicMock() + mock_client.containers.get.return_value = mock_container + + with patch("fluster.cluster.worker.docker.docker") as mock_docker: + mock_docker.from_env.return_value = mock_client + stats = runtime.get_stats("test_container") + + assert stats.available is True + assert stats.memory_mb == 100 + assert stats.cpu_percent == 400 # (1000000000 / 1000000000) * 4 * 100 + assert stats.process_count == 5 + + +def test_get_stats_not_found_exception(): + """Test that NotFound exception returns available=False.""" + runtime = DockerRuntime() + + mock_client = MagicMock() + mock_client.containers.get.side_effect = docker.errors.NotFound("Container not found") + + with patch("fluster.cluster.worker.docker.docker") as mock_docker: + mock_docker.from_env.return_value = mock_client + mock_docker.errors = docker.errors + stats = runtime.get_stats("missing_container") + + assert stats.available is False + + +def test_get_stats_api_error_exception(): + """Test that APIError exception returns available=False.""" + runtime = DockerRuntime() + 
+ mock_client = MagicMock() + mock_client.containers.get.side_effect = docker.errors.APIError("API error") + + with patch("fluster.cluster.worker.docker.docker") as mock_docker: + mock_docker.from_env.return_value = mock_client + mock_docker.errors = docker.errors + stats = runtime.get_stats("error_container") + + assert stats.available is False diff --git a/lib/fluster/tests/cluster/worker/test_worker.py b/lib/fluster/tests/cluster/worker/test_worker.py new file mode 100644 index 0000000000..441866edbd --- /dev/null +++ b/lib/fluster/tests/cluster/worker/test_worker.py @@ -0,0 +1,845 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Worker class (includes PortAllocator and job management).""" + +import socket +import subprocess +import time +import zipfile +from pathlib import Path +from unittest.mock import Mock + +import cloudpickle +import pytest +from connectrpc.request import RequestContext + +from fluster import cluster_pb2 +from fluster.cluster.types import Entrypoint +from fluster.cluster.worker.builder import BuildResult, ImageCache, VenvCache +from fluster.cluster.worker.bundle import BundleCache +from fluster.cluster.worker.docker import ContainerConfig, ContainerStats, ContainerStatus, DockerRuntime, ImageBuilder +from fluster.cluster.worker.service import WorkerServiceImpl +from fluster.cluster.worker.worker import PortAllocator, Worker, WorkerConfig + + +# ============================================================================ +# PortAllocator Tests +# ============================================================================ + + +@pytest.fixture +def allocator(): + """Create PortAllocator with small range for testing.""" + return PortAllocator(port_range=(40000, 40100)) + + +def test_allocate_single_port(allocator): + """Test allocating a single port.""" + ports = allocator.allocate(count=1) + assert len(ports) == 1 + assert 40000 <= ports[0] < 40100 + + +def test_allocate_multiple_ports(allocator): + """Test allocating multiple ports at once.""" + ports = allocator.allocate(count=5) + assert len(ports) == 5 + assert len(set(ports)) == 5 # All unique + for port in ports: + assert 40000 <= port < 40100 + + +def test_allocated_ports_are_usable(allocator): + """Test that allocated ports can actually be bound.""" + ports = allocator.allocate(count=3) + + for port in ports: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + + +def test_no_port_reuse_before_release(allocator): + """Test that allocated ports are not reused before release.""" + ports1 = allocator.allocate(count=5) + ports2 = allocator.allocate(count=5) + + assert len(set(ports1) & set(ports2)) == 0 + + +def test_ports_reused_after_release(): + """Test that ports can be reused after release.""" + allocator_small = PortAllocator(port_range=(40000, 40003)) + + ports1 = allocator_small.allocate(count=3) + assert len(ports1) == 3 + + allocator_small.release(ports1) + + ports2 = allocator_small.allocate(count=3) + assert 
len(ports2) == 3 + assert set(ports1) == set(ports2) + + +def test_release_partial_ports(allocator): + """Test releasing only some ports.""" + ports = allocator.allocate(count=5) + + allocator.release(ports[:3]) + + new_ports = allocator.allocate(count=2) + assert len(set(new_ports) & set(ports[:3])) > 0 + + +def test_exhausted_port_range(): + """Test behavior when port range is exhausted.""" + allocator_tiny = PortAllocator(port_range=(40000, 40002)) + + ports = allocator_tiny.allocate(count=2) + assert len(ports) == 2 + + with pytest.raises(RuntimeError, match="No free ports available"): + allocator_tiny.allocate(count=1) + + +def test_concurrent_allocations(allocator): + """Test concurrent port allocations are thread-safe.""" + import threading + + results = [] + + def allocate_ports(): + ports = allocator.allocate(count=5) + results.append(ports) + + threads = [threading.Thread(target=allocate_ports) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + all_ports = [] + for ports in results: + all_ports.extend(ports) + + assert len(all_ports) == len(set(all_ports)) + + +def test_release_nonexistent_port(allocator): + """Test that releasing a non-allocated port doesn't cause errors.""" + allocator.release([99999]) + + +def test_default_port_range(): + """Test default port range is 30000-40000.""" + allocator = PortAllocator() + ports = allocator.allocate(count=5) + + for port in ports: + assert 30000 <= port < 40000 + + +# ============================================================================ +# Worker Tests (with mocked dependencies) +# ============================================================================ + + +@pytest.fixture +def mock_bundle_cache(): + """Create mock BundleCache.""" + cache = Mock(spec=BundleCache) + cache.get_bundle = Mock(return_value=Path("/tmp/bundle")) + return cache + + +@pytest.fixture +def mock_venv_cache(): + """Create mock VenvCache.""" + cache = Mock(spec=VenvCache) + cache.compute_deps_hash = Mock(return_value="abc123") + return cache + + +@pytest.fixture +def mock_image_cache(): + """Create mock ImageBuilder.""" + builder = Mock(spec=ImageBuilder) + builder.build = Mock( + return_value=BuildResult( + image_tag="test-image:latest", + deps_hash="abc123", + build_time_ms=1000, + from_cache=False, + ) + ) + return builder + + +@pytest.fixture +def mock_runtime(): + """Create mock DockerRuntime. + + By default, simulates a container that runs and completes successfully. 
+ """ + runtime = Mock(spec=DockerRuntime) + runtime.create_container = Mock(return_value="container123") + runtime.start_container = Mock() + + call_count = [0] + + def inspect_side_effect(container_id): + call_count[0] += 1 + if call_count[0] == 1: + return ContainerStatus(running=True) + return ContainerStatus(running=False, exit_code=0) + + runtime.inspect = Mock(side_effect=inspect_side_effect) + runtime.kill = Mock() + runtime.remove = Mock() + runtime.get_stats = Mock( + return_value=ContainerStats(memory_mb=100, cpu_percent=50, process_count=5, available=True) + ) + runtime.get_logs = Mock(return_value=[]) + return runtime + + +@pytest.fixture +def worker(mock_bundle_cache, mock_venv_cache, mock_image_cache, mock_runtime): + """Create Worker with mocked dependencies.""" + config = WorkerConfig( + port=0, + max_concurrent_jobs=5, + port_range=(50000, 50100), + ) + return Worker( + config, + bundle_provider=mock_bundle_cache, + image_provider=mock_image_cache, + container_runtime=mock_runtime, + ) + + +def create_test_entrypoint(): + """Create a simple test entrypoint.""" + from dataclasses import dataclass + + @dataclass + class TestEntrypoint: + callable: object + args: tuple = () + kwargs: dict | None = None + + def __post_init__(self): + if self.kwargs is None: + self.kwargs = {} + + def test_fn(): + print("Hello from test") + + return TestEntrypoint(callable=test_fn) + + +def create_run_job_request(job_id: str = "test-job-1", ports: list[str] | None = None): + """Create a RunJobRequest for testing.""" + entrypoint = create_test_entrypoint() + serialized_entrypoint = cloudpickle.dumps(entrypoint) + + env_config = cluster_pb2.EnvironmentConfig( + workspace="/workspace", + env_vars={ + "TEST_VAR": "value", + "JOB_VAR": "job_value", + }, + extras=["dev"], + ) + + resources = cluster_pb2.ResourceSpec( + cpu=2, + memory="4g", + ) + + return cluster_pb2.Worker.RunJobRequest( + job_id=job_id, + serialized_entrypoint=serialized_entrypoint, + environment=env_config, + bundle_gcs_path="gs://bucket/bundle.zip", + resources=resources, + timeout_seconds=300, + ports=ports or [], + ) + + +def test_submit_job_returns_job_id(worker): + """Test that submit_job returns job_id immediately.""" + request = create_run_job_request() + job_id = worker.submit_job(request) + + assert job_id == "test-job-1" + + job = worker.get_job(job_id) + assert job is not None + assert job.job_id == job_id + + +def test_job_lifecycle_phases(worker): + """Test job transitions through PENDING → BUILDING → RUNNING → SUCCEEDED.""" + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_SUCCEEDED + assert final_job.exit_code == 0 + + +def test_job_with_ports(worker): + """Test job with port allocation.""" + request = create_run_job_request(ports=["http", "grpc"]) + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + assert len(job.ports) == 2 + assert "http" in job.ports + assert "grpc" in job.ports + assert job.ports["http"] != job.ports["grpc"] + + job.thread.join(timeout=15.0) + + +def test_job_failure_on_nonzero_exit(worker, mock_runtime): + """Test job fails when container exits with non-zero code.""" + mock_runtime.inspect = Mock(return_value=ContainerStatus(running=False, exit_code=1)) + + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + 
+ final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_FAILED + assert final_job.exit_code == 1 + assert "Exit code: 1" in final_job.error + + +def test_job_failure_on_error(worker, mock_runtime): + """Test job fails when container returns error.""" + call_count = [0] + + def inspect_side_effect(container_id): + call_count[0] += 1 + if call_count[0] == 1: + return ContainerStatus(running=True) + return ContainerStatus(running=False, exit_code=1, error="Container crashed") + + mock_runtime.inspect = Mock(side_effect=inspect_side_effect) + + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=10.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_FAILED + assert final_job.error == "Container crashed" + + +def test_job_exception_handling(worker, mock_bundle_cache): + """Test job handles exceptions during execution.""" + mock_bundle_cache.get_bundle = Mock(side_effect=Exception("Bundle download failed")) + + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_FAILED + assert "Bundle download failed" in final_job.error + + +def test_list_jobs(worker): + """Test listing all jobs.""" + requests = [create_run_job_request(job_id=f"job-{i}") for i in range(3)] + + for request in requests: + worker.submit_job(request) + + jobs = worker.list_jobs() + assert len(jobs) == 3 + assert {job.job_id for job in jobs} == {"job-0", "job-1", "job-2"} + + +def test_kill_running_job(worker, mock_runtime): + """Test killing a running job with graceful timeout.""" + mock_runtime.inspect = Mock(return_value=ContainerStatus(running=True)) + + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + for _ in range(20): + if job.status == cluster_pb2.JOB_STATE_RUNNING and job.container_id: + break + time.sleep(0.1) + + result = worker.kill_job(job_id, term_timeout_ms=100) + assert result is True + + job.thread.join(timeout=15.0) + + assert job.status == cluster_pb2.JOB_STATE_KILLED + mock_runtime.kill.assert_any_call("container123", force=False) + + +def test_kill_nonexistent_job(worker): + """Test killing a nonexistent job returns False.""" + result = worker.kill_job("nonexistent-job") + assert result is False + + +def test_get_logs_empty(worker): + """Test getting logs for job immediately after submission.""" + request = create_run_job_request() + job_id = worker.submit_job(request) + + logs = worker.get_logs(job_id) + assert isinstance(logs, list) + + +def test_get_logs_nonexistent_job(worker): + """Test getting logs for nonexistent job returns empty list.""" + logs = worker.get_logs("nonexistent-job") + assert logs == [] + + +def test_build_command_with_entrypoint(worker): + """Test _build_command creates correct cloudpickle command.""" + entrypoint = create_test_entrypoint() + command = worker._build_command(entrypoint) + + assert command[0] == "python" + assert command[1] == "-c" + assert "cloudpickle" in command[2] + assert "base64" in command[2] + + +def test_fray_port_mapping_env_var(worker, mock_runtime): + """Test that FRAY_PORT_MAPPING environment variable is set with port mappings.""" + request = create_run_job_request(ports=["web", "api", "metrics"]) + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + 
job.thread.join(timeout=15.0) + + assert mock_runtime.create_container.called + call_args = mock_runtime.create_container.call_args + config = call_args[0][0] + + assert "FRAY_PORT_MAPPING" in config.env + + port_mapping = config.env["FRAY_PORT_MAPPING"] + mappings = {} + for pair in port_mapping.split(","): + name, port = pair.split(":") + mappings[name] = int(port) + + assert set(mappings.keys()) == {"web", "api", "metrics"} + assert len(set(mappings.values())) == 3 + + assert "FLUSTER_PORT_WEB" in config.env + assert "FLUSTER_PORT_API" in config.env + assert "FLUSTER_PORT_METRICS" in config.env + + +def test_fray_port_mapping_not_set_when_no_ports(worker, mock_runtime): + """Test that FRAY_PORT_MAPPING is not set when no ports are requested.""" + request = create_run_job_request(ports=[]) + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + assert mock_runtime.create_container.called + call_args = mock_runtime.create_container.call_args + config = call_args[0][0] + + assert "FRAY_PORT_MAPPING" not in config.env + + +def test_job_failure_error_appears_in_logs(worker, mock_bundle_cache): + """Test that job failure errors appear in logs.""" + mock_bundle_cache.get_bundle = Mock(side_effect=Exception("Bundle download failed")) + + request = create_run_job_request() + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_FAILED + assert "Bundle download failed" in final_job.error + + logs = worker.get_logs(job_id) + error_logs = [log for log in logs if log.source == "error"] + assert len(error_logs) >= 1 + assert any("Bundle download failed" in log.data for log in error_logs) + + +def test_port_retry_on_binding_failure(mock_bundle_cache, mock_venv_cache, mock_image_cache): + """Test that job retries with new ports when port binding fails.""" + runtime = Mock(spec=DockerRuntime) + runtime.create_container = Mock(return_value="container123") + + call_count = [0] + + def start_side_effect(container_id): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("failed to bind host port: address already in use") + return None + + runtime.start_container = Mock(side_effect=start_side_effect) + runtime.remove = Mock() + + inspect_call_count = [0] + + def inspect_side_effect(container_id): + inspect_call_count[0] += 1 + if inspect_call_count[0] == 1: + return ContainerStatus(running=True) + return ContainerStatus(running=False, exit_code=0) + + runtime.inspect = Mock(side_effect=inspect_side_effect) + runtime.get_stats = Mock( + return_value=ContainerStats(memory_mb=100, cpu_percent=50, process_count=5, available=True) + ) + runtime.get_logs = Mock(return_value=[]) + + config = WorkerConfig( + port=0, + max_concurrent_jobs=5, + port_range=(50000, 50100), + ) + worker = Worker( + config, + bundle_provider=mock_bundle_cache, + image_provider=mock_image_cache, + container_runtime=runtime, + ) + + request = create_run_job_request(ports=["actor"]) + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_SUCCEEDED + + assert runtime.start_container.call_count == 2 + assert runtime.remove.call_count == 1 + + logs = worker.get_logs(job_id) + build_logs = [log for log in logs if log.source == "build"] + assert any("Port conflict" in log.data for log in build_logs) + + +def 
test_port_retry_exhausted(mock_bundle_cache, mock_venv_cache, mock_image_cache): + """Test that job fails after max port retries are exhausted.""" + runtime = Mock(spec=DockerRuntime) + runtime.create_container = Mock(return_value="container123") + runtime.start_container = Mock( + side_effect=RuntimeError("failed to bind host port: address already in use") + ) + runtime.remove = Mock() + runtime.get_logs = Mock(return_value=[]) + + config = WorkerConfig( + port=0, + max_concurrent_jobs=5, + port_range=(50000, 50100), + ) + worker = Worker( + config, + bundle_provider=mock_bundle_cache, + image_provider=mock_image_cache, + container_runtime=runtime, + ) + + request = create_run_job_request(ports=["actor"]) + job_id = worker.submit_job(request) + + job = worker.get_job(job_id) + job.thread.join(timeout=15.0) + + final_job = worker.get_job(job_id) + assert final_job.status == cluster_pb2.JOB_STATE_FAILED + assert "address already in use" in final_job.error + + assert runtime.start_container.call_count == 3 + + +# ============================================================================ +# Integration Tests (with real Docker) +# ============================================================================ + + +def check_docker_available(): + """Check if Docker is available and running.""" + try: + result = subprocess.run( + ["docker", "info"], + check=True, + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return False + + +def create_test_bundle(tmp_path): + """Create a minimal test bundle with pyproject.toml.""" + bundle_dir = tmp_path / "bundle" + bundle_dir.mkdir() + + (bundle_dir / "pyproject.toml").write_text( + """[project] +name = "test-job" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [] +""" + ) + + zip_path = tmp_path / "bundle.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + for f in bundle_dir.rglob("*"): + if f.is_file(): + zf.write(f, f.relative_to(bundle_dir)) + + return f"file://{zip_path}" + + +def create_integration_entrypoint(): + """Create a simple test entrypoint for integration tests.""" + + def test_fn(): + print("Hello from test job!") + return 42 + + return Entrypoint(callable=test_fn, args=(), kwargs={}) + + +def create_integration_run_job_request(bundle_path: str, job_id: str): + """Create a RunJobRequest for integration testing.""" + entrypoint = create_integration_entrypoint() + + return cluster_pb2.Worker.RunJobRequest( + job_id=job_id, + serialized_entrypoint=cloudpickle.dumps(entrypoint), + bundle_gcs_path=bundle_path, + environment=cluster_pb2.EnvironmentConfig( + workspace="/app", + ), + resources=cluster_pb2.ResourceSpec( + cpu=1, + memory="512m", + ), + ) + + +@pytest.fixture +def cache_dir(tmp_path): + """Create a temporary cache directory.""" + cache = tmp_path / "cache" + cache.mkdir() + return cache + + +@pytest.fixture +def test_bundle(tmp_path): + """Create a test bundle and return file:// path.""" + return create_test_bundle(tmp_path) + + +@pytest.fixture +def real_worker(cache_dir): + """Create Worker with real components (not mocks).""" + config = WorkerConfig( + port=0, + cache_dir=cache_dir, + registry="localhost:5000", + max_concurrent_jobs=2, + port_range=(40000, 40100), + ) + return Worker(config) + + +@pytest.fixture +def real_service(real_worker): + """Create WorkerServiceImpl with real worker.""" + return WorkerServiceImpl(real_worker) + + +@pytest.fixture +def runtime(): + """Create DockerRuntime 
instance.""" + return DockerRuntime() + + +class TestDockerRuntimeIntegration: + """Integration tests for DockerRuntime with real containers.""" + + @pytest.mark.slow + def test_create_and_start_container(self, runtime): + """Create and start a simple container and verify it runs.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["echo", "hello"], + env={}, + ) + + container_id = runtime.create_container(config) + assert container_id is not None + + runtime.start_container(container_id) + + time.sleep(1) + + status = runtime.inspect(container_id) + assert not status.running + assert status.exit_code == 0 + + runtime.remove(container_id) + + @pytest.mark.slow + def test_kill_container(self, runtime): + """Test killing a running container.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + config = ContainerConfig( + image="alpine:latest", + command=["sleep", "60"], + env={}, + ) + + container_id = runtime.create_container(config) + runtime.start_container(container_id) + + time.sleep(1) + + runtime.kill(container_id, force=True) + + status = runtime.inspect(container_id) + assert not status.running + + runtime.remove(container_id) + + +class TestWorkerIntegration: + """Integration tests for Worker with real components.""" + + @pytest.mark.slow + def test_submit_job_lifecycle(self, real_worker, test_bundle): + """Test full job lifecycle from submission to completion.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + request = create_integration_run_job_request(test_bundle, "integration-test-1") + + job_id = real_worker.submit_job(request) + assert job_id == "integration-test-1" + + for _ in range(30): + time.sleep(1) + job = real_worker.get_job(job_id) + + if job.status in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + cluster_pb2.JOB_STATE_KILLED, + ): + break + + job = real_worker.get_job(job_id) + assert job.status in ( + cluster_pb2.JOB_STATE_SUCCEEDED, + cluster_pb2.JOB_STATE_FAILED, + ) + + @pytest.mark.slow + def test_concurrent_job_limit(self, real_worker, test_bundle): + """Test that max_concurrent_jobs is enforced.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + requests = [ + create_integration_run_job_request(test_bundle, f"concurrent-{i}") for i in range(4) + ] + + _job_ids = [real_worker.submit_job(r) for r in requests] + + time.sleep(1) + + jobs = real_worker.list_jobs() + running = sum(1 for j in jobs if j.status == cluster_pb2.JOB_STATE_RUNNING) + + assert running <= 2 + + +class TestWorkerServiceIntegration: + """Integration tests for WorkerService RPC implementation.""" + + @pytest.mark.slow + def test_health_check_rpc(self, real_service): + """Test HealthCheck RPC returns healthy status.""" + ctx = Mock(spec=RequestContext) + + response = real_service.health_check(cluster_pb2.Empty(), ctx) + + assert response.healthy + assert response.uptime_ms >= 0 + + @pytest.mark.slow + def test_fetch_logs_tail(self, real_service, test_bundle): + """Test FetchLogs with negative start_line for tailing.""" + if not check_docker_available(): + pytest.skip("Docker not available") + + ctx = Mock(spec=RequestContext) + + request = create_integration_run_job_request(test_bundle, "logs-test") + real_service.run_job(request, ctx) + + time.sleep(2) + + log_request = cluster_pb2.Worker.FetchLogsRequest( + job_id="logs-test", + filter=cluster_pb2.Worker.FetchLogsFilter(start_line=-10), + ) + + 
response = real_service.fetch_logs(log_request, ctx) + assert response.logs is not None + assert len(response.logs) >= 0 diff --git a/lib/fluster/tests/conftest.py b/lib/fluster/tests/conftest.py new file mode 100644 index 0000000000..a5a7a57323 --- /dev/null +++ b/lib/fluster/tests/conftest.py @@ -0,0 +1,15 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test configuration for fluster diff --git a/lib/fluster/tests/test_worker_pool.py b/lib/fluster/tests/test_worker_pool.py new file mode 100644 index 0000000000..9ea44818f7 --- /dev/null +++ b/lib/fluster/tests/test_worker_pool.py @@ -0,0 +1,751 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for WorkerPool behavior. + +These tests exercise the WorkerPool through its public interface, testing +observable behavior rather than implementation details. The test harness +uses real ActorServers with TaskExecutorActor instances, bypassing only +the job-launching infrastructure (ClusterClient). +""" + +import time +from concurrent.futures import Future +from dataclasses import dataclass +from queue import Queue + +import cloudpickle +import pytest +from connectrpc.errors import ConnectError +from fluster import cluster_pb2 +from fluster.actor import ActorServer +from fluster.actor.resolver import FixedResolver +from fluster.cluster.types import Entrypoint, JobId +from fluster.worker_pool import ( + PendingTask, + TaskExecutorActor, + WorkerDispatcher, + WorkerPool, + WorkerPoolConfig, + WorkerState, + WorkerStatus, +) + +# ============================================================================= +# Unit tests for TaskExecutorActor +# ============================================================================= + + +def test_execute_basic(): + """Test basic function execution.""" + executor = TaskExecutorActor() + + fn_bytes = cloudpickle.dumps(lambda x, y: x + y) + args_bytes = cloudpickle.dumps((1, 2)) + kwargs_bytes = cloudpickle.dumps({}) + + result = executor.execute(fn_bytes, args_bytes, kwargs_bytes) + assert result == 3 + + +def test_execute_with_kwargs(): + """Test execution with keyword arguments.""" + executor = TaskExecutorActor() + + def greet(name, greeting="Hello"): + return f"{greeting}, {name}!" + + fn_bytes = cloudpickle.dumps(greet) + args_bytes = cloudpickle.dumps(("World",)) + kwargs_bytes = cloudpickle.dumps({"greeting": "Hi"}) + + result = executor.execute(fn_bytes, args_bytes, kwargs_bytes) + assert result == "Hi, World!" 
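
# The TaskExecutorActor contract assumed by the surrounding tests (a sketch of
# the expected behavior, not the implementation): `execute` unpickles the
# function, args, and kwargs with cloudpickle, calls the function, and returns
# its result directly, letting any exception propagate to the caller. Roughly:
#
#     fn = cloudpickle.loads(fn_bytes)
#     args = cloudpickle.loads(args_bytes)
#     kwargs = cloudpickle.loads(kwargs_bytes)
#     return fn(*args, **kwargs)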
+ + +def test_execute_returns_complex_object(): + """Test that complex objects can be returned.""" + executor = TaskExecutorActor() + + def create_dict(): + return {"a": [1, 2, 3], "b": {"nested": True}} + + fn_bytes = cloudpickle.dumps(create_dict) + args_bytes = cloudpickle.dumps(()) + kwargs_bytes = cloudpickle.dumps({}) + + result = executor.execute(fn_bytes, args_bytes, kwargs_bytes) + assert result == {"a": [1, 2, 3], "b": {"nested": True}} + + +def test_execute_propagates_exception(): + """Test that exceptions are propagated.""" + executor = TaskExecutorActor() + + def raise_error(): + raise ValueError("test error") + + fn_bytes = cloudpickle.dumps(raise_error) + args_bytes = cloudpickle.dumps(()) + kwargs_bytes = cloudpickle.dumps({}) + + with pytest.raises(ValueError, match="test error"): + executor.execute(fn_bytes, args_bytes, kwargs_bytes) + + +def test_execute_with_closure(): + """Test execution of closures that capture variables.""" + executor = TaskExecutorActor() + + multiplier = 10 + + def multiply(x): + return x * multiplier + + fn_bytes = cloudpickle.dumps(multiply) + args_bytes = cloudpickle.dumps((5,)) + kwargs_bytes = cloudpickle.dumps({}) + + result = executor.execute(fn_bytes, args_bytes, kwargs_bytes) + assert result == 50 + + +# ============================================================================= +# WorkerDispatcher behavioral tests +# ============================================================================= + + +@pytest.fixture +def worker_server(): + """Start a real ActorServer with TaskExecutorActor for testing.""" + server = ActorServer(host="127.0.0.1", port=0) + actor_name = "_test_worker" + server.register(actor_name, TaskExecutorActor()) + port = server.serve_background() + yield f"http://127.0.0.1:{port}", actor_name + + +def test_dispatch_discovers_worker_endpoint(worker_server): + """Dispatcher transitions worker from PENDING to IDLE when endpoint is discovered.""" + url, actor_name = worker_server + + worker_state = WorkerState( + worker_id="w-0", + worker_name=actor_name, + status=WorkerStatus.PENDING, + ) + + resolver = FixedResolver({actor_name: url}) + task_queue: Queue[PendingTask] = Queue() + + dispatcher = WorkerDispatcher( + state=worker_state, + task_queue=task_queue, + resolver=resolver, + timeout=5.0, + ) + dispatcher.start() + + # Wait for discovery + deadline = time.time() + 2.0 + while worker_state.status == WorkerStatus.PENDING and time.time() < deadline: + time.sleep(0.05) + + dispatcher.stop() + dispatcher.join(timeout=1.0) + + assert worker_state.status == WorkerStatus.IDLE + assert worker_state.endpoint_url == url + + +def test_dispatch_executes_task_on_worker(worker_server): + """Dispatcher executes tasks and sets results on futures.""" + url, actor_name = worker_server + + worker_state = WorkerState( + worker_id="w-0", + worker_name=actor_name, + endpoint_url=url, + status=WorkerStatus.IDLE, + ) + + resolver = FixedResolver({actor_name: url}) + task_queue: Queue[PendingTask] = Queue() + + dispatcher = WorkerDispatcher( + state=worker_state, + task_queue=task_queue, + resolver=resolver, + timeout=5.0, + ) + dispatcher.start() + + # Submit a task + future: Future = Future() + task = PendingTask( + task_id="task-1", + serialized_fn=cloudpickle.dumps(lambda x: x * 2), + serialized_args=cloudpickle.dumps((21,)), + serialized_kwargs=cloudpickle.dumps({}), + future=future, + fn_name="multiply", + submitted_at=time.monotonic(), + retries_remaining=0, + ) + task_queue.put(task) + + # Wait for result + result = 
future.result(timeout=5.0) + assert result == 42 + + dispatcher.stop() + dispatcher.join(timeout=1.0) + + assert worker_state.tasks_completed == 1 + + +def test_dispatch_propagates_user_exceptions(worker_server): + """User exceptions are propagated without retry.""" + url, actor_name = worker_server + + worker_state = WorkerState( + worker_id="w-0", + worker_name=actor_name, + endpoint_url=url, + status=WorkerStatus.IDLE, + ) + + resolver = FixedResolver({actor_name: url}) + task_queue: Queue[PendingTask] = Queue() + + dispatcher = WorkerDispatcher( + state=worker_state, + task_queue=task_queue, + resolver=resolver, + timeout=5.0, + ) + dispatcher.start() + + # Submit a task that raises + def raise_error(): + raise ValueError("test error") + + future: Future = Future() + task = PendingTask( + task_id="task-err", + serialized_fn=cloudpickle.dumps(raise_error), + serialized_args=cloudpickle.dumps(()), + serialized_kwargs=cloudpickle.dumps({}), + future=future, + fn_name="raise_error", + submitted_at=time.monotonic(), + retries_remaining=3, # Should not retry user exceptions + ) + task_queue.put(task) + + # Should get the ValueError, not be retried + with pytest.raises(ValueError, match="test error"): + future.result(timeout=5.0) + + dispatcher.stop() + dispatcher.join(timeout=1.0) + + assert worker_state.tasks_failed == 1 + # Worker should still be IDLE (not FAILED) since it was a user exception + assert worker_state.status == WorkerStatus.IDLE + + +def test_dispatch_retries_on_infrastructure_failure(): + """Infrastructure failures cause re-queue if retries remain.""" + # Start a real server for the "good" worker + server = ActorServer(host="127.0.0.1", port=0) + actor_name = "_test_worker" + server.register(actor_name, TaskExecutorActor()) + port = server.serve_background() + good_url = f"http://127.0.0.1:{port}" + + # Worker 1 points to a non-existent endpoint (will fail) + worker_state_1 = WorkerState( + worker_id="w-0", + worker_name=actor_name, + endpoint_url="http://127.0.0.1:9999", + status=WorkerStatus.IDLE, + ) + + # Worker 2 points to the real server + worker_state_2 = WorkerState( + worker_id="w-1", + worker_name=actor_name, + endpoint_url=good_url, + status=WorkerStatus.IDLE, + ) + + resolver = FixedResolver({actor_name: good_url}) + task_queue: Queue[PendingTask] = Queue() + + # Start both dispatchers + dispatcher_1 = WorkerDispatcher( + state=worker_state_1, + task_queue=task_queue, + resolver=resolver, + timeout=1.0, # Short timeout for faster failure + ) + dispatcher_2 = WorkerDispatcher( + state=worker_state_2, + task_queue=task_queue, + resolver=resolver, + timeout=5.0, + ) + + dispatcher_1.start() + dispatcher_2.start() + + # Submit a task with retries + future: Future = Future() + task = PendingTask( + task_id="task-retry", + serialized_fn=cloudpickle.dumps(lambda: "success"), + serialized_args=cloudpickle.dumps(()), + serialized_kwargs=cloudpickle.dumps({}), + future=future, + fn_name="success_fn", + submitted_at=time.monotonic(), + retries_remaining=2, + ) + task_queue.put(task) + + # Worker 1 fails, task re-queued, worker 2 succeeds + result = future.result(timeout=10.0) + assert result == "success" + + dispatcher_1.stop() + dispatcher_2.stop() + dispatcher_1.join(timeout=1.0) + dispatcher_2.join(timeout=1.0) + + # Worker 1 should be FAILED after the connection error + assert worker_state_1.status == WorkerStatus.FAILED + # Worker 2 completed the task + assert worker_state_2.tasks_completed == 1 + + +# 
============================================================================= +# E2E tests for WorkerPool +# ============================================================================= + + +@dataclass +class MockJob: + """Tracks a job submission.""" + + job_id: JobId + name: str + entrypoint: Entrypoint + resources: cluster_pb2.ResourceSpec + + +class MockClusterClient: + """Mock ClusterClient that tracks submissions without launching jobs. + + For E2E testing, we bypass the actual job infrastructure. Workers are + started directly as ActorServers and discovered via FixedResolver. + """ + + def __init__(self, controller_address: str = "http://mock-controller:8080"): + self._controller_address = controller_address + self._jobs: dict[JobId, MockJob] = {} + self._job_counter = 0 + + @property + def controller_address(self) -> str: + return self._controller_address + + def submit( + self, + entrypoint: Entrypoint, + name: str, + resources: cluster_pb2.ResourceSpec, + environment: cluster_pb2.EnvironmentConfig | None = None, + namespace: str = "", + ports: list[str] | None = None, + ) -> JobId: + job_id = JobId(f"mock-job-{self._job_counter}") + self._job_counter += 1 + self._jobs[job_id] = MockJob( + job_id=job_id, + name=name, + entrypoint=entrypoint, + resources=resources, + ) + return job_id + + def status(self, job_id: JobId) -> cluster_pb2.JobStatus: + return cluster_pb2.JobStatus( + job_id=str(job_id), + name=self._jobs[job_id].name if job_id in self._jobs else "", + state=cluster_pb2.JOB_STATE_RUNNING, + ) + + def wait( + self, + job_id: JobId, + timeout: float = 300.0, + poll_interval: float = 0.5, + ) -> cluster_pb2.JobStatus: + return self.status(job_id) + + def terminate(self, job_id: JobId) -> None: + pass + + @property + def submitted_jobs(self) -> list[MockJob]: + return list(self._jobs.values()) + + +@dataclass +class WorkerPoolTestHarness: + """Test harness that manages worker servers and provides a configured pool. + + Starts real ActorServers with TaskExecutorActor instances and configures + a FixedResolver to discover them. The WorkerPool uses these servers instead + of launching jobs via ClusterClient. 
+ """ + + pool: WorkerPool + client: MockClusterClient + servers: list[ActorServer] + endpoints: dict[str, str] + + +@pytest.fixture +def worker_pool_harness(): + """Create a WorkerPool with 2 real worker servers.""" + num_workers = 2 + + # Start real ActorServers + servers = [] + endpoints = {} + for _ in range(num_workers): + server = ActorServer(host="127.0.0.1", port=0) + # Worker names are generated by WorkerPool as _workerpool_{pool_id}:worker-{i} + # We'll register with a placeholder name and update the resolver after pool creation + servers.append(server) + + # Create the pool with mock client + client = MockClusterClient() + config = WorkerPoolConfig( + num_workers=num_workers, + resources=cluster_pb2.ResourceSpec(cpu=1, memory="512m"), + max_retries=1, + ) + pool = WorkerPool(client, config, timeout=5.0) + + # Get the pool_id and set up workers with correct names + pool_id = pool.pool_id + for i, server in enumerate(servers): + worker_name = f"_workerpool_{pool_id}:worker-{i}" + server.register(worker_name, TaskExecutorActor()) + port = server.serve_background() + endpoints[worker_name] = f"http://127.0.0.1:{port}" + + # Create resolver and inject it + resolver = FixedResolver(endpoints) + pool._resolver = resolver + + yield WorkerPoolTestHarness( + pool=pool, + client=client, + servers=servers, + endpoints=endpoints, + ) + + # Cleanup + pool.shutdown(wait=False) + + +class TestWorkerPoolE2E: + """End-to-end tests for WorkerPool through its public interface.""" + + def test_pool_discovers_workers(self, worker_pool_harness): + """Workers transition from PENDING to IDLE when discovered.""" + harness = worker_pool_harness + pool = harness.pool + + # Manually trigger worker launch (normally done by __enter__) + pool._launch_workers() + + # Wait for workers to be discovered + pool._wait_for_workers(min_workers=2, timeout=5.0) + + assert pool.size == 2 + assert pool.idle_count == 2 + + def test_submit_executes_task(self, worker_pool_harness): + """submit() dispatches a task and returns correct result.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + def add(a, b): + return a + b + + future = pool.submit(add, 10, 20) + result = future.result(timeout=5.0) + + assert result == 30 + + def test_submit_with_kwargs(self, worker_pool_harness): + """submit() passes keyword arguments correctly.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + def greet(name, prefix="Hello"): + return f"{prefix}, {name}!" + + future = pool.submit(greet, "World", prefix="Hi") + result = future.result(timeout=5.0) + + assert result == "Hi, World!" 
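+
+    # Taken together, the submit/result tests above and the map test below
+    # describe the pool's public calling convention. A hedged usage sketch
+    # (assumes a resolver-backed pool wired up as in `worker_pool_harness`;
+    # it is illustrative only, not itself a test):
+    #
+    #   config = WorkerPoolConfig(num_workers=2, resources=cluster_pb2.ResourceSpec(cpu=1))
+    #   with WorkerPool(client, config, timeout=5.0, resolver=FixedResolver(endpoints)) as pool:
+    #       fut = pool.submit(lambda x: x + 1, 41)          # single task
+    #       futs = pool.map(lambda x: x * x, [1, 2, 3])     # fan-out
+    #       assert fut.result(timeout=5.0) == 42
+    #       assert [f.result(timeout=5.0) for f in futs] == [1, 4, 9]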
+ + def test_map_executes_in_parallel(self, worker_pool_harness): + """map() distributes work across workers.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=2, timeout=5.0) + + def square(x): + return x * x + + futures = pool.map(square, [1, 2, 3, 4, 5]) + results = [f.result(timeout=5.0) for f in futures] + + assert results == [1, 4, 9, 16, 25] + + def test_exception_propagates_to_caller(self, worker_pool_harness): + """Exceptions raised by user code propagate to the caller.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + def fail(): + raise ValueError("intentional error") + + future = pool.submit(fail) + + with pytest.raises(ValueError, match="intentional error"): + future.result(timeout=5.0) + + def test_complex_return_values(self, worker_pool_harness): + """Complex objects are properly serialized and returned.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + def create_complex(): + return { + "numbers": [1, 2, 3], + "nested": {"a": 1, "b": 2}, + "tuple": (1, "two", 3.0), + } + + future = pool.submit(create_complex) + result = future.result(timeout=5.0) + + assert result["numbers"] == [1, 2, 3] + assert result["nested"]["b"] == 2 + assert result["tuple"] == (1, "two", 3.0) + + def test_closures_work(self, worker_pool_harness): + """Functions that capture variables work correctly.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + multiplier = 7 + + def multiply(x): + return x * multiplier + + future = pool.submit(multiply, 6) + result = future.result(timeout=5.0) + + assert result == 42 + + def test_status_reflects_pool_state(self, worker_pool_harness): + """status() returns accurate information about pool state.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=2, timeout=5.0) + + status = pool.status() + + assert status.pool_id == pool.pool_id + assert status.num_workers == 2 + assert status.workers_idle == 2 + assert status.workers_pending == 0 + assert status.tasks_queued == 0 + + def test_future_done_and_exception(self, worker_pool_harness): + """WorkerFuture.done() and exception() work correctly.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + # Test successful completion + future_success = pool.submit(lambda: 42) + result = future_success.result(timeout=5.0) + assert result == 42 + assert future_success.done() + assert future_success.exception() is None + + # Test exception case + def fail(): + raise RuntimeError("expected") + + future_fail = pool.submit(fail) + with pytest.raises(RuntimeError): + future_fail.result(timeout=5.0) + + assert future_fail.done() + assert isinstance(future_fail.exception(), RuntimeError) + + def test_context_manager_waits_for_at_least_one_worker(self): + """__enter__ waits for at least one worker before returning.""" + # Set up a dedicated server for this test + server = ActorServer(host="127.0.0.1", port=0) + worker_name = "_workerpool_ctxtest:worker-0" + server.register(worker_name, TaskExecutorActor()) + port = server.serve_background() + + endpoints = {worker_name: f"http://127.0.0.1:{port}"} + + client = MockClusterClient() + config = 
WorkerPoolConfig( + num_workers=1, + resources=cluster_pb2.ResourceSpec(cpu=1), + ) + pool = WorkerPool(client, config, timeout=5.0, resolver=FixedResolver(endpoints)) + pool._pool_id = "ctxtest" + + with pool: + # By the time __enter__ returns, we should have at least 1 worker + assert pool.size >= 1 + + def test_shutdown_prevents_new_submissions(self, worker_pool_harness): + """After shutdown, submit() raises RuntimeError.""" + harness = worker_pool_harness + pool = harness.pool + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + pool.shutdown(wait=False) + + with pytest.raises(RuntimeError, match="shutdown"): + pool.submit(lambda: 42) + + +class TestWorkerPoolRetry: + """Tests for retry behavior on infrastructure failures.""" + + def test_task_retries_on_worker_failure(self): + """When a worker fails, the task is re-queued and picked up by another worker.""" + # Set up one bad worker and one good worker + good_server = ActorServer(host="127.0.0.1", port=0) + good_server.register("_workerpool_test:worker-1", TaskExecutorActor()) + good_port = good_server.serve_background() + + # Endpoints: worker-0 points to non-existent server, worker-1 is real + endpoints = { + "_workerpool_test:worker-0": "http://127.0.0.1:9999", # Will fail + "_workerpool_test:worker-1": f"http://127.0.0.1:{good_port}", + } + + client = MockClusterClient() + config = WorkerPoolConfig( + num_workers=2, + resources=cluster_pb2.ResourceSpec(cpu=1), + max_retries=2, + ) + + pool = WorkerPool( + client, + config, + timeout=2.0, + resolver=FixedResolver(endpoints), + ) + + # Override pool_id to match our endpoints + pool._pool_id = "test" + + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + # Submit task - may hit failed worker first but should succeed after retry + future = pool.submit(lambda: "success") + result = future.result(timeout=10.0) + + assert result == "success" + + pool.shutdown(wait=False) + + def test_task_fails_when_retries_exhausted(self): + """When all retries are exhausted, the error propagates to caller.""" + # All workers point to non-existent servers + endpoints = { + "_workerpool_noretry:worker-0": "http://127.0.0.1:9999", + } + + client = MockClusterClient() + config = WorkerPoolConfig( + num_workers=1, + resources=cluster_pb2.ResourceSpec(cpu=1), + max_retries=0, # No retries + ) + + pool = WorkerPool( + client, + config, + timeout=1.0, + resolver=FixedResolver(endpoints), + ) + + pool._pool_id = "noretry" + pool._launch_workers() + pool._wait_for_workers(min_workers=1, timeout=5.0) + + future = pool.submit(lambda: "should fail") + + # Should fail with connection error from failed RPC + with pytest.raises(ConnectError): + future.result(timeout=5.0) + + pool.shutdown(wait=False) diff --git a/lib/zephyr/src/zephyr/backends.py b/lib/zephyr/src/zephyr/backends.py index e8102f2b9d..65ba1ab83a 100644 --- a/lib/zephyr/src/zephyr/backends.py +++ b/lib/zephyr/src/zephyr/backends.py @@ -41,6 +41,7 @@ compute_plan, run_stage, ) +from zephyr.storage import DEFAULT_SPILL_THRESHOLD_BYTES, StorageManager logger = logging.getLogger(__name__) @@ -67,39 +68,22 @@ class Shard: idx: int # Shard index (e.g., 0 of 50) chunks: list[Chunk] - context: JobContext @property def count(self) -> int: """Total number of items across all chunks.""" return sum(c.count for c in self.chunks) - def iter_chunks(self) -> Iterator[list]: - """Iterate over chunks (each chunk is a list of items).""" + def iter_chunks(self) -> Iterator: + """Iterate over chunks (each chunk is 
an iterable of items).""" for chunk in self.chunks: - data = self.context.get(chunk.data) - yield data + yield chunk.data def __iter__(self): """Flat map over all chunks.""" for chunk_data in self.iter_chunks(): yield from chunk_data - @staticmethod - def from_single_ref(ref: Any, context: JobContext, idx: int, count: int) -> Shard: - """Wrap a single ref as a Shard. - - Args: - ref: Reference to wrap (type depends on context) - context: Execution context for get operations - idx: Shard index - count: Number of items in the ref - - Returns: - Shard containing the single ref - """ - return Shard(idx=idx, chunks=[Chunk(count=count, data=ref)], context=context) - def format_shard_path(pattern: str, shard_idx: int, total: int) -> str: """Format output path with shard information. @@ -134,29 +118,28 @@ def reshard_refs(shards: list[Shard], num_shards: int) -> list[Shard]: if not shards: return [] - context = shards[0].context all_chunks = [chunk for shard in shards for chunk in shard.chunks] if not all_chunks: return [] chunk_groups = np.array_split(all_chunks, num_shards) # type: ignore - return [ - Shard(idx=idx, chunks=list(group), context=context) for idx, group in enumerate(chunk_groups) if len(group) > 0 - ] + return [Shard(idx=idx, chunks=list(group)) for idx, group in enumerate(chunk_groups) if len(group) > 0] class Backend: - def __init__(self, context: JobContext, config: BackendConfig): + def __init__(self, context: JobContext, config: BackendConfig, storage: StorageManager): """Initialize backend with execution context and configuration. Args: context: Execution context providing put/get/run/wait primitives config: Backend configuration + storage: Storage manager for chunk serialization """ self.context = context self.config = config self.dry_run = config.dry_run + self.storage = storage @staticmethod def execute( @@ -166,16 +149,20 @@ def execute( verbose: bool = False, max_parallelism: int = 1024, dry_run: bool = False, + storage_path: str | None = None, + spill_threshold_bytes: int | None = None, ) -> Sequence[T]: """Execute a dataset and return results. Args: dataset: Dataset to execute context: JobContext to use for execution. If None, uses get_default_job_ctx() - hints: Execution hints (chunk_size, intra_shard_parallelism, etc.) + hints: Execution hints (chunk_size, etc.) verbose: Print additional logging and optimization stats max_parallelism: Maximum number of concurrent tasks dry_run: If True, show optimization plan without executing + storage_path: Base path for intermediate storage. Defaults to MARIN_PREFIX/tmp. + spill_threshold_bytes: Size threshold for spilling to storage. Defaults to 1MB. 
Returns: Sequence of results @@ -190,15 +177,20 @@ def execute( if context is None: context = get_default_job_ctx() config = BackendConfig(max_parallelism=max_parallelism, dry_run=dry_run) - backend = Backend(context, config) - plan = compute_plan(dataset, hints) - if verbose: - backend._print_plan(dataset.operations, plan) - if dry_run: - return [] + if spill_threshold_bytes is None: + spill_threshold_bytes = DEFAULT_SPILL_THRESHOLD_BYTES + + with StorageManager(storage_path, spill_threshold_bytes) as storage: + backend = Backend(context, config, storage) - return list(backend._execute_plan(plan, hints)) + plan = compute_plan(dataset, hints) + if verbose: + backend._print_plan(dataset.operations, plan) + if dry_run: + return [] + + return list(backend._execute_plan(plan, hints)) def _print_plan(self, original_ops: list, plan: PhysicalPlan) -> None: """Print the physical plan showing shard count and operation fusion. @@ -252,6 +244,8 @@ def _shards_from_source_items(self, source_items: list[SourceItem]) -> list[Shar Returns: List of Shards ready for processing, one per unique shard_idx """ + from zephyr.storage import InlineRef + # Group by shard_idx items_by_shard: dict[int, list[SourceItem]] = defaultdict(list) for item in source_items: @@ -264,10 +258,10 @@ def _shards_from_source_items(self, source_items: list[SourceItem]) -> list[Shar chunks = [] for item in items: - # Pass the data field directly to the first operation - chunks.append(Chunk(count=1, data=self.context.put([item.data]))) + # Use InlineRef so Shards are pickle-able for Ray workers + chunks.append(Chunk(count=1, data=InlineRef(data=[item.data]))) - shards.append(Shard(idx=shard_idx, chunks=chunks, context=self.context)) + shards.append(Shard(idx=shard_idx, chunks=chunks)) return shards @@ -301,8 +295,13 @@ def _execute_stage( # Compute aux shards for joins aux_shards_per_shard = self._compute_join_aux_shards(stage, shards, hints) - # Single execution path - ForkChunks handles parallelism internally - return self._execute_shard_parallel(stage.operations, shards, aux_shards_per_shard, hints) + # Execute stage + result = self._execute_shard_parallel(stage.operations, shards, aux_shards_per_shard, hints) + + # Advance storage stage counter for next stage + self.storage.next_stage() + + return result def _compute_join_aux_shards( self, @@ -366,45 +365,53 @@ def _run_tasks( ) -> dict[int, list[tuple[ChunkHeader, Any]]]: """Run stage tasks for contexts, return results grouped by output shard_idx. + Workers return list[tuple[ChunkHeader, ChunkRef]] directly. The controller + collects these refs without materializing data. + Args: contexts: List of StageContext to process operations: Physical operations to execute Returns: - Dict mapping shard_idx -> list of (header, data_ref) tuples. + Dict mapping shard_idx -> list of (header, ChunkRef) tuples. 
""" results_by_shard: dict[int, list[tuple[ChunkHeader, Any]]] = defaultdict(list) if not contexts: return results_by_shard - active_gens: list[tuple[Any, StageContext]] = [] + active: dict[int, Any] = {} queued = list(contexts) + task_counter = 0 # Start initial batch - while len(active_gens) < self.config.max_parallelism and queued: + while len(active) < self.config.max_parallelism and queued: ctx = queued.pop(0) - active_gens.append((self.context.run(run_stage, ctx, operations), ctx)) - - # Process results - while active_gens or queued: - gen_objs = [g for g, _ in active_gens] - ready, _ = self.context.wait(gen_objs, num_returns=1) - - for ready_gen in ready: - # Find matching entry - for g, ctx in active_gens: - if g is ready_gen: - try: - header = self.context.get(next(ready_gen)) - data_ref = next(ready_gen) - results_by_shard[header.shard_idx].append((header, data_ref)) - except StopIteration: - active_gens.remove((g, ctx)) - if queued: - next_ctx = queued.pop(0) - active_gens.append((self.context.run(run_stage, next_ctx, operations), next_ctx)) - break + future = self.context.run(run_stage, ctx, operations) + active[task_counter] = future + task_counter += 1 + + # Collect results + while active or queued: + futures = list(active.values()) + ready, _ = self.context.wait(futures, num_returns=1) + + for ready_future in ready: + # Find and remove matching entry + task_id = next(tid for tid, f in active.items() if f is ready_future) + del active[task_id] + + # Workers return list of (header, ChunkRef) tuples directly + result_pairs = self.context.get(ready_future) + for header, ref in result_pairs: + results_by_shard[header.shard_idx].append((header, ref)) + + # Start next task if available + if queued: + next_ctx = queued.pop(0) + next_future = self.context.run(run_stage, next_ctx, operations) + active[task_counter] = next_future + task_counter += 1 return results_by_shard @@ -417,6 +424,8 @@ def _execute_shard_parallel( ) -> list[Shard]: """Execute operations on shards with one task per shard. + Workers make spill decisions and return ChunkRefs (InlineRef or StorageRef). 
+ Args: operations: Physical operations to execute shards: List of input Shards @@ -424,7 +433,7 @@ def _execute_shard_parallel( hints: Execution hints Returns: - List of output Shards assembled from streamed chunks + List of output Shards assembled from worker-returned ChunkRefs """ if aux_shards_per_shard is None: aux_shards_per_shard = [{} for _ in range(len(shards))] @@ -433,14 +442,18 @@ def _execute_shard_parallel( total = len(shards) + # Stage storage path for workers to spill chunks + stage_path = f"{self.storage.job_path}/stage_{self.storage._stage_idx}" + contexts = [ StageContext( shard=shard, shard_idx=shard_idx, total_shards=total, chunk_size=hints.chunk_size, + storage_path=stage_path, + spill_threshold_bytes=self.storage.spill_threshold_bytes, aux_shards=aux_shards, - execution_context=self.context, ) for shard_idx, (shard, aux_shards) in enumerate(zip(shards, aux_shards_per_shard, strict=True)) ] @@ -459,13 +472,12 @@ def _execute_shard_parallel( shards = [] for idx in range(num_output_shards): if idx not in results: - shards.append(Shard(idx=idx, chunks=[], context=self.context)) + shards.append(Shard(idx=idx, chunks=[])) else: shards.append( Shard( idx=idx, chunks=[Chunk(header.count, data_ref) for header, data_ref in results[idx]], - context=self.context, ) ) diff --git a/lib/zephyr/src/zephyr/plan.py b/lib/zephyr/src/zephyr/plan.py index 93f971f6eb..083d24c8ed 100644 --- a/lib/zephyr/src/zephyr/plan.py +++ b/lib/zephyr/src/zephyr/plan.py @@ -34,7 +34,6 @@ import fsspec import msgspec -from fray.job import JobContext from zephyr.dataset import ( FilterOp, @@ -61,12 +60,6 @@ # Default number of items per output chunk during streaming DEFAULT_CHUNK_SIZE = 100_000 -# Default number of parallel chunks when splitting files for intra-shard parallelism -DEFAULT_INTRA_SHARD_PARALLELISM = 1 - -# Size of micro-batches yielded from parallel chunk workers to reduce overhead -DEFAULT_MICRO_BATCH_SIZE = 1024 - @dataclass class SourceItem: @@ -146,18 +139,7 @@ class Join: right_plan: PhysicalPlan | None = None -@dataclass -class ForkChunks: - """Fork stream into N parallel chunk streams. - - Child operations are applied in parallel, and merged as available. - """ - - target_chunks: int = DEFAULT_INTRA_SHARD_PARALLELISM - parallel_ops: list = field(default_factory=list) # list[PhysicalOp] - - -PhysicalOp = Map | Write | Scatter | Reduce | Fold | Reshard | Join | ForkChunks +PhysicalOp = Map | Write | Scatter | Reduce | Fold | Reshard | Join class StageType(StrEnum): @@ -308,13 +290,9 @@ class ExecutionHint: Attributes: chunk_size: Number of items per output chunk during streaming. Use -1 for 1 chunk per shard. - intra_shard_parallelism: Controls parallel processing of chunks within - a shard. Set to -1 (default) for auto (parallel when chunks > 1), - 0 to disable, or N to limit max parallel chunks per shard. """ chunk_size: int = DEFAULT_CHUNK_SIZE - intra_shard_parallelism: int = -1 @dataclass @@ -326,40 +304,19 @@ class FusionState: pending_fusible: list = field(default_factory=list) output_shards: int | None = None stage_type: StageType = StageType.WORKER - hints: ExecutionHint = field(default_factory=ExecutionHint) def flush_pending(self) -> None: - """Convert pending fusible ops to a physical Map or ForkChunks. - - When the first op is LoadFileOp and parallelism is enabled, creates ForkChunks - for parallel chunk processing. 
- """ + """Convert pending fusible ops to a physical Map.""" if not self.pending_fusible: return - has_load_file = isinstance(self.pending_fusible[0], LoadFileOp) requires_full_shard = any(isinstance(op, MapShardOp) for op in self.pending_fusible) - - # Create ForkChunks for file pipelines with parallelism enabled - if has_load_file and self.hints.intra_shard_parallelism != 0 and not requires_full_shard: - user_ops = self.pending_fusible[1:] # Exclude LoadFileOp, ForkChunks handles file loading - target_chunks = ( - self.hints.intra_shard_parallelism - if self.hints.intra_shard_parallelism > 0 - else DEFAULT_INTRA_SHARD_PARALLELISM - ) - parallel_ops = [Map(fn=compose_map(user_ops), requires_full_shard=False)] if user_ops else [] - logger.info("Creating ForkChunks with %d parallel ops, %d target chunks", len(parallel_ops), target_chunks) - self.current_ops.append(ForkChunks(target_chunks=target_chunks, parallel_ops=parallel_ops)) - else: - # Regular Map - self.current_ops.append( - Map( - fn=compose_map(self.pending_fusible[:]), - requires_full_shard=requires_full_shard, - ) + self.current_ops.append( + Map( + fn=compose_map(self.pending_fusible[:]), + requires_full_shard=requires_full_shard, ) - + ) self.pending_fusible = [] def add_op( @@ -412,9 +369,6 @@ def _fuse_operations(operations: list, hints: ExecutionHint | None = None) -> li - ReshardOp → Reshard - JoinOp → Join (with pre-computed right_plan) - When a stage starts with LoadFileOp and parallelism is enabled, the leading Maps - are wrapped in ForkChunks for parallel chunk processing. - Args: operations: List of logical operations hints: Execution hints (used for pre-computing join right plans) @@ -428,7 +382,7 @@ def _fuse_operations(operations: list, hints: ExecutionHint | None = None) -> li if hints is None: hints = ExecutionHint() - state = FusionState(hints=hints) + state = FusionState() for op in operations: if isinstance(op, WriteOp): @@ -613,23 +567,6 @@ def make_windows( yield window -def _stream_chunks(items: Iterator, shard_idx: int, chunk_size: int) -> Iterator[ChunkHeader | list[Any]]: - """Stream chunks from an iterator, yielding header/data pairs.""" - chunk: list = [] - for item in items: - chunk.append(item) - if chunk_size > 0 and len(chunk) >= chunk_size: - header = ChunkHeader(shard_idx=shard_idx, count=len(chunk)) - yield header - yield chunk - chunk = [] - # Yield final partial chunk - if chunk: - header = ChunkHeader(shard_idx=shard_idx, count=len(chunk)) - yield header - yield chunk - - def _group_items_by_hash( items: Iterable, key_fn: Callable, @@ -690,111 +627,6 @@ def _merge_sorted_chunks(shard, key_fn: Callable) -> Iterator[tuple[object, Iter yield from groupby(merged_stream, key=key_fn) -def _compute_chunk_specs(spec, target_chunks: int) -> list: - """Compute chunk specs for a file.""" - from zephyr.readers import open_file - - if target_chunks <= 1 or not isinstance(spec, InputFileSpec) or not spec.path.endswith((".parquet", ".vortex")): - return [spec] - - if spec.path.endswith(".parquet"): - import pyarrow.parquet as pq - - with open_file(spec.path, "rb") as f: - parquet_file = pq.ParquetFile(f) - num_rows = parquet_file.metadata.num_rows - else: - import vortex - - f = vortex.open(spec.path) - num_rows = f.to_dataset().count_rows() - - row_ranges = [] - rows_per_chunk = num_rows // target_chunks - for i in range(target_chunks): - start = i * rows_per_chunk - end = (i + 1) * rows_per_chunk - row_ranges.append((start, end)) - - row_ranges[-1] = (row_ranges[-1][0], num_rows) - - return [ - 
InputFileSpec( - path=spec.path, - format=spec.format, - columns=spec.columns, - row_start=start, - row_end=end, - filter_expr=spec.filter_expr, - ) - for start, end in row_ranges - ] - - -def _merge_chunk_streams(exec_ctx, futures: list): - active = {id(f): f for f in futures} - - while active: - ready, _ = exec_ctx.wait(list(active.values()), num_returns=1) - for gen in ready: - try: - items = exec_ctx.get(next(gen)) - yield from items - except StopIteration: - del active[id(gen)] - - -def _execute_fork_join( - exec_ctx, - source_stream, - parallel_ops: list[PhysicalOp], - target_chunks: int, -): - """Execute ops in parallel across chunks, merging results.""" - from zephyr.readers import load_file - - source_items = list(source_stream) - - logger.info("Source items: %s", source_items) - - # For each source item, compute chunk specs - all_chunk_specs = [] - for item in source_items: - if isinstance(item, InputFileSpec): - chunk_specs = _compute_chunk_specs(item, target_chunks) - all_chunk_specs.extend(chunk_specs) - else: - all_chunk_specs.append(item) - - logger.info("All chunk specs: %s", all_chunk_specs) - - def process_chunk(chunk_spec): - if isinstance(chunk_spec, InputFileSpec): - stream = load_file(chunk_spec) - else: - stream = iter([chunk_spec]) - for op in parallel_ops: - assert isinstance(op, Map) - stream = op.fn(stream) - - # batch into micro-chunks to reduce overhead - micro_chunks = [] - for item in stream: - micro_chunks.append(item) - if len(micro_chunks) >= DEFAULT_MICRO_BATCH_SIZE: - yield micro_chunks - micro_chunks = [] - if micro_chunks: - yield micro_chunks - - if len(all_chunk_specs) == 1: - for batch in process_chunk(all_chunk_specs[0]): - yield from batch - else: - futures = [exec_ctx.run(process_chunk, spec) for spec in all_chunk_specs] - yield from _merge_chunk_streams(exec_ctx, futures) - - def _sorted_merge_join( left_stream: Iterable, right_stream: Iterable, @@ -856,16 +688,18 @@ class StageContext: shard_idx: Index of this shard total_shards: Total number of shards chunk_size: Number of items per output chunk + storage_path: Base path for spilling chunks (e.g., gs://bucket/tmp/job_xxx/stage_0) + spill_threshold_bytes: Size threshold for spilling to storage aux_shards: Auxiliary shards for joins, keyed by op index - execution_context: Execution context for put/get/run/wait operations (for ForkChunks) """ shard: Any # Shard object (avoids circular import) shard_idx: int total_shards: int chunk_size: int + storage_path: str + spill_threshold_bytes: int aux_shards: dict[int, list[Any]] = field(default_factory=dict) - execution_context: JobContext = None def get_right_shard(self, op_index: int) -> Any: """Get right shard for join at given op index. @@ -879,10 +713,54 @@ def get_right_shard(self, op_index: int) -> Any: return shards[0] +def _collect_chunks( + items: Iterator, + shard_idx: int, + chunk_size: int, + storage_path: str, + spill_threshold_bytes: int, +) -> list[tuple[ChunkHeader, Any]]: + """Collect items into chunks and write via ChunkWriter. + + Returns list of (header, ChunkRef) where ChunkRef is InlineRef or StorageRef. 
+ """ + from zephyr.storage import ChunkWriter + + results: list[tuple[ChunkHeader, Any]] = [] + current_chunk: list = [] + chunk_idx = 0 + + for item in items: + current_chunk.append(item) + + if chunk_size > 0 and len(current_chunk) >= chunk_size: + # Flush current chunk + if current_chunk: + spill_path = f"{storage_path}/shard_{shard_idx:05d}_chunk_{chunk_idx:05d}.vortex" + writer = ChunkWriter(spill_path=spill_path, spill_threshold_bytes=spill_threshold_bytes) + for chunk_item in current_chunk: + writer.write(chunk_item) + ref = writer.finish() + results.append((ChunkHeader(shard_idx=shard_idx, count=len(current_chunk)), ref)) + current_chunk = [] + chunk_idx += 1 + + # Flush remaining items + if current_chunk: + spill_path = f"{storage_path}/shard_{shard_idx:05d}_chunk_{chunk_idx:05d}.vortex" + writer = ChunkWriter(spill_path=spill_path, spill_threshold_bytes=spill_threshold_bytes) + for chunk_item in current_chunk: + writer.write(chunk_item) + ref = writer.finish() + results.append((ChunkHeader(shard_idx=shard_idx, count=len(current_chunk)), ref)) + + return results + + def run_stage( ctx: StageContext, ops: list[PhysicalOp], -) -> Iterator[ChunkHeader | list[Any]]: +) -> list[tuple[ChunkHeader, Any]]: """Execute a stage's physical ops in a single pass. This is the single worker function that backends call to execute physical ops. @@ -892,8 +770,8 @@ def run_stage( ctx: Stage execution context providing shard data and metadata ops: List of physical operations to execute in sequence - Yields: - ChunkHeader followed by list of items for each chunk produced + Returns: + List of (ChunkHeader, ChunkRef) tuples where ChunkRef is InlineRef or StorageRef """ from zephyr.writers import write_binary_file, write_jsonl_file, write_levanter_cache, write_parquet_file @@ -906,15 +784,6 @@ def run_stage( if isinstance(op, Map): stream = op.fn(stream) op_index += 1 - elif isinstance(op, ForkChunks): - # Execute chunk parallelism with contained parallel_ops - stream = _execute_fork_join( - ctx.execution_context, - stream, - op.parallel_ops, - op.target_chunks, - ) - op_index += 1 elif isinstance(op, Write): output_path = op.output_pattern(ctx.shard_idx, ctx.total_shards) @@ -928,8 +797,9 @@ def run_stage( if fs.exists(test_path): logger.info(f"Skipping write, output exists: {output_path}") - yield from _stream_chunks(iter([output_path]), ctx.shard_idx, ctx.chunk_size) - return + return _collect_chunks( + iter([output_path]), ctx.shard_idx, ctx.chunk_size, ctx.storage_path, ctx.spill_threshold_bytes + ) # Write based on type if op.writer_type == "jsonl": @@ -948,27 +818,31 @@ def run_stage( else: raise ValueError(f"Unknown writer_type: {op.writer_type}") - yield from _stream_chunks(iter([result]), ctx.shard_idx, ctx.chunk_size) - return + return _collect_chunks( + iter([result]), ctx.shard_idx, ctx.chunk_size, ctx.storage_path, ctx.spill_threshold_bytes + ) elif isinstance(op, Scatter): # Hash items to output shards num_output_shards = op.num_output_shards if op.num_output_shards > 0 else ctx.total_shards output_chunks = _group_items_by_hash(stream, op.key_fn, num_output_shards, ctx.chunk_size) - # Yield chunks for each output shard - for shard_idx in range(num_output_shards): - if output_chunks[shard_idx]: - for chunk in output_chunks[shard_idx]: - header = ChunkHeader(shard_idx=shard_idx, count=chunk.count) - yield header - yield chunk.data + # Collect chunks for each output shard + results: list[tuple[ChunkHeader, Any]] = [] + for target_shard_idx in range(num_output_shards): + if 
output_chunks[target_shard_idx]: + for chunk in output_chunks[target_shard_idx]: + # Scatter chunks are already materialized, wrap as InlineRef + from zephyr.storage import InlineRef + + ref = InlineRef(data=chunk.data) + results.append((ChunkHeader(shard_idx=target_shard_idx, count=chunk.count), ref)) else: - # Yield empty chunk so controller knows this shard exists - header = ChunkHeader(shard_idx=shard_idx, count=0) - yield header - yield [] - return + # Empty chunk so controller knows this shard exists + from zephyr.storage import InlineRef + + results.append((ChunkHeader(shard_idx=target_shard_idx, count=0), InlineRef(data=[]))) + return results elif isinstance(op, Reduce): # Merge sorted chunks and reduce per key @@ -990,5 +864,5 @@ def run_stage( stream = op.fn(stream, iter(right_shard)) op_index += 1 - # Yield remaining items as chunks - yield from _stream_chunks(stream, ctx.shard_idx, ctx.chunk_size) + # Collect remaining items as chunks + return _collect_chunks(stream, ctx.shard_idx, ctx.chunk_size, ctx.storage_path, ctx.spill_threshold_bytes) diff --git a/lib/zephyr/src/zephyr/storage.py b/lib/zephyr/src/zephyr/storage.py new file mode 100644 index 0000000000..7764f8cadb --- /dev/null +++ b/lib/zephyr/src/zephyr/storage.py @@ -0,0 +1,169 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Storage abstractions for Zephyr chunk serialization. + +This module provides InlineRef and StorageRef types for representing chunk data +that is either kept in memory (small) or spilled to storage (large). +""" + +from __future__ import annotations + +import os +import sys +import uuid +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any + +# Threshold for inline vs storage decision (1MB) +DEFAULT_SPILL_THRESHOLD_BYTES = 1 * 1024 * 1024 + + +@dataclass +class InlineRef: + """Data kept inline in memory (small chunks).""" + + data: list[Any] + + @property + def count(self) -> int: + return len(self.data) + + def __iter__(self) -> Iterator[Any]: + return iter(self.data) + + +@dataclass +class StorageRef: + """Reference to data stored in GCS/S3/local filesystem.""" + + path: str + count: int + + def __iter__(self) -> Iterator[dict]: + from zephyr.readers import load_vortex + + return load_vortex(self.path) + + +# Union type for chunk references +ChunkRef = InlineRef | StorageRef + + +def _estimate_item_size(item: Any) -> int: + """Estimate serialized size of an item.""" + if isinstance(item, dict): + total = 0 + for k, v in item.items(): + total += _estimate_item_size(k) + _estimate_item_size(v) + return total + return sys.getsizeof(item) + + +class ChunkWriter: + """Writes chunk data, choosing inline vs storage based on size threshold. + + Small chunks (< spill_threshold_bytes) are kept inline in memory. + Large chunks are written to storage as Vortex files. 
+ + Usage: + writer = ChunkWriter(spill_path="/tmp/chunk.vortex") + for item in items: + writer.write(item) + ref = writer.finish() # Returns InlineRef or StorageRef + """ + + def __init__( + self, + spill_path: str, + spill_threshold_bytes: int = DEFAULT_SPILL_THRESHOLD_BYTES, + ): + self.spill_path = spill_path + self.spill_threshold_bytes = spill_threshold_bytes + self._items: list[Any] = [] + self._size_estimate = 0 + + def write(self, item: Any) -> None: + """Add item to chunk.""" + self._items.append(item) + self._size_estimate += _estimate_item_size(item) + + def finish(self) -> ChunkRef: + """Finalize and return appropriate ref type.""" + if self._size_estimate < self.spill_threshold_bytes: + return InlineRef(data=self._items) + + from zephyr.writers import write_vortex_file + + result = write_vortex_file(self._items, self.spill_path) + return StorageRef(path=result["path"], count=result["count"]) + + +class StorageManager: + """Manages storage paths and cleanup for a job execution. + + Usage: + with StorageManager() as storage: + # ... use chunk_path/job_path to reference or write chunks + # Job directory cleaned up on exit + """ + + def __init__( + self, + base_path: str | None = None, + spill_threshold_bytes: int = DEFAULT_SPILL_THRESHOLD_BYTES, + ): + if base_path is None: + prefix = os.environ.get("MARIN_PREFIX", "/tmp") + base_path = f"{prefix}/zephyr/tmp/" + + self.base_path = base_path.rstrip("/") + self.job_id = str(uuid.uuid4())[:8] + self.spill_threshold_bytes = spill_threshold_bytes + self._stage_idx = 0 + + @property + def job_path(self) -> str: + return f"{self.base_path}/job_{self.job_id}" + + def chunk_path(self, shard_idx: int, chunk_idx: int) -> str: + return f"{self.job_path}/stage_{self._stage_idx}/shard_{shard_idx:05d}_chunk_{chunk_idx:05d}.vortex" + + def create_writer(self, shard_idx: int, chunk_idx: int) -> ChunkWriter: + """Create a ChunkWriter for the given shard/chunk.""" + path = self.chunk_path(shard_idx, chunk_idx) + return ChunkWriter(spill_path=path, spill_threshold_bytes=self.spill_threshold_bytes) + + def next_stage(self) -> None: + """Advance to next stage.""" + self._stage_idx += 1 + + def cleanup(self) -> None: + """Best-effort cleanup of job directory.""" + import fsspec + + try: + fs, _ = fsspec.core.url_to_fs(self.job_path) + if fs.exists(self.job_path): + fs.rm(self.job_path, recursive=True) + except Exception: + pass # ignore any failures during cleanup + + def __enter__(self) -> StorageManager: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + self.cleanup() + return False diff --git a/lib/zephyr/tests/conftest.py b/lib/zephyr/tests/conftest.py index 8941b19650..82225957d4 100644 --- a/lib/zephyr/tests/conftest.py +++ b/lib/zephyr/tests/conftest.py @@ -14,10 +14,15 @@ """Pytest fixtures for zephyr tests.""" +import threading +import time + import pytest import ray from fray.job import create_job_ctx +from fray.job.rpc.controller import FrayControllerServer +from fray.job.rpc.worker import FrayWorker from zephyr import load_file @@ -31,6 +36,41 @@ def ray_cluster(): # Don't shutdown - let pytest handle cleanup +@pytest.fixture(scope="module") +def rpc_infrastructure(): + """Start RPC controller and workers for Zephyr tests. + + Creates a controller server and 2 workers to match threadpool + parallelism (max_workers=2). 
+ """ + # Start controller on random port + server = FrayControllerServer(port=0) + port = server.start() + + # Start 2 workers for parallel execution + workers = [] + threads = [] + + for i in range(2): + worker = FrayWorker(f"http://localhost:{port}", port=0) + workers.append(worker) + thread = threading.Thread(target=worker.run, daemon=True, name=f"zephyr-rpc-worker-{i}") + thread.start() + threads.append(thread) + + # Give workers time to register + time.sleep(0.3) + + yield port + + # Cleanup + for worker in workers: + worker.stop() + server.stop() + for thread in threads: + thread.join(timeout=2.0) + + @pytest.fixture def sample_data(): """Sample data for testing.""" @@ -39,14 +79,29 @@ def sample_data(): @pytest.fixture( params=[ - pytest.param(create_job_ctx("sync"), id="sync"), - pytest.param(create_job_ctx("threadpool", max_workers=2), id="thread"), - pytest.param(create_job_ctx("ray"), id="ray"), + pytest.param("sync", id="sync"), + pytest.param("threadpool", id="thread"), + pytest.param("ray", id="ray"), + pytest.param("rpc", id="rpc"), ] ) -def backend(request): - """Parametrized fixture providing all job contexts for testing.""" - return request.param +def backend(request, ray_cluster, rpc_infrastructure): + """Parametrized fixture providing all job contexts for testing. + + Tests run against all 4 backends: sync, threadpool, ray, and rpc. + """ + backend_type = request.param + + if backend_type == "sync": + return create_job_ctx("sync") + elif backend_type == "threadpool": + return create_job_ctx("threadpool", max_workers=2) + elif backend_type == "ray": + return create_job_ctx("ray") + elif backend_type == "rpc": + return create_job_ctx("fray", controller_address=f"http://localhost:{rpc_infrastructure}") + else: + raise ValueError(f"Unknown backend: {backend_type}") class CallCounter: diff --git a/lib/zephyr/tests/test_dataset.py b/lib/zephyr/tests/test_dataset.py index 088cf85072..0961a8e39d 100644 --- a/lib/zephyr/tests/test_dataset.py +++ b/lib/zephyr/tests/test_dataset.py @@ -29,6 +29,19 @@ from .conftest import CallCounter +def create_vortex_file(tmp_path): + """Create a test vortex file with sample data for pushdown tests. + + Creates 100 records with id (0-99), name, and score (id * 10). 
+ """ + from zephyr.writers import write_vortex_file + + records = [{"id": i, "name": f"item_{i}", "score": i * 10} for i in range(100)] + path = tmp_path / "test.vortex" + write_vortex_file(records, str(path)) + return path + + @pytest.fixture def sample_data(): """Sample data for testing.""" @@ -41,20 +54,6 @@ def test_from_list(sample_data, backend): assert list(Backend.execute(ds, context=backend)) == sample_data -def test_dataclass_round_trip_preserves_type(backend): - """Ensure dataclass items are not downcast to dicts during execution.""" - items = [SampleDataclass("alpha", 1), SampleDataclass("beta", 2)] - - ds = Dataset.from_list(items) - result = list(Backend.execute(ds, context=backend)) - - assert [item.name for item in result] == ["alpha", "beta"] - assert all(isinstance(item, SampleDataclass) for item in result) - - doubled = Dataset.from_list(items).map(lambda x: x.value * 2) - assert list(Backend.execute(doubled, context=backend)) == [2, 4] - - def test_from_iterable(backend): """Test creating dataset from iterable.""" ds = Dataset.from_iterable(range(5)) @@ -1276,33 +1275,6 @@ def test_mixed_filter_expression_and_lambda(backend): assert results[0] == {"a": 3, "b": "x"} -# ============================================================================= -# InputFileSpec and Chunking Tests -# ============================================================================= - - -def test_input_file_spec_row_range_basic(tmp_path): - """Test InputFileSpec reads only the specified row range.""" - from zephyr.readers import InputFileSpec, load_parquet - - data = [{"id": i, "value": i * 10} for i in range(100)] - input_path = tmp_path / "data.parquet" - write_parquet_file(data, str(input_path)) - - spec = InputFileSpec( - path=str(input_path), - format="parquet", - row_start=10, - row_end=20, - ) - - records = list(load_parquet(spec)) - - assert len(records) == 10 - assert records[0]["id"] == 10 - assert records[-1]["id"] == 19 - - def test_input_file_spec_with_columns_and_row_range(tmp_path): """Test InputFileSpec with both columns and row_range.""" from zephyr.readers import InputFileSpec, load_parquet @@ -1327,109 +1299,53 @@ def test_input_file_spec_with_columns_and_row_range(tmp_path): assert records[-1]["id"] == 9 -def test_fork_chunks_inserted_with_intra_shard_parallelism(tmp_path): - """Test that ForkChunks is inserted when intra_shard_parallelism is enabled.""" - import pyarrow as pa - import pyarrow.parquet as pq - from zephyr.plan import ExecutionHint, ForkChunks, Map, compute_plan - - # Create a parquet file - data = [{"id": i, "value": i * 2} for i in range(100)] - input_path = tmp_path / "data.parquet" - table = pa.Table.from_pylist(data) - pq.write_table(table, str(input_path)) - - # Create pipeline: load_parquet -> map -> write - output_path = tmp_path / "output.parquet" - ds = ( - Dataset.from_files(str(input_path)) - .load_parquet() - .map(lambda x: {"id": x["id"], "doubled": x["value"]}) - .write_parquet(str(output_path)) - ) - - # Test with intra_shard_parallelism enabled - hints = ExecutionHint(intra_shard_parallelism=4) - plan = compute_plan(ds, hints) - - # Should have one stage - assert len(plan.stages) == 1 - stage = plan.stages[0] - - # Check that ForkChunks is the first op - assert len(stage.operations) > 0 - assert isinstance(stage.operations[0], ForkChunks) - assert stage.operations[0].target_chunks == 4 - - # Check that ForkChunks contains the user ops (map) - fork_chunks = stage.operations[0] - assert len(fork_chunks.parallel_ops) == 1 - assert 
isinstance(fork_chunks.parallel_ops[0], Map) +def test_vortex_load_file_auto_detects_format(backend, tmp_path): + """Test that load_file() auto-detects vortex format.""" + from zephyr.writers import write_vortex_file + # Create input vortex file + records = [{"id": i, "name": f"item_{i}"} for i in range(50)] + input_path = tmp_path / "input.vortex" + write_vortex_file(records, str(input_path)) -def test_fork_chunks_not_inserted_when_disabled(tmp_path): - """Test that ForkChunks is NOT inserted when intra_shard_parallelism is 0.""" - import pyarrow as pa - import pyarrow.parquet as pq - from zephyr.plan import ExecutionHint, ForkChunks, compute_plan + output_pattern = str(tmp_path / "output-{shard:05d}.jsonl.gz") - # Create a parquet file - data = [{"id": i, "value": i * 2} for i in range(100)] - input_path = tmp_path / "data.parquet" - table = pa.Table.from_pylist(data) - pq.write_table(table, str(input_path)) - - # Create pipeline: load_parquet -> map -> write - output_path = tmp_path / "output.parquet" ds = ( Dataset.from_files(str(input_path)) - .load_parquet() - .map(lambda x: {"id": x["id"], "doubled": x["value"]}) - .write_parquet(str(output_path)) + .load_file() # Should auto-detect vortex + .filter(lambda r: r["id"] < 10) + .write_jsonl(output_pattern) ) - # Test with intra_shard_parallelism disabled - hints = ExecutionHint(intra_shard_parallelism=0) - plan = compute_plan(ds, hints) + results = list(Backend.execute(ds, context=backend)) + assert len(results) == 1 - # Should have one stage - assert len(plan.stages) == 1 - stage = plan.stages[0] - # Check that ForkChunks is NOT present - fork_chunks_found = any(isinstance(op, ForkChunks) for op in stage.operations) - assert not fork_chunks_found +def test_expression_filter_pushdown(backend, tmp_path): + """Test filter pushdown with expression. + Verifies that vortex format supports predicate pushdown, + filtering at the I/O layer instead of in Python. + """ + from zephyr.expr import col -def test_fork_chunks_not_inserted_with_map_shard(tmp_path): - """Test that ForkChunks is NOT inserted when MapShardOp requires full shard context.""" - import pyarrow as pa - import pyarrow.parquet as pq - from zephyr.plan import ExecutionHint, ForkChunks, compute_plan + vortex_file = create_vortex_file(tmp_path) + ds = Dataset.from_files(str(vortex_file)).load_vortex().filter(col("score") > 500) - # Create a parquet file - data = [{"id": i, "value": i * 2} for i in range(100)] - input_path = tmp_path / "data.parquet" - table = pa.Table.from_pylist(data) - pq.write_table(table, str(input_path)) + results = list(Backend.execute(ds, context=backend)) + assert len(results) == 49 # scores 510, 520, ..., 990 + assert all(r["score"] > 500 for r in results) - # Create pipeline with map_shard (requires full shard context) - output_path = tmp_path / "output.parquet" - ds = ( - Dataset.from_files(str(input_path)) - .load_parquet() - .map_shard(lambda items: list(items)[:10]) - .write_parquet(str(output_path)) - ) - # Test with intra_shard_parallelism enabled - hints = ExecutionHint(intra_shard_parallelism=4) - plan = compute_plan(ds, hints) +def test_column_select_pushdown(backend, tmp_path): + """Test column selection pushdown. - # Should have one stage - assert len(plan.stages) == 1 - stage = plan.stages[0] + Verifies that vortex format supports projection pushdown, + loading only requested columns. 
+ """ + vortex_file = create_vortex_file(tmp_path) + ds = Dataset.from_files(str(vortex_file)).load_vortex().select("id", "score") - # Check that ForkChunks is NOT present (because map_shard requires full shard) - fork_chunks_found = any(isinstance(op, ForkChunks) for op in stage.operations) - assert not fork_chunks_found + results = list(Backend.execute(ds, context=backend)) + assert len(results) == 100 + assert set(results[0].keys()) == {"id", "score"} diff --git a/lib/zephyr/tests/test_storage.py b/lib/zephyr/tests/test_storage.py new file mode 100644 index 0000000000..1e7b8ff802 --- /dev/null +++ b/lib/zephyr/tests/test_storage.py @@ -0,0 +1,143 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from zephyr.storage import ChunkWriter, InlineRef, StorageManager, StorageRef +from zephyr.writers import write_vortex_file + + +def test_inline_ref_load_roundtrip(): + """Test InlineRef iteration roundtrip.""" + data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}, {"id": 3, "name": "Charlie"}] + ref = InlineRef(data=data) + + assert ref.count == 3 + loaded = list(ref) + assert loaded == data + + +def test_storage_ref_load_roundtrip(tmp_path): + """Test StorageRef iteration roundtrip with Vortex file.""" + data = [{"id": 1, "value": 100}, {"id": 2, "value": 200}, {"id": 3, "value": 300}] + vortex_path = tmp_path / "test.vortex" + + result = write_vortex_file(data, str(vortex_path)) + ref = StorageRef(path=str(vortex_path), count=result["count"]) + + assert ref.count == 3 + loaded = list(ref) + assert len(loaded) == 3 + assert loaded[0]["id"] == 1 + assert loaded[1]["value"] == 200 + + +def test_chunk_writer_spills_over_threshold(tmp_path): + """Test ChunkWriter spills to storage when over threshold.""" + spill_path = tmp_path / "chunk.vortex" + # Use a low threshold to force spill + writer = ChunkWriter(spill_path=str(spill_path), spill_threshold_bytes=1000) + + # Write enough data to exceed threshold + items = [{"id": i, "data": "x" * 1000} for i in range(10)] + for item in items: + writer.write(item) + + ref = writer.finish() + + # Should be spilled to storage + assert isinstance(ref, StorageRef) + assert ref.count == 10 + assert ref.path == str(spill_path) + assert spill_path.exists() + + # Verify data roundtrip + loaded = list(ref) + assert len(loaded) == 10 + assert loaded[0]["id"] == 0 + assert loaded[9]["id"] == 9 + + +def test_storage_manager_path_generation(tmp_path): + """Test StorageManager generates correct paths.""" + storage = StorageManager(base_path=str(tmp_path)) + + # Check job_path + assert storage.job_path.startswith(str(tmp_path)) + assert "job_" in storage.job_path + + # Check chunk_path format + chunk_path = storage.chunk_path(shard_idx=0, chunk_idx=0) + assert "stage_0" in chunk_path + assert "shard_00000" in chunk_path + assert "chunk_00000.vortex" in chunk_path + + # Check with different indices + chunk_path = storage.chunk_path(shard_idx=123, chunk_idx=456) + assert "shard_00123" in chunk_path + 
assert "chunk_00456.vortex" in chunk_path + + +def test_storage_manager_cleanup_removes_directory(tmp_path): + """Test StorageManager.cleanup() removes job directory.""" + storage = StorageManager(base_path=str(tmp_path), spill_threshold_bytes=1000) + + # Write some data to create the directory + items = [{"id": i, "data": "x" * 1000} for i in range(5)] + writer = storage.create_writer(shard_idx=0, chunk_idx=0) + for item in items: + writer.write(item) + ref = writer.finish() + + assert isinstance(ref, StorageRef) + job_path = Path(storage.job_path) + assert job_path.exists() + + # Cleanup should remove the directory + storage.cleanup() + assert not job_path.exists() + + +def test_storage_manager_context_manager_calls_cleanup(tmp_path): + """Test StorageManager context manager calls cleanup on exit.""" + items = [{"id": i, "data": "x" * 1000} for i in range(5)] + job_path = None + + with StorageManager(base_path=str(tmp_path), spill_threshold_bytes=1000) as storage: + writer = storage.create_writer(shard_idx=0, chunk_idx=0) + for item in items: + writer.write(item) + ref = writer.finish() + assert isinstance(ref, StorageRef) + job_path = Path(storage.job_path) + assert job_path.exists() + + # After exiting context, cleanup should have run + assert not job_path.exists() + + +def test_storage_ref_with_complex_data(tmp_path): + """Test StorageRef with complex nested data structures.""" + data = [ + {"id": 1, "nested": {"key": "value"}, "array": [1, 2, 3]}, + {"id": 2, "nested": {"key": "other"}, "array": [4, 5, 6]}, + ] + vortex_path = tmp_path / "complex.vortex" + write_vortex_file(data, str(vortex_path)) + + ref = StorageRef(path=str(vortex_path), count=2) + loaded = list(ref) + + assert len(loaded) == 2 + assert loaded[0]["id"] == 1 diff --git a/lib/zephyr/tests/test_vortex.py b/lib/zephyr/tests/test_vortex.py deleted file mode 100644 index bf5aa3f15f..0000000000 --- a/lib/zephyr/tests/test_vortex.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2025 The Marin Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for vortex file format support.""" - -import pytest - -from fray.job import create_job_ctx -from zephyr import Backend, Dataset -from zephyr.expr import col -from zephyr.readers import InputFileSpec, load_vortex -from zephyr.writers import write_vortex_file - - -@pytest.fixture -def vortex_file(tmp_path): - """Create a test vortex file with sample data.""" - records = [{"id": i, "name": f"item_{i}", "score": i * 10} for i in range(100)] - path = tmp_path / "test.vortex" - write_vortex_file(records, str(path)) - return path - - -@pytest.fixture( - params=[ - pytest.param("sync", id="sync"), - pytest.param("threadpool", id="thread"), - ] -) -def sync_backend(request): - """Backend fixture for sync and threadpool backends.""" - return create_job_ctx(request.param, max_workers=2) - - -class TestVortexReader: - """Tests for load_vortex() function.""" - - def test_load_vortex_basic(self, vortex_file): - """Test basic vortex file reading.""" - records = list(load_vortex(str(vortex_file))) - assert len(records) == 100 - assert records[0]["id"] == 0 - assert records[0]["name"] == "item_0" - assert records[0]["score"] == 0 - - def test_load_vortex_column_projection(self, vortex_file): - """Test column selection (projection).""" - spec = InputFileSpec(path=str(vortex_file), columns=["id", "name"]) - records = list(load_vortex(spec)) - assert len(records) == 100 - assert set(records[0].keys()) == {"id", "name"} - - def test_load_vortex_empty_file(self, tmp_path): - """Test loading an empty vortex file.""" - empty_path = tmp_path / "empty.vortex" - write_vortex_file([], str(empty_path)) - - records = list(load_vortex(str(empty_path))) - assert records == [] - - -class TestVortexWriter: - """Tests for write_vortex_file() function.""" - - def test_write_vortex_basic(self, tmp_path): - """Test basic vortex file writing.""" - records = [{"id": i, "value": i * 2} for i in range(10)] - output_path = tmp_path / "output.vortex" - - result = write_vortex_file(records, str(output_path)) - - assert result["count"] == 10 - assert output_path.exists() - - # Verify roundtrip - loaded = list(load_vortex(str(output_path))) - assert loaded == records - - def test_write_vortex_empty(self, tmp_path): - """Test writing empty dataset.""" - output_path = tmp_path / "empty.vortex" - result = write_vortex_file([], str(output_path)) - - assert result["count"] == 0 - assert output_path.exists() - - def test_write_vortex_single_record(self, tmp_path): - """Test writing single record.""" - records = [{"key": "value", "number": 42}] - output_path = tmp_path / "single.vortex" - - result = write_vortex_file(records, str(output_path)) - assert result["count"] == 1 - - loaded = list(load_vortex(str(output_path))) - assert loaded == records - - -class TestVortexPipeline: - """Tests for vortex in Dataset pipelines.""" - - def test_read_write_pipeline(self, sync_backend, vortex_file, tmp_path): - """Test read -> filter -> write pipeline with vortex.""" - output_pattern = str(tmp_path / "output-{shard:05d}.vortex") - - ds = ( - Dataset.from_files(str(vortex_file)) - .load_vortex() - .filter(lambda r: r["score"] > 500) - .write_vortex(output_pattern) - ) - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 1 - - # Verify output - loaded = list(load_vortex(results[0])) - assert len(loaded) == 49 # scores 510, 520, ..., 990 - assert all(r["score"] > 500 for r in loaded) - - def test_load_file_auto_detects_vortex(self, sync_backend, vortex_file, tmp_path): - """Test that load_file() auto-detects 
vortex format.""" - output_pattern = str(tmp_path / "output-{shard:05d}.jsonl.gz") - - ds = Dataset.from_files(str(vortex_file)).load_file().filter(lambda r: r["id"] < 10).write_jsonl(output_pattern) - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 1 - - def test_vortex_to_parquet_conversion(self, sync_backend, vortex_file, tmp_path): - """Test converting vortex to parquet.""" - output_pattern = str(tmp_path / "output-{shard:05d}.parquet") - - ds = Dataset.from_files(str(vortex_file)).load_vortex().write_parquet(output_pattern) - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 1 - - # Verify parquet output - from zephyr.readers import load_parquet - - loaded = list(load_parquet(results[0])) - assert len(loaded) == 100 - - def test_parquet_to_vortex_conversion(self, sync_backend, tmp_path): - """Test converting parquet to vortex.""" - # Create parquet file - from zephyr.writers import write_parquet_file - - records = [{"a": i, "b": f"val_{i}"} for i in range(50)] - parquet_path = tmp_path / "input.parquet" - write_parquet_file(records, str(parquet_path)) - - output_pattern = str(tmp_path / "output-{shard:05d}.vortex") - - ds = Dataset.from_files(str(parquet_path)).load_parquet().write_vortex(output_pattern) - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 1 - - # Verify vortex output - loaded = list(load_vortex(results[0])) - assert loaded == records - - -class TestVortexFilterPushdown: - """Tests for filter pushdown to vortex reader.""" - - def test_expression_filter_pushdown(self, sync_backend, vortex_file): - """Test filter pushdown with expression.""" - ds = Dataset.from_files(str(vortex_file)).load_vortex().filter(col("score") > 500) - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 49 # scores 510, 520, ..., 990 - assert all(r["score"] > 500 for r in results) - - def test_column_select_pushdown(self, sync_backend, vortex_file): - """Test column selection pushdown.""" - ds = Dataset.from_files(str(vortex_file)).load_vortex().select("id", "score") - - results = list(Backend.execute(ds, context=sync_backend)) - assert len(results) == 100 - assert set(results[0].keys()) == {"id", "score"} diff --git a/pyproject.toml b/pyproject.toml index 9912f20351..bd2e77d632 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,11 +9,12 @@ description = "Marin workspace root and experiments" license = { file = "LICENSE" } requires-python = ">=3.11" dependencies = [ - "marin", + "fluster", + "fray", + "haliax", "levanter", + "marin", "zephyr", - "haliax", - "fray", ] [tool.uv] @@ -25,14 +26,16 @@ override-dependencies = [ [tool.uv.workspace] members = [ - "lib/marin", + "lib/fluster", + "lib/fray", + "lib/haliax", "lib/levanter", + "lib/marin", "lib/zephyr", - "lib/haliax", - "lib/fray", ] [tool.uv.sources] +fluster = { workspace = true } fray = { workspace = true } haliax = { workspace = true } levanter = { workspace = true } @@ -67,6 +70,7 @@ ignore = ["F722", "B008", "UP015", "A005", "I001"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "F401"] +"lib/fluster/src/fluster/cluster/worker/dashboard.py" = ["E501"] [tool.mypy] diff --git a/uv.lock b/uv.lock index b94a494ba6..a5123ece7b 100644 --- a/uv.lock +++ b/uv.lock @@ -97,6 +97,7 @@ fork-strategy = "fewest" [manifest] members = [ + "fluster", "fray", "haliax", "levanter", @@ -681,6 +682,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/c3/11/25cdf9d5fc21efd30134fc74c43702c6f7ef09ebae8ed927f1283403ad8d/colorful-0.5.8-py2.py3-none-any.whl", hash = "sha256:a9381fdda3337fbaba5771991020abc69676afa102646650b759927892875992", size = 201334, upload-time = "2025-10-29T11:53:20.251Z" }, ] +[[package]] +name = "connect-python" +version = "0.7.0" +source = { git = "https://github.com/connectrpc/connect-python.git?rev=5342eacecef85e52717604ee5ac7e03a1e16c7ac#5342eacecef85e52717604ee5ac7e03a1e16c7ac" } +dependencies = [ + { name = "httpx" }, + { name = "protobuf" }, +] + [[package]] name = "contourpy" version = "1.3.3" @@ -931,6 +941,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32' or (extra == 'extra-5-marin-cpu' and extra == 'extra-5-marin-gpu') or (extra == 'extra-5-marin-gpu' and extra == 'extra-5-marin-tpu') or (extra == 'extra-8-levanter-gpu' and extra == 'extra-8-levanter-tpu')" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "draccus" version = "0.11.5" @@ -1130,6 +1154,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/cf/d118be069d0fbbfbb6c5591069233678b573e49d81e2a387afb0297528a0/flax-0.12.0-py3-none-any.whl", hash = "sha256:13dee5f2658e8b51e22cef54b788ddbad1b6e14a9d7c5b6d1ae9c3a0110d645e", size = 466282, upload-time = "2025-09-25T23:58:58.631Z" }, ] +[[package]] +name = "fluster" +version = "0.1.0" +source = { editable = "lib/fluster" } +dependencies = [ + { name = "cloudpickle" }, + { name = "connect-python" }, + { name = "docker" }, + { name = "fsspec" }, + { name = "httpx" }, + { name = "humanfriendly" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "uvicorn", extra = ["standard"] }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-timeout" }, +] + +[package.metadata] +requires-dist = [ + { name = "cloudpickle", specifier = ">=3.1.2" }, + { name = "connect-python", git = "https://github.com/connectrpc/connect-python.git?rev=5342eacecef85e52717604ee5ac7e03a1e16c7ac" }, + { name = "docker", specifier = ">=7.0.0" }, + { name = "fsspec", specifier = ">=2024.0.0" }, + { name = "httpx", specifier = ">=0.28.1" }, + { name = "humanfriendly", specifier = ">=10.0" }, + { name = "pydantic", specifier = ">=2.12.5" }, + { name = "starlette", specifier = ">=0.50.0" }, + { name = "typing-extensions", specifier = ">=4.0" }, + { name = "uvicorn", extras = ["standard"], specifier = ">=0.23.0" }, +] + +[package.metadata.requires-dev] 
+dev = [ + { name = "pytest", specifier = ">=8.3.2" }, + { name = "pytest-asyncio" }, + { name = "pytest-timeout" }, +] + [[package]] name = "fonttools" version = "4.61.1" @@ -2996,6 +3065,7 @@ name = "marin-root" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "fluster" }, { name = "fray" }, { name = "haliax" }, { name = "levanter" }, @@ -3005,6 +3075,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "fluster", editable = "lib/fluster" }, { name = "fray", editable = "lib/fray" }, { name = "haliax", editable = "lib/haliax" }, { name = "levanter", editable = "lib/levanter" }, @@ -4466,17 +4537,17 @@ wheels = [ [[package]] name = "protobuf" -version = "6.33.2" +version = "6.33.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/34/44/e49ecff446afeec9d1a66d6bbf9adc21e3c7cea7803a920ca3773379d4f6/protobuf-6.33.2.tar.gz", hash = "sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4", size = 444296, upload-time = "2025-12-06T00:17:53.311Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/b8/cda15d9d46d03d4aa3a67cb6bffe05173440ccf86a9541afaf7ac59a1b6b/protobuf-6.33.4.tar.gz", hash = "sha256:dc2e61bca3b10470c1912d166fe0af67bfc20eb55971dcef8dfa48ce14f0ed91", size = 444346, upload-time = "2026-01-12T18:33:40.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/91/1e3a34881a88697a7354ffd177e8746e97a722e5e8db101544b47e84afb1/protobuf-6.33.2-cp310-abi3-win32.whl", hash = "sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d", size = 425603, upload-time = "2025-12-06T00:17:41.114Z" }, - { url = "https://files.pythonhosted.org/packages/64/20/4d50191997e917ae13ad0a235c8b42d8c1ab9c3e6fd455ca16d416944355/protobuf-6.33.2-cp310-abi3-win_amd64.whl", hash = "sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4", size = 436930, upload-time = "2025-12-06T00:17:43.278Z" }, - { url = "https://files.pythonhosted.org/packages/b2/ca/7e485da88ba45c920fb3f50ae78de29ab925d9e54ef0de678306abfbb497/protobuf-6.33.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43", size = 427621, upload-time = "2025-12-06T00:17:44.445Z" }, - { url = "https://files.pythonhosted.org/packages/7d/4f/f743761e41d3b2b2566748eb76bbff2b43e14d5fcab694f494a16458b05f/protobuf-6.33.2-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e", size = 324460, upload-time = "2025-12-06T00:17:45.678Z" }, - { url = "https://files.pythonhosted.org/packages/b1/fa/26468d00a92824020f6f2090d827078c09c9c587e34cbfd2d0c7911221f8/protobuf-6.33.2-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872", size = 339168, upload-time = "2025-12-06T00:17:46.813Z" }, - { url = "https://files.pythonhosted.org/packages/56/13/333b8f421738f149d4fe5e49553bc2a2ab75235486259f689b4b91f96cec/protobuf-6.33.2-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f", size = 323270, upload-time = "2025-12-06T00:17:48.253Z" }, - { url = "https://files.pythonhosted.org/packages/0e/15/4f02896cc3df04fc465010a4c6a0cd89810f54617a32a70ef531ed75d61c/protobuf-6.33.2-py3-none-any.whl", hash = "sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c", size = 170501, upload-time = "2025-12-06T00:17:52.211Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/be/24ef9f3095bacdf95b458543334d0c4908ccdaee5130420bf064492c325f/protobuf-6.33.4-cp310-abi3-win32.whl", hash = "sha256:918966612c8232fc6c24c78e1cd89784307f5814ad7506c308ee3cf86662850d", size = 425612, upload-time = "2026-01-12T18:33:29.656Z" }, + { url = "https://files.pythonhosted.org/packages/31/ad/e5693e1974a28869e7cd244302911955c1cebc0161eb32dfa2b25b6e96f0/protobuf-6.33.4-cp310-abi3-win_amd64.whl", hash = "sha256:8f11ffae31ec67fc2554c2ef891dcb561dae9a2a3ed941f9e134c2db06657dbc", size = 436962, upload-time = "2026-01-12T18:33:31.345Z" }, + { url = "https://files.pythonhosted.org/packages/66/15/6ee23553b6bfd82670207ead921f4d8ef14c107e5e11443b04caeb5ab5ec/protobuf-6.33.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2fe67f6c014c84f655ee06f6f66213f9254b3a8b6bda6cda0ccd4232c73c06f0", size = 427612, upload-time = "2026-01-12T18:33:32.646Z" }, + { url = "https://files.pythonhosted.org/packages/2b/48/d301907ce6d0db75f959ca74f44b475a9caa8fcba102d098d3c3dd0f2d3f/protobuf-6.33.4-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:757c978f82e74d75cba88eddec479df9b99a42b31193313b75e492c06a51764e", size = 324484, upload-time = "2026-01-12T18:33:33.789Z" }, + { url = "https://files.pythonhosted.org/packages/92/1c/e53078d3f7fe710572ab2dcffd993e1e3b438ae71cfc031b71bae44fcb2d/protobuf-6.33.4-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c7c64f259c618f0bef7bee042075e390debbf9682334be2b67408ec7c1c09ee6", size = 339256, upload-time = "2026-01-12T18:33:35.231Z" }, + { url = "https://files.pythonhosted.org/packages/e8/8e/971c0edd084914f7ee7c23aa70ba89e8903918adca179319ee94403701d5/protobuf-6.33.4-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:3df850c2f8db9934de4cf8f9152f8dc2558f49f298f37f90c517e8e5c84c30e9", size = 323311, upload-time = "2026-01-12T18:33:36.305Z" }, + { url = "https://files.pythonhosted.org/packages/75/b1/1dc83c2c661b4c62d56cc081706ee33a4fc2835bd90f965baa2663ef7676/protobuf-6.33.4-py3-none-any.whl", hash = "sha256:1fe3730068fcf2e595816a6c34fe66eeedd37d51d0400b72fabc848811fdc1bc", size = 170532, upload-time = "2026-01-12T18:33:39.199Z" }, ] [[package]] @@ -6735,4 +6806,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" }, { url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" }, { url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" }, -] \ No newline at end of file +]