diff --git a/mars/services/storage/api/oscar.py b/mars/services/storage/api/oscar.py index 81277abaa0..a71ed0f6f5 100644 --- a/mars/services/storage/api/oscar.py +++ b/mars/services/storage/api/oscar.py @@ -23,9 +23,8 @@ StorageManagerActor, DataManagerActor, DataInfo, - WrappedStorageFileObject, ) -from ..handler import StorageHandlerActor +from ..handler import StorageHandlerActor, WrappedStorageFileObject from .core import AbstractStorageAPI _is_windows = sys.platform.lower().startswith("win") diff --git a/mars/services/storage/core.py b/mars/services/storage/core.py index eedc375389..6e4d33daf2 100644 --- a/mars/services/storage/core.py +++ b/mars/services/storage/core.py @@ -19,12 +19,10 @@ from typing import Dict, List, Optional, Union, Tuple from ... import oscar as mo -from ...lib.aio import AioFileObject from ...oscar.backends.allocate_strategy import IdleLabel, NoIdleSlot from ...resource import cuda_card_stats from ...storage import StorageLevel, get_storage_backend -from ...storage.base import ObjectInfo, StorageBackend -from ...storage.core import StorageFileObject +from ...storage.base import ObjectInfo from ...utils import dataslots from .errors import DataNotExist, StorageFull @@ -44,50 +42,6 @@ def build_data_info(storage_info: ObjectInfo, level, size, band_name=None): return DataInfo(storage_info.object_id, level, size, store_size, band_name) -class WrappedStorageFileObject(AioFileObject): - """ - Wrap to hold ref after write close - """ - - def __init__( - self, - file: StorageFileObject, - level: StorageLevel, - size: int, - session_id: str, - data_key: str, - data_manager: mo.ActorRefType["DataManagerActor"], - storage_handler: StorageBackend, - ): - self._object_id = file.object_id - super().__init__(file) - self._size = size - self._level = level - self._session_id = session_id - self._data_key = data_key - self._data_manager = data_manager - self._storage_handler = storage_handler - - def __getattr__(self, item): - return getattr(self._file, item) - - async def clean_up(self): - self._file.close() - - async def close(self): - self._file.close() - if self._object_id is None: - # for some backends like vineyard, - # object id is generated after write close - self._object_id = self._file.object_id - if "w" in self._file.mode: - object_info = await self._storage_handler.object_info(self._object_id) - data_info = build_data_info(object_info, self._level, self._size) - await self._data_manager.put_data_info( - self._session_id, self._data_key, data_info, object_info - ) - - class StorageQuotaActor(mo.Actor): def __init__( self, diff --git a/mars/services/storage/handler.py b/mars/services/storage/handler.py index d134ea8775..6e45872e5a 100644 --- a/mars/services/storage/handler.py +++ b/mars/services/storage/handler.py @@ -18,6 +18,7 @@ from typing import Any, Dict, List, Union from ... import oscar as mo +from ...lib.aio import AioFileObject from ...storage import StorageLevel, get_storage_backend from ...storage.core import StorageFileObject from ...typing import BandType @@ -29,7 +30,6 @@ DataManagerActor, DataInfo, build_data_info, - WrappedStorageFileObject, ) from .errors import DataNotExist, NoDataToSpill @@ -39,6 +39,54 @@ logger = logging.getLogger(__name__) +class WrappedStorageFileObject(AioFileObject): + """ + Wrap to hold ref after write close + """ + + def __init__( + self, + file: StorageFileObject, + level: StorageLevel, + size: int, + session_id: str, + data_key: str, + storage_handler: mo.ActorRefType["StorageHandlerActor"], + ): + self._object_id = file.object_id + super().__init__(file) + self._size = size + self._level = level + self._session_id = session_id + self._data_key = data_key + self._storage_handler = storage_handler + + def __getattr__(self, item): + return getattr(self._file, item) + + @property + def file(self): + return self._file + + @property + def object_id(self): + return self._object_id + + @property + def level(self): + return self._level + + @property + def size(self): + return self._size + + async def clean_up(self): + self._file.close() + + async def close(self): + await self._storage_handler.close_writer(self) + + class StorageHandlerActor(mo.Actor): """ Storage handler actor, provide methods like `get`, `put`, etc. @@ -360,8 +408,7 @@ async def open_writer( size, session_id, data_key, - self._data_manager_ref, - self._clients[level], + self, ) @open_writer.batch @@ -392,12 +439,48 @@ async def batch_open_writers(self, args_list, kwargs_list): size, session_id, data_key, - self._data_manager_ref, - self._clients[level], + self, ) ) return wrapped_writers + @mo.extensible + async def close_writer(self, writer: WrappedStorageFileObject): + writer.file.close() + if writer.object_id is None: + # for some backends like vineyard, + # object id is generated after write close + writer._object_id = writer.file.object_id + if "w" in writer.file.mode: + client = self._clients[writer.level] + object_info = await client.object_info(writer.object_id) + data_info = build_data_info(object_info, writer.level, writer.size) + await self._data_manager_ref.put_data_info( + writer._session_id, writer._data_key, data_info, object_info + ) + + @close_writer.batch + async def batch_close_writers(self, args_list, kwargs_list): + put_info_tasks = [] + for args, kwargs in zip(args_list, kwargs_list): + (writer,) = self.close_writer.bind(*args, **kwargs) + writer.file.close() + if writer.object_id is None: + # for some backends like vineyard, + # object id is generated after write close + writer._object_id = writer.file.object_id + if "w" in writer.file.mode: + client = self._clients[writer.level] + object_info = await client.object_info(writer.object_id) + data_info = build_data_info(object_info, writer.level, writer.size) + put_info_tasks.append( + self._data_manager_ref.put_data_info.delay( + writer._session_id, writer._data_key, data_info, object_info + ) + ) + if put_info_tasks: + await self._data_manager_ref.put_data_info.batch(*put_info_tasks) + async def _get_meta_api(self, session_id: str): if self._supervisor_address is None: cluster_api = await ClusterAPI.create(self.address) diff --git a/mars/services/storage/tests/test_transfer.py b/mars/services/storage/tests/test_transfer.py index ec801220e7..2a37f56141 100644 --- a/mars/services/storage/tests/test_transfer.py +++ b/mars/services/storage/tests/test_transfer.py @@ -266,7 +266,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver): send_task = asyncio.create_task( sender_actor.send_batch_data( - "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False + "mock", + ["data_key1"], + worker_address_2, + StorageLevel.MEMORY, + is_small_objects=False, ) ) @@ -284,7 +288,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver): send_task = asyncio.create_task( sender_actor.send_batch_data( - "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False + "mock", + ["data_key1"], + worker_address_2, + StorageLevel.MEMORY, + is_small_objects=False, ) ) await send_task @@ -295,12 +303,20 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver): if mock_sender is MockSenderManagerActor: send_task1 = asyncio.create_task( sender_actor.send_batch_data( - "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False + "mock", + ["data_key2"], + worker_address_2, + StorageLevel.MEMORY, + is_small_objects=False, ) ) send_task2 = asyncio.create_task( sender_actor.send_batch_data( - "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False + "mock", + ["data_key2"], + worker_address_2, + StorageLevel.MEMORY, + is_small_objects=False, ) ) await asyncio.sleep(0.5) diff --git a/mars/services/storage/transfer.py b/mars/services/storage/transfer.py index 30b27386b2..b095db016c 100644 --- a/mars/services/storage/transfer.py +++ b/mars/services/storage/transfer.py @@ -21,8 +21,8 @@ from ...lib.aio import alru_cache from ...storage import StorageLevel from ...utils import dataslots -from .core import DataManagerActor, WrappedStorageFileObject, DataInfo -from .handler import StorageHandlerActor +from .core import DataManagerActor, DataInfo +from .handler import StorageHandlerActor, WrappedStorageFileObject DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024**2 @@ -96,9 +96,7 @@ async def send(self, buffer, eof_mark, key): open_reader_tasks = [] storage_client = await self._storage_handler.get_client(level) for info in data_infos: - open_reader_tasks.append( - storage_client.open_reader(info.object_id) - ) + open_reader_tasks.append(storage_client.open_reader(info.object_id)) readers = await asyncio.gather(*open_reader_tasks) for data_key, reader in zip(data_keys, readers): @@ -129,7 +127,9 @@ async def _send( band_name: str, level: StorageLevel, ): - receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name) + receiver_ref: mo.ActorRefType[ + ReceiverManagerActor + ] = await self.get_receiver_ref(address, band_name) is_transferring_list = await receiver_ref.open_writers( session_id, data_keys, data_sizes, level ) @@ -163,11 +163,11 @@ async def _send_small_objects( ): # simple get all objects and send them all to receiver storage_client = await self._storage_handler.get_client(level) - get_tasks = [ - storage_client.get(info.object_id) for info in data_infos - ] + get_tasks = [storage_client.get(info.object_id) for info in data_infos] data_list = list(await asyncio.gather(*get_tasks)) - receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name) + receiver_ref: mo.ActorRefType[ + ReceiverManagerActor + ] = await self.get_receiver_ref(address, band_name) await receiver_ref.put_small_objects(session_id, data_keys, data_list, level) async def send_batch_data( @@ -358,9 +358,9 @@ async def do_write( if data: await writer.write(data) if is_eof: - close_tasks.append(writer.close()) + close_tasks.append(self._storage_handler.close_writer.delay(writer)) finished_keys.append(data_key) - await asyncio.gather(*close_tasks) + await self._storage_handler.close_writer.batch(*close_tasks) async with self._lock: for data_key in finished_keys: event = self._writing_infos[(session_id, data_key)].event