Add batch method for closing writers
hekaisheng committed Jun 7, 2022
1 parent 324d598 commit 99d086e
Showing 5 changed files with 122 additions and 70 deletions.
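The new `close_writer` on `StorageHandlerActor` is declared `@mo.extensible`, so callers can either await it directly or collect `.delay()` calls and flush them with a single `.batch()` call, the same pattern already used for `open_writer` and `put_data_info`, and the pattern `ReceiverManagerActor.do_write` switches to below. A minimal caller sketch (the `storage_handler` ref and `writers` list are placeholders, not part of this diff):

# Close several writers with one batched actor call instead of
# awaiting writer.close() once per writer.
close_tasks = [storage_handler.close_writer.delay(writer) for writer in writers]
await storage_handler.close_writer.batch(*close_tasks)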
3 changes: 1 addition & 2 deletions mars/services/storage/api/oscar.py
@@ -23,9 +23,8 @@
StorageManagerActor,
DataManagerActor,
DataInfo,
WrappedStorageFileObject,
)
from ..handler import StorageHandlerActor
from ..handler import StorageHandlerActor, WrappedStorageFileObject
from .core import AbstractStorageAPI

_is_windows = sys.platform.lower().startswith("win")
48 changes: 1 addition & 47 deletions mars/services/storage/core.py
@@ -19,12 +19,10 @@
from typing import Dict, List, Optional, Union, Tuple

from ... import oscar as mo
from ...lib.aio import AioFileObject
from ...oscar.backends.allocate_strategy import IdleLabel, NoIdleSlot
from ...resource import cuda_card_stats
from ...storage import StorageLevel, get_storage_backend
from ...storage.base import ObjectInfo, StorageBackend
from ...storage.core import StorageFileObject
from ...storage.base import ObjectInfo
from ...utils import dataslots
from .errors import DataNotExist, StorageFull

@@ -44,50 +42,6 @@ def build_data_info(storage_info: ObjectInfo, level, size, band_name=None):
return DataInfo(storage_info.object_id, level, size, store_size, band_name)


class WrappedStorageFileObject(AioFileObject):
"""
Wrap to hold ref after write close
"""

def __init__(
self,
file: StorageFileObject,
level: StorageLevel,
size: int,
session_id: str,
data_key: str,
data_manager: mo.ActorRefType["DataManagerActor"],
storage_handler: StorageBackend,
):
self._object_id = file.object_id
super().__init__(file)
self._size = size
self._level = level
self._session_id = session_id
self._data_key = data_key
self._data_manager = data_manager
self._storage_handler = storage_handler

def __getattr__(self, item):
return getattr(self._file, item)

async def clean_up(self):
self._file.close()

async def close(self):
self._file.close()
if self._object_id is None:
# for some backends like vineyard,
# object id is generated after write close
self._object_id = self._file.object_id
if "w" in self._file.mode:
object_info = await self._storage_handler.object_info(self._object_id)
data_info = build_data_info(object_info, self._level, self._size)
await self._data_manager.put_data_info(
self._session_id, self._data_key, data_info, object_info
)


class StorageQuotaActor(mo.Actor):
def __init__(
self,
93 changes: 88 additions & 5 deletions mars/services/storage/handler.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, List, Union

from ... import oscar as mo
from ...lib.aio import AioFileObject
from ...storage import StorageLevel, get_storage_backend
from ...storage.core import StorageFileObject
from ...typing import BandType
@@ -29,7 +30,6 @@
DataManagerActor,
DataInfo,
build_data_info,
WrappedStorageFileObject,
)
from .errors import DataNotExist, NoDataToSpill

@@ -39,6 +39,54 @@
logger = logging.getLogger(__name__)


class WrappedStorageFileObject(AioFileObject):
"""
Wrap to hold ref after write close
"""

def __init__(
self,
file: StorageFileObject,
level: StorageLevel,
size: int,
session_id: str,
data_key: str,
storage_handler: mo.ActorRefType["StorageHandlerActor"],
):
self._object_id = file.object_id
super().__init__(file)
self._size = size
self._level = level
self._session_id = session_id
self._data_key = data_key
self._storage_handler = storage_handler

def __getattr__(self, item):
return getattr(self._file, item)

@property
def file(self):
return self._file

@property
def object_id(self):
return self._object_id

@property
def level(self):
return self._level

@property
def size(self):
return self._size

async def clean_up(self):
self._file.close()

async def close(self):
await self._storage_handler.close_writer(self)


class StorageHandlerActor(mo.Actor):
"""
Storage handler actor, provide methods like `get`, `put`, etc.
@@ -360,8 +408,7 @@ async def open_writer(
size,
session_id,
data_key,
self._data_manager_ref,
self._clients[level],
self,
)

@open_writer.batch
@@ -392,12 +439,48 @@ async def batch_open_writers(self, args_list, kwargs_list):
size,
session_id,
data_key,
self._data_manager_ref,
self._clients[level],
self,
)
)
return wrapped_writers

@mo.extensible
async def close_writer(self, writer: WrappedStorageFileObject):
writer.file.close()
if writer.object_id is None:
# for some backends like vineyard,
# object id is generated after write close
writer._object_id = writer.file.object_id
if "w" in writer.file.mode:
client = self._clients[writer.level]
object_info = await client.object_info(writer.object_id)
data_info = build_data_info(object_info, writer.level, writer.size)
await self._data_manager_ref.put_data_info(
writer._session_id, writer._data_key, data_info, object_info
)

@close_writer.batch
async def batch_close_writers(self, args_list, kwargs_list):
put_info_tasks = []
for args, kwargs in zip(args_list, kwargs_list):
(writer,) = self.close_writer.bind(*args, **kwargs)
writer.file.close()
if writer.object_id is None:
# for some backends like vineyard,
# object id is generated after write close
writer._object_id = writer.file.object_id
if "w" in writer.file.mode:
client = self._clients[writer.level]
object_info = await client.object_info(writer.object_id)
data_info = build_data_info(object_info, writer.level, writer.size)
put_info_tasks.append(
self._data_manager_ref.put_data_info.delay(
writer._session_id, writer._data_key, data_info, object_info
)
)
if put_info_tasks:
await self._data_manager_ref.put_data_info.batch(*put_info_tasks)

async def _get_meta_api(self, session_id: str):
if self._supervisor_address is None:
cluster_api = await ClusterAPI.create(self.address)
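Note that `WrappedStorageFileObject.close()` now simply delegates to the handler's `close_writer`, so the unbatched path still registers data info the same way. A rough single-writer sketch, assuming `open_writer` keeps its existing parameters (`storage_handler`, `session_id`, `data_key`, `size`, `level`, and `data` are placeholders; only the argument names visible in this diff are certain):

# Closing the wrapped file object routes through StorageHandlerActor.close_writer,
# which records the data info once the backend write has finished.
writer = await storage_handler.open_writer(session_id, data_key, size, level)
await writer.write(data)
await writer.close()  # delegates to storage_handler.close_writer(writer)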
24 changes: 20 additions & 4 deletions mars/services/storage/tests/test_transfer.py
@@ -266,7 +266,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):

send_task = asyncio.create_task(
sender_actor.send_batch_data(
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
"mock",
["data_key1"],
worker_address_2,
StorageLevel.MEMORY,
is_small_objects=False,
)
)

@@ -284,7 +288,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):

send_task = asyncio.create_task(
sender_actor.send_batch_data(
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
"mock",
["data_key1"],
worker_address_2,
StorageLevel.MEMORY,
is_small_objects=False,
)
)
await send_task
@@ -295,12 +303,20 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
if mock_sender is MockSenderManagerActor:
send_task1 = asyncio.create_task(
sender_actor.send_batch_data(
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
"mock",
["data_key2"],
worker_address_2,
StorageLevel.MEMORY,
is_small_objects=False,
)
)
send_task2 = asyncio.create_task(
sender_actor.send_batch_data(
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
"mock",
["data_key2"],
worker_address_2,
StorageLevel.MEMORY,
is_small_objects=False,
)
)
await asyncio.sleep(0.5)
24 changes: 12 additions & 12 deletions mars/services/storage/transfer.py
@@ -21,8 +21,8 @@
from ...lib.aio import alru_cache
from ...storage import StorageLevel
from ...utils import dataslots
from .core import DataManagerActor, WrappedStorageFileObject, DataInfo
from .handler import StorageHandlerActor
from .core import DataManagerActor, DataInfo
from .handler import StorageHandlerActor, WrappedStorageFileObject

DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024**2

@@ -96,9 +96,7 @@ async def send(self, buffer, eof_mark, key):
open_reader_tasks = []
storage_client = await self._storage_handler.get_client(level)
for info in data_infos:
open_reader_tasks.append(
storage_client.open_reader(info.object_id)
)
open_reader_tasks.append(storage_client.open_reader(info.object_id))
readers = await asyncio.gather(*open_reader_tasks)

for data_key, reader in zip(data_keys, readers):
@@ -129,7 +127,9 @@ async def _send(
band_name: str,
level: StorageLevel,
):
receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name)
receiver_ref: mo.ActorRefType[
ReceiverManagerActor
] = await self.get_receiver_ref(address, band_name)
is_transferring_list = await receiver_ref.open_writers(
session_id, data_keys, data_sizes, level
)
@@ -163,11 +163,11 @@ async def _send_small_objects(
):
# simple get all objects and send them all to receiver
storage_client = await self._storage_handler.get_client(level)
get_tasks = [
storage_client.get(info.object_id) for info in data_infos
]
get_tasks = [storage_client.get(info.object_id) for info in data_infos]
data_list = list(await asyncio.gather(*get_tasks))
receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name)
receiver_ref: mo.ActorRefType[
ReceiverManagerActor
] = await self.get_receiver_ref(address, band_name)
await receiver_ref.put_small_objects(session_id, data_keys, data_list, level)

async def send_batch_data(
@@ -358,9 +358,9 @@ async def do_write(
if data:
await writer.write(data)
if is_eof:
close_tasks.append(writer.close())
close_tasks.append(self._storage_handler.close_writer.delay(writer))
finished_keys.append(data_key)
await asyncio.gather(*close_tasks)
await self._storage_handler.close_writer.batch(*close_tasks)
async with self._lock:
for data_key in finished_keys:
event = self._writing_infos[(session_id, data_key)].event
