Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add session context support for Ray DAG mode #3358

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mars/deploy/oscar/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@
client = new_cluster_in_ray(backend=backend, **new_cluster_kwargs)
session_id = session_id or client.session.session_id
address = client.address
logger.warning("CLIENT ADDRESS: %s", address)

Check warning on line 392 in mars/deploy/oscar/ray.py

View check run for this annotation

Codecov / codecov/patch

mars/deploy/oscar/ray.py#L392

Added line #L392 was not covered by tests
session = new_session(
address=address, session_id=session_id, backend=backend, default=default
)
Expand Down
2 changes: 2 additions & 0 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ async def fetch(self, *tileables, **kwargs) -> list:
chunks, chunk_metas, itertools.chain(*fetch_infos_list)
):
await fetcher.append(chunk.key, meta, fetch_info.indexes)
logger.warning("FETCH!! %r", fetcher)
fetched_data = await fetcher.get()
logger.warning("FETCH2!!")
for fetch_info, data in zip(
itertools.chain(*fetch_infos_list), fetched_data
):
Expand Down
31 changes: 31 additions & 0 deletions mars/deploy/oscar/tests/test_ray_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import copy
import logging
import os
import time

import pytest

from .... import get_context
from .... import remote as mr
from .... import tensor as mt
from ....session import new_session, get_default_async_session
from ....tests import test_session
Expand Down Expand Up @@ -125,6 +127,35 @@ def test_sync_execute(ray_start_regular_shared2, config):
test_local.test_sync_execute(config)


@require_ray
@pytest.mark.parametrize("config", [{"backend": "ray"}])
def test_spawn_execution(ray_start_regular_shared2, config):
    """Nested ``mr.spawn`` must be executable from inside a Ray DAG worker.

    Exercises the session-context support this PR adds: ``f1`` runs remotely
    and itself spawns, executes and fetches another remote function, which
    requires a usable default session inside the worker process.
    """
    session = new_session(
        backend=config["backend"],
        n_cpu=2,
        web=False,
        use_uvloop=False,
        config={"task.execution_config.ray.monitor_interval_seconds": 0},
    )

    # Ray DAG clusters are created without a web UI.
    assert session._session.client.web_address is None
    assert session.get_web_endpoint() is None

    def f1(c=0):
        if c:
            # Nested spawn: the inner f1() runs with the default c=0 and
            # therefore skips this branch and returns 0.
            executed = mr.spawn(f1).execute()
            assert executed.fetch() == 0
        return c

    with session:
        assert 10 == mr.spawn(f1, 10).execute().fetch()

    session.stop_server()
    assert get_default_async_session() is None


