From 2966c8e9b9a95fc1cf02279a2d60490c0c9ba37b Mon Sep 17 00:00:00 2001
From: "Xuye (Chris) Qin"
Date: Wed, 27 Apr 2022 11:20:25 +0800
Subject: [PATCH] Support reporting tile progress (#2954)

---
 mars/core/__init__.py                         |   2 +
 mars/core/base.py                             |  11 +-
 mars/core/entity/utils.py                     |   9 +-
 mars/core/graph/__init__.py                   |   2 +-
 mars/core/graph/builder/__init__.py           |   2 +-
 mars/core/graph/builder/chunk.py              | 116 ++++-
 mars/dataframe/merge/merge.py                 |  14 +-
 .../merge/tests/test_merge_execution.py       |   6 +-
 .../chunk/tests/test_column_pruning.py        |  10 +-
 .../logical/chunk/tests/test_head.py          |  10 +-
 mars/optimization/physical/tests/test_cupy.py |  10 +-
 mars/oscar/backends/pool.py                   |  16 +-
 mars/services/task/analyzer/analyzer.py       |   6 +-
 .../task/analyzer/tests/test_assigner.py      |   2 +-
 mars/services/task/execution/api.py           |   8 +-
 mars/services/task/execution/mars/executor.py |  32 +-
 mars/services/task/execution/ray/executor.py  |   6 +-
 mars/services/task/supervisor/manager.py      |  11 +-
 mars/services/task/supervisor/preprocessor.py |  18 +-
 mars/services/task/supervisor/processor.py    | 421 ++---
 mars/services/task/supervisor/task.py         | 424 ++++++++++++++++++
 .../supervisor/tests/task_preprocessor.py     |   9 +-
 mars/services/task/tests/test_service.py      |  84 +++-
 23 files changed, 774 insertions(+), 455 deletions(-)
 create mode 100644 mars/services/task/supervisor/task.py

diff --git a/mars/core/__init__.py b/mars/core/__init__.py
index d2cf2ce335..abc0385114 100644
--- a/mars/core/__init__.py
+++ b/mars/core/__init__.py
@@ -61,5 +61,7 @@
     ChunkGraph,
     TileableGraphBuilder,
     ChunkGraphBuilder,
+    TileContext,
+    TileStatus,
 )
 from .mode import enter_mode, is_build_mode, is_eager_mode, is_kernel_mode
diff --git a/mars/core/base.py b/mars/core/base.py
index 4af63d0fbf..fe9ac60346 100644
--- a/mars/core/base.py
+++ b/mars/core/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
from functools import wraps -from typing import Dict +from typing import Dict, Tuple, Type from ..serialization.core import Placeholder, fast_id from ..serialization.serializables import Serializable, StringField @@ -117,6 +117,15 @@ def key(self): def id(self): return self._id + def to_kv(self, exclude_fields: Tuple[str], accept_value_types: Tuple[Type]): + fields = self._FIELDS + field_values = self._FIELD_VALUES + return { + fields[attr_name].tag: value + for attr_name, value in field_values.items() + if attr_name not in exclude_fields and isinstance(value, accept_value_types) + } + def buffered_base(func): @wraps(func) diff --git a/mars/core/entity/utils.py b/mars/core/entity/utils.py index 45303770b2..ea3198d8e8 100644 --- a/mars/core/entity/utils.py +++ b/mars/core/entity/utils.py @@ -28,7 +28,12 @@ def refresh_tileable_shape(tileable): def tile(tileable, *tileables: TileableType): - from ..graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder + from ..graph import ( + TileableGraph, + TileableGraphBuilder, + ChunkGraphBuilder, + TileContext, + ) raw_tileables = target_tileables = [tileable] + list(tileables) target_tileables = [t.data if hasattr(t, "data") else t for t in target_tileables] @@ -38,7 +43,7 @@ def tile(tileable, *tileables: TileableType): next(tileable_graph_builder.build()) # tile - tile_context = dict() + tile_context = TileContext() chunk_graph_builder = ChunkGraphBuilder( tileable_graph, fuse_enabled=False, tile_context=tile_context ) diff --git a/mars/core/graph/__init__.py b/mars/core/graph/__init__.py index cee2235ad2..b62d5c463f 100644 --- a/mars/core/graph/__init__.py +++ b/mars/core/graph/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .builder import TileableGraphBuilder, ChunkGraphBuilder +from .builder import TileableGraphBuilder, ChunkGraphBuilder, TileContext, TileStatus from .core import DirectedGraph, DAG, GraphContainsCycleError from .entity import TileableGraph, ChunkGraph, EntityGraph diff --git a/mars/core/graph/builder/__init__.py b/mars/core/graph/builder/__init__.py index f6ea328d4a..e73bb0b792 100644 --- a/mars/core/graph/builder/__init__.py +++ b/mars/core/graph/builder/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .chunk import ChunkGraphBuilder +from .chunk import ChunkGraphBuilder, TileContext, TileStatus from .tileable import TileableGraphBuilder diff --git a/mars/core/graph/builder/chunk.py b/mars/core/graph/builder/chunk.py index 2807d497b4..f58d6cc70a 100644 --- a/mars/core/graph/builder/chunk.py +++ b/mars/core/graph/builder/chunk.py @@ -13,6 +13,7 @@ # limitations under the License. 
 import dataclasses
+import functools
 from typing import (
     Callable,
     Dict,
@@ -35,6 +36,7 @@
 
 tile_gen_type = Generator[List[ChunkType], List[ChunkType], List[TileableType]]
 
+DEFAULT_UPDATED_PROGRESS = 0.4
 
 
 @dataclasses.dataclass
@@ -44,14 +46,84 @@ class _TileableHandler:
     last_need_processes: List[EntityType] = None
 
 
+@dataclasses.dataclass
+class _TileableTileInfo:
+    curr_iter: int
+    # incremental progress for this iteration
+    tile_progress: float
+    # newly generated chunks by a tileable in this iteration
+    generated_chunks: List[ChunkType] = dataclasses.field(default_factory=list)
+
+
+class TileContext(Dict[TileableType, TileableType]):
+    _tileables: Set[TileableType]
+    _tileable_to_progress: Dict[TileableType, float]
+    _tileable_to_tile_infos: Dict[TileableType, List[_TileableTileInfo]]
+
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+        self._tileables = None
+        self._tileable_to_progress = dict()
+        self._tileable_to_tile_infos = dict()
+
+    def set_tileables(self, tileables: Set[TileableType]):
+        self._tileables = tileables
+
+    def __setitem__(self, key, value):
+        self._tileable_to_progress.pop(key, None)
+        return super().__setitem__(key, value)
+
+    def set_progress(self, tileable: TileableType, progress: float):
+        assert 0.0 <= progress <= 1.0
+        last_progress = self._tileable_to_progress.get(tileable, 0.0)
+        self._tileable_to_progress[tileable] = max(progress, last_progress)
+
+    def get_progress(self, tileable: TileableType) -> float:
+        if tileable in self:
+            return 1.0
+        else:
+            return self._tileable_to_progress.get(tileable, 0.0)
+
+    def get_all_progress(self) -> float:
+        return sum(self.get_progress(t) for t in self._tileables) / len(self._tileables)
+
+    def record_tileable_tile_info(
+        self, tileable: TileableType, curr_iter: int, generated_chunks: List[ChunkType]
+    ):
+        if tileable not in self._tileable_to_tile_infos:
+            self._tileable_to_tile_infos[tileable] = []
+        prev_progress = sum(
+            info.tile_progress for info in self._tileable_to_tile_infos[tileable]
+        )
+        curr_progress = self.get_progress(tileable)
+        infos = self._tileable_to_tile_infos[tileable]
+        infos.append(
+            _TileableTileInfo(
+                curr_iter=curr_iter,
+                tile_progress=curr_progress - prev_progress,
+                generated_chunks=generated_chunks,
+            )
+        )
+
+    def get_tileable_tile_infos(self) -> Dict[TileableType, List[_TileableTileInfo]]:
+        return {t: self._tileable_to_tile_infos.get(t, list()) for t in self._tileables}
+
+
+@dataclasses.dataclass
+class TileStatus:
+    entities: List[EntityType] = None
+    progress: float = None
+
+
 class Tiler:
+    _curr_iter: int
     _cur_chunk_graph: Optional[ChunkGraph]
     _tileable_handlers: Iterable[_TileableHandler]
 
     def __init__(
         self,
         tileable_graph: TileableGraph,
-        tile_context: Dict[TileableType, TileableType],
+        tile_context: TileContext,
         processed_chunks: Set[ChunkType],
         chunk_to_fetch: Dict[ChunkType, ChunkType],
         add_nodes: Callable,
@@ -60,13 +132,31 @@
         self._tile_context = tile_context
         self._processed_chunks = processed_chunks
         self._chunk_to_fetch = chunk_to_fetch
-        self._add_nodes = add_nodes
+        self._add_nodes = self._wrap_add_nodes(add_nodes)
+        self._curr_iter = 0
         self._cur_chunk_graph = None
         self._tileable_handlers = (
             _TileableHandler(tileable, self._tile_handler(tileable))
             for tileable in tileable_graph.topological_iter()
         )
 
+    def _wrap_add_nodes(self, add_nodes: Callable):
+        @functools.wraps(add_nodes)
+        def inner(
+            chunk_graph: ChunkGraph,
+            chunks: List[ChunkType],
+            visited: Set[ChunkType],
+            tileable: TileableType,
+        ):
+            prev_chunks = set(chunk_graph)
+
add_nodes(chunk_graph, chunks, visited) + new_chunks = set(chunk_graph) + self._tile_context.record_tileable_tile_info( + tileable, self._curr_iter, list(new_chunks - prev_chunks) + ) + + return inner + @staticmethod def _get_data(entity: EntityType): return entity.data if hasattr(entity, "data") else entity @@ -119,6 +209,17 @@ def _tile( ): try: need_process = next(tile_handler) + + if isinstance(need_process, TileStatus): + # process tile that returns progress + self._tile_context.set_progress(tileable, need_process.progress) + need_process = need_process.entities + else: + # if progress not specified, we just update 0.4 * rest progress + progress = self._tile_context.get_progress(tileable) + new_progress = progress + (1.0 - progress) * DEFAULT_UPDATED_PROGRESS + self._tile_context.set_progress(tileable, new_progress) + chunks = [] if need_process is not None: for t in need_process: @@ -127,7 +228,7 @@ def _tile( elif isinstance(t, TILEABLE_TYPE): to_update_tileables.append(self._get_data(t)) # not finished yet - self._add_nodes(chunk_graph, chunks.copy(), visited) + self._add_nodes(chunk_graph, chunks.copy(), visited, tileable) next_tileable_handlers.append( _TileableHandler(tileable, tile_handler, need_process) ) @@ -145,8 +246,8 @@ def _tile( if chunks is None: # pragma: no cover raise ValueError(f"tileable({out}) is still coarse after tile") chunks = [self._get_data(c) for c in chunks] - self._add_nodes(chunk_graph, chunks, visited) self._tile_context[out] = tiled_tileable + self._add_nodes(chunk_graph, chunks, visited, tileable) def _gen_result_chunks( self, @@ -227,6 +328,8 @@ def _iter(self): # prune unused chunks prune_chunk_graph(chunk_graph) + self._curr_iter += 1 + return to_update_tileables def __iter__(self): @@ -278,12 +381,13 @@ def __init__( self, graph: TileableGraph, fuse_enabled: bool = True, - tile_context: Dict[TileableType, TileableType] = None, + tile_context: TileContext = None, tiler_cls: Union[Type[Tiler], Callable] = None, ): super().__init__(graph) self.fuse_enabled = fuse_enabled - self.tile_context = dict() if tile_context is None else tile_context + self.tile_context = TileContext() if tile_context is None else tile_context + self.tile_context.set_tileables(set(graph)) self._processed_chunks: Set[ChunkType] = set() self._chunk_to_fetch: Dict[ChunkType, ChunkType] = dict() diff --git a/mars/dataframe/merge/merge.py b/mars/dataframe/merge/merge.py index 2dacd99a10..8d7dc52a92 100644 --- a/mars/dataframe/merge/merge.py +++ b/mars/dataframe/merge/merge.py @@ -21,7 +21,7 @@ import pandas as pd from ... 
import opcodes as OperandDef -from ...core import OutputType, recursive_tile +from ...core import OutputType, recursive_tile, TileStatus from ...core.context import get_context from ...core.operand import OperandStage, MapReduceOperand from ...serialization.serializables import ( @@ -609,7 +609,7 @@ def tile(cls, op: "DataFrameMerge"): auto_merge_before and len(left.chunks) + len(right.chunks) > auto_merge_threshold ): - yield [left, right] + left.chunks + right.chunks + yield TileStatus([left, right] + left.chunks + right.chunks, progress=0.2) left = auto_merge_chunks(ctx, left) right = auto_merge_chunks(ctx, right) @@ -626,7 +626,7 @@ def tile(cls, op: "DataFrameMerge"): right_on = _prepare_shuffle_on(op.right_index, op.right_on, op.on) if op.how == "inner" and op.bloom_filter: if has_unknown_shape(left, right): - yield left.chunks + right.chunks + yield TileStatus(left.chunks + right.chunks, progress=0.3) small_one = right if len(left.chunks) > len(right.chunks) else left logger.debug( "Apply bloom filter for operand %s, use DataFrame %s to build bloom filter.", @@ -637,7 +637,9 @@ def tile(cls, op: "DataFrameMerge"): *cls._apply_bloom_filter(left, right, left_on, right_on, op) ) # auto merge after bloom filter - yield [left, right] + left.chunks + right.chunks + yield TileStatus( + [left, right] + left.chunks + right.chunks, progress=0.5 + ) left = auto_merge_chunks(ctx, left) right = auto_merge_chunks(ctx, right) @@ -660,7 +662,9 @@ def tile(cls, op: "DataFrameMerge"): ): # if how=="inner", output data size will reduce greatly with high probability, # use auto_merge_chunks to combine small chunks. - yield ret[0].chunks # trigger execution for chunks + yield TileStatus( + ret[0].chunks, progress=0.8 + ) # trigger execution for chunks return [auto_merge_chunks(get_context(), ret[0])] else: return ret diff --git a/mars/dataframe/merge/tests/test_merge_execution.py b/mars/dataframe/merge/tests/test_merge_execution.py index 9063bac6c5..50a0f0d217 100644 --- a/mars/dataframe/merge/tests/test_merge_execution.py +++ b/mars/dataframe/merge/tests/test_merge_execution.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd +import pytest from ....core.graph.builder.utils import build_graph from ...datasource.dataframe import from_pandas @@ -597,7 +598,8 @@ def test_merge_with_bloom_filter(setup): ) -def test_merge_on_duplicate_columns(setup): +@pytest.mark.parametrize("auto_merge", ["none", "both", "before", "after"]) +def test_merge_on_duplicate_columns(setup, auto_merge): raw1 = pd.DataFrame( [["foo", 1, "bar"], ["bar", 2, "foo"], ["baz", 3, "foo"]], columns=["lkey", "value", "value"], @@ -611,7 +613,7 @@ def test_merge_on_duplicate_columns(setup): df1 = from_pandas(raw1, chunk_size=2) df2 = from_pandas(raw2, chunk_size=3) - r = df1.merge(df2, left_on="lkey", right_on="rkey", auto_merge="none") + r = df1.merge(df2, left_on="lkey", right_on="rkey", auto_merge=auto_merge) result = r.execute().fetch() expected = raw1.merge(raw2, left_on="lkey", right_on="rkey") pd.testing.assert_frame_equal(expected, result) diff --git a/mars/optimization/logical/chunk/tests/test_column_pruning.py b/mars/optimization/logical/chunk/tests/test_column_pruning.py index a65267cd90..953c3b2f0c 100644 --- a/mars/optimization/logical/chunk/tests/test_column_pruning.py +++ b/mars/optimization/logical/chunk/tests/test_column_pruning.py @@ -19,7 +19,13 @@ import pytest from ..... 
import dataframe as md -from .....core import enter_mode, TileableGraph, TileableGraphBuilder, ChunkGraphBuilder +from .....core import ( + enter_mode, + TileableGraph, + TileableGraphBuilder, + ChunkGraphBuilder, + TileContext, +) from .. import optimize @@ -47,7 +53,7 @@ def test_groupby_read_csv(gen_data1): df2 = df1[["a", "b"]] graph = TileableGraph([df2.data]) next(TileableGraphBuilder(graph).build()) - context = dict() + context = TileContext() chunk_graph_builder = ChunkGraphBuilder( graph, fuse_enabled=False, tile_context=context ) diff --git a/mars/optimization/logical/chunk/tests/test_head.py b/mars/optimization/logical/chunk/tests/test_head.py index 1ec37fe9c9..42db361ced 100644 --- a/mars/optimization/logical/chunk/tests/test_head.py +++ b/mars/optimization/logical/chunk/tests/test_head.py @@ -19,7 +19,13 @@ import pytest from ..... import dataframe as md -from .....core import enter_mode, TileableGraph, TileableGraphBuilder, ChunkGraphBuilder +from .....core import ( + enter_mode, + TileableGraph, + TileableGraphBuilder, + ChunkGraphBuilder, + TileContext, +) from .. import optimize @@ -47,7 +53,7 @@ def test_read_csv_head(gen_data1): df2 = df1.head(5) graph = TileableGraph([df2.data]) next(TileableGraphBuilder(graph).build()) - context = dict() + context = TileContext() chunk_graph_builder = ChunkGraphBuilder( graph, fuse_enabled=False, tile_context=context ) diff --git a/mars/optimization/physical/tests/test_cupy.py b/mars/optimization/physical/tests/test_cupy.py index c863bb47e0..646891c492 100644 --- a/mars/optimization/physical/tests/test_cupy.py +++ b/mars/optimization/physical/tests/test_cupy.py @@ -13,7 +13,13 @@ # limitations under the License. from .... import tensor as mt -from ....core import enter_mode, TileableGraph, TileableGraphBuilder, ChunkGraphBuilder +from ....core import ( + enter_mode, + TileableGraph, + TileableGraphBuilder, + ChunkGraphBuilder, + TileContext, +) from ..cupy import CupyRuntimeOptimizer @@ -25,7 +31,7 @@ def test_cupy(): graph = TileableGraph([t.data]) next(TileableGraphBuilder(graph).build()) - context = dict() + context = TileContext() chunk_graph_builder = ChunkGraphBuilder( graph, fuse_enabled=False, tile_context=context ) diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py index 3f62ffcc74..dce0f68bfb 100644 --- a/mars/oscar/backends/pool.py +++ b/mars/oscar/backends/pool.py @@ -19,6 +19,7 @@ import multiprocessing import os import threading +import traceback from abc import ABC, ABCMeta, abstractmethod from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional @@ -361,10 +362,19 @@ async def _send_channel( with _ErrorProcessor( self.external_address, result.message_id, result.protocol ) as processor: - raise SendMessageFailed( + error_msg = ( f"Error when sending message {result.message_id.hex()}. " - f"Caused by {ex!r}. See server logs for more details" - ) from None + f"Caused by {ex!r}. 
" + ) + if isinstance(result, ErrorMessage): + format_tb = "\n".join(traceback.format_tb(result.traceback)) + error_msg += ( + f"\nOriginal error: {result.error!r}" + f"Traceback: \n{format_tb}" + ) + else: + error_msg += "See server logs for more details" + raise SendMessageFailed(error_msg) from None await self._send_channel(processor.result, channel, resend_failure=False) async def process_message(self, message: _MessageBase, channel: Channel): diff --git a/mars/services/task/analyzer/analyzer.py b/mars/services/task/analyzer/analyzer.py index 714172fc13..afcfbdb994 100644 --- a/mars/services/task/analyzer/analyzer.py +++ b/mars/services/task/analyzer/analyzer.py @@ -38,6 +38,7 @@ def __init__( band_resource: Dict[BandType, Resource], task: Task, config: Config, + chunk_to_subtasks: Dict[ChunkType, Subtask], graph_assigner_cls: Type[AbstractGraphAssigner] = None, stage_id: str = None, ): @@ -48,6 +49,7 @@ def __init__( self._config = config self._fuse_enabled = task.fuse_enabled self._extra_config = task.extra_config + self._chunk_to_subtasks = chunk_to_subtasks if graph_assigner_cls is None: graph_assigner_cls = GraphAssigner self._graph_assigner_cls = graph_assigner_cls @@ -291,8 +293,6 @@ def gen_subtask_graph( ------- subtask_graph: SubtaskGraph Subtask graph. - op_to_bands: Dict - Assigned operand's band, usually for fetch operands. """ reassign_worker_ops = [ chunk.op for chunk in self._chunk_graph if chunk.op.reassign_worker @@ -372,7 +372,7 @@ def gen_subtask_graph( # gen subtask graph subtask_graph = SubtaskGraph() chunk_to_fetch_chunk = dict() - chunk_to_subtask = dict() + chunk_to_subtask = self._chunk_to_subtasks # states visited = set() logic_key_to_subtasks = defaultdict(list) diff --git a/mars/services/task/analyzer/tests/test_assigner.py b/mars/services/task/analyzer/tests/test_assigner.py index 57dc64cc26..ea93453059 100644 --- a/mars/services/task/analyzer/tests/test_assigner.py +++ b/mars/services/task/analyzer/tests/test_assigner.py @@ -52,7 +52,7 @@ def test_assigner_with_fetch_inputs(): band_resource = dict((band, Resource(num_cpus=1)) for band in all_bands) task = Task("mock_task", "mock_session") - analyzer = GraphAnalyzer(chunk_graph, band_resource, task, Config()) + analyzer = GraphAnalyzer(chunk_graph, band_resource, task, Config(), dict()) subtask_graph = analyzer.gen_subtask_graph(cur_assigns) assigner = GraphAssigner( diff --git a/mars/services/task/execution/api.py b/mars/services/task/execution/api.py index 184a26ac66..264d024976 100644 --- a/mars/services/task/execution/api.py +++ b/mars/services/task/execution/api.py @@ -16,9 +16,9 @@ from dataclasses import dataclass from typing import List, Dict, Any, Type -from ....core import ChunkGraph, Chunk +from ....core import ChunkGraph, Chunk, TileContext from ....resource import Resource -from ....typing import BandType, TileableType +from ....typing import BandType from ...subtask import SubtaskGraph, SubtaskResult @@ -40,7 +40,7 @@ async def create( session_id: str, address: str, task, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, **kwargs, ) -> "TaskExecutor": name = config.get("backend", "mars") @@ -68,7 +68,7 @@ async def execute_subtask_graph( stage_id: str, subtask_graph: SubtaskGraph, chunk_graph: ChunkGraph, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, context: Any = None, ) -> Dict[Chunk, ExecutionChunkResult]: """Execute a subtask graph and returns result.""" diff --git a/mars/services/task/execution/mars/executor.py 
b/mars/services/task/execution/mars/executor.py index 7609a8291d..f77b693c24 100644 --- a/mars/services/task/execution/mars/executor.py +++ b/mars/services/task/execution/mars/executor.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Set from ..... import oscar as mo -from .....core import ChunkGraph +from .....core import ChunkGraph, TileContext from .....core.operand import ( Fetch, MapReduceOperand, @@ -60,6 +60,7 @@ def _get_n_reducer(subtask: Subtask) -> int: class MarsTaskExecutor(TaskExecutor): name = "mars" _stage_processors: List[TaskStageProcessor] + _stage_tile_progresses: List[float] _cur_stage_processor: Optional[TaskStageProcessor] _meta_updated_tileables: Set[TileableType] @@ -67,7 +68,7 @@ def __init__( self, config: Dict, task: Task, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, cluster_api: ClusterAPI, lifecycle_api: LifecycleAPI, scheduling_api: SchedulingAPI, @@ -87,6 +88,7 @@ def __init__( self._meta_api = meta_api self._stage_processors = [] + self._stage_tile_progresses = [] self._cur_stage_processor = None self._lifecycle_processed_tileables = set() self._subtask_decref_events = dict() @@ -100,7 +102,7 @@ async def create( session_id: str, address: str, task: Task, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, **kwargs, ) -> "TaskExecutor": assert ( @@ -141,7 +143,7 @@ async def execute_subtask_graph( stage_id: str, subtask_graph: SubtaskGraph, chunk_graph: ChunkGraph, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, context=None, ): available_bands = await self.get_available_band_resources() @@ -162,6 +164,10 @@ async def execute_subtask_graph( resource_evaluator.evaluate() self._stage_processors.append(stage_processor) self._cur_stage_processor = stage_processor + # get the tiled progress for current stage + prev_progress = sum(self._stage_tile_progresses) + curr_tile_progress = self._tile_context.get_all_progress() - prev_progress + self._stage_tile_progresses.append(curr_tile_progress) return await stage_processor.run() async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -186,10 +192,11 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]: async def get_progress(self) -> float: # get progress of stages - subtask_progress = 0.0 - n_stage = 0 - - for stage_processor in self._stage_processors: + executor_progress = 0.0 + assert len(self._stage_tile_progresses) == len(self._stage_processors) + for stage_processor, stage_tile_progress in zip( + self._stage_processors, self._stage_tile_progresses + ): if stage_processor.subtask_graph is None: # pragma: no cover # generating subtask continue @@ -204,12 +211,9 @@ async def get_progress(self) -> float: for subtask_key, result in stage_processor.subtask_snapshots.items() if subtask_key not in stage_processor.subtask_results ) - subtask_progress += progress / n_subtask - n_stage += 1 - if n_stage > 0: - subtask_progress /= n_stage - - return subtask_progress + subtask_progress = progress / n_subtask + executor_progress += subtask_progress * stage_tile_progress + return executor_progress async def cancel(self): if self._cur_stage_processor is not None: diff --git a/mars/services/task/execution/ray/executor.py b/mars/services/task/execution/ray/executor.py index 37a1252b86..fc00467cdf 100644 --- a/mars/services/task/execution/ray/executor.py +++ b/mars/services/task/execution/ray/executor.py @@ -15,7 +15,7 @@ import asyncio import logging from typing import List, Dict, Any, Set -from 
.....core import ChunkGraph, Chunk +from .....core import ChunkGraph, Chunk, TileContext from .....core.operand import ( Fuse, VirtualOperand, @@ -26,7 +26,7 @@ from .....optimization.physical import optimize from .....resource import Resource from .....serialization import serialize, deserialize -from .....typing import BandType, TileableType +from .....typing import BandType from .....utils import ( lazy_import, get_chunk_params, @@ -147,7 +147,7 @@ async def execute_subtask_graph( stage_id: str, subtask_graph: SubtaskGraph, chunk_graph: ChunkGraph, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, context: Any = None, ) -> Dict[Chunk, ExecutionChunkResult]: logger.info("Stage %s start.", stage_id) diff --git a/mars/services/task/supervisor/manager.py b/mars/services/task/supervisor/manager.py index 5ef5a9b6c3..8571098b16 100644 --- a/mars/services/task/supervisor/manager.py +++ b/mars/services/task/supervisor/manager.py @@ -21,7 +21,7 @@ from typing import Any, Dict, List, Type, Union from .... import oscar as mo -from ....core import TileableGraph, TileableType, enter_mode +from ....core import TileableGraph, TileableType, enter_mode, TileContext from ....core.context import set_context from ....core.operand import Fetch from ...context import ThreadedServiceContext @@ -30,7 +30,8 @@ from ..core import Task, new_task_id, TaskStatus from ..errors import TaskNotExist from .preprocessor import TaskPreprocessor -from .processor import TaskProcessorActor, TaskProcessor +from .processor import TaskProcessor +from .task import TaskProcessorActor logger = logging.getLogger(__name__) @@ -225,11 +226,9 @@ async def get_tileable_subtasks(self, task_id, tileable_id, with_input_output): return await processor_ref.get_tileable_subtasks(tileable_id, with_input_output) - async def _gen_tiled_context( - self, graph: TileableGraph - ) -> Dict[TileableType, TileableType]: + async def _gen_tiled_context(self, graph: TileableGraph) -> TileContext: # process graph, add fetch node to tiled context - tiled_context = dict() + tiled_context = TileContext() for tileable in graph: if isinstance(tileable.op, Fetch) and tileable.is_coarse(): info = self._tileable_key_to_info[tileable.key][-1] diff --git a/mars/services/task/supervisor/preprocessor.py b/mars/services/task/supervisor/preprocessor.py index 95858b52bb..f441639462 100644 --- a/mars/services/task/supervisor/preprocessor.py +++ b/mars/services/task/supervisor/preprocessor.py @@ -19,7 +19,7 @@ from typing import Callable, Dict, List, Iterable, Set from ....config import Config -from ....core import TileableGraph, ChunkGraph, ChunkGraphBuilder +from ....core import TileableGraph, ChunkGraph, ChunkGraphBuilder, TileContext from ....core.graph.builder.chunk import Tiler, _TileableHandler from ....core.operand import Fetch from ....resource import Resource @@ -27,7 +27,7 @@ from ....optimization.logical.chunk import optimize as optimize_chunk_graph from ....optimization.logical.tileable import optimize as optimize_tileable_graph from ....typing import BandType -from ...subtask import SubtaskGraph +from ...subtask import Subtask, SubtaskGraph from ..analyzer import GraphAnalyzer from ..core import Task @@ -38,7 +38,7 @@ class CancellableTiler(Tiler): def __init__( self, tileable_graph: TileableGraph, - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, processed_chunks: Set[ChunkType], chunk_to_fetch: Dict[ChunkType, ChunkType], add_nodes: Callable, @@ -117,12 +117,12 @@ class TaskPreprocessor: "_done", ) - 
tile_context: Dict[TileableType, TileableType] + tile_context: TileContext def __init__( self, task: Task, - tiled_context: Dict[TileableType, TileableType] = None, + tiled_context: TileContext = None, config: Config = None, ): self._task = task @@ -204,6 +204,7 @@ def post_chunk_graph_execution(self): # pylint: disable=no-self-use def analyze( self, chunk_graph: ChunkGraph, + chunk_to_subtasks: Dict[ChunkType, Subtask], available_bands: Dict[BandType, Resource], stage_id: str = None, op_to_bands: Dict[str, BandType] = None, @@ -211,7 +212,12 @@ def analyze( logger.debug("Start to gen subtask graph for task %s", self._task.task_id) task = self._task analyzer = GraphAnalyzer( - chunk_graph, available_bands, task, self._config, stage_id=stage_id + chunk_graph, + available_bands, + task, + self._config, + chunk_to_subtasks, + stage_id=stage_id, ) graph = analyzer.gen_subtask_graph(op_to_bands) logger.debug( diff --git a/mars/services/task/supervisor/processor.py b/mars/services/task/supervisor/processor.py index ddbb418c16..bb1d7dea14 100644 --- a/mars/services/task/supervisor/processor.py +++ b/mars/services/task/supervisor/processor.py @@ -13,20 +13,14 @@ # limitations under the License. import asyncio -import importlib import logging -import operator import os import tempfile import time -from collections import defaultdict -from functools import reduce -from typing import Dict, Iterator, Optional, Type, List, Set - -from .... import oscar as mo -from ....config import Config -from ....core import ChunkGraph, TileableGraph, Chunk -from ....core.operand import Fetch, FetchShuffle +from typing import Dict, Iterator, Optional, List, Set + +from ....core import ChunkGraph, TileableGraph, Chunk, TileContext +from ....core.operand import Fetch from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE from ....metrics import Metrics from ....optimization.logical import OptimizationRecords @@ -35,9 +29,9 @@ MARS_ENABLE_PROFILING, ) from ....tensor.core import TENSOR_TYPE -from ....typing import TileableType -from ....utils import build_fetch, Timer -from ...subtask import SubtaskResult, SubtaskStatus, SubtaskGraph, Subtask +from ....typing import TileableType, ChunkType +from ....utils import Timer +from ...subtask import SubtaskResult, Subtask from ..core import Task, TaskResult, TaskStatus, new_task_id from ..execution.api import TaskExecutor, ExecutionChunkResult from .preprocessor import TaskPreprocessor @@ -50,6 +44,7 @@ class TaskProcessor: _tileable_to_subtasks: Dict[TileableType, List[Subtask]] _tileable_id_to_tileable: Dict[str, TileableType] + _chunk_to_subtasks: Dict[ChunkType, Subtask] _stage_tileables: Set[TileableType] def __init__( @@ -62,8 +57,8 @@ def __init__( self._preprocessor = preprocessor self._executor = executor - self._tileable_to_subtasks = dict() self._tileable_id_to_tileable = dict() + self._chunk_to_subtasks = dict() self._stage_tileables = set() if MARS_ENABLE_PROFILING: @@ -110,14 +105,14 @@ def task_id(self): def tileable_graph(self): return self._preprocessor.tileable_graph - @property - def tileable_to_subtasks(self): - return self._tileable_to_subtasks - @property def tileable_id_to_tileable(self): return self._tileable_id_to_tileable + @property + def tile_context(self) -> TileContext: + return self._preprocessor.tile_context + @property def stage_processors(self): # TODO(fyrestone): Remove it. 
@@ -126,6 +121,22 @@ def stage_processors(self): def get_tiled(self, tileable: TileableType): return self._preprocessor.get_tiled(tileable) + def get_subtasks(self, chunks: List[ChunkType]) -> List[Subtask]: + return [self._chunk_to_subtasks[chunk] for chunk in chunks] + + def get_tileable_to_subtasks(self) -> Dict[TileableType, List[Subtask]]: + tile_context = self.tile_context + result = dict() + for tileable, infos in tile_context.get_tileable_tile_infos().items(): + subtasks = [] + for info in infos: + chunks = [ + c for c in info.generated_chunks if not isinstance(c.op, Fetch) + ] + subtasks.extend(self.get_subtasks(chunks)) + result[tileable] = subtasks + return result + @staticmethod async def _get_next_chunk_graph( chunk_graph_iter: Iterator[ChunkGraph], @@ -202,6 +213,7 @@ async def _process_stage_chunk_graph( subtask_graph = await asyncio.to_thread( self._preprocessor.analyze, chunk_graph, + self._chunk_to_subtasks, available_bands, stage_id=stage_id, op_to_bands=fetch_op_to_bands, @@ -223,14 +235,6 @@ async def _process_stage_chunk_graph( }, ) - tileable_to_subtasks = await asyncio.to_thread( - self._get_tileable_to_subtasks, - self._preprocessor.tileable_graph, - self._preprocessor.tile_context, - subtask_graph, - ) - self._tileable_to_subtasks.update(tileable_to_subtasks) - tile_context = await asyncio.to_thread( self._get_stage_tile_context, {c for c in chunk_graph.result_chunks if not isinstance(c.op, Fetch)}, @@ -251,9 +255,9 @@ async def _process_stage_chunk_graph( optimization_records = None self._update_stage_meta(chunk_to_result, tile_context, optimization_records) - def _get_stage_tile_context(self, result_chunks: Set[Chunk]): + def _get_stage_tile_context(self, result_chunks: Set[Chunk]) -> TileContext: collected = self._stage_tileables - tile_context = {} + tile_context = TileContext() for tileable in self.tileable_graph: if tileable in collected: continue @@ -270,7 +274,7 @@ def _get_stage_tile_context(self, result_chunks: Set[Chunk]): def _update_stage_meta( cls, chunk_to_result: Dict[Chunk, ExecutionChunkResult], - tile_context: Dict[TileableType, TileableType], + tile_context: TileContext, optimization_records: OptimizationRecords, ): for tiled_tileable in tile_context.values(): @@ -364,12 +368,9 @@ async def run(self): self._gen_result() self._finish() - async def get_progress(self): + async def get_progress(self) -> float: # get tileable proportion that is tiled - tileable_graph = self._preprocessor.tileable_graph - tileable_context = self._preprocessor.tile_context - tiled_percentage = len(tileable_context) / len(tileable_graph) - return tiled_percentage * await self._executor.get_progress() + return await self._executor.get_progress() async def cancel(self): self._preprocessor.cancel() @@ -378,45 +379,6 @@ async def cancel(self): async def set_subtask_result(self, subtask_result: SubtaskResult): await self._executor.set_subtask_result(subtask_result) - @staticmethod - def _get_tileable_to_subtasks( - tileable_graph: TileableGraph, - tile_context: Dict[TileableType, TileableType], - subtask_graph: SubtaskGraph, - ) -> Dict[TileableType, List[Subtask]]: - tileable_to_chunks = defaultdict(set) - chunk_to_subtasks = dict() - - for tileable in tileable_graph: - if tileable not in tile_context: - continue - for chunk in tile_context[tileable].chunks: - tileable_to_chunks[tileable].add(chunk.key) - # register chunk mapping for tiled terminals - chunk_to_subtasks[chunk.key] = set() - - for subtask in subtask_graph: - for chunk in subtask.chunk_graph: - # for every 
non-fuse chunks (including fused), - # register subtasks if needed - if ( - isinstance(chunk.op, (FetchShuffle, Fetch)) - or chunk.key not in chunk_to_subtasks - ): - continue - chunk_to_subtasks[chunk.key].add(subtask) - - tileable_to_subtasks = dict() - # collect subtasks for tileables - for tileable, chunk_keys in tileable_to_chunks.items(): - tileable_to_subtasks[tileable] = list( - reduce( - operator.or_, - [chunk_to_subtasks[chunk_key] for chunk_key in chunk_keys], - ) - ) - return tileable_to_subtasks - @staticmethod def _get_tileable_id_to_tileable( tileable_graph: TileableGraph, @@ -487,316 +449,3 @@ def _finish(self): def is_done(self) -> bool: return self.done.is_set() - - -class TaskProcessorActor(mo.Actor): - _task_id_to_processor: Dict[str, TaskProcessor] - _cur_processor: Optional[TaskProcessor] - - def __init__( - self, - session_id: str, - task_id: str, - task_name: str = None, - task_processor_cls: Type[TaskPreprocessor] = None, - ): - self.session_id = session_id - self.task_id = task_id - self.task_name = task_name - - self._task_processor_cls = self._get_task_processor_cls(task_processor_cls) - self._task_id_to_processor = dict() - self._cur_processor = None - - @classmethod - def gen_uid(cls, session_id: str, task_id: str): - return f"task_processor_{session_id}_{task_id}" - - async def add_task( - self, - task: Task, - tiled_context: Dict[TileableType, TileableType], - config: Config, - task_executor_config: Dict, - task_preprocessor_cls: Type[TaskPreprocessor], - ): - task_preprocessor = task_preprocessor_cls( - task, tiled_context=tiled_context, config=config - ) - task_executor = await TaskExecutor.create( - task_executor_config, - task=task, - session_id=self.session_id, - address=self.address, - tile_context=task_preprocessor.tile_context, - ) - processor = self._task_processor_cls( - task, - task_preprocessor, - task_executor, - ) - self._task_id_to_processor[task.task_id] = processor - - # tell self to start running - await self.ref().start.tell() - - @classmethod - def _get_task_processor_cls(cls, task_processor_cls): - if task_processor_cls is not None: - assert isinstance(task_processor_cls, str) - module, name = task_processor_cls.rsplit(".", 1) - return getattr(importlib.import_module(module), name) - else: - return TaskProcessor - - def _get_unprocessed_task_processor(self): - for processor in self._task_id_to_processor.values(): - if processor.result.status == TaskStatus.pending: - return processor - - async def start(self): - if self._cur_processor is not None: # pragma: no cover - # some processor is running - return - - processor = self._get_unprocessed_task_processor() - if processor is None: # pragma: no cover - return - self._cur_processor = processor - try: - yield processor.run() - finally: - self._cur_processor = None - - async def wait(self, timeout: int = None): - fs = [ - asyncio.ensure_future(processor.done.wait()) - for processor in self._task_id_to_processor.values() - ] - - _, pending = yield asyncio.wait(fs, timeout=timeout) - if not pending: - raise mo.Return(self.result()) - else: - [fut.cancel() for fut in pending] - - async def cancel(self): - if self._cur_processor: - await self._cur_processor.cancel() - - def result(self): - terminated_result = None - for processor in self._task_id_to_processor.values(): - if processor.result.status != TaskStatus.terminated: - return processor.result - else: - terminated_result = processor.result - return terminated_result - - async def progress(self): - processor_progresses = [ - await 
processor.get_progress() - for processor in self._task_id_to_processor.values() - ] - return sum(processor_progresses) / len(processor_progresses) - - def get_result_tileables(self): - processor = list(self._task_id_to_processor.values())[-1] - tileable_graph = processor.tileable_graph - result = [] - for result_tileable in tileable_graph.result_tileables: - tiled = processor.get_tiled(result_tileable) - result.append(build_fetch(tiled)) - return result - - def get_subtask_graphs(self, task_id: str) -> List[SubtaskGraph]: - return [ - stage_processor.subtask_graph - for stage_processor in self._task_id_to_processor[task_id].stage_processors - ] - - def get_tileable_graph_as_dict(self): - processor = list(self._task_id_to_processor.values())[-1] - tileable_graph = processor.tileable_graph - - node_list = [] - edge_list = [] - - visited = set() - - for chunk in tileable_graph: - if chunk.key in visited: - continue - visited.add(chunk.key) - - node_name = str(chunk.op) - - node_list.append({"tileableId": chunk.key, "tileableName": node_name}) - for inp, is_pure_dep in zip(chunk.inputs, chunk.op.pure_depends): - if inp not in tileable_graph: # pragma: no cover - continue - edge_list.append( - { - "fromTileableId": inp.key, - "toTileableId": chunk.key, - "linkType": 1 if is_pure_dep else 0, - } - ) - - graph_dict = {"tileables": node_list, "dependencies": edge_list} - return graph_dict - - def get_tileable_details(self): - tileable_to_subtasks = dict() - subtask_results = dict() - - for processor in self._task_id_to_processor.values(): - tileable_to_subtasks.update(processor.tileable_to_subtasks) - for stage in processor.stage_processors: - for subtask, result in stage.subtask_results.items(): - subtask_results[subtask.subtask_id] = result - for subtask, result in stage.subtask_snapshots.items(): - if subtask.subtask_id in subtask_results: - continue - subtask_results[subtask.subtask_id] = result - - tileable_infos = dict() - for tileable, subtasks in tileable_to_subtasks.items(): - results = [ - subtask_results.get( - subtask.subtask_id, - SubtaskResult( - progress=0.0, - status=SubtaskStatus.pending, - stage_id=subtask.stage_id, - ), - ) - for subtask in subtasks - ] - - # calc progress - if not results: # pragma: no cover - progress = 1.0 - else: - progress = ( - 1.0 * sum(result.progress for result in results) / len(results) - ) - - # calc status - statuses = set(result.status for result in results) - if not results or statuses == {SubtaskStatus.succeeded}: - status = SubtaskStatus.succeeded - elif statuses == {SubtaskStatus.cancelled}: - status = SubtaskStatus.cancelled - elif statuses == {SubtaskStatus.pending}: - status = SubtaskStatus.pending - elif SubtaskStatus.errored in statuses: - status = SubtaskStatus.errored - else: - status = SubtaskStatus.running - - fields = tileable.op._FIELDS - field_values = tileable.op._FIELD_VALUES - props = { - fields[attr_name].tag: value - for attr_name, value in field_values.items() - if attr_name not in ("_key", "_id") - and isinstance(value, (int, float, str)) - } - - tileable_infos[tileable.key] = { - "progress": progress, - "subtaskCount": len(results), - "status": status.value, - "properties": props, - } - - return tileable_infos - - def get_tileable_subtasks(self, tileable_id: str, with_input_output: bool): - returned_subtasks = dict() - subtask_id_to_types = dict() - - subtask_details = dict() - subtask_graph = subtask_results = subtask_snapshots = None - for processor in self._task_id_to_processor.values(): - tileable_to_subtasks = 
processor.tileable_to_subtasks - tileable_id_to_tileable = processor.tileable_id_to_tileable - for stage in processor.stage_processors: - if tileable_id in tileable_id_to_tileable: - tileable = tileable_id_to_tileable[tileable_id] - returned_subtasks = { - subtask.subtask_id: subtask - for subtask in tileable_to_subtasks[tileable] - } - subtask_graph = stage.subtask_graph - subtask_results = stage.subtask_results - subtask_snapshots = stage.subtask_snapshots - break - if returned_subtasks: - break - - if subtask_graph is None: # pragma: no cover - return {} - - if with_input_output: - for subtask in list(returned_subtasks.values()): - for pred in subtask_graph.iter_predecessors(subtask): - if pred.subtask_id in returned_subtasks: # pragma: no cover - continue - returned_subtasks[pred.subtask_id] = pred - subtask_id_to_types[pred.subtask_id] = "Input" - for succ in subtask_graph.iter_successors(subtask): - if succ.subtask_id in returned_subtasks: # pragma: no cover - continue - returned_subtasks[succ.subtask_id] = succ - subtask_id_to_types[succ.subtask_id] = "Output" - - for subtask in returned_subtasks.values(): - subtask_result = subtask_results.get( - subtask, - subtask_snapshots.get( - subtask, - SubtaskResult( - progress=0.0, - status=SubtaskStatus.pending, - stage_id=subtask.stage_id, - ), - ), - ) - subtask_details[subtask.subtask_id] = { - "name": subtask.subtask_name, - "status": subtask_result.status.value, - "progress": subtask_result.progress, - "nodeType": subtask_id_to_types.get(subtask.subtask_id, "Calculation"), - } - - for subtask in returned_subtasks.values(): - pred_ids = [] - for pred in subtask_graph.iter_predecessors(subtask): - if pred.subtask_id in returned_subtasks: - pred_ids.append(pred.subtask_id) - subtask_details[subtask.subtask_id]["fromSubtaskIds"] = pred_ids - return subtask_details - - def get_result_tileable(self, tileable_key: str): - processor = list(self._task_id_to_processor.values())[-1] - tileable_graph = processor.tileable_graph - for result_tileable in tileable_graph.result_tileables: - if result_tileable.key == tileable_key: - tiled = processor.get_tiled(result_tileable) - return build_fetch(tiled) - raise KeyError(f"Tileable {tileable_key} does not exist") # pragma: no cover - - async def set_subtask_result(self, subtask_result: SubtaskResult): - logger.debug( - "Set subtask %s with result %s.", subtask_result.subtask_id, subtask_result - ) - if self._cur_processor is not None: - await self._cur_processor.set_subtask_result(subtask_result) - - def is_done(self) -> bool: - for processor in self._task_id_to_processor.values(): - if not processor.is_done(): - return False - return True diff --git a/mars/services/task/supervisor/task.py b/mars/services/task/supervisor/task.py new file mode 100644 index 0000000000..dc741addc7 --- /dev/null +++ b/mars/services/task/supervisor/task.py @@ -0,0 +1,424 @@ +# Copyright 1999-2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import dataclasses +import importlib +import logging +from typing import Any, Dict, Optional, Set, Type, List + +from .... import oscar as mo +from ....config import Config +from ....core import TileContext +from ....core.operand import Fetch +from ....typing import TileableType +from ....utils import build_fetch +from ...subtask import SubtaskResult, SubtaskStatus, SubtaskGraph +from ..core import Task, TaskStatus +from ..execution.api import TaskExecutor +from .preprocessor import TaskPreprocessor +from .processor import TaskProcessor + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class _TileableStageInfo: + progress: float + subtask_ids: Set[str] + + +@dataclasses.dataclass +class _TileableDetailInfo: + progress: float + subtask_count: int + status: int + properties: Dict[str, Any] + + +class _TaskInfoProcessorMixin: + _task_id_to_processor: Dict[str, TaskProcessor] + _tileable_to_details_cache: Dict[TileableType, _TileableDetailInfo] + + def _init_cache(self): + try: + return self._tileable_to_details_cache + except AttributeError: + cache = self._tileable_to_details_cache = dict() + return cache + + def _get_all_subtask_results(self) -> Dict[str, SubtaskResult]: + subtask_results = dict() + for processor in self._task_id_to_processor.values(): + for stage in processor.stage_processors: + for subtask, result in stage.subtask_results.items(): + subtask_results[subtask.subtask_id] = result + for subtask, result in stage.subtask_snapshots.items(): + if subtask.subtask_id in subtask_results: + continue + subtask_results[subtask.subtask_id] = result + return subtask_results + + def _get_tileable_infos(self) -> Dict[TileableType, _TileableDetailInfo]: + cache = self._init_cache() + + tileable_to_stage_infos: Dict[TileableType, List[_TileableStageInfo]] = dict() + for processor in self._task_id_to_processor.values(): + tile_context = processor.tile_context + for tileable, infos in tile_context.get_tileable_tile_infos().items(): + tileable_to_stage_infos[tileable] = [] + if tileable in cache: + # cached + continue + for info in infos: + chunks = [ + c for c in info.generated_chunks if not isinstance(c.op, Fetch) + ] + try: + subtask_ids = { + st.subtask_id for st in processor.get_subtasks(chunks) + } + except KeyError: # pragma: no cover + subtask_ids = None + stage_info = _TileableStageInfo( + progress=info.tile_progress, subtask_ids=subtask_ids + ) + tileable_to_stage_infos[tileable].append(stage_info) + + tileable_to_defails = dict() + subtask_id_to_results = self._get_all_subtask_results() + for tileable, infos in tileable_to_stage_infos.items(): + if tileable in cache: + # cached + tileable_to_defails[tileable] = cache[tileable] + continue + + statuses = set() + progress = 0.0 if not isinstance(tileable.op, Fetch) else 1.0 + n_subtask = 0 + for stage_info in infos: + tile_progress = stage_info.progress + stage_progress = 0.0 + if stage_info.subtask_ids is None: + continue + for subtask_id in stage_info.subtask_ids: + try: + result = subtask_id_to_results[subtask_id] + stage_progress += result.progress * tile_progress + statuses.add(result.status) + except KeyError: + # pending + statuses.add(SubtaskStatus.pending) + n_subtask += len(stage_info.subtask_ids) + if stage_info.subtask_ids: + progress += stage_progress / len(stage_info.subtask_ids) + else: + progress += tile_progress + + # calc status + if (not statuses or statuses == {SubtaskStatus.succeeded}) and abs( + progress - 1.0 + ) < 1e-3: + status = SubtaskStatus.succeeded + elif statuses == 
{SubtaskStatus.cancelled}: + status = SubtaskStatus.cancelled + elif statuses == {SubtaskStatus.pending}: + status = SubtaskStatus.pending + elif SubtaskStatus.errored in statuses: + status = SubtaskStatus.errored + else: + status = SubtaskStatus.running + + props = tileable.op.to_kv( + exclude_fields=("_key", "_id"), accept_value_types=(int, float, str) + ) + info = _TileableDetailInfo( + progress=progress, + subtask_count=n_subtask, + status=status.value, + properties=props, + ) + tileable_to_defails[tileable] = info + if status.is_done and tileable not in cache: + cache[tileable] = info + + return tileable_to_defails + + async def get_tileable_details(self): + tileable_to_details = yield asyncio.to_thread(self._get_tileable_infos) + raise mo.Return( + { + t.key: { + "progress": info.progress, + "subtaskCount": info.subtask_count, + "status": info.status, + "properties": info.properties, + } + for t, info in tileable_to_details.items() + } + ) + + def _get_tileable_graph_as_dict(self): + processor = list(self._task_id_to_processor.values())[-1] + tileable_graph = processor.tileable_graph + + node_list = [] + edge_list = [] + + visited = set() + + for chunk in tileable_graph: + if chunk.key in visited: # pragma: no cover + continue + visited.add(chunk.key) + + node_name = str(chunk.op) + + node_list.append({"tileableId": chunk.key, "tileableName": node_name}) + for inp, is_pure_dep in zip(chunk.inputs, chunk.op.pure_depends): + if inp not in tileable_graph: # pragma: no cover + continue + edge_list.append( + { + "fromTileableId": inp.key, + "toTileableId": chunk.key, + "linkType": 1 if is_pure_dep else 0, + } + ) + + graph_dict = {"tileables": node_list, "dependencies": edge_list} + return graph_dict + + async def get_tileable_graph_as_dict(self): + return await asyncio.to_thread(self._get_tileable_graph_as_dict) + + def _get_tileable_subtasks(self, tileable_id: str, with_input_output: bool): + returned_subtasks = dict() + subtask_id_to_types = dict() + + subtask_details = dict() + subtask_graph = subtask_results = subtask_snapshots = None + for processor in self._task_id_to_processor.values(): + tileable_to_subtasks = processor.get_tileable_to_subtasks() + tileable_id_to_tileable = processor.tileable_id_to_tileable + for stage in processor.stage_processors: + if tileable_id in tileable_id_to_tileable: + tileable = tileable_id_to_tileable[tileable_id] + returned_subtasks = { + subtask.subtask_id: subtask + for subtask in tileable_to_subtasks[tileable] + } + subtask_graph = stage.subtask_graph + subtask_results = stage.subtask_results + subtask_snapshots = stage.subtask_snapshots + break + if returned_subtasks: + break + + if subtask_graph is None: # pragma: no cover + return {} + + if with_input_output: + for subtask in list(returned_subtasks.values()): + for pred in subtask_graph.iter_predecessors(subtask): + if pred.subtask_id in returned_subtasks: # pragma: no cover + continue + returned_subtasks[pred.subtask_id] = pred + subtask_id_to_types[pred.subtask_id] = "Input" + for succ in subtask_graph.iter_successors(subtask): + if succ.subtask_id in returned_subtasks: # pragma: no cover + continue + returned_subtasks[succ.subtask_id] = succ + subtask_id_to_types[succ.subtask_id] = "Output" + + for subtask in returned_subtasks.values(): + subtask_result = subtask_results.get( + subtask, + subtask_snapshots.get( + subtask, + SubtaskResult( + progress=0.0, + status=SubtaskStatus.pending, + stage_id=subtask.stage_id, + ), + ), + ) + subtask_details[subtask.subtask_id] = { + "name": 
subtask.subtask_name, + "status": subtask_result.status.value, + "progress": subtask_result.progress, + "nodeType": subtask_id_to_types.get(subtask.subtask_id, "Calculation"), + } + + for subtask in returned_subtasks.values(): + pred_ids = [] + for pred in subtask_graph.iter_predecessors(subtask): + if pred.subtask_id in returned_subtasks: + pred_ids.append(pred.subtask_id) + subtask_details[subtask.subtask_id]["fromSubtaskIds"] = pred_ids + return subtask_details + + async def get_tileable_subtasks(self, tileable_id: str, with_input_output: bool): + return await asyncio.to_thread( + self._get_tileable_subtasks, tileable_id, with_input_output + ) + + +class TaskProcessorActor(mo.Actor, _TaskInfoProcessorMixin): + _task_id_to_processor: Dict[str, TaskProcessor] + _cur_processor: Optional[TaskProcessor] + + def __init__( + self, + session_id: str, + task_id: str, + task_name: str = None, + task_processor_cls: Type[TaskPreprocessor] = None, + ): + self.session_id = session_id + self.task_id = task_id + self.task_name = task_name + + self._task_processor_cls = self._get_task_processor_cls(task_processor_cls) + self._task_id_to_processor = dict() + self._cur_processor = None + + @classmethod + def gen_uid(cls, session_id: str, task_id: str): + return f"task_processor_{session_id}_{task_id}" + + async def add_task( + self, + task: Task, + tiled_context: TileContext, + config: Config, + task_executor_config: Dict, + task_preprocessor_cls: Type[TaskPreprocessor], + ): + task_preprocessor = task_preprocessor_cls( + task, tiled_context=tiled_context, config=config + ) + task_executor = await TaskExecutor.create( + task_executor_config, + task=task, + session_id=self.session_id, + address=self.address, + tile_context=task_preprocessor.tile_context, + ) + processor = self._task_processor_cls( + task, + task_preprocessor, + task_executor, + ) + self._task_id_to_processor[task.task_id] = processor + + # tell self to start running + await self.ref().start.tell() + + @classmethod + def _get_task_processor_cls(cls, task_processor_cls): + if task_processor_cls is not None: # pragma: no cover + assert isinstance(task_processor_cls, str) + module, name = task_processor_cls.rsplit(".", 1) + return getattr(importlib.import_module(module), name) + else: + return TaskProcessor + + def _get_unprocessed_task_processor(self): + for processor in self._task_id_to_processor.values(): + if processor.result.status == TaskStatus.pending: + return processor + + async def start(self): + if self._cur_processor is not None: # pragma: no cover + # some processor is running + return + + processor = self._get_unprocessed_task_processor() + if processor is None: # pragma: no cover + return + self._cur_processor = processor + try: + yield processor.run() + finally: + self._cur_processor = None + + async def wait(self, timeout: int = None): + fs = [ + asyncio.ensure_future(processor.done.wait()) + for processor in self._task_id_to_processor.values() + ] + + _, pending = yield asyncio.wait(fs, timeout=timeout) + if not pending: + raise mo.Return(self.result()) + else: + _ = [fut.cancel() for fut in pending] + + async def cancel(self): + if self._cur_processor: + await self._cur_processor.cancel() + + def result(self): + terminated_result = None + for processor in self._task_id_to_processor.values(): + if processor.result.status != TaskStatus.terminated: + return processor.result + else: + terminated_result = processor.result + return terminated_result + + async def progress(self): + processor_progresses = [ + await 
processor.get_progress() + for processor in self._task_id_to_processor.values() + ] + return sum(processor_progresses) / len(processor_progresses) + + def get_result_tileables(self): + processor = list(self._task_id_to_processor.values())[-1] + tileable_graph = processor.tileable_graph + result = [] + for result_tileable in tileable_graph.result_tileables: + tiled = processor.get_tiled(result_tileable) + result.append(build_fetch(tiled)) + return result + + def get_subtask_graphs(self, task_id: str) -> List[SubtaskGraph]: + return [ + stage_processor.subtask_graph + for stage_processor in self._task_id_to_processor[task_id].stage_processors + ] + + def get_result_tileable(self, tileable_key: str): + processor = list(self._task_id_to_processor.values())[-1] + tileable_graph = processor.tileable_graph + for result_tileable in tileable_graph.result_tileables: + if result_tileable.key == tileable_key: + tiled = processor.get_tiled(result_tileable) + return build_fetch(tiled) + raise KeyError(f"Tileable {tileable_key} does not exist") # pragma: no cover + + async def set_subtask_result(self, subtask_result: SubtaskResult): + logger.debug( + "Set subtask %s with result %s.", subtask_result.subtask_id, subtask_result + ) + if self._cur_processor is not None: + await self._cur_processor.set_subtask_result(subtask_result) + + def is_done(self) -> bool: + for processor in self._task_id_to_processor.values(): + if not processor.is_done(): + return False + return True diff --git a/mars/services/task/supervisor/tests/task_preprocessor.py b/mars/services/task/supervisor/tests/task_preprocessor.py index b9b0a8011e..d2de0bce70 100644 --- a/mars/services/task/supervisor/tests/task_preprocessor.py +++ b/mars/services/task/supervisor/tests/task_preprocessor.py @@ -29,8 +29,8 @@ from .....core.operand import Fetch from .....resource import Resource from .....tests.core import _check_args, ObjectCheckMixin -from .....typing import BandType -from ....subtask import SubtaskGraph +from .....typing import BandType, ChunkType +from ....subtask import Subtask, SubtaskGraph from ...analyzer import GraphAnalyzer from ..preprocessor import CancellableTiler, TaskPreprocessor @@ -140,6 +140,7 @@ def _get_tiler_cls(self) -> Callable: def analyze( self, chunk_graph: ChunkGraph, + chunk_to_subtasks: Dict[ChunkType, Subtask], available_bands: Dict[BandType, Resource], stage_id: str, op_to_bands: Dict[str, BandType] = None, @@ -148,7 +149,9 @@ def analyze( for n in chunk_graph: self._raw_chunk_shapes[n.key] = getattr(n, "shape", None) task = self._task - analyzer = GraphAnalyzer(chunk_graph, available_bands, task, self._config) + analyzer = GraphAnalyzer( + chunk_graph, available_bands, task, self._config, chunk_to_subtasks + ) subtask_graph = analyzer.gen_subtask_graph() results = set( analyzer._chunk_to_copied[c] diff --git a/mars/services/task/tests/test_service.py b/mars/services/task/tests/test_service.py index b7162484af..5eed76977f 100644 --- a/mars/services/task/tests/test_service.py +++ b/mars/services/task/tests/test_service.py @@ -16,15 +16,19 @@ import time import numpy as np +import pandas as pd import pytest from .... import dataframe as md from .... import oscar as mo from .... import remote as mr -from ....core import TileableGraph, TileableGraphBuilder +from .... 
import tensor as mt +from ....core import TileableGraph, TileableGraphBuilder, TileStatus, recursive_tile from ....core.context import get_context from ....resource import Resource -from ....utils import Timer +from ....tensor.core import TensorOrder +from ....tensor.operands import TensorOperand, TensorOperandMixin +from ....utils import Timer, build_fetch from ... import start_services, stop_services, NodeRole from ...session import SessionAPI from ...storage import MockStorageAPI @@ -319,6 +323,50 @@ def f1(count: int): assert results[0].progress == 1.0 +class _TileProgressOperand(TensorOperand, TensorOperandMixin): + @classmethod + def tile(cls, op: "_TileProgressOperand"): + progress_controller = get_context().get_remote_object("progress_controller") + + t = yield from recursive_tile(mt.random.rand(10, 10, chunk_size=5)) + yield TileStatus(t.chunks, progress=0.25) + progress_controller.wait() + + new_op = op.copy() + params = op.outputs[0].params.copy() + params["chunks"] = t.chunks + params["nsplits"] = t.nsplits + return new_op.new_tileables(t.inputs, kws=[params]) + + +@pytest.mark.asyncio +async def test_task_tile_progress(start_test_service): + sv_pool_address, task_api, storage_api = start_test_service + + session_api = await SessionAPI.create(address=sv_pool_address) + ref = await session_api.create_remote_object( + task_api._session_id, "progress_controller", _ProgressController + ) + + t = _TileProgressOperand(dtype=np.dtype(np.float64)).new_tensor( + None, (10, 10), order=TensorOrder.C_ORDER + ) + + graph = TileableGraph([t.data]) + next(TileableGraphBuilder(graph).build()) + + await task_api.submit_tileable_graph(graph, fuse_enabled=False) + + await asyncio.sleep(1) + results = await task_api.get_task_results(progress=True) + assert results[0].progress == 0.25 + + await ref.set() + await asyncio.sleep(1) + results = await task_api.get_task_results(progress=True) + assert results[0].progress == 1.0 + + @pytest.mark.asyncio async def test_get_tileable_graph(start_test_service): _sv_pool_address, task_api, storage_api = start_test_service @@ -475,6 +523,38 @@ def _get_fields(details, field, wrapper=None): assert property_key != "id" assert isinstance(property_value, (int, float, str)) + # test merge + d1 = pd.DataFrame({"a": np.random.rand(100), "b": np.random.randint(3, size=100)}) + d2 = pd.DataFrame({"c": np.random.rand(100), "b": np.random.randint(3, size=100)}) + df1 = md.DataFrame(d1, chunk_size=10) + df2 = md.DataFrame(d2, chunk_size=10) + + graph = TileableGraph([df1.data, df2.data]) + next(TileableGraphBuilder(graph).build()) + + task_id = await task_api.submit_tileable_graph(graph, fuse_enabled=True) + await task_api.wait_task(task_id) + details = await task_api.get_tileable_details(task_id) + assert details[df1.key]["progress"] == details[df2.key]["progress"] == 1.0 + + f1 = build_fetch(df1) + f2 = build_fetch(df2) + df3 = f1.merge(f2, auto_merge="none", bloom_filter=False) + graph = TileableGraph([df3.data]) + next(TileableGraphBuilder(graph).build()) + + task_id = await task_api.submit_tileable_graph(graph, fuse_enabled=True) + await task_api.wait_task(task_id) + for _ in range(2): + # get twice to ensure cache work + details = await task_api.get_tileable_details(task_id) + assert ( + details[df3.key]["progress"] + == details[f1.key]["progress"] + == details[f2.key]["progress"] + == 1.0 + ) + @pytest.mark.asyncio @pytest.mark.parametrize("with_input_output", [False, True])
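
A minimal usage sketch of the progress bookkeeping this patch introduces (not part of the patch itself; it assumes the patch is applied and `mars` is importable, and the bare `object()` values are hypothetical stand-ins for real tileables):

    import math

    from mars.core import TileContext, TileStatus

    t1, t2 = object(), object()  # hypothetical stand-ins for two tileables
    ctx = TileContext()
    ctx.set_tileables({t1, t2})

    # An operand's tile() can `yield TileStatus(chunks, progress=...)` to
    # report partial tiling progress; the Tiler forwards it to the context:
    status = TileStatus(entities=[], progress=0.25)
    ctx.set_progress(t1, status.progress)
    assert ctx.get_all_progress() == 0.125  # (0.25 + 0.0) / 2 tileables

    # When tile() yields plain entities instead of a TileStatus, the Tiler
    # advances the tileable by DEFAULT_UPDATED_PROGRESS (0.4) of what remains:
    p = ctx.get_progress(t1)
    ctx.set_progress(t1, p + (1.0 - p) * 0.4)
    assert math.isclose(ctx.get_progress(t1), 0.55)

    # Mapping a tileable to its tiled result marks it fully tiled.
    ctx[t1] = t1
    assert ctx.get_progress(t1) == 1.0

This is also why `MarsTaskExecutor.get_progress` above weights each stage's subtask progress by the tile-progress delta recorded for that stage rather than averaging stages uniformly: stages that covered more of the tiling work contribute proportionally more to the overall task progress.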