Make push mapper data work
xuye.qin committed Jun 8, 2022
1 parent b767644 commit c52c107
Showing 13 changed files with 231 additions and 31 deletions.
80 changes: 62 additions & 18 deletions mars/services/subtask/worker/processor.py
@@ -17,12 +17,17 @@
 import sys
 import time
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Dict, List, Optional, Set, Type, Tuple

 from .... import oscar as mo
 from ....core import ChunkGraph, OperandType, enter_mode, ExecutionError
 from ....core.context import get_context, set_context
-from ....core.operand import Fetch, FetchShuffle, execute
+from ....core.operand import (
+    Fetch,
+    FetchShuffle,
+    execute,
+)
+from ....lib.aio import alru_cache
 from ....metrics import Metrics
 from ....optimization.physical import optimize
 from ....typing import BandType, ChunkType
@@ -420,26 +425,56 @@ async def set_chunks_meta():
         # set result data size
         self.result.data_size = result_data_size

-    async def _push_mapper_data(self, chunk_graph):
-        # TODO: use task api to get reducer bands
-        reducer_idx_to_band = dict()
-        if not reducer_idx_to_band:
-            return
+    @classmethod
+    @alru_cache(cache_exceptions=False)
+    async def _gen_reducer_index_to_bands(
+        cls, session_id: str, supervisor_address: str, task_id: str, map_reduce_id: int
+    ) -> Dict[Tuple[int], BandType]:
+        task_api = await TaskAPI.create(session_id, supervisor_address)
+        map_reduce_info = await task_api.get_map_reduce_info(task_id, map_reduce_id)
+        assert len(map_reduce_info.reducer_indexes) == len(
+            map_reduce_info.reducer_bands
+        )
+        return {
+            reducer_index: band
+            for reducer_index, band in zip(
+                map_reduce_info.reducer_indexes, map_reduce_info.reducer_bands
+            )
+        }
+
+    async def _push_mapper_data(self):
         storage_api_to_fetch_tasks = defaultdict(list)
-        for result_chunk in chunk_graph.result_chunks:
-            key = result_chunk.key
-            reducer_idx = key[1]
-            if isinstance(key, tuple):
+        skip = True
+        for result_chunk in self._chunk_graph.result_chunks:
+            map_reduce_id = getattr(result_chunk.op, "extra_params", dict()).get(
+                "analyzer_map_reduce_id"
+            )
+            if map_reduce_id is None:
+                continue
+            skip = False
+            reducer_index_to_bands = await self._gen_reducer_index_to_bands(
+                self._session_id,
+                self._supervisor_address,
+                self.subtask.task_id,
+                map_reduce_id,
+            )
+            for reducer_index, band in reducer_index_to_bands.items():
                 # mapper key is a tuple
-                address, band_name = reducer_idx_to_band[reducer_idx]
-                storage_api = StorageAPI(address, self._session_id, band_name)
+                address, band_name = band
+                storage_api = await StorageAPI.create(
+                    self._session_id, address, band_name
+                )
                 fetch_task = storage_api.fetch.delay(
-                    key, band_name=self._band[1], remote_address=self._band[0]
+                    (result_chunk.key, reducer_index),
+                    band_name=self._band[1],
+                    remote_address=self._band[0],
                 )
                 storage_api_to_fetch_tasks[storage_api].append(fetch_task)
+        if skip:
+            return
         batch_tasks = []
         for storage_api, tasks in storage_api_to_fetch_tasks.items():
-            batch_tasks.append(asyncio.create_task(storage_api.fetch.batch(*tasks)))
+            batch_tasks.append(storage_api.fetch.batch(*tasks))
         await asyncio.gather(*batch_tasks)

     async def done(self):
@@ -513,8 +548,6 @@ async def run(self):
             await self._unpin_data(input_keys)

         await self.done()
-        # after done, we push mapper data to reducers in advance.
-        await self.ref()._push_mapper_data.tell(chunk_graph)
         if self.result.status == SubtaskStatus.succeeded:
             cost_time_secs = (
                 self.result.execution_end_time - self.result.execution_start_time
@@ -536,6 +569,9 @@
             pass
         return self.result

+    async def post_run(self):
+        await self._push_mapper_data()
+
     async def report_progress_periodically(self, interval=0.5, eps=0.001):
         last_progress = self.result.progress
         while not self.result.status.is_done:
@@ -618,7 +654,7 @@ async def _init_context(self, session_id: str):
         await context.init()
         set_context(context)

-    async def run(self, subtask: Subtask):
+    async def run(self, subtask: Subtask, wait_post_run: bool = False):
         logger.info(
             "Start to run subtask: %r on %s. chunk graph contains %s",
             subtask,
@@ -644,10 +680,18 @@ async def run(self, subtask: Subtask):
         try:
             result = yield self._running_aio_task
             logger.info("Finished subtask: %s", subtask.subtask_id)
+            # post run with actor tell which will not block
+            if not wait_post_run:
+                await self.ref().post_run.tell(processor)
+            else:
+                await self.post_run(processor)
             raise mo.Return(result)
         finally:
             self._processor = self._running_aio_task = None

+    async def post_run(self, processor: SubtaskProcessor):
+        await processor.post_run()
+
     async def wait(self):
         return self._processor.is_done.wait()

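Note: taken together, the two methods above resolve where every reducer of a shuffle lives and then pre-push mapper results in one batched request per destination band. A minimal, runnable sketch of that grouping logic follows; MapReduceInfoStub and the other names here are hypothetical stand-ins for the Mars TaskAPI/StorageAPI objects, with only the two parallel lists (reducer_indexes, reducer_bands) taken from the diff.

    from collections import defaultdict
    from typing import Dict, List, Tuple

    BandType = Tuple[str, str]  # (worker address, band name), e.g. ("127.0.0.1:7777", "numa-0")


    class MapReduceInfoStub:
        # Stand-in for the MapReduceInfo returned by TaskAPI.get_map_reduce_info:
        # two parallel lists mapping each reducer index to its assigned band.
        def __init__(self, reducer_indexes: List[Tuple[int, ...]], reducer_bands: List[BandType]):
            self.reducer_indexes = reducer_indexes
            self.reducer_bands = reducer_bands


    def gen_reducer_index_to_bands(info: MapReduceInfoStub) -> Dict[Tuple[int, ...], BandType]:
        # Same zip-based construction as _gen_reducer_index_to_bands in the diff.
        assert len(info.reducer_indexes) == len(info.reducer_bands)
        return dict(zip(info.reducer_indexes, info.reducer_bands))


    def group_fetches_by_band(mapper_keys: List[str], info: MapReduceInfoStub):
        # Mirrors the batching in _push_mapper_data: collect one list of fetch
        # keys per destination band so each band gets a single batched request.
        band_to_fetch_keys = defaultdict(list)
        for reducer_index, band in gen_reducer_index_to_bands(info).items():
            for mapper_key in mapper_keys:
                band_to_fetch_keys[band].append((mapper_key, reducer_index))
        return dict(band_to_fetch_keys)


    info = MapReduceInfoStub(
        reducer_indexes=[(0, 0), (1, 0)],
        reducer_bands=[("worker-1:7777", "numa-0"), ("worker-2:7777", "numa-0")],
    )
    print(group_fetches_by_band(["mapper-chunk-a"], info))
    # {('worker-1:7777', 'numa-0'): [('mapper-chunk-a', (0, 0))],
    #  ('worker-2:7777', 'numa-0'): [('mapper-chunk-a', (1, 0))]}
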
6 changes: 4 additions & 2 deletions mars/services/subtask/worker/runner.py
@@ -89,7 +89,7 @@ async def _get_supervisor_address(self, session_id: str):
         [address] = await self._cluster_api.get_supervisors_by_keys([session_id])
         return address

-    async def run_subtask(self, subtask: Subtask):
+    async def run_subtask(self, subtask: Subtask, wait_post_run: bool = False):
         if self._running_processor is not None:  # pragma: no cover
             running_subtask_id = await self._running_processor.get_running_subtask_id()
             # current subtask is still running
@@ -122,7 +122,9 @@ async def run_subtask(self, subtask: Subtask):
         processor = self._session_id_to_processors[session_id]
         try:
             self._running_processor = self._last_processor = processor
-            result = yield self._running_processor.run(subtask)
+            result = yield self._running_processor.run(
+                subtask, wait_post_run=wait_post_run
+            )
         finally:
             self._running_processor = None
         raise mo.Return(result)
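Note: the new wait_post_run flag only decides whether the mapper push is awaited inline or fired off through an actor tell; scheduling keeps the non-blocking default, while the test below passes True so its assertions run after the push completes. A toy asyncio sketch of that trade-off, using hypothetical names rather than the Mars actor machinery:

    import asyncio


    class RunnerSketch:
        # Toy model of the tell-vs-await choice in SubtaskProcessorActor.run.
        async def _push_mapper_data(self):
            await asyncio.sleep(0.01)  # stands in for the batched storage fetches
            print("mapper data pushed")

        async def run(self, wait_post_run: bool = False):
            print("subtask finished")
            if wait_post_run:
                # like `await self.post_run(processor)`: caller blocks until pushed
                await self._push_mapper_data()
            else:
                # like `self.ref().post_run.tell(processor)`: fire-and-forget
                asyncio.create_task(self._push_mapper_data())
            return "result"


    async def main():
        runner = RunnerSketch()
        await runner.run(wait_post_run=True)  # deterministic: push done before return
        await runner.run()                    # default: push completes in the background
        await asyncio.sleep(0.05)             # give the background push time to finish


    asyncio.run(main())
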
47 changes: 45 additions & 2 deletions mars/services/subtask/worker/tests/test_subtask.py
@@ -18,14 +18,17 @@
 import time

 import numpy as np
+import pandas as pd
 import pytest

 from ..... import oscar as mo
+from ..... import dataframe as md
 from ..... import tensor as mt
 from ..... import remote as mr
-from .....core import ExecutionError
+from .....core import ExecutionError, ChunkGraph
 from .....core.context import get_context
+from .....core.graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
+from .....core.operand import OperandStage
 from .....resource import Resource
 from .....utils import Timer
 from ....cluster import MockClusterAPI
@@ -34,7 +37,7 @@
 from ....scheduling import MockSchedulingAPI
 from ....session import MockSessionAPI
 from ....storage import MockStorageAPI
-from ....task import new_task_id
+from ....task import new_task_id, MapReduceInfo
 from ....task.supervisor.manager import TaskManagerActor, TaskConfigurationActor
 from ....mutable import MockMutableAPI
 from ... import Subtask, SubtaskStatus, SubtaskResult
@@ -46,6 +49,13 @@ class FakeTaskManager(TaskManagerActor):
     def set_subtask_result(self, subtask_result: SubtaskResult):
         return

+    def get_map_reduce_info(self, task_id: str, map_reduce_id: int) -> MapReduceInfo:
+        return MapReduceInfo(
+            map_reduce_id=0,
+            reducer_indexes=[(0, 0)],
+            reducer_bands=[(self.address, "numa-0")],
+        )
+

 @pytest.fixture
 async def actor_pool():
@@ -142,6 +152,39 @@ async def test_subtask_success(actor_pool):
     assert await subtask_runner.is_runner_free() is True


+@pytest.mark.asyncio
+async def test_shuffle_subtask(actor_pool):
+    pool, session_id, meta_api, storage_api, manager = actor_pool
+
+    pdf = pd.DataFrame({"f1": ["a", "b", "a"], "f2": [1, 2, 3]})
+    df = md.DataFrame(pdf)
+    result = df.groupby("f1").sum(method="shuffle")
+
+    graph = TileableGraph([result.data])
+    next(TileableGraphBuilder(graph).build())
+    chunk_graph = next(ChunkGraphBuilder(graph, fuse_enabled=False).build())
+    result_chunks = []
+    new_chunk_graph = ChunkGraph(result_chunks)
+    chunk_graph_iter = chunk_graph.topological_iter()
+    curr = None
+    for _ in range(3):
+        prev = curr
+        curr = next(chunk_graph_iter)
+        new_chunk_graph.add_node(curr)
+        if prev is not None:
+            new_chunk_graph.add_edge(prev, curr)
+    assert curr.op.stage == OperandStage.map
+    curr.op.extra_params = {"analyzer_map_reduce_id": 0}
+    result_chunks.append(curr)
+    subtask = Subtask(new_task_id(), session_id, new_task_id(), new_chunk_graph)
+    subtask_runner: SubtaskRunnerRef = await mo.actor_ref(
+        SubtaskRunnerActor.gen_uid("numa-0", 0), address=pool.external_address
+    )
+    await subtask_runner.run_subtask(subtask, wait_post_run=True)
+    result = await subtask_runner.get_subtask_result()
+    assert result.status == SubtaskStatus.succeeded
+
+
 @pytest.mark.asyncio
 async def test_subtask_failure(actor_pool):
     pool, session_id, meta_api, storage_api, manager = actor_pool
2 changes: 1 addition & 1 deletion mars/services/task/__init__.py
@@ -14,5 +14,5 @@

 from .api import AbstractTaskAPI, TaskAPI, WebTaskAPI
 from .config import task_options
-from .core import Task, TaskStatus, TaskResult, new_task_id
+from .core import Task, TaskStatus, TaskResult, new_task_id, MapReduceInfo
 from .errors import TaskNotExist
48 changes: 47 additions & 1 deletion mars/services/task/analyzer/analyzer.py
@@ -25,12 +25,14 @@
     LogicKeyGenerator,
     MapReduceOperand,
     OperandStage,
+    ShuffleProxy,
 )
 from ....lib.ordered_set import OrderedSet
 from ....resource import Resource
 from ....typing import BandType, OperandType
 from ....utils import build_fetch, tokenize
 from ...subtask import SubtaskGraph, Subtask
-from ..core import Task, new_task_id
+from ..core import Task, new_task_id, MapReduceInfo
 from .assigner import AbstractGraphAssigner, GraphAssigner
 from .fusion import Coloring
@@ -50,6 +52,8 @@ def need_reassign_worker(op: OperandType) -> bool:


 class GraphAnalyzer:
+    _map_reduce_id = itertools.count()
+
     def __init__(
         self,
         chunk_graph: ChunkGraph,
@@ -59,6 +63,7 @@ def __init__(
         chunk_to_subtasks: Dict[ChunkType, Subtask],
         graph_assigner_cls: Type[AbstractGraphAssigner] = None,
         stage_id: str = None,
+        map_reduce_id_to_infos: Dict[int, MapReduceInfo] = None,
     ):
         self._chunk_graph = chunk_graph
         self._band_resource = band_resource
@@ -68,12 +73,17 @@ def __init__(
         self._fuse_enabled = task.fuse_enabled
         self._extra_config = task.extra_config
         self._chunk_to_subtasks = chunk_to_subtasks
+        self._map_reduce_id_to_infos = map_reduce_id_to_infos
         if graph_assigner_cls is None:
             graph_assigner_cls = GraphAssigner
         self._graph_assigner_cls = graph_assigner_cls
         self._chunk_to_copied = dict()
         self._logic_key_generator = LogicKeyGenerator()

+    @classmethod
+    def next_map_reduce_id(cls) -> int:
+        return next(cls._map_reduce_id)
+
     @classmethod
     def _iter_start_ops(cls, chunk_graph: ChunkGraph):
         visited = set()
@@ -300,6 +310,38 @@ def _gen_logic_key(self, chunks: List[ChunkType]):
             *[self._logic_key_generator.get_logic_key(chunk.op) for chunk in chunks]
         )

+    def _gen_map_reduce_info(
+        self, chunk: ChunkType, assign_results: Dict[ChunkType, BandType]
+    ):
+        reducer_ops = OrderedSet(
+            [
+                c.op
+                for c in self._chunk_graph.successors(chunk)
+                if c.op.stage == OperandStage.reduce
+            ]
+        )
+        map_chunks = [
+            c
+            for c in self._chunk_graph.predecessors(chunk)
+            if c.op.stage == OperandStage.map
+        ]
+        map_reduce_id = self.next_map_reduce_id()
+        for map_chunk in map_chunks:
+            # record analyzer map reduce id for mapper op
+            # copied chunk exists because map chunk must have
+            # been processed before shuffle proxy
+            copied_map_chunk_op = self._chunk_to_copied[map_chunk].op
+            if not hasattr(copied_map_chunk_op, "extra_params"):
+                copied_map_chunk_op.extra_params = dict()
+            copied_map_chunk_op.extra_params["analyzer_map_reduce_id"] = map_reduce_id
+        reducer_bands = [assign_results[r.outputs[0]] for r in reducer_ops]
+        map_reduce_info = MapReduceInfo(
+            map_reduce_id=map_reduce_id,
+            reducer_indexes=[reducer_op.reducer_index for reducer_op in reducer_ops],
+            reducer_bands=reducer_bands,
+        )
+        self._map_reduce_id_to_infos[map_reduce_id] = map_reduce_info
+
     @enter_mode(build=True)
     def gen_subtask_graph(
         self, op_to_bands: Dict[str, BandType] = None
@@ -420,6 +462,10 @@ def gen_subtask_graph(

             for c in same_color_chunks:
                 chunk_to_subtask[c] = subtask
+            if self._map_reduce_id_to_infos is not None and isinstance(
+                chunk.op, ShuffleProxy
+            ):
+                self._gen_map_reduce_info(chunk, chunk_to_bands)
             visited.update(same_color_chunks)

         for subtasks in logic_key_to_subtasks.values():
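Note: _gen_map_reduce_info constructs a MapReduceInfo whose definition (in mars/services/task/core.py) is not part of this page. A sketch of the shape implied by its call sites; field types are inferred, not confirmed against the real class:

    from dataclasses import dataclass, field
    from typing import List, Tuple

    BandType = Tuple[str, str]  # (worker address, band name)


    @dataclass
    class MapReduceInfoSketch:
        # Inferred from the constructor call in _gen_map_reduce_info and the
        # consumption in _gen_reducer_index_to_bands; the real class may differ.
        map_reduce_id: int  # issued by GraphAnalyzer.next_map_reduce_id()
        reducer_indexes: List[Tuple[int, ...]] = field(default_factory=list)  # e.g. [(0, 0)]
        reducer_bands: List[BandType] = field(default_factory=list)  # parallel to reducer_indexes
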
7 changes: 6 additions & 1 deletion mars/services/task/api/oscar.py
@@ -18,7 +18,7 @@

 from ....core import Tileable
 from ....lib.aio import alru_cache
 from ...subtask import SubtaskResult
-from ..core import TileableGraph, TaskResult
+from ..core import TileableGraph, TaskResult, MapReduceInfo
 from ..supervisor.manager import TaskManagerActor
 from .core import AbstractTaskAPI
@@ -104,3 +104,8 @@ async def set_subtask_result(self, subtask_result: SubtaskResult):

     async def get_last_idle_time(self) -> Union[float, None]:
         return await self._task_manager_ref.get_last_idle_time()
+
+    async def get_map_reduce_info(
+        self, task_id: str, map_reduce_id: int
+    ) -> MapReduceInfo:
+        return await self._task_manager_ref.get_map_reduce_info(task_id, map_reduce_id)
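Note: this supervisor endpoint is what the worker-side _gen_reducer_index_to_bands (decorated with @alru_cache(cache_exceptions=False)) ultimately calls, so a worker pays the round-trip once per (task_id, map_reduce_id). A toy demonstration of that caching effect, with a plain dict standing in for mars.lib.aio.alru_cache:

    import asyncio

    _cache = {}


    async def cached_reducer_layout(task_id: str, map_reduce_id: int):
        # The arguments fully determine one shuffle's reducer layout, which is
        # what makes the result safe to cache for the lifetime of the task.
        key = (task_id, map_reduce_id)
        if key not in _cache:
            print("supervisor round-trip for", key)  # printed once per key
            await asyncio.sleep(0.01)  # stands in for task_api.get_map_reduce_info
            _cache[key] = {(0, 0): ("worker-1:7777", "numa-0")}
        return _cache[key]


    async def main():
        await cached_reducer_layout("task-1", 0)
        await cached_reducer_layout("task-1", 0)  # served from the cache, no round-trip


    asyncio.run(main())
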
(Diffs for the remaining 7 changed files were not loaded in this view.)
