Optimize transferring small objects for storage #3058

Draft · wants to merge 5 commits into base: master
55 changes: 38 additions & 17 deletions benchmarks/asv_bench/benchmarks/storage.py
@@ -13,6 +13,7 @@
# limitations under the License.

import itertools
from typing import List

import cloudpickle
import numpy as np
@@ -24,20 +25,30 @@
from mars.utils import Timer, readable_size


def send_1_to_1(n: int = None):
def send_1_to_1(n: int = None, n_out: int = 1):
ctx = get_context()
workers = ctx.get_worker_addresses()

worker_to_gen_data = {
w: mr.spawn(_gen_data, kwargs=dict(n=n, worker=w), expect_worker=w)
w: mr.spawn(
_gen_data,
kwargs=dict(n=n, worker=w, n_out=n_out),
expect_worker=w,
n_output=n_out,
)
for i, w in enumerate(workers)
}
all_data = mars.execute(list(worker_to_gen_data.values()))
all_data = mars.execute(list(itertools.chain(*worker_to_gen_data.values())))
progress = 0.1
ctx.set_progress(progress)
infos = [d._fetch_infos(fields=["data_key", "store_size"]) for d in all_data]
data_size = infos[0]["store_size"][0]
worker_to_data_keys = dict(zip(workers, [info["data_key"][0] for info in infos]))
infos = np.array(
[d._fetch_infos(fields=["data_key", "store_size"]) for d in all_data],
dtype=object,
)
data_size = sum(info["store_size"][0] for info in infos[:n_out])
worker_to_data_keys = dict()
for worker, infos in zip(workers, np.split(infos, len(infos) // n_out)):
worker_to_data_keys[worker] = [info["data_key"][0] for info in infos]

workers_to_durations = dict()
size = len(workers) * (len(workers) - 1)
@@ -60,27 +71,31 @@ def send_1_to_1(n: int = None):


def _gen_data(
n: int = None, worker: str = None, check_addr: bool = True
) -> pd.DataFrame:
n: int = None, worker: str = None, check_addr: bool = True, n_out: int = 1
) -> List[pd.DataFrame]:
if check_addr:
ctx = get_context()
assert ctx.worker_address == worker
n = n if n is not None else 5_000_000
rs = np.random.RandomState(123)
data = {
"a": rs.rand(n),
"b": rs.randint(n * 10, size=n),
"c": [f"foo{i}" for i in range(n)],
}
return pd.DataFrame(data)

outs = []
for _ in range(n_out):
n = n if n is not None else 5_000_000
data = {
"a": rs.rand(n),
"b": rs.randint(n * 10, size=n),
"c": [f"foo{i}" for i in range(n)],
}
outs.append(pd.DataFrame(data))
return outs

def _fetch_data(data_key: str, worker: str = None):

def _fetch_data(data_keys: List[str], worker: str = None):
# do nothing actually
ctx = get_context()
assert ctx.worker_address == worker
with Timer() as timer:
ctx.get_chunks_result([data_key], fetch_only=True)
ctx.get_chunks_result(data_keys, fetch_only=True)
return timer.duration


Expand All @@ -107,9 +122,15 @@ def teardown(self):
def time_1_to_1(self):
return mr.spawn(send_1_to_1).execute().fetch()

def time_1_to_1_small_objects(self):
return mr.spawn(send_1_to_1, kwargs=dict(n=1_000, n_out=100)).execute().fetch()


if __name__ == "__main__":
suite = TransferPackageSuite()
suite.setup()
print("- Bench 1 to 1 -")
print(suite.time_1_to_1())
print("- Bench 1 to 1 with small objects -")
print(suite.time_1_to_1_small_objects())
suite.teardown()
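For context, a minimal sketch of how the small-object variant of this benchmark could be driven end to end, assuming a locally initialized Mars session; the import path of the benchmark module is an assumption and may differ in your checkout:

```python
# Hedged sketch: drives send_1_to_1 with many small outputs per worker,
# mirroring time_1_to_1_small_objects above. Assumes a local Mars session.
import mars
import mars.remote as mr

from benchmarks.asv_bench.benchmarks.storage import send_1_to_1  # assumed path

mars.new_session()  # local test cluster
# 100 outputs of 1_000 rows each per worker instead of one 5_000_000-row frame
durations = mr.spawn(send_1_to_1, kwargs=dict(n=1_000, n_out=100)).execute().fetch()
print(durations)  # per worker-pair transfer timings
```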
3 changes: 1 addition & 2 deletions mars/services/storage/api/oscar.py
@@ -23,9 +23,8 @@
StorageManagerActor,
DataManagerActor,
DataInfo,
WrappedStorageFileObject,
)
from ..handler import StorageHandlerActor
from ..handler import StorageHandlerActor, WrappedStorageFileObject
from .core import AbstractStorageAPI

_is_windows = sys.platform.lower().startswith("win")
48 changes: 1 addition & 47 deletions mars/services/storage/core.py
@@ -19,12 +19,10 @@
from typing import Dict, List, Optional, Union, Tuple

from ... import oscar as mo
from ...lib.aio import AioFileObject
from ...oscar.backends.allocate_strategy import IdleLabel, NoIdleSlot
from ...resource import cuda_card_stats
from ...storage import StorageLevel, get_storage_backend
from ...storage.base import ObjectInfo, StorageBackend
from ...storage.core import StorageFileObject
from ...storage.base import ObjectInfo
from ...utils import dataslots
from .errors import DataNotExist, StorageFull

@@ -44,50 +42,6 @@ def build_data_info(storage_info: ObjectInfo, level, size, band_name=None):
return DataInfo(storage_info.object_id, level, size, store_size, band_name)


class WrappedStorageFileObject(AioFileObject):
"""
Wrap to hold ref after write close
"""

def __init__(
self,
file: StorageFileObject,
level: StorageLevel,
size: int,
session_id: str,
data_key: str,
data_manager: mo.ActorRefType["DataManagerActor"],
storage_handler: StorageBackend,
):
self._object_id = file.object_id
super().__init__(file)
self._size = size
self._level = level
self._session_id = session_id
self._data_key = data_key
self._data_manager = data_manager
self._storage_handler = storage_handler

def __getattr__(self, item):
return getattr(self._file, item)

async def clean_up(self):
self._file.close()

async def close(self):
self._file.close()
if self._object_id is None:
# for some backends like vineyard,
# object id is generated after write close
self._object_id = self._file.object_id
if "w" in self._file.mode:
object_info = await self._storage_handler.object_info(self._object_id)
data_info = build_data_info(object_info, self._level, self._size)
await self._data_manager.put_data_info(
self._session_id, self._data_key, data_info, object_info
)


class StorageQuotaActor(mo.Actor):
def __init__(
self,
122 changes: 105 additions & 17 deletions mars/services/storage/handler.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, List, Union

from ... import oscar as mo
from ...lib.aio import AioFileObject
from ...storage import StorageLevel, get_storage_backend
from ...storage.core import StorageFileObject
from ...typing import BandType
@@ -29,7 +30,6 @@
DataManagerActor,
DataInfo,
build_data_info,
WrappedStorageFileObject,
)
from .errors import DataNotExist, NoDataToSpill

@@ -39,6 +39,54 @@
logger = logging.getLogger(__name__)


class WrappedStorageFileObject(AioFileObject):
"""
Wrap to hold ref after write close
"""

def __init__(
self,
file: StorageFileObject,
level: StorageLevel,
size: int,
session_id: str,
data_key: str,
storage_handler: mo.ActorRefType["StorageHandlerActor"],
):
self._object_id = file.object_id
super().__init__(file)
self._size = size
self._level = level
self._session_id = session_id
self._data_key = data_key
self._storage_handler = storage_handler

def __getattr__(self, item):
return getattr(self._file, item)

@property
def file(self):
return self._file

@property
def object_id(self):
return self._object_id

@property
def level(self):
return self._level

@property
def size(self):
return self._size

async def clean_up(self):
self._file.close()

async def close(self):
await self._storage_handler.close_writer(self)


class StorageHandlerActor(mo.Actor):
"""
Storage handler actor, provide methods like `get`, `put`, etc.
@@ -82,22 +130,26 @@ async def __post_create__(self):
if client.level & level:
clients[level] = client

async def _get_data(self, data_info, conditions):
@mo.extensible
async def get_data_by_info(self, data_info: DataInfo, conditions: List = None):
if conditions is None:
res = yield self._clients[data_info.level].get(data_info.object_id)
res = await self._clients[data_info.level].get(data_info.object_id)
else:
try:
res = yield self._clients[data_info.level].get(
res = await self._clients[data_info.level].get(
data_info.object_id, conditions=conditions
)
except NotImplementedError:
data = yield self._clients[data_info.level].get(data_info.object_id)
data = await self._clients[data_info.level].get(data_info.object_id)
try:
sliced_value = data.iloc[tuple(conditions)]
except AttributeError:
sliced_value = data[tuple(conditions)]
res = sliced_value
raise mo.Return(res)
return res

def get_client(self, level: StorageLevel):
return self._clients[level]

@mo.extensible
async def get(
@@ -111,7 +163,7 @@ def get(
data_info = await self._data_manager_ref.get_data_info(
session_id, data_key, self._band_name
)
data = yield self._get_data(data_info, conditions)
data = yield self.get_data_by_info(data_info, conditions)
raise mo.Return(data)
except DataNotExist:
if error == "raise":
@@ -143,7 +195,7 @@ async def batch_get(self, args_list, kwargs_list):
if data_info is None:
results.append(None)
else:
result = yield self._get_data(data_info, conditions)
result = yield self.get_data_by_info(data_info, conditions)
results.append(result)
raise mo.Return(results)

@@ -314,12 +366,16 @@ async def batch_delete(self, args_list, kwargs_list):
for level, size in level_sizes.items():
await self._quota_refs[level].release_quota(size)

@mo.extensible
async def open_reader_by_info(self, data_info: DataInfo) -> StorageFileObject:
return await self._clients[data_info.level].open_reader(data_info.object_id)

@mo.extensible
async def open_reader(self, session_id: str, data_key: str) -> StorageFileObject:
data_info = await self._data_manager_ref.get_data_info(
session_id, data_key, self._band_name
)
reader = await self._clients[data_info.level].open_reader(data_info.object_id)
reader = await self.open_reader_by_info(data_info)
return reader

@open_reader.batch
@@ -333,10 +389,7 @@ async def batch_open_readers(self, args_list, kwargs_list):
)
data_infos = await self._data_manager_ref.get_data_info.batch(*get_data_infos)
return await asyncio.gather(
*[
self._clients[data_info.level].open_reader(data_info.object_id)
for data_info in data_infos
]
*[self.open_reader_by_info(data_info) for data_info in data_infos]
)

@mo.extensible
@@ -357,8 +410,7 @@ async def open_writer(
size,
session_id,
data_key,
self._data_manager_ref,
self._clients[level],
self,
)

@open_writer.batch
@@ -389,12 +441,48 @@ async def batch_open_writers(self, args_list, kwargs_list):
size,
session_id,
data_key,
self._data_manager_ref,
self._clients[level],
self,
)
)
return wrapped_writers

@mo.extensible
async def close_writer(self, writer: WrappedStorageFileObject):
writer.file.close()
if writer.object_id is None:
# for some backends like vineyard,
# object id is generated after write close
writer._object_id = writer.file.object_id
if "w" in writer.file.mode:
client = self._clients[writer.level]
object_info = await client.object_info(writer.object_id)
data_info = build_data_info(object_info, writer.level, writer.size)
await self._data_manager_ref.put_data_info(
writer._session_id, writer._data_key, data_info, object_info
)

@close_writer.batch
async def batch_close_writers(self, args_list, kwargs_list):
put_info_tasks = []
for args, kwargs in zip(args_list, kwargs_list):
(writer,) = self.close_writer.bind(*args, **kwargs)
writer.file.close()
if writer.object_id is None:
# for some backends like vineyard,
# object id is generated after write close
writer._object_id = writer.file.object_id
if "w" in writer.file.mode:
client = self._clients[writer.level]
object_info = await client.object_info(writer.object_id)
data_info = build_data_info(object_info, writer.level, writer.size)
put_info_tasks.append(
self._data_manager_ref.put_data_info.delay(
writer._session_id, writer._data_key, data_info, object_info
)
)
if put_info_tasks:
await self._data_manager_ref.put_data_info.batch(*put_info_tasks)

async def _get_meta_api(self, session_id: str):
if self._supervisor_address is None:
cluster_api = await ClusterAPI.create(self.address)
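To illustrate the writer lifecycle after this refactor, a minimal caller-side usage sketch (not part of the diff): the helper name, payload, and storage level below are assumptions; `close()` now routes through `StorageHandlerActor.close_writer`, which fills in the object id for backends that assign it on close and registers the data info with `DataManagerActor`.

```python
# Hypothetical caller-side sketch of the handler-owned writer lifecycle.
# `handler` is assumed to be an actor ref to StorageHandlerActor.
from mars.storage import StorageLevel


async def put_serialized_chunk(handler, session_id: str, data_key: str, payload: bytes):
    writer = await handler.open_writer(
        session_id, data_key, len(payload), StorageLevel.MEMORY
    )
    await writer.write(payload)
    # close() delegates to StorageHandlerActor.close_writer, which records the
    # data info once the backend file is closed.
    await writer.close()
```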