diff --git a/integration_tests/benchmarks_old/dist.py b/integration_tests/benchmarks_old/dist.py
index f3a8d5d..edcd09c 100644
--- a/integration_tests/benchmarks_old/dist.py
+++ b/integration_tests/benchmarks_old/dist.py
@@ -15,7 +15,7 @@
 import os
 
 from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import JobInstance, SchedulingConstraint
+from cascade.low.core import JobInstance, SchedulingConstraint, TaskId
 
 
 def source_func() -> int:
@@ -112,6 +112,6 @@ def get_job() -> JobInstance:
         job.nodes["sink"].definition.input_schema[f"v{i}"] = "int"
     # TODO put some allow_kw into TaskDefinition instead to allow this
     job = job.build().get_or_raise()
-    job.ext_outputs = list(job.outputs_of("sink"))
-    job.constraints = [SchedulingConstraint(gang=[f"proc{i}" for i in range(L)])]
+    job.ext_outputs = list(job.outputs_of(TaskId("sink")))
+    job.constraints = [SchedulingConstraint(gang=[TaskId(f"proc{i}") for i in range(L)])]
     return job
diff --git a/integration_tests/benchmarks_old/matmul.py b/integration_tests/benchmarks_old/matmul.py
index dc1eb6f..d1795d8 100644
--- a/integration_tests/benchmarks_old/matmul.py
+++ b/integration_tests/benchmarks_old/matmul.py
@@ -6,7 +6,7 @@
 import jax.random as jr  # ty: ignore[unresolved-import]
 
 from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import JobInstance
+from cascade.low.core import JobInstance, TaskId
 
 
 def get_funcs():
@@ -47,7 +47,7 @@ def get_job() -> JobInstance:
         prv = cur
 
     job = job.build().get_or_raise()
-    job.ext_outputs = list(job.outputs_of(cur))
+    job.ext_outputs = list(job.outputs_of(TaskId(cur)))
 
     return job
diff --git a/src/cascade/controller/notify.py b/src/cascade/controller/notify.py
index 6185fa5..f4790fe 100644
--- a/src/cascade/controller/notify.py
+++ b/src/cascade/controller/notify.py
@@ -90,7 +90,7 @@ def notify(
     if isinstance(event, DatasetPublished):
         logger.debug(f"received {event=}")
         # NOTE here we'll need to distinguish memory-only and host-wide (shm) publications, currently all events mean shm
-        host = event.origin if isinstance(event.origin, HostId) else event.origin.host
+        host = cast(HostId, event.origin) if not isinstance(event.origin, WorkerId) else event.origin.host
         context.host2ds[host][event.ds] = DatasetStatus.available
         context.ds2host[event.ds][host] = DatasetStatus.available
         state.consider_fetch(event.ds, host)
diff --git a/src/cascade/controller/report.py b/src/cascade/controller/report.py
index 337e2d5..3fb4252 100644
--- a/src/cascade/controller/report.py
+++ b/src/cascade/controller/report.py
@@ -12,6 +12,7 @@
 import pickle
 from dataclasses import dataclass
 from time import monotonic_ns
+from typing import NewType
 
 import zmq
 from typing_extensions import Self
@@ -23,7 +24,7 @@
 logger = logging.getLogger(__name__)
 
 
-JobId = str
+JobId = NewType("JobId", str)
 
 
 @dataclass
@@ -83,7 +84,7 @@ def __init__(self, report_address: str | None) -> None:
             return
         address, job_id = report_address.split(",", 1)
         logger.debug(f"initialising reporter with {address=} and {job_id=}")
-        self.job_id = job_id
+        self.job_id = JobId(job_id)
         self.socket = get_context().socket(zmq.PUSH)
         self.socket.connect(address)
diff --git a/src/cascade/executor/bridge.py b/src/cascade/executor/bridge.py
index 13cb3a3..528c903 100644
--- a/src/cascade/executor/bridge.py
+++ b/src/cascade/executor/bridge.py
@@ -10,6 +10,7 @@
 
 import logging
 import time
+from typing import cast
 
 from cascade.executor.checkpoints import build_persist_command, build_retrieve_command, serialize_params
 from cascade.executor.comms import GraceWatcher, Listener, ReliableSender
@@ -72,7 +73,7 @@ def __init__(self, controller_url: str, expected_executors: int, checkpoint_spec
                     logger.warning(f"double registration of {message.host}, suggesting network congestion")
                     continue
                 self.sender.add_host(message.host, message.maddress)
-                self.sender.add_host("data." + message.host, message.daddress)
+                self.sender.add_host(HostId("data." + message.host), message.daddress)
                 for worker in message.workers:
                     self.environment.workers[worker.worker_id] = Worker(cpu=worker.cpu, gpu=worker.gpu, memory_mb=worker.memory_mb)
                 self.environment.host_url_base[message.host] = message.url_base
@@ -97,9 +98,11 @@ def recv_events(self) -> list[Event]:
         while (not events) and (not shutdown_reason):
             # timeout ms matches
             for message in self.mlistener.recv_messages(timeout_ms=resend_grace_ms):
-                if hasattr(message, "host") and isinstance((host := message.host), HostId):
+                if hasattr(message, "host") and isinstance(message.host, str):
+                    host = cast(HostId, message.host)
                     self.heartbeat_checker[host].step()
-                if hasattr(message, "worker") and isinstance((worker := message.worker), WorkerId):
+                if hasattr(message, "worker") and isinstance(message.worker, WorkerId):
+                    worker = message.worker
                     self.heartbeat_checker[worker.host].step()
                 if isinstance(message, Event):
                     events.append(message)
@@ -111,7 +114,7 @@ def recv_events(self) -> list[Event]:
                     logger.critical(f"received failure {message=}, proceeding with a shutdown")
                     if isinstance(message, ExecutorExit | ExecutorFailure) and message.host in self.sender.hosts:
                         self.sender.hosts.pop(message.host)
-                        self.sender.hosts.pop("data." + message.host)
+                        self.sender.hosts.pop(HostId("data." + message.host))
                     shutdown_reason = message
                 elif isinstance(message, Unsupported):
                     logger.critical(f"received unexpected {message=}, proceeding with a shutdown")
@@ -150,32 +153,32 @@ def purge(self, host: HostId, ds: DatasetId) -> None:
     def transmit(self, ds: DatasetId, source: HostId, target: HostId) -> None:
         if source == VirtualCheckpointHost:
             command = build_retrieve_command(self.checkpoint_spec, ds, target)
-            self.sender.send("data." + target, command)
+            self.sender.send(HostId("data." + target), command)
         else:
             m = DatasetTransmitCommand(
                 source=source,
                 target=target,
-                daddress=self.sender.hosts["data." + target][1],
+                daddress=self.sender.hosts[HostId("data." + target)][1],
                 ds=ds,
                 idx=self.transmit_idx_counter,
             )
             self.transmit_idx_counter += 1
-            self.sender.send("data." + source, m)
+            self.sender.send(HostId("data." + source), m)
 
     def persist(self, ds: DatasetId, source: HostId) -> None:
         command = build_persist_command(self.checkpoint_spec, ds, source)
-        self.sender.send("data." + source, command)
+        self.sender.send(HostId("data." + source), command)
 
     def fetch(self, ds: DatasetId, source: HostId) -> None:
         m = DatasetTransmitCommand(
             source=source,
-            target="controller",
+            target=HostId("controller"),
             daddress=self.mlistener.address,
             ds=ds,
             idx=self.transmit_idx_counter,
         )
         self.transmit_idx_counter += 1
-        self.sender.send("data." + source, m)
+        self.sender.send(HostId("data." + source), m)
 
     def shutdown(self) -> None:
         m = ExecutorShutdown()
@@ -189,7 +192,7 @@ def shutdown(self) -> None:
             if isinstance(message, ExecutorExit | ExecutorFailure):
                 if message.host in self.sender.hosts:
                     self.sender.hosts.pop(message.host)
-                    self.sender.hosts.pop("data." + message.host)
+                    self.sender.hosts.pop(HostId("data." + message.host))
             else:
                 logger.warning(f"ignoring {type(message)}")
         if self.sender.hosts:
diff --git a/src/cascade/executor/data_server.py b/src/cascade/executor/data_server.py
index 2d38c34..c641573 100644
--- a/src/cascade/executor/data_server.py
+++ b/src/cascade/executor/data_server.py
@@ -41,7 +41,7 @@
     Syn,
 )
 from cascade.executor.runner.memory import ds2shmid
-from cascade.low.core import DatasetId
+from cascade.low.core import DatasetId, HostId
 from cascade.low.exceptions import CascadeError, CascadeInfrastructureError, CascadeInternalError, ser
 from cascade.low.func import assert_never
 from cascade.low.tracing import TransmitLifecycle, label, mark
@@ -54,7 +54,7 @@ def __init__(
         self,
         maddress: BackboneAddress,
         daddress: BackboneAddress,
-        host: str,
+        host: HostId,
         logging_config: dict,
     ):
         logging.config.dictConfig(logging_config)
@@ -347,7 +347,7 @@ def recv_loop(self) -> None:
 def start_data_server(
     maddress: BackboneAddress,
     daddress: BackboneAddress,
-    host: str,
+    host: HostId,
     logging_config: dict,
 ):
     server = DataServer(maddress, daddress, host, logging_config)
diff --git a/src/cascade/executor/executor.py b/src/cascade/executor/executor.py
index 0729f93..3120570 100644
--- a/src/cascade/executor/executor.py
+++ b/src/cascade/executor/executor.py
@@ -111,7 +111,7 @@ def __init__(
         # NOTE following inits are with potential side effects
         self.mlistener = Listener(address_of(portBase))
         self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
-        self.sender.add_host("controller", controller_address)
+        self.sender.add_host(HostId("controller"), controller_address)
         # TODO make the shm server params configurable
         shm_port = f"/tmp/cascShmSock-{uuid.uuid4()}"  # portBase + 2
         shm_api.publish_socket_addr(shm_port)
@@ -191,7 +191,7 @@ def terminate(self) -> None:
 
     def to_controller(self, m: Message) -> None:
         self.heartbeat_watcher.step()
-        self.sender.send("controller", m)
+        self.sender.send(HostId("controller"), m)
 
     def _start_worker(self, worker: WorkerId, attempt_cnt: int, seq: None | TaskSequence) -> WorkerHandle:
         ctx = platform.get_mp_ctx("worker")
diff --git a/src/cascade/gateway/router.py b/src/cascade/gateway/router.py
index c1c22a0..8516221 100644
--- a/src/cascade/gateway/router.py
+++ b/src/cascade/gateway/router.py
@@ -56,11 +56,11 @@ def __init__(
         max_jobs: int | None,
     ):
         self.poller = poller
-        self.jobs: dict[str, Job] = {}
+        self.jobs: dict[JobId, Job] = {}
        self.active_jobs = 0
         self.max_jobs = max_jobs
         self.jobs_queue: OrderedDict[JobId, JobSpec] = OrderedDict()
-        self.procs: dict[str, subprocess.Popen] = {}
+        self.procs: dict[JobId, subprocess.Popen] = {}
         self.loggingConfig = loggingConfig
         self.troika_config = troika_config
@@ -85,7 +85,7 @@ def maybe_spawn(self) -> None:
     def enqueue_job(self, job_spec: JobSpec) -> JobId:
         job_id = next_uuid(
             set(self.jobs.keys()).union(self.jobs_queue.keys()),
-            lambda: str(uuid.uuid4()),
+            lambda: JobId(str(uuid.uuid4())),
         )
         self.jobs_queue[job_id] = job_spec
         self.maybe_spawn()
diff --git a/src/cascade/gateway/server.py b/src/cascade/gateway/server.py
index e39ac77..ed1754a 100644
--- a/src/cascade/gateway/server.py
+++ b/src/cascade/gateway/server.py
@@ -18,7 +18,7 @@
 import zmq
 
 import cascade.gateway.api as api
-from cascade.controller.report import JobProgress, deserialize
+from cascade.controller.report import JobId, JobProgress, deserialize
 from cascade.deployment.logging import LoggingConfig, init_from_obj
 from cascade.executor.comms import get_context
 from cascade.gateway.client import parse_request, serialize_response
@@ -44,7 +44,7 @@ def handle_fe(socket: zmq.Socket, jobs: JobRouter) -> bool:
         try:
             progresses, datasets, queue_length = jobs.progress_of(m.job_ids)
             rv = api.JobProgressResponse(
-                progresses=cast(dict[str, JobProgress | None], progresses),
+                progresses=cast(dict[JobId, JobProgress | None], progresses),
                 datasets=datasets,
                 error=None,
                 queue_length=queue_length,
diff --git a/src/cascade/low/builders.py b/src/cascade/low/builders.py
index a3b8d15..424a36b 100644
--- a/src/cascade/low/builders.py
+++ b/src/cascade/low/builders.py
@@ -21,6 +21,7 @@
     JobInstance,
     Task2TaskEdge,
     TaskDefinition,
+    TaskId,
     TaskInstance,
 )
 from cascade.low.func import Either
@@ -100,12 +101,12 @@ def with_node(self, name: str, task: TaskInstance) -> Self:
         return replace(self, nodes=self.nodes.set(name, task))
 
     def with_output(self, task: str, output: str = DefaultTaskOutput) -> Self:
-        return replace(self, outputs=self.outputs.append(DatasetId(task, output)))
+        return replace(self, outputs=self.outputs.append(DatasetId(TaskId(task), output)))
 
     def with_edge(self, source: str, sink: str, into: str | int, frum: str = DefaultTaskOutput) -> Self:
         new_edge = Task2TaskEdge(
-            source=DatasetId(source, frum),
-            sink_task=sink,
+            source=DatasetId(TaskId(source), frum),
+            sink_task=TaskId(sink),
             sink_input_kw=into if isinstance(into, str) else None,
             sink_input_ps=into if isinstance(into, int) else None,
         )
@@ -193,7 +194,7 @@ def get_edge_errors(edge: Task2TaskEdge) -> Iterator[str]:
         else:
             return Either.ok(
                 JobInstance(
-                    tasks=cast(dict[str, TaskInstance], pyrsistent.thaw(self.nodes)),
+                    tasks=cast(dict[TaskId, TaskInstance], pyrsistent.thaw(self.nodes)),
                     edges=pyrsistent.thaw(self.edges),
                     ext_outputs=pyrsistent.thaw(self.outputs),
                 )
diff --git a/src/cascade/low/core.py b/src/cascade/low/core.py
index 8643284..b5f18bb 100644
--- a/src/cascade/low/core.py
+++ b/src/cascade/low/core.py
@@ -12,7 +12,7 @@
 from base64 import b64decode, b64encode
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional, Type, cast
+from typing import Any, Callable, Literal, NewType, Optional, Type, cast
 
 import cloudpickle
 from pydantic import BaseModel, Field
@@ -58,7 +58,7 @@ def func_enc(f: Callable) -> str:
     return b64encode(cloudpickle.dumps(f)).decode("ascii")
 
 
-TaskId = str
+TaskId = NewType("TaskId", str)
 
 
 @dataclass(frozen=True)
@@ -80,7 +80,7 @@ def ser(self) -> str:
     @classmethod
     def des(cls, value: str) -> Self:
         pref, suf = value.split(".", 1)
         task_len = int.from_bytes(bytes.fromhex(pref), byteorder="big")
-        return cls(task=suf[:task_len], output=suf[task_len:])
+        return cls(task=TaskId(suf[:task_len]), output=suf[task_len:])
 
 
 class Task2TaskEdge(BaseModel):
@@ -138,7 +138,7 @@ def outputs_of(self, task_id: TaskId) -> set[DatasetId]:
         return {DatasetId(task_id, output) for output, _ in self.tasks[task_id].definition.output_schema}
 
 
-HostId = str
+HostId = NewType("HostId", str)
 
 
 @dataclass(frozen=True)
@@ -152,7 +152,7 @@ def __repr__(self) -> str:
     @classmethod
     def from_repr(cls, value: str) -> Self:
         host, worker = value.split(".", 1)
-        return cls(host=host, worker=worker)
+        return cls(host=HostId(host), worker=worker)
 
     def worker_num(self) -> int:
         """Used eg for gpu allocation"""
@@ -193,7 +193,7 @@ class JobExecutionRecord(BaseModel):
 
 
 CheckpointStorageType = Literal["fs"]
-StorageId = str
+StorageId = NewType("StorageId", str)
 
 
 class CheckpointSpec(BaseModel):
diff --git a/src/cascade/low/dask.py b/src/cascade/low/dask.py
index f9d08fc..714f606 100644
--- a/src/cascade/low/dask.py
+++ b/src/cascade/low/dask.py
@@ -21,7 +21,7 @@
 from dask._task_spec import Alias, DataNode, Task, TaskRef
 
 from cascade.low.builders import TaskBuilder
-from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskInstance
+from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskId, TaskInstance
 from cascade.low.exceptions import CascadeUserError
 
 logger = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
     for i, v in enumerate(task.args):
         if isinstance(v, Alias | TaskRef):
             edge = Task2TaskEdge(
-                source=DatasetId(task=daskKeyRepr(v.key), output=DefaultTaskOutput),
-                sink_task=key,
+                source=DatasetId(task=TaskId(daskKeyRepr(v.key)), output=DefaultTaskOutput),
+                sink_task=TaskId(key),
                 sink_input_ps=i,
                 sink_input_kw=None,
             )
@@ -52,8 +52,8 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
     for k, v in task.kwargs.items():
         if isinstance(v, Alias | TaskRef):
             edge = Task2TaskEdge(
-                source=DatasetId(task=daskKeyRepr(v.key), output=DefaultTaskOutput),
-                sink_task=key,
+                source=DatasetId(task=TaskId(daskKeyRepr(v.key)), output=DefaultTaskOutput),
+                sink_task=TaskId(key),
                 sink_input_kw=k,
                 sink_input_ps=None,
             )
@@ -68,7 +68,7 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
 
 
 def graph2job(dask: dict) -> JobInstance:
-    task_nodes = {}
+    task_nodes: dict[TaskId, TaskInstance] = {}
     edges = []
 
     for node, value in dask.items():
@@ -78,10 +78,10 @@ def graph2job(dask: dict) -> JobInstance:
             def provider() -> Any:
                 return value.value
 
-            task_nodes[key] = TaskBuilder.from_callable(provider)
+            task_nodes[TaskId(key)] = TaskBuilder.from_callable(provider)
         elif isinstance(value, Task):
             node, _edges = task2task(key, value)
-            task_nodes[key] = node
+            task_nodes[TaskId(key)] = node
             edges.extend(_edges)
         elif isinstance(value, list | tuple | set):
             # TODO implement, consult further:
diff --git a/src/cascade/low/execution_context.py b/src/cascade/low/execution_context.py
index 1cf84c4..f9a8c7d 100644
--- a/src/cascade/low/execution_context.py
+++ b/src/cascade/low/execution_context.py
@@ -42,7 +42,7 @@ class TaskStatus(int, Enum):
     failed = 3  # set by executor
 
 
-VirtualCheckpointHost: HostId = "virtualCheckpointHost"
+VirtualCheckpointHost: HostId = HostId("virtualCheckpointHost")
 
 
 @dataclass
diff --git a/src/cascade/low/into.py b/src/cascade/low/into.py
index d522f8b..04a75e8 100644
--- a/src/cascade/low/into.py
+++ b/src/cascade/low/into.py
@@ -11,7 +11,7 @@
 import logging
 from typing import Any, Callable, cast
 
-from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskDefinition, TaskInstance
+from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskDefinition, TaskId, TaskInstance
 
 logger = logging.getLogger(__name__)
 
@@ -55,13 +55,13 @@ def node2task(name: str, node: dict) -> tuple[TaskInstance, list[Task2TaskEdge]]
     edges = []
     for param, other in node["inputs"].items():
         if isinstance(other, str):
-            source = DatasetId(other, DefaultTaskOutput)
+            source = DatasetId(TaskId(other), DefaultTaskOutput)
         else:
-            source = DatasetId(other[0], other[1])
+            source = DatasetId(TaskId(other[0]), other[1])
         edges.append(
             Task2TaskEdge(
                 source=source,
-                sink_task=name,
+                sink_task=TaskId(name),
                 sink_input_ps=rev_lookup[param],
                 sink_input_kw=None,
             )
@@ -89,9 +89,9 @@ def node2task(name: str, node: dict) -> tuple[TaskInstance, list[Task2TaskEdge]]
 def graph2job(graph: dict) -> JobInstance:
     # graph assumed to be ekw.graph.serialise(ekw.graph.Graph)
     edges = []
-    tasks = {}
+    tasks: dict[TaskId, TaskInstance] = {}
     for node_name, node_val in graph.items():
         task, task_edges = node2task(node_name, node_val)
         edges += task_edges
-        tasks[node_name] = task
+        tasks[TaskId(node_name)] = task
     return JobInstance(tasks=tasks, edges=edges)
diff --git a/src/cascade/main.py b/src/cascade/main.py
index 425af14..e5ae8fc 100644
--- a/src/cascade/main.py
+++ b/src/cascade/main.py
@@ -28,7 +28,7 @@
 from cascade.executor.config import logging_config, logging_config_filehandler
 from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
-from cascade.low.core import DatasetId, JobInstance, JobInstanceRich
+from cascade.low.core import DatasetId, HostId, JobInstance, JobInstanceRich
 from cascade.low.exceptions import CascadeError, CascadeInfrastructureError
 from cascade.low.func import msum
 from cascade.scheduler.precompute import precompute
@@ -85,7 +85,7 @@ def launch_executor(
         job.jobInstance,
         controller_address,
         workers_per_host,
-        f"h{i}",
+        HostId(f"h{i}"),
         portBase,
         shm_vol_gb,
         loggingConfig,
diff --git a/src/cascade/scheduler/api.py b/src/cascade/scheduler/api.py
index f5b323d..9d2408e 100644
--- a/src/cascade/scheduler/api.py
+++ b/src/cascade/scheduler/api.py
@@ -93,7 +93,7 @@ def init_schedule(preschedule: Preschedule, context: JobExecutionContext) -> Sch
         components.append(component)
         computable += len(precomponent.sources)
         for task in precomponent.nodes:
-            ts2component[task] = componentId
+            ts2component[task] = ComponentId(componentId)
 
     if gangs:
         for gang in gangs:
@@ -133,7 +133,9 @@ def assign(schedule: Schedule, context: JobExecutionContext) -> Iterator[Assignm
         return
 
     # step II: assign remaining workers to new components
-    components = [(component.weight, component_id) for component_id, component in enumerate(schedule.components) if component.weight > 0]
+    components = [
+        (component.weight, ComponentId(component_id)) for component_id, component in enumerate(schedule.components) if component.weight > 0
+    ]
     if not components:
         return
diff --git a/src/cascade/scheduler/core.py b/src/cascade/scheduler/core.py
index 20b52ca..2ebe109 100644
--- a/src/cascade/scheduler/core.py
+++ b/src/cascade/scheduler/core.py
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 
 from dataclasses import dataclass
+from typing import NewType
 
 from cascade.low.core import DatasetId, HostId, TaskId, WorkerId
 
@@ -39,7 +40,7 @@ class Preschedule:
 
 Worker2TaskDistance = dict[WorkerId, dict[TaskId, int]]
 
-ComponentId = int
+ComponentId = NewType("ComponentId", int)
 
 
 @dataclass
diff --git a/tests/cascade/controller/test_run.py b/tests/cascade/controller/test_run.py
index efa6a06..933035e 100644
--- a/tests/cascade/controller/test_run.py
+++ b/tests/cascade/controller/test_run.py
@@ -30,7 +30,7 @@
 from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
 from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import CheckpointSpec, DatasetId, JobInstance, JobInstanceRich
+from cascade.low.core import CheckpointSpec, DatasetId, HostId, JobInstance, JobInstanceRich, StorageId, TaskId
 from cascade.scheduler.core import Preschedule
 from cascade.scheduler.precompute import precompute
 
@@ -53,7 +53,7 @@ def launch_executor(
         job_instance,
         controller_address,
         2,
-        f"test_executor{i}",
+        HostId(f"test_executor{i}"),
         portBase,
         None,
         DefaultLoggingConfig,
@@ -169,8 +169,8 @@ def test_para1_persist():
             storage_type="fs",
             storage_params=td,
             retrieve_id=None,
-            persist_id="run1",
-            to_persist=[DatasetId(task="c2i1", output="0")],
+            persist_id=StorageId("run1"),
+            to_persist=[DatasetId(task=TaskId("c2i1"), output="0")],
         )
         job.checkpointSpec = spec
         run_cluster(job, 12600, 1)
@@ -210,13 +210,13 @@ def test_fusing():
         .get_or_raise()
     )
     preschedule = precompute(job)
-    assert preschedule.components[0].fusing_opportunities["source"] == [
-        "source",
-        "m1",
-        "m2",
-        "m3",
-        "m4",
-        "sink",
+    assert preschedule.components[0].fusing_opportunities[TaskId("source")] == [
+        TaskId("source"),
+        TaskId("m1"),
+        TaskId("m2"),
+        TaskId("m3"),
+        TaskId("m4"),
+        TaskId("sink"),
     ]
     # TODO we currently dont check that those actually *got fused* -- fix
     jobInstanceRich = JobInstanceRich(jobInstance=job, checkpointSpec=None)
@@ -249,9 +249,9 @@ def test_checkpoints():
         checkpointSpec = CheckpointSpec(
             storage_type="fs",
             storage_params=ckpt_root,
-            retrieve_id="f1",
-            persist_id="f1",
-            to_persist=[DatasetId(task="source", output="0")],
+            retrieve_id=StorageId("f1"),
+            persist_id=StorageId("f1"),
+            to_persist=[DatasetId(task=TaskId("source"), output="0")],
         )
 
         jobInstanceRich = JobInstanceRich(
diff --git a/tests/cascade/executor/test_checkpoints.py b/tests/cascade/executor/test_checkpoints.py
index c4ca48f..ab2f4b5 100644
--- a/tests/cascade/executor/test_checkpoints.py
+++ b/tests/cascade/executor/test_checkpoints.py
@@ -10,19 +10,19 @@
     retrieve_dataset,
 )
 from cascade.executor.serde import DefaultSerde
-from cascade.low.core import CheckpointSpec, DatasetId
+from cascade.low.core import CheckpointSpec, DatasetId, StorageId, TaskId
 from cascade.low.execution_context import VirtualCheckpointHost
 from cascade.shm.client import AllocatedBuffer
 
 
 def test_rw():
     with tempfile.TemporaryDirectory() as td:
-        ds1 = DatasetId(task="1", output="0")
+        ds1 = DatasetId(task=TaskId("1"), output="0")
         spec = CheckpointSpec(
             storage_type="fs",
             storage_params=td,
-            retrieve_id="subfolder",
-            persist_id="subfolder",
+            retrieve_id=StorageId("subfolder"),
+            persist_id=StorageId("subfolder"),
             to_persist=[ds1],
         )
diff --git a/tests/cascade/executor/test_executor.py b/tests/cascade/executor/test_executor.py
index 2c47e1c..924c708 100644
--- a/tests/cascade/executor/test_executor.py
+++ b/tests/cascade/executor/test_executor.py
@@ -39,9 +39,11 @@
 )
 from cascade.low.core import (
     DatasetId,
+    HostId,
     JobInstance,
     Task2TaskEdge,
     TaskDefinition,
+    TaskId,
     TaskInstance,
     WorkerId,
 )
@@ -55,7 +57,7 @@ def launch_executor(job_instance: JobInstance, controller_address: BackboneAddre
         job_instance,
         controller_address,
         4,
-        "test_executor",
+        HostId("test_executor"),
         portBase,
         None,
         DefaultLoggingConfig,
@@ -81,16 +83,16 @@ def test_func(x: np.ndarray) -> np.ndarray:
         static_input_kw={"x": np.array([1.0])},
         static_input_ps={},
     )
-    source_o = DatasetId("source", "o")
+    source_o = DatasetId(TaskId("source"), "o")
     sink = TaskInstance(
         definition=task_definition,
         static_input_kw={},
         static_input_ps={},
     )
-    sink_o = DatasetId("sink", "o")
+    sink_o = DatasetId(TaskId("sink"), "o")
     job = JobInstance(
-        tasks={"source": source, "sink": sink},
-        edges=[Task2TaskEdge(source=source_o, sink_task="sink", sink_input_kw="x", sink_input_ps=None)],
+        tasks={TaskId("source"): source, TaskId("sink"): sink},
+        edges=[Task2TaskEdge(source=source_o, sink_task=TaskId("sink"), sink_input_kw="x", sink_input_ps=None)],
     )
 
     # cluster setup
@@ -106,12 +108,12 @@ def test_func(x: np.ndarray) -> np.ndarray:
         # register
         ms = l.recv_messages(None)
         expected_registration = ExecutorRegistration(
-            host="test_executor",
+            host=HostId("test_executor"),
             maddress=m1,
             daddress=d1,
             workers=[
                 Worker(
-                    worker_id=WorkerId("test_executor", f"w{i}"),
+                    worker_id=WorkerId(HostId("test_executor"), f"w{i}"),
                     cpu=1,
                     gpu=0,
                     memory_mb=1024,
@@ -126,10 +128,10 @@ def test_func(x: np.ndarray) -> np.ndarray:
         assert m == expected_registration
 
         # submit graph
-        w0 = WorkerId("test_executor", "w0")
+        w0 = WorkerId(HostId("test_executor"), "w0")
         callback(
             m1,
-            TaskSequence(worker=w0, tasks=["source", "sink"], publish={sink_o}, extra_env=[]),
+            TaskSequence(worker=w0, tasks=[TaskId("source"), TaskId("sink")], publish={sink_o}, extra_env=[]),
         )
         # NOTE we need to expect source_o dataset too, because of no finegraining for host-wide and worker-only
         expected = {
@@ -148,13 +150,13 @@ def test_func(x: np.ndarray) -> np.ndarray:
             DatasetTransmitCommand(
                 ds=sink_o,
                 idx=0,
-                source="test_executor",
-                target="controller",
+                source=HostId("test_executor"),
+                target=HostId("controller"),
                 daddress=c1,
             ),
         )
         ms = l.recv_messages()
-        assert len(ms) == 1 and isinstance(ms[0], DatasetTransmitPayload) and ms[0].header.ds == DatasetId(task="sink", output="o")
+        assert len(ms) == 1 and isinstance(ms[0], DatasetTransmitPayload) and ms[0].header.ds == DatasetId(task=TaskId("sink"), output="o")
         assert serde.des_output(ms[0].value, "int", ms[0].header.deser_fun)[0] == 3.0
 
         # purge, store, run partial and fetch again
@@ -168,14 +170,14 @@ def test_func(x: np.ndarray) -> np.ndarray:
         send_data(d1, payload, syn)
         expected = {
             Ack(idx=1),
-            DatasetPublished(origin="test_executor", ds=source_o, transmit_idx=1),
+            DatasetPublished(origin=HostId("test_executor"), ds=source_o, transmit_idx=1),
         }
         while expected:
             ms = l.recv_messages()
             for m in ms:
                 logger.debug(f"about to remove received message {m}")
                 expected.remove(m)
-        callback(m1, TaskSequence(worker=w0, tasks=["sink"], publish={sink_o}, extra_env=[]))
+        callback(m1, TaskSequence(worker=w0, tasks=[TaskId("sink")], publish={sink_o}, extra_env=[]))
         expected = [
             DatasetPublished(w0, ds=sink_o, transmit_idx=None),
         ]
@@ -189,8 +191,8 @@ def test_func(x: np.ndarray) -> np.ndarray:
             DatasetTransmitCommand(
                 ds=sink_o,
                 idx=2,
-                source="test_executor",
-                target="controller",
+                source=HostId("test_executor"),
+                target=HostId("controller"),
                 daddress=c1,
             ),
         )
@@ -202,7 +204,7 @@ def test_func(x: np.ndarray) -> np.ndarray:
         # shutdown
         callback(m1, ExecutorShutdown())
         ms = l.recv_messages()
-        assert ExecutorExit(host="test_executor") in ms
+        assert ExecutorExit(host=HostId("test_executor")) in ms
         p.join()
     except:
         if p.is_alive():
diff --git a/tests/cascade/executor/test_runner.py b/tests/cascade/executor/test_runner.py
index 11e9f14..64da340 100644
--- a/tests/cascade/executor/test_runner.py
+++ b/tests/cascade/executor/test_runner.py
@@ -21,9 +21,11 @@
 from cascade.executor.runner.packages import PackagesEnv
 from cascade.low.core import (
     DatasetId,
+    HostId,
     JobInstance,
     Task2TaskEdge,
     TaskDefinition,
+    TaskId,
     TaskInstance,
     WorkerId,
 )
@@ -31,7 +33,7 @@
 
 
 def test_runner(monkeypatch):
-    worker = WorkerId("h0", "w0")
+    worker = WorkerId(HostId("h0"), "w0")
 
     # monkeypatching
     test_address = "zmq:test"
@@ -112,14 +114,14 @@ def test_func(x):
         static_input_kw={"x": 1},
         static_input_ps={},
     )
-    t2ds = DatasetId("t2", "o")
+    t2ds = DatasetId(TaskId("t2"), "o")
     oneTaskTs = TaskSequence(
         worker=worker,
-        tasks=["t2"],
+        tasks=[TaskId("t2")],
         publish={t2ds},
         extra_env=[],
     )
-    oneTaskJob = JobInstance(tasks={"t2": t2}, edges=[])
+    oneTaskJob = JobInstance(tasks={TaskId("t2"): t2}, edges=[])
     oneTaskRc = entrypoint.RunnerContext(
         workerId=worker,
         workerAttemptCnt=0,
@@ -151,17 +153,17 @@ def test_func(x):
         static_input_kw={},
         static_input_ps={},
     )
-    t3i = DatasetId("t3a", "o")
-    t3o = DatasetId("t3b", "o")
+    t3i = DatasetId(TaskId("t3a"), "o")
+    t3o = DatasetId(TaskId("t3b"), "o")
     twoTaskTs = TaskSequence(
         worker=worker,
-        tasks=["t3a", "t3b"],
+        tasks=[TaskId("t3a"), TaskId("t3b")],
         publish={t3o},
         extra_env=[],
     )
     twoTaskJob = JobInstance(
-        tasks={"t3a": t3a, "t3b": t3b},
-        edges=[Task2TaskEdge(source=t3i, sink_task="t3b", sink_input_kw="x", sink_input_ps=None)],
+        tasks={TaskId("t3a"): t3a, TaskId("t3b"): t3b},
+        edges=[Task2TaskEdge(source=t3i, sink_task=TaskId("t3b"), sink_input_kw="x", sink_input_ps=None)],
     )
     twoTaskRc = entrypoint.RunnerContext(
         workerId=worker,
@@ -207,25 +209,25 @@ def gen_func():
         static_input_kw={},
         static_input_ps={},
     )
-    t4gOutputs = [DatasetId("t4g", k) for k, _ in gen_definition.output_schema]
+    t4gOutputs = [DatasetId(TaskId("t4g"), k) for k, _ in gen_definition.output_schema]
     t4c = TaskInstance(
         definition=task_definition,
         static_input_kw={},
         static_input_ps={},
     )
-    t4pOutputs = [DatasetId(f"t4c{i}", "o") for i in range(N)]
+    t4pOutputs = [DatasetId(TaskId(f"t4c{i}"), "o") for i in range(N)]
     t4TaskTs = TaskSequence(
         worker=worker,
-        tasks=["t4g"] + [f"t4c{i}" for i in range(N)],
+        tasks=[TaskId("t4g")] + [TaskId(f"t4c{i}") for i in range(N)],
         publish=set(t4pOutputs),
         extra_env=[],
     )
     t4Job = JobInstance(
-        tasks={**{"t4g": t4g}, **{f"t4c{i}": t4c for i in range(N)}},
+        tasks={TaskId("t4g"): t4g, **{TaskId(f"t4c{i}"): t4c for i in range(N)}},
         edges=[
             Task2TaskEdge(
                 source=t4gOutputs[i],
-                sink_task=f"t4c{i}",
+                sink_task=TaskId(f"t4c{i}"),
                 sink_input_kw="x",
                 sink_input_ps=None,
             )
diff --git a/tests/cascade/executor/util.py b/tests/cascade/executor/util.py
index 38e88c9..b11a0db 100644
--- a/tests/cascade/executor/util.py
+++ b/tests/cascade/executor/util.py
@@ -30,7 +30,7 @@
 from cascade.executor.runner.packages import PackagesEnv
 from cascade.executor.runner.runner import ExecutionContext, run
 from cascade.low.builders import TaskBuilder
-from cascade.low.core import DatasetId, WorkerId
+from cascade.low.core import DatasetId, HostId, TaskId, WorkerId
 from cascade.shm.server import entrypoint as shm_server
 
 logger = logging.getLogger(__name__)
@@ -82,7 +82,7 @@ def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext)
         raise ValueError(f"expected 1 task, gotten {len(tasks)}")
     taskId = tasks[0]
     taskInstance = executionContext.tasks[taskId]
-    with Memory(callback, WorkerId(host="testHost", worker="testWorker")) as memory, PackagesEnv() as pckg:
+    with Memory(callback, WorkerId(host=HostId("testHost"), worker="testWorker")) as memory, PackagesEnv() as pckg:
         # for key, value in taskSequence.extra_env.items():
         #     os.environ[key] = value
 
@@ -94,8 +94,8 @@ def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext)
 def callable2ctx(callableInstance: CallableInstance, callback: BackboneAddress) -> ExecutionContext:
     taskInstance = TaskBuilder.from_callable(callableInstance.func, callableInstance.env)
     param_source = {}
-    params = [(key, DatasetId("taskId", f"kwarg.{key}"), value) for key, value in callableInstance.kwargs.items()] + [
-        (key, DatasetId("taskId", f"pos.{key}"), value) for key, value in callableInstance.args
+    params = [(key, DatasetId(TaskId("taskId"), f"kwarg.{key}"), value) for key, value in callableInstance.kwargs.items()] + [
+        (key, DatasetId(TaskId("taskId"), f"pos.{key}"), value) for key, value in callableInstance.args
     ]
     for key, ds_key, value in params:
         raw = cloudpickle.dumps(value)
@@ -106,10 +106,10 @@ def callable2ctx(callableInstance: CallableInstance, callback: BackboneAddress)
         param_source[key] = (ds_key, "Any")
 
     return ExecutionContext(
-        tasks={"taskId": taskInstance},
-        param_source={"taskId": param_source},
+        tasks={TaskId("taskId"): taskInstance},
+        param_source={TaskId("taskId"): param_source},
         callback=callback,
-        publish={DatasetId("taskId", output) for output, _ in taskInstance.definition.output_schema},
+        publish={DatasetId(TaskId("taskId"), output) for output, _ in taskInstance.definition.output_schema},
     )
@@ -121,12 +121,12 @@ def run_test(callableInstance: CallableInstance, testId: str, max_runtime_sec: i
     mp_ctx = platform.get_mp_ctx("executor-aux")
     runner = mp_ctx.Process(target=simple_runner, args=(addr, ec_ctx))
     runner.start()
-    output = DatasetId("taskId", "0")
+    output = DatasetId(TaskId("taskId"), "0")
 
     end = perf_counter_ns() + max_runtime_sec * int(1e9)
     while perf_counter_ns() < end:
         mess = listener.recv_messages()
-        if mess == [DatasetPublished(origin=WorkerId(host="testHost", worker="testWorker"), ds=output, transmit_idx=None)]:
+        if mess == [DatasetPublished(origin=WorkerId(host=HostId("testHost"), worker="testWorker"), ds=output, transmit_idx=None)]:
             break
         elif not mess:
             continue
diff --git a/tests/cascade/gateway/test_run.py b/tests/cascade/gateway/test_run.py
index bdbba1c..1470f3b 100644
--- a/tests/cascade/gateway/test_run.py
+++ b/tests/cascade/gateway/test_run.py
@@ -5,7 +5,7 @@
 import cascade.gateway.client as client
 from cascade.gateway.__main__ import main_cli
 from cascade.low.builders import JobBuilder
-from cascade.low.core import DatasetId, JobInstanceRich, TaskDefinition, TaskInstance
+from cascade.low.core import DatasetId, JobInstanceRich, TaskDefinition, TaskId, TaskInstance
 
 init_value = 10
 job_func = lambda i: i * 2
@@ -30,7 +30,7 @@ def get_job_succ() -> JobInstanceRich:
     sii = TaskInstance(definition=sid, static_input_kw={}, static_input_ps={})
     ji = JobBuilder().with_node("so", soi).with_node("si", sii).with_edge("so", "si", 0, "o").build().get_or_raise()
 
-    ji.ext_outputs = [DatasetId("si", "o")]
+    ji.ext_outputs = [DatasetId(TaskId("si"), "o")]
 
     return JobInstanceRich(jobInstance=ji, checkpointSpec=None)
diff --git a/tests/cascade/low/test_core.py b/tests/cascade/low/test_core.py
index 7d69820..7cf9fa5 100644
--- a/tests/cascade/low/test_core.py
+++ b/tests/cascade/low/test_core.py
@@ -1,11 +1,11 @@
-from cascade.low.core import DatasetId
+from cascade.low.core import DatasetId, TaskId
 
 
 def test_datasetid_serde():
     cases = [
-        DatasetId(task="basic", output="0"),
+        DatasetId(task=TaskId("basic"), output="0"),
         DatasetId(
-            task="whoa_tricky!()(??!@#!$34--thisWouldBeABadFileNameReally\n\n\1\0\t",
+            task=TaskId("whoa_tricky!()(??!@#!$34--thisWouldBeABadFileNameReally\n\n\1\0\t"),
             output="thiscanbetrickytoo",
         ),
     ]
diff --git a/tests/cascade/scheduler/test_api.py b/tests/cascade/scheduler/test_api.py
index ec48f5b..b9127de 100644
--- a/tests/cascade/scheduler/test_api.py
+++ b/tests/cascade/scheduler/test_api.py
@@ -8,7 +8,7 @@
 
 """Tests calculation of preschedule, state initialize & first assign and plan"""
 
-from cascade.low.core import DatasetId, WorkerId
+from cascade.low.core import DatasetId, HostId, TaskId, WorkerId
 from cascade.low.execution_context import TaskStatus, init_context
 from cascade.scheduler.api import assign, init_schedule, plan
 from cascade.scheduler.core import Assignment
@@ -27,22 +27,22 @@ def test_job0():
     preschedule.components[0].fusing_opportunities = {}
 
     h1w1 = get_env(1, 1)
-    h1w1_w = WorkerId("h0", "w0")
+    h1w1_w = WorkerId(HostId("h0"), "w0")
     context = init_context(h1w1, job0, preschedule.edge_o, preschedule.edge_i)
     schedule = init_schedule(preschedule, context)
     assignment = list(assign(schedule, context))
     assert assignment == [
         Assignment(
             worker=h1w1_w,
-            tasks=["source"],
+            tasks=[TaskId("source")],
             prep=[],
-            outputs={DatasetId(task="source", output="0")},
+            outputs={DatasetId(task=TaskId("source"), output="0")},
             extra_env=[],
         )
     ]
 
     plan(schedule, context, assignment)
-    assert context.worker2ts == {h1w1_w: {"source": TaskStatus.enqueued}}
+    assert context.worker2ts == {h1w1_w: {TaskId("source"): TaskStatus.enqueued}}
 
 
 def test_job1():
@@ -52,22 +52,22 @@ def test_job1():
     preschedule.components[0].fusing_opportunities = {}
 
     h1w1 = get_env(1, 1)
-    h1w1_w = WorkerId("h0", "w0")
+    h1w1_w = WorkerId(HostId("h0"), "w0")
     context = init_context(h1w1, job1, preschedule.edge_o, preschedule.edge_i)
     schedule = init_schedule(preschedule, context)
     assignment = list(assign(schedule, context))
     assert assignment == [
         Assignment(
             worker=h1w1_w,
-            tasks=["source"],
+            tasks=[TaskId("source")],
             prep=[],
-            outputs={DatasetId(task="source", output="0")},
+            outputs={DatasetId(task=TaskId("source"), output="0")},
             extra_env=[],
         )
     ]
 
     plan(schedule, context, assignment)
-    assert context.worker2ts == {h1w1_w: {"source": TaskStatus.enqueued}}
+    assert context.worker2ts == {h1w1_w: {TaskId("source"): TaskStatus.enqueued}}
 
 
 # TODO add some multi-source or multi-component job
diff --git a/tests/cascade/scheduler/test_checkpoints.py b/tests/cascade/scheduler/test_checkpoints.py
index d3115b2..b6c00f9 100644
--- a/tests/cascade/scheduler/test_checkpoints.py
+++ b/tests/cascade/scheduler/test_checkpoints.py
@@ -1,5 +1,5 @@
 from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import DatasetId, JobInstanceRich
+from cascade.low.core import DatasetId, JobInstanceRich, TaskId
 from cascade.scheduler.checkpoints import trim_with_persisted
 from cascade.scheduler.precompute import precompute
 
@@ -51,8 +51,14 @@ def test_trim_with_checkpoints():
     )
     jobRich = JobInstanceRich(jobInstance=jobInstanceOrig, checkpointSpec=None)
     preschedule = precompute(jobRich.jobInstance)
-    persisted = {DatasetId(task="transform", output="0")}
+    persisted = {DatasetId(task=TaskId("transform"), output="0")}
     jobInstanceNew, preschedule, persisted_valid = trim_with_persisted(jobRich, preschedule, persisted)
 
     assert persisted_valid == persisted
-    assert set(jobInstanceNew.tasks.keys()) == {"source2", "transform", "product1", "product2", "sink"}
+    assert set(jobInstanceNew.tasks.keys()) == {
+        TaskId("source2"),
+        TaskId("transform"),
+        TaskId("product1"),
+        TaskId("product2"),
+        TaskId("sink"),
+    }
diff --git a/tests/cascade/scheduler/test_graph.py b/tests/cascade/scheduler/test_graph.py
index 70d1527..ffc929a 100644
--- a/tests/cascade/scheduler/test_graph.py
+++ b/tests/cascade/scheduler/test_graph.py
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 
 from collections import defaultdict
+from typing import cast
 
 from cascade.low.core import TaskId
 from cascade.scheduler.precompute import _decompose, _enrich
@@ -23,21 +24,21 @@ def _oedge2iedge(edge_o: dict[TaskId, set[TaskId]]) -> dict[TaskId, set[TaskId]]
 def test_decompose():
     # comp1: v0 -> v1 -> v2 + v3 -> v1
     # comp2: v4 -> v5, v4 -> v6
-    nodes = [f"v{i}" for i in range(7)]
-    edge_o = defaultdict(
-        set,
-        **{
-            "v0": {"v1"},
-            "v1": {"v2"},
-            "v3": {"v1"},
-            "v4": {"v5", "v6"},
-        },
+    nodes = [TaskId(f"v{i}") for i in range(7)]
+    edge_o: dict[TaskId, set[TaskId]] = defaultdict(set)
+    edge_o.update(
+        {
+            TaskId("v0"): {TaskId("v1")},
+            TaskId("v1"): {TaskId("v2")},
+            TaskId("v3"): {TaskId("v1")},
+            TaskId("v4"): {TaskId("v5"), TaskId("v6")},
+        }
     )
     edge_i = _oedge2iedge(edge_o)
 
     expected = {
-        (frozenset({"v0", "v1", "v2", "v3"}), frozenset({"v0", "v3"})),
-        (frozenset({"v4", "v5", "v6"}), frozenset({"v4"})),
+        (frozenset({TaskId("v0"), TaskId("v1"), TaskId("v2"), TaskId("v3")}), frozenset({TaskId("v0"), TaskId("v3")})),
+        (frozenset({TaskId("v4"), TaskId("v5"), TaskId("v6")}), frozenset({TaskId("v4")})),
     }
     for component in _decompose(nodes, edge_i, edge_o):
         e = (frozenset(component[0]), frozenset(component[1]))
@@ -51,41 +52,97 @@ def test_enrich():
     # v3 -> v1
     # v4 -> v5 -> v2
     # v4 -> v6
-    edge_o = defaultdict(
-        set,
-        **{
-            "v0": {"v1"},
-            "v1": {"v2"},
-            "v3": {"v1"},
-            "v4": {"v5", "v6"},
-            "v5": {"v2"},
-        },
+    edge_o: dict[TaskId, set[TaskId]] = defaultdict(set)
+    edge_o.update(
+        {
+            TaskId("v0"): {TaskId("v1")},
+            TaskId("v1"): {TaskId("v2")},
+            TaskId("v3"): {TaskId("v1")},
+            TaskId("v4"): {TaskId("v5"), TaskId("v6")},
+            TaskId("v5"): {TaskId("v2")},
+        }
     )
     edge_i = _oedge2iedge(edge_o)
-    component = (list(set(edge_o.keys()).union(set(edge_i.keys()))), ["v0", "v3", "v4"])
+    component = (list(set(edge_o.keys()).union(set(edge_i.keys()))), [TaskId("v0"), TaskId("v3"), TaskId("v4")])
 
     res = _enrich(component, edge_i, edge_o, set(), set())
 
     assert res.nodes == component[0]
     assert res.sources == component[1]
     assert res.weight() == len(component[0])
-    value = {
-        "v0": 1,
-        "v1": 2,
-        "v2": 3,
-        "v3": 1,
-        "v4": 2,
-        "v5": 2,
-        "v6": 3,
+    value: dict[TaskId, int] = {
+        TaskId("v0"): 1,
+        TaskId("v1"): 2,
+        TaskId("v2"): 3,
+        TaskId("v3"): 1,
+        TaskId("v4"): 2,
+        TaskId("v5"): 2,
+        TaskId("v6"): 3,
     }
     assert res.value == value
-    distance_matrix = {
-        "v0": {"v0": 0, "v1": 1, "v2": 2, "v3": 1, "v4": 2, "v5": 2, "v6": 3},
-        "v1": {"v0": 1, "v1": 0, "v2": 1, "v3": 1, "v4": 2, "v5": 1, "v6": 3},
-        "v2": {"v0": 2, "v1": 1, "v2": 0, "v3": 2, "v4": 2, "v5": 1, "v6": 3},
-        "v3": {"v0": 1, "v1": 1, "v2": 2, "v3": 0, "v4": 2, "v5": 2, "v6": 3},
-        "v4": {"v0": 2, "v1": 2, "v2": 2, "v3": 2, "v4": 0, "v5": 1, "v6": 1},
-        "v5": {"v0": 2, "v1": 1, "v2": 1, "v3": 2, "v4": 1, "v5": 0, "v6": 3},
- "v6": {"v0": 3, "v1": 3, "v2": 3, "v3": 3, "v4": 1, "v5": 3, "v6": 0}, + distance_matrix: dict[TaskId, dict[TaskId, int]] = { + TaskId("v0"): { + TaskId("v0"): 0, + TaskId("v1"): 1, + TaskId("v2"): 2, + TaskId("v3"): 1, + TaskId("v4"): 2, + TaskId("v5"): 2, + TaskId("v6"): 3, + }, + TaskId("v1"): { + TaskId("v0"): 1, + TaskId("v1"): 0, + TaskId("v2"): 1, + TaskId("v3"): 1, + TaskId("v4"): 2, + TaskId("v5"): 1, + TaskId("v6"): 3, + }, + TaskId("v2"): { + TaskId("v0"): 2, + TaskId("v1"): 1, + TaskId("v2"): 0, + TaskId("v3"): 2, + TaskId("v4"): 2, + TaskId("v5"): 1, + TaskId("v6"): 3, + }, + TaskId("v3"): { + TaskId("v0"): 1, + TaskId("v1"): 1, + TaskId("v2"): 2, + TaskId("v3"): 0, + TaskId("v4"): 2, + TaskId("v5"): 2, + TaskId("v6"): 3, + }, + TaskId("v4"): { + TaskId("v0"): 2, + TaskId("v1"): 2, + TaskId("v2"): 2, + TaskId("v3"): 2, + TaskId("v4"): 0, + TaskId("v5"): 1, + TaskId("v6"): 1, + }, + TaskId("v5"): { + TaskId("v0"): 2, + TaskId("v1"): 1, + TaskId("v2"): 1, + TaskId("v3"): 2, + TaskId("v4"): 1, + TaskId("v5"): 0, + TaskId("v6"): 3, + }, + TaskId("v6"): { + TaskId("v0"): 3, + TaskId("v1"): 3, + TaskId("v2"): 3, + TaskId("v3"): 3, + TaskId("v4"): 1, + TaskId("v5"): 3, + TaskId("v6"): 0, + }, } assert res.distance_matrix == distance_matrix diff --git a/tests/cascade/scheduler/util.py b/tests/cascade/scheduler/util.py index 3c6f5c2..26dd079 100644 --- a/tests/cascade/scheduler/util.py +++ b/tests/cascade/scheduler/util.py @@ -22,10 +22,12 @@ from cascade.low.core import ( DatasetId, Environment, + HostId, JobExecutionRecord, JobInstance, JobInstanceRich, TaskExecutionRecord, + TaskId, Worker, WorkerId, ) @@ -61,8 +63,8 @@ class BuilderGroup: def add_large_source(builder: BuilderGroup, runtime: int, runmem: int, outsize: int) -> None: builder.job = builder.job.with_node("source", TaskBuilder.from_callable(sourceFunc)) - builder.record.tasks["source"] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) - builder.record.datasets_mb[DatasetId("source", Node.DEFAULT_OUTPUT)] = outsize + builder.record.tasks[TaskId("source")] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) + builder.record.datasets_mb[DatasetId(TaskId("source"), Node.DEFAULT_OUTPUT)] = outsize builder.layers = [1] @@ -86,8 +88,8 @@ def add_postproc( builder.job = builder.job.with_edge(e1, node, "a") builder.job = builder.job.with_edge(e2, node, "b") # print(f"adding {node} with edges {e1}, {e2}") - builder.record.tasks[node] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) - builder.record.datasets_mb[DatasetId(node, Node.DEFAULT_OUTPUT)] = outsize + builder.record.tasks[TaskId(node)] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) + builder.record.datasets_mb[DatasetId(TaskId(node), Node.DEFAULT_OUTPUT)] = outsize builder.layers.append(n) @@ -104,7 +106,7 @@ def add_sink( for i in range(builder.layers[from_layer] // frac): source = ((i * frac) + 157) % builder.layers[from_layer] builder.job = builder.job.with_edge(f"pproc{from_layer}-{source}", node, i) - builder.record.tasks[node] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) + builder.record.tasks[TaskId(node)] = TaskExecutionRecord(cpuseconds=runtime, memory_mb=runmem) def enrich_instance(job: JobInstance) -> JobInstanceRich: @@ -149,6 +151,8 @@ def get_job1() -> tuple[JobInstanceRich, JobExecutionRecord]: ## *** environment builders *** def get_env(hosts: int, workers_per_host: int) -> Environment: return Environment( - workers={WorkerId(f"h{h}", f"w{w}"): Worker(cpu=1, gpu=0, memory_mb=1000) for h in 
range(hosts) for w in range(workers_per_host)}, - host_url_base={f"h{h}": "tcp://localhost" for h in range(hosts)}, + workers={ + WorkerId(HostId(f"h{h}"), f"w{w}"): Worker(cpu=1, gpu=0, memory_mb=1000) for h in range(hosts) for w in range(workers_per_host) + }, + host_url_base={HostId(f"h{h}"): "tcp://localhost" for h in range(hosts)}, )
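
Reviewer note, not part of the patch: every hunk above makes the same mechanical change — the bare aliases TaskId/HostId/JobId/StorageId (str) and ComponentId (int) become typing.NewType wrappers, and call sites wrap raw literals accordingly. Below is a minimal sketch of what this buys at type-check time and of the runtime caveat that motivates the bridge.py and notify.py hunks; submit() is a hypothetical function for illustration, not part of cascade.

    from typing import NewType, cast

    TaskId = NewType("TaskId", str)  # mirrors src/cascade/low/core.py
    HostId = NewType("HostId", str)

    def submit(task: TaskId, host: HostId) -> None:  # hypothetical signature
        print(f"running {task} on {host}")

    submit(TaskId("proc0"), HostId("h0"))    # ok; TaskId("proc0") is the plain str "proc0" at runtime
    # submit("proc0", "h0")                  # flagged by a checker such as mypy or ty: str is not TaskId
    # submit(HostId("h0"), TaskId("proc0"))  # flagged: the swapped arguments no longer type-check

    # Caveat: a NewType is not a usable runtime class, so isinstance(x, HostId) raises TypeError.
    # Hence bridge.py now narrows on the underlying str and casts back, and notify.py tests
    # `not isinstance(event.origin, WorkerId)` instead of `isinstance(event.origin, HostId)`:
    origin = HostId("h0")
    if isinstance(origin, str):
        host = cast(HostId, origin)  # no-op at runtime; typed as HostId for the checker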