@require_ray
@pytest.mark.parametrize(
"create_cluster",
Expand Down
8 changes: 8 additions & 0 deletions mars/lib/aio/isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

import asyncio
import atexit
import logging
import threading
from typing import Dict, Optional

logger = logging.getLogger(__name__)


class Isolation:
loop: asyncio.AbstractEventLoop
Expand All @@ -31,6 +34,9 @@
self._thread = None
self._thread_ident = None

def __repr__(self):
    # Debug representation. The loop's identity is shown as a hex id
    # followed by its repr, separated by a space — the original fused
    # a decimal id() directly into the repr text, which was unreadable.
    return (
        f"<Isolation loop={id(self.loop):#x} {self.loop!r} "
        f"threaded={self._threaded} thread_ident={self._thread_ident}>"
    )

def _run(self):
asyncio.set_event_loop(self.loop)
self._stopped = asyncio.Event()
Expand Down Expand Up @@ -72,9 +78,11 @@

if loop is None:
loop = asyncio.new_event_loop()
logger.warning("NEW_LOOP %d", id(loop))

Check warning on line 81 in mars/lib/aio/isolation.py

View check run for this annotation

Codecov / codecov/patch

mars/lib/aio/isolation.py#L81

Added line #L81 was not covered by tests

isolation = Isolation(loop, threaded=threaded)
isolation.start()
logger.warning("NEW_ISOLATION! loop: %r", loop)
_name_to_isolation[name] = isolation
return isolation

Expand Down
4 changes: 3 additions & 1 deletion mars/services/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
local_address: str,
loop: asyncio.AbstractEventLoop,
band: BandType = None,
isolation_threaded: bool = False,
):
super().__init__(
session_id=session_id,
Expand All @@ -59,7 +60,8 @@ def __init__(
# new isolation with current loop,
# so that session created in tile and execute
# can get the right isolation
new_isolation(loop=self._loop, threaded=False)
logger.warning("NEW_ISOLATION in ThreadedServiceContext.__init__")
new_isolation(loop=self._loop, threaded=isolation_threaded)

self._running_session_id = None
self._running_op_key = None
Expand Down
23 changes: 19 additions & 4 deletions mars/services/task/execution/ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from typing import Dict, List, Callable

from .....core.context import Context
from .....session import ensure_isolation_created
from .....storage.base import StorageLevel
from .....typing import ChunkType
from .....typing import ChunkType, SessionType
from .....utils import implements, lazy_import, sync_to_async
from ....context import ThreadedServiceContext
from .config import RayExecutionConfig
Expand Down Expand Up @@ -187,13 +188,27 @@


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
class RayExecutionWorkerContext(_RayRemoteObjectContext, ThreadedServiceContext, dict):
"""The context for executing operands."""

def __init__(
    self,
    get_or_create_actor: Callable[[], "ray.actor.ActorHandle"],
    *args,
    **kwargs,
):
    """Worker-side execution context that also acts as a chunk-data dict.

    Parameters
    ----------
    get_or_create_actor
        Zero-arg callable returning the ``RayTaskState`` actor handle,
        forwarded to the remote-object mixin.
    """
    # Explicit base initialisation: the remote-object mixin is set up with
    # loop=None and a threaded isolation (this runs inside a Ray worker
    # task, not on a dedicated event-loop thread), and the dict base is
    # initialised empty. NOTE(review): ThreadedServiceContext.__init__ is
    # presumably reached through the mixin's super() chain — confirm MRO.
    _RayRemoteObjectContext.__init__(
        self,
        get_or_create_actor,
        *args,
        loop=None,
        isolation_threaded=True,
        **kwargs,
    )
    dict.__init__(self)
    # Chunk currently being executed; updated via set_current_chunk().
    self._current_chunk = None

@implements(Context.get_current_session)
def get_current_session(self) -> SessionType:
    """Return a session bound to this worker's supervisor address.

    ``new=False`` and ``default=False`` are passed through — presumably so
    an existing session is reused without replacing the caller's default
    session; confirm against ``new_session``'s documentation.
    """
    # Local import to avoid a module-level import cycle with the
    # session package.
    from .....session import new_session

    return new_session(
        self.supervisor_address,
        self.session_id,
        backend="ray",
        new=False,
        default=False,
    )

@classmethod
@implements(Context.new_custom_log_dir)
def new_custom_log_dir(cls):
Expand Down
64 changes: 49 additions & 15 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import itertools
import logging
import operator
import os
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Callable

import numpy as np

from .....core import ChunkGraph, Chunk, TileContext
from .....core.context import set_context
from .....core.context import set_context, get_context
from .....core.operand import (
Fetch,
Fuse,
Expand All @@ -38,6 +39,7 @@
from .....metrics.api import init_metrics, Metrics
from .....resource import Resource
from .....serialization import serialize, deserialize
from .....session import AbstractSession, get_default_session
from .....typing import BandType
from .....utils import (
aiotask_wrapper,
Expand Down Expand Up @@ -149,10 +151,12 @@


def execute_subtask(
session_id: str,
subtask_id: str,
subtask_chunk_graph: ChunkGraph,
output_meta_n_keys: int,
is_mapper,
address: str,
*inputs,
):
"""
Expand All @@ -176,6 +180,9 @@
-------
subtask outputs and meta for outputs if `output_meta_keys` is provided.
"""
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)

Check warning on line 184 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L183-L184

Added lines #L183 - L184 were not covered by tests

init_metrics("ray")
started_subtask_number.record(1)
ray_task_id = ray.get_runtime_context().get_task_id()
Expand All @@ -184,7 +191,16 @@
# Optimize chunk graph.
subtask_chunk_graph = _optimize_subtask_graph(subtask_chunk_graph)
fetch_chunks, shuffle_fetch_chunk = _get_fetch_chunks(subtask_chunk_graph)
context = RayExecutionWorkerContext(RayTaskState.get_handle)

context = RayExecutionWorkerContext(

Check warning on line 195 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L195

Added line #L195 was not covered by tests
RayTaskState.get_handle,
session_id,
address,
address,
address,
)
set_context(context)

Check warning on line 202 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L202

Added line #L202 was not covered by tests

if shuffle_fetch_chunk is not None:
# The subtask is a reducer subtask.
n_mappers = shuffle_fetch_chunk.op.n_mappers
Expand All @@ -209,19 +225,28 @@
# Update non shuffle inputs to context.
context.update(zip((start_chunk.key for start_chunk in fetch_chunks), inputs))

for chunk in subtask_chunk_graph.topological_iter():
if chunk.key not in context:
try:
context.set_current_chunk(chunk)
execute(context, chunk.op)
except Exception:
logger.exception(
"Execute operand %s of graph %s failed.",
chunk.op,
subtask_chunk_graph.to_dot(),
)
raise
subtask_gc.gc_inputs(chunk)
default_session = get_default_session()
try:
context.get_current_session().as_default()

Check warning on line 230 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L228-L230

Added lines #L228 - L230 were not covered by tests

for chunk in subtask_chunk_graph.topological_iter():
if chunk.key not in context:
try:
context.set_current_chunk(chunk)
execute(context, chunk.op)
except Exception:
logger.exception(

Check warning on line 238 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L234-L238

Added lines #L234 - L238 were not covered by tests
"Execute operand %s of graph %s failed.",
chunk.op,
subtask_chunk_graph.to_dot(),
)
raise
subtask_gc.gc_inputs(chunk)

Check warning on line 244 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L243-L244

Added lines #L243 - L244 were not covered by tests
finally:
if default_session is not None:
default_session.as_default()
else:
AbstractSession.reset_default()

# For non-mapper subtask, output context is chunk key to results.
# For mapper subtasks, output context is data key to results.
Expand Down Expand Up @@ -455,6 +480,7 @@
task_chunks_meta: Dict[str, _RayChunkMeta],
lifecycle_api: LifecycleAPI,
meta_api: MetaAPI,
address: str,
):
logger.info(
"Start task %s with GC method %s.",
Expand All @@ -475,6 +501,8 @@
self._available_band_resources = None
self._result_tileables_lifecycle = None

self._address = address

Check warning on line 504 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L504

Added line #L504 was not covered by tests

# For progress and task cancel
self._stage_index = 0
self._pre_all_stages_progress = 0.0
Expand Down Expand Up @@ -507,6 +535,7 @@
task_chunks_meta,
lifecycle_api,
meta_api,
address,
)
available_band_resources = await executor.get_available_band_resources()
worker_addresses = list(
Expand Down Expand Up @@ -710,10 +739,12 @@
memory=subtask_memory,
scheduling_strategy="DEFAULT" if len(input_object_refs) else "SPREAD",
).remote(
subtask.session_id,
subtask.subtask_id,
serialize(subtask_chunk_graph, context={"serializer": "ray"}),
subtask.stage_n_outputs,
is_mapper,
self._address,
*input_object_refs,
)
await asyncio.sleep(0)
Expand All @@ -739,6 +770,7 @@
task_context[chunk_key] = object_ref
logger.info("Submitted %s subtasks of stage %s.", len(subtask_graph), stage_id)

logger.warning("SUBTASK_RUN_1")

Check warning on line 773 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L773

Added line #L773 was not covered by tests
monitor_context.stage = _RayExecutionStage.WAITING
key_to_meta = {}
if len(output_meta_object_refs) > 0:
Expand All @@ -752,6 +784,7 @@
self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size)
logger.info("Got %s metas of stage %s.", meta_count, stage_id)

logger.warning("SUBTASK_RUN_2")

Check warning on line 787 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L787

Added line #L787 was not covered by tests
chunk_to_meta = {}
# ray.wait requires the object ref list is unique.
output_object_refs = set()
Expand All @@ -773,6 +806,7 @@
await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)

logger.info("Complete stage %s.", stage_id)
logger.warning("%d: SUBTASK_RUN_3: %r", os.getpid(), output_object_refs)

Check warning on line 809 in mars/services/task/execution/ray/executor.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/executor.py#L809

Added line #L809 was not covered by tests
return chunk_to_meta

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
14 changes: 11 additions & 3 deletions mars/services/task/execution/ray/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

import asyncio
import functools
import logging
from collections import namedtuple
from typing import Dict, List

from .....utils import lazy_import
from ..api import Fetcher, register_fetcher_cls

logger = logging.getLogger(__name__)

ray = lazy_import("ray")
_FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"])

Expand All @@ -36,9 +39,10 @@
name = "ray"
required_meta_keys = ("object_refs",)

def __init__(self, loop=None, **kwargs):
    """Create an empty Ray fetcher.

    Parameters
    ----------
    loop : asyncio.AbstractEventLoop, optional
        Event loop forwarded to ``asyncio.gather(loop=...)`` in ``get()``.
        NOTE(review): the ``loop`` argument of ``asyncio.gather`` was
        deprecated in 3.8 and removed in Python 3.10 — confirm the target
        Python versions before relying on it.
    """
    self._loop = loop
    # Set to False as soon as any appended fetch carries slicing conditions.
    self._no_conditions = True
    self._fetch_info_list = []

Check warning on line 45 in mars/services/task/execution/ray/fetcher.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/fetcher.py#L45

Added line #L45 was not covered by tests

@staticmethod
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
Expand All @@ -55,9 +59,12 @@

async def get(self):
if self._no_conditions:
logger.warning(f"FETCHER_0 {self._fetch_info_list}")

Check warning on line 62 in mars/services/task/execution/ray/fetcher.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/fetcher.py#L62

Added line #L62 was not covered by tests
return await asyncio.gather(
*(info.object_ref for info in self._fetch_info_list)
*(info.object_ref for info in self._fetch_info_list),
loop=self._loop,
)
logger.warning("FETCHER_1")

Check warning on line 67 in mars/services/task/execution/ray/fetcher.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/fetcher.py#L67

Added line #L67 was not covered by tests
refs = [None] * len(self._fetch_info_list)
for index, fetch_info in enumerate(self._fetch_info_list):
if fetch_info.conditions is None:
Expand All @@ -66,4 +73,5 @@
refs[index] = self._remote_query_object_with_condition().remote(
fetch_info.object_ref, tuple(fetch_info.conditions)
)
return await asyncio.gather(*refs)
logger.warning("FETCHER_2")
return await asyncio.gather(*refs, loop=self._loop)

Check warning on line 77 in mars/services/task/execution/ray/fetcher.py

View check run for this annotation

Codecov / codecov/patch

mars/services/task/execution/ray/fetcher.py#L76-L77

Added lines #L76 - L77 were not covered by tests
Loading
Loading