6 changes: 3 additions & 3 deletions integration_tests/benchmarks_old/dist.py
@@ -15,7 +15,7 @@
import os

from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import JobInstance, SchedulingConstraint
+from cascade.low.core import JobInstance, SchedulingConstraint, TaskId


def source_func() -> int:
@@ -112,6 +112,6 @@ def get_job() -> JobInstance:
job.nodes["sink"].definition.input_schema[f"v{i}"] = "int" # TODO put some allow_kw into TaskDefinition instead to allow this

job = job.build().get_or_raise()
job.ext_outputs = list(job.outputs_of("sink"))
job.constraints = [SchedulingConstraint(gang=[f"proc{i}" for i in range(L)])]
job.ext_outputs = list(job.outputs_of(TaskId("sink")))
job.constraints = [SchedulingConstraint(gang=[TaskId(f"proc{i}") for i in range(L)])]
return job
4 changes: 2 additions & 2 deletions integration_tests/benchmarks_old/matmul.py
@@ -6,7 +6,7 @@
import jax.random as jr # ty: ignore[unresolved-import]

from cascade.low.builders import JobBuilder, TaskBuilder
-from cascade.low.core import JobInstance
+from cascade.low.core import JobInstance, TaskId


def get_funcs():
@@ -47,7 +47,7 @@ def get_job() -> JobInstance:
prv = cur

job = job.build().get_or_raise()
-job.ext_outputs = list(job.outputs_of(cur))
+job.ext_outputs = list(job.outputs_of(TaskId(cur)))
return job


2 changes: 1 addition & 1 deletion src/cascade/controller/notify.py
@@ -90,7 +90,7 @@ def notify(
if isinstance(event, DatasetPublished):
logger.debug(f"received {event=}")
# NOTE here we'll need to distinguish memory-only and host-wide (shm) publications, currently all events mean shm
-host = event.origin if isinstance(event.origin, HostId) else event.origin.host
+host = cast(HostId, event.origin) if not isinstance(event.origin, WorkerId) else event.origin.host
context.host2ds[host][event.ds] = DatasetStatus.available
context.ds2host[event.ds][host] = DatasetStatus.available
state.consider_fetch(event.ds, host)
5 changes: 3 additions & 2 deletions src/cascade/controller/report.py
@@ -12,6 +12,7 @@
import pickle
from dataclasses import dataclass
from time import monotonic_ns
+from typing import NewType

import zmq
from typing_extensions import Self
@@ -23,7 +24,7 @@

logger = logging.getLogger(__name__)

-JobId = str
+JobId = NewType("JobId", str)


@dataclass
@@ -83,7 +84,7 @@ def __init__(self, report_address: str | None) -> None:
return
address, job_id = report_address.split(",", 1)
logger.debug(f"initialising reporter with {address=} and {job_id=}")
-self.job_id = job_id
+self.job_id = JobId(job_id)
self.socket = get_context().socket(zmq.PUSH)
self.socket.connect(address)

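
The JobId change above is the pattern applied throughout this PR: NewType gives the checker a type distinct from str while the runtime value stays a plain string, so raw strings coming off the wire (here, from report_address.split) must be wrapped explicitly. A minimal sketch of that behaviour, assuming only the standard library; the helper name below is illustrative, not from the codebase:

from typing import NewType

JobId = NewType("JobId", str)

def progress_key(job_id: JobId) -> str:
    # illustrative helper, not part of cascade
    return f"progress/{job_id}"

raw = "4f1c2d9e"
# progress_key(raw)       # a type checker rejects this: "str" is not "JobId"
progress_key(JobId(raw))  # OK; JobId() is an identity call, so there is no runtime overhead
assert JobId(raw) == raw  # the wrapped value is still the same plain string
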
25 changes: 14 additions & 11 deletions src/cascade/executor/bridge.py
@@ -10,6 +10,7 @@

import logging
import time
+from typing import cast

from cascade.executor.checkpoints import build_persist_command, build_retrieve_command, serialize_params
from cascade.executor.comms import GraceWatcher, Listener, ReliableSender
@@ -72,7 +73,7 @@ def __init__(self, controller_url: str, expected_executors: int, checkpoint_spec
logger.warning(f"double registration of {message.host}, suggesting network congestion")
continue
self.sender.add_host(message.host, message.maddress)
self.sender.add_host("data." + message.host, message.daddress)
self.sender.add_host(HostId("data." + message.host), message.daddress)
for worker in message.workers:
self.environment.workers[worker.worker_id] = Worker(cpu=worker.cpu, gpu=worker.gpu, memory_mb=worker.memory_mb)
self.environment.host_url_base[message.host] = message.url_base
@@ -97,9 +98,11 @@ def recv_events(self) -> list[Event]:
while (not events) and (not shutdown_reason):
# timeout ms matches
for message in self.mlistener.recv_messages(timeout_ms=resend_grace_ms):
if hasattr(message, "host") and isinstance((host := message.host), HostId):
if hasattr(message, "host") and isinstance(message.host, str):
host = cast(HostId, message.host)
self.heartbeat_checker[host].step()
if hasattr(message, "worker") and isinstance((worker := message.worker), WorkerId):
if hasattr(message, "worker") and isinstance(message.worker, WorkerId):
worker = message.worker
self.heartbeat_checker[worker.host].step()
if isinstance(message, Event):
events.append(message)
@@ -111,7 +114,7 @@ def recv_events(self) -> list[Event]:
logger.critical(f"received failure {message=}, proceeding with a shutdown")
if isinstance(message, ExecutorExit | ExecutorFailure) and message.host in self.sender.hosts:
self.sender.hosts.pop(message.host)
self.sender.hosts.pop("data." + message.host)
self.sender.hosts.pop(HostId("data." + message.host))
shutdown_reason = message
elif isinstance(message, Unsupported):
logger.critical(f"received unexpected {message=}, proceeding with a shutdown")
@@ -150,32 +153,32 @@ def purge(self, host: HostId, ds: DatasetId) -> None:
def transmit(self, ds: DatasetId, source: HostId, target: HostId) -> None:
if source == VirtualCheckpointHost:
command = build_retrieve_command(self.checkpoint_spec, ds, target)
self.sender.send("data." + target, command)
self.sender.send(HostId("data." + target), command)
else:
m = DatasetTransmitCommand(
source=source,
target=target,
daddress=self.sender.hosts["data." + target][1],
daddress=self.sender.hosts[HostId("data." + target)][1],
ds=ds,
idx=self.transmit_idx_counter,
)
self.transmit_idx_counter += 1
self.sender.send("data." + source, m)
self.sender.send(HostId("data." + source), m)

def persist(self, ds: DatasetId, source: HostId) -> None:
command = build_persist_command(self.checkpoint_spec, ds, source)
self.sender.send("data." + source, command)
self.sender.send(HostId("data." + source), command)

def fetch(self, ds: DatasetId, source: HostId) -> None:
m = DatasetTransmitCommand(
source=source,
target="controller",
target=HostId("controller"),
daddress=self.mlistener.address,
ds=ds,
idx=self.transmit_idx_counter,
)
self.transmit_idx_counter += 1
self.sender.send("data." + source, m)
self.sender.send(HostId("data." + source), m)

def shutdown(self) -> None:
m = ExecutorShutdown()
@@ -189,7 +192,7 @@ def shutdown(self) -> None:
if isinstance(message, ExecutorExit | ExecutorFailure):
if message.host in self.sender.hosts:
self.sender.hosts.pop(message.host)
self.sender.hosts.pop("data." + message.host)
self.sender.hosts.pop(HostId("data." + message.host))
else:
logger.warning(f"ignoring {type(message)}")
if self.sender.hosts:
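
The reshuffled isinstance checks in bridge.py follow from the same switch: a NewType such as HostId is not a runtime class, so it can no longer appear as the second argument to isinstance. The runtime check therefore narrows against the underlying str (or a real class like WorkerId), and cast() restores the static type for the checker. A rough sketch of the pattern, using an illustrative value rather than a real message:

from typing import NewType, cast

HostId = NewType("HostId", str)

value: object = "node-01"
# isinstance(value, HostId) fails at runtime: a NewType cannot be used with isinstance()
if isinstance(value, str):       # runtime narrowing happens against the base type
    host = cast(HostId, value)   # cast() only informs the type checker; it does nothing at runtime
    print(f"heartbeat from {host}")
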
6 changes: 3 additions & 3 deletions src/cascade/executor/data_server.py
@@ -41,7 +41,7 @@
Syn,
)
from cascade.executor.runner.memory import ds2shmid
-from cascade.low.core import DatasetId
+from cascade.low.core import DatasetId, HostId
from cascade.low.exceptions import CascadeError, CascadeInfrastructureError, CascadeInternalError, ser
from cascade.low.func import assert_never
from cascade.low.tracing import TransmitLifecycle, label, mark
@@ -54,7 +54,7 @@ def __init__(
self,
maddress: BackboneAddress,
daddress: BackboneAddress,
-host: str,
+host: HostId,
logging_config: dict,
):
logging.config.dictConfig(logging_config)
@@ -347,7 +347,7 @@ def recv_loop(self) -> None:
def start_data_server(
maddress: BackboneAddress,
daddress: BackboneAddress,
-host: str,
+host: HostId,
logging_config: dict,
):
server = DataServer(maddress, daddress, host, logging_config)
4 changes: 2 additions & 2 deletions src/cascade/executor/executor.py
@@ -111,7 +111,7 @@ def __init__(
# NOTE following inits are with potential side effects
self.mlistener = Listener(address_of(portBase))
self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
self.sender.add_host("controller", controller_address)
self.sender.add_host(HostId("controller"), controller_address)
# TODO make the shm server params configurable
shm_port = f"/tmp/cascShmSock-{uuid.uuid4()}" # portBase + 2
shm_api.publish_socket_addr(shm_port)
@@ -191,7 +191,7 @@ def terminate(self) -> None:

def to_controller(self, m: Message) -> None:
self.heartbeat_watcher.step()
self.sender.send("controller", m)
self.sender.send(HostId("controller"), m)

def _start_worker(self, worker: WorkerId, attempt_cnt: int, seq: None | TaskSequence) -> WorkerHandle:
ctx = platform.get_mp_ctx("worker")
6 changes: 3 additions & 3 deletions src/cascade/gateway/router.py
@@ -56,11 +56,11 @@ def __init__(
max_jobs: int | None,
):
self.poller = poller
-self.jobs: dict[str, Job] = {}
+self.jobs: dict[JobId, Job] = {}
self.active_jobs = 0
self.max_jobs = max_jobs
self.jobs_queue: OrderedDict[JobId, JobSpec] = OrderedDict()
-self.procs: dict[str, subprocess.Popen] = {}
+self.procs: dict[JobId, subprocess.Popen] = {}
self.loggingConfig = loggingConfig
self.troika_config = troika_config

@@ -85,7 +85,7 @@ def maybe_spawn(self) -> None:
def enqueue_job(self, job_spec: JobSpec) -> JobId:
job_id = next_uuid(
set(self.jobs.keys()).union(self.jobs_queue.keys()),
-lambda: str(uuid.uuid4()),
+lambda: JobId(str(uuid.uuid4())),
)
self.jobs_queue[job_id] = job_spec
self.maybe_spawn()
4 changes: 2 additions & 2 deletions src/cascade/gateway/server.py
@@ -18,7 +18,7 @@
import zmq

import cascade.gateway.api as api
-from cascade.controller.report import JobProgress, deserialize
+from cascade.controller.report import JobId, JobProgress, deserialize
from cascade.deployment.logging import LoggingConfig, init_from_obj
from cascade.executor.comms import get_context
from cascade.gateway.client import parse_request, serialize_response
@@ -44,7 +44,7 @@ def handle_fe(socket: zmq.Socket, jobs: JobRouter) -> bool:
try:
progresses, datasets, queue_length = jobs.progress_of(m.job_ids)
rv = api.JobProgressResponse(
-progresses=cast(dict[str, JobProgress | None], progresses),
+progresses=cast(dict[JobId, JobProgress | None], progresses),
datasets=datasets,
error=None,
queue_length=queue_length,
9 changes: 5 additions & 4 deletions src/cascade/low/builders.py
@@ -21,6 +21,7 @@
JobInstance,
Task2TaskEdge,
TaskDefinition,
+TaskId,
TaskInstance,
)
from cascade.low.func import Either
@@ -100,12 +101,12 @@ def with_node(self, name: str, task: TaskInstance) -> Self:
return replace(self, nodes=self.nodes.set(name, task))

def with_output(self, task: str, output: str = DefaultTaskOutput) -> Self:
-return replace(self, outputs=self.outputs.append(DatasetId(task, output)))
+return replace(self, outputs=self.outputs.append(DatasetId(TaskId(task), output)))

def with_edge(self, source: str, sink: str, into: str | int, frum: str = DefaultTaskOutput) -> Self:
new_edge = Task2TaskEdge(
-source=DatasetId(source, frum),
-sink_task=sink,
+source=DatasetId(TaskId(source), frum),
+sink_task=TaskId(sink),
sink_input_kw=into if isinstance(into, str) else None,
sink_input_ps=into if isinstance(into, int) else None,
)
@@ -193,7 +194,7 @@ def get_edge_errors(edge: Task2TaskEdge) -> Iterator[str]:
else:
return Either.ok(
JobInstance(
-tasks=cast(dict[str, TaskInstance], pyrsistent.thaw(self.nodes)),
+tasks=cast(dict[TaskId, TaskInstance], pyrsistent.thaw(self.nodes)),
edges=pyrsistent.thaw(self.edges),
ext_outputs=pyrsistent.thaw(self.outputs),
)
12 changes: 6 additions & 6 deletions src/cascade/low/core.py
@@ -12,7 +12,7 @@
from base64 import b64decode, b64encode
from collections import defaultdict
from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional, Type, cast
+from typing import Any, Callable, Literal, NewType, Optional, Type, cast

import cloudpickle
from pydantic import BaseModel, Field
@@ -58,7 +58,7 @@ def func_enc(f: Callable) -> str:
return b64encode(cloudpickle.dumps(f)).decode("ascii")


-TaskId = str
+TaskId = NewType("TaskId", str)


@dataclass(frozen=True)
@@ -80,7 +80,7 @@ def ser(self) -> str:
def des(cls, value: str) -> Self:
pref, suf = value.split(".", 1)
task_len = int.from_bytes(bytes.fromhex(pref), byteorder="big")
-return cls(task=suf[:task_len], output=suf[task_len:])
+return cls(task=TaskId(suf[:task_len]), output=suf[task_len:])


class Task2TaskEdge(BaseModel):
@@ -138,7 +138,7 @@ def outputs_of(self, task_id: TaskId) -> set[DatasetId]:
return {DatasetId(task_id, output) for output, _ in self.tasks[task_id].definition.output_schema}


-HostId = str
+HostId = NewType("HostId", str)


@dataclass(frozen=True)
@@ -152,7 +152,7 @@ def __repr__(self) -> str:
@classmethod
def from_repr(cls, value: str) -> Self:
host, worker = value.split(".", 1)
-return cls(host=host, worker=worker)
+return cls(host=HostId(host), worker=worker)

def worker_num(self) -> int:
"""Used eg for gpu allocation"""
@@ -193,7 +193,7 @@ class JobExecutionRecord(BaseModel):


CheckpointStorageType = Literal["fs"]
-StorageId = str
+StorageId = NewType("StorageId", str)


class CheckpointSpec(BaseModel):
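
With TaskId, HostId and StorageId now defined via NewType rather than as bare aliases for str, the checker can tell the different identifier kinds apart, which is what forces the wrapping seen at the call sites in this PR. A small sketch of the payoff, assuming nothing beyond the definitions above; outputs_of here is a simplified stand-in for the JobInstance method:

from typing import NewType

TaskId = NewType("TaskId", str)
HostId = NewType("HostId", str)

def outputs_of(task_id: TaskId) -> set[str]:
    # simplified stand-in for JobInstance.outputs_of
    return {f"{task_id}.0"}

task = TaskId("proc0")
host = HostId("node-01")
outputs_of(task)       # OK
# outputs_of(host)     # with the old "TaskId = str" alias this passed silently; the checker now flags it
# outputs_of("proc0")  # plain strings must be wrapped, hence TaskId("sink") and friends in the diffs above
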
16 changes: 8 additions & 8 deletions src/cascade/low/dask.py
@@ -21,7 +21,7 @@
from dask._task_spec import Alias, DataNode, Task, TaskRef

from cascade.low.builders import TaskBuilder
-from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskInstance
+from cascade.low.core import DatasetId, DefaultTaskOutput, JobInstance, Task2TaskEdge, TaskId, TaskInstance
from cascade.low.exceptions import CascadeUserError

logger = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
for i, v in enumerate(task.args):
if isinstance(v, Alias | TaskRef):
edge = Task2TaskEdge(
-source=DatasetId(task=daskKeyRepr(v.key), output=DefaultTaskOutput),
-sink_task=key,
+source=DatasetId(task=TaskId(daskKeyRepr(v.key)), output=DefaultTaskOutput),
+sink_task=TaskId(key),
sink_input_ps=i,
sink_input_kw=None,
)
@@ -52,8 +52,8 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
for k, v in task.kwargs.items():
if isinstance(v, Alias | TaskRef):
edge = Task2TaskEdge(
-source=DatasetId(task=daskKeyRepr(v.key), output=DefaultTaskOutput),
-sink_task=key,
+source=DatasetId(task=TaskId(daskKeyRepr(v.key)), output=DefaultTaskOutput),
+sink_task=TaskId(key),
sink_input_kw=k,
sink_input_ps=None,
)
@@ -68,7 +68,7 @@ def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:


def graph2job(dask: dict) -> JobInstance:
-task_nodes = {}
+task_nodes: dict[TaskId, TaskInstance] = {}
edges = []

for node, value in dask.items():
@@ -78,10 +78,10 @@ def provider() -> Any:
def provider() -> Any:
return value.value

-task_nodes[key] = TaskBuilder.from_callable(provider)
+task_nodes[TaskId(key)] = TaskBuilder.from_callable(provider)
elif isinstance(value, Task):
node, _edges = task2task(key, value)
-task_nodes[key] = node
+task_nodes[TaskId(key)] = node
edges.extend(_edges)
elif isinstance(value, list | tuple | set):
# TODO implement, consult further:
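
The annotation added to task_nodes matters because an empty dict literal gives the checker nothing to enforce; spelling out dict[TaskId, TaskInstance] is what makes the TaskId(key) wrapping meaningful and keeps the dict compatible with JobInstance.tasks. A toy version of the same idea, with TaskInstance replaced by str to keep it self-contained:

from typing import NewType

TaskId = NewType("TaskId", str)

task_nodes: dict[TaskId, str] = {}   # explicit key type; a bare {} would be inferred too loosely
key = "sum-1a2b3c"
# task_nodes[key] = "node"           # flagged: a plain str key is not a TaskId
task_nodes[TaskId(key)] = "node"     # keys are wrapped at the boundary, as in graph2job above
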
2 changes: 1 addition & 1 deletion src/cascade/low/execution_context.py
@@ -42,7 +42,7 @@ class TaskStatus(int, Enum):
failed = 3 # set by executor


VirtualCheckpointHost: HostId = "virtualCheckpointHost"
VirtualCheckpointHost: HostId = HostId("virtualCheckpointHost")


@dataclass