diff --git a/mars/services/storage/api/oscar.py b/mars/services/storage/api/oscar.py index 81277abaa0..a2dadbc66b 100644 --- a/mars/services/storage/api/oscar.py +++ b/mars/services/storage/api/oscar.py @@ -13,7 +13,7 @@ # limitations under the License. import sys -from typing import Any, List, Type, TypeVar +from typing import Any, List, Tuple, Type, TypeVar from .... import oscar as mo from ....lib.aio import alru_cache @@ -163,7 +163,7 @@ async def batch_delete(self, args_list, kwargs_list): @mo.extensible async def fetch( self, - data_key: str, + data_key: Union[str, Tuple], level: StorageLevel = None, band_name: str = None, remote_address: str = None, diff --git a/mars/services/subtask/worker/processor.py b/mars/services/subtask/worker/processor.py index f57f838c98..78b916eee9 100644 --- a/mars/services/subtask/worker/processor.py +++ b/mars/services/subtask/worker/processor.py @@ -26,6 +26,8 @@ Fetch, FetchShuffle, execute, + MapReduceOperand, + OperandStage, ) from ....metrics import Metrics from ....optimization.physical import optimize @@ -424,6 +426,28 @@ async def set_chunks_meta(): # set result data size self.result.data_size = result_data_size + async def push_mapper_data(self, chunk_graph): + # TODO: use task api to get reducer bands + reducer_idx_to_band = dict() + if not reducer_idx_to_band: + return + storage_api_to_fetch_tasks = defaultdict(list) + for result_chunk in chunk_graph.result_chunks: + key = result_chunk.key + reducer_idx = key[1] + if isinstance(key, tuple): + # mapper key is a tuple + address, band_name = reducer_idx_to_band[reducer_idx] + storage_api = StorageAPI(address, self._session_id, band_name) + fetch_task = storage_api.fetch.delay( + key, band_name=self._band[1], remote_address=self._band[0] + ) + storage_api_to_fetch_tasks[storage_api].append(fetch_task) + batch_tasks = [] + for storage_api, tasks in storage_api_to_fetch_tasks.items(): + batch_tasks.append(asyncio.create_task(storage_api.fetch.batch(*tasks))) + await asyncio.gather(*batch_tasks) + async def done(self): if self.result.status == SubtaskStatus.running: self.result.status = SubtaskStatus.succeeded @@ -495,6 +519,8 @@ async def run(self): await self._unpin_data(input_keys) await self.done() + # after done, we push mapper data to reducers in advance. + await self.push_mapper_data(chunk_graph) if self.result.status == SubtaskStatus.succeeded: cost_time_secs = ( self.result.execution_end_time - self.result.execution_start_time