Commit

Fix ut
hekaisheng committed May 23, 2022
1 parent 2376990 commit e62540e
Showing 2 changed files with 17 additions and 14 deletions.
8 changes: 4 additions & 4 deletions mars/services/storage/tests/test_transfer.py
@@ -265,7 +265,7 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
 
     send_task = asyncio.create_task(
         sender_actor.send_batch_data(
-            "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY
+            "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
         )
     )
 
@@ -283,7 +283,7 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
 
     send_task = asyncio.create_task(
         sender_actor.send_batch_data(
-            "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY
+            "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
         )
     )
     await send_task
@@ -294,12 +294,12 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
     if mock_sender is MockSenderManagerActor:
         send_task1 = asyncio.create_task(
             sender_actor.send_batch_data(
-                "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY
+                "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
             )
         )
         send_task2 = asyncio.create_task(
             sender_actor.send_batch_data(
-                "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY
+                "mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
             )
         )
         await asyncio.sleep(0.5)
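
Note on the test change above: each send_batch_data call now passes is_small_objects=False explicitly, presumably so the block-based transfer path (the one the mocked sender/receiver and the cancellation checks hook into) is exercised even though the test payloads are tiny. A minimal sketch of the updated call shape, assuming the sender_actor and worker_address_2 fixtures from the test; the absolute import path and helper name are illustrative:

import asyncio

from mars.storage import StorageLevel  # import path assumed for this sketch


async def start_forced_block_transfer(sender_actor, worker_address_2):
    # is_small_objects=False forces the block-based sender even for a tiny payload;
    # left as None, send_batch_data would choose the path from total size vs block_size.
    return asyncio.create_task(
        sender_actor.send_batch_data(
            "mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY,
            is_small_objects=False,
        )
    )
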
23 changes: 13 additions & 10 deletions mars/services/storage/transfer.py
@@ -93,7 +93,7 @@ async def get_receiver_ref(address: str, band_name: str):
 
     async def _send_data(
         self,
-        receiver_ref: Union[mo.ActorRef],
+        receiver_ref: Union[mo.ActorRef, "ReceiverManagerActor"],
         session_id: str,
         data_keys: List[str],
         data_infos: List[DataInfo],
@@ -218,6 +218,7 @@ async def send_batch_data(
         level: StorageLevel,
         band_name: str = "numa-0",
         block_size: int = None,
+        is_small_objects=None,
         error: str = "raise",
     ):
         logger.debug(
@@ -253,7 +254,17 @@ async def send_batch_data(
         if level is None:
             level = infos[0].level
         total_size = sum(data_sizes)
-        if total_size > block_size:
+        if is_small_objects is None:
+            is_small_objects = total_size <= block_size
+        if is_small_objects:
+            logger.debug(
+                "Choose send_small_objects method for sending data of %s bytes",
+                total_size,
+            )
+            await self._send_small_objects(
+                session_id, data_keys, infos, address, band_name, level
+            )
+        else:
             logger.debug("Choose block method for sending data of %s bytes", total_size)
             await self._send(
                 session_id,
@@ -265,14 +276,6 @@ async def send_batch_data(
                 band_name,
                 level,
             )
-        else:
-            logger.debug(
-                "Choose send_small_objects method for sending data of %s bytes",
-                total_size,
-            )
-            await self._send_small_objects(
-                session_id, data_keys, infos, address, band_name, level
-            )
         unpin_tasks = []
         for data_key in data_keys:
             unpin_tasks.append(
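
For readers skimming the transfer.py hunks above, the effective dispatch rule after this commit is: an explicit is_small_objects argument wins, otherwise the old size-vs-block_size comparison decides which sender is used. A standalone sketch of just that decision, where the function name and return strings are illustrative stand-ins for the _send_small_objects / _send calls in the diff:

def choose_transfer_method(total_size: int, block_size: int, is_small_objects=None) -> str:
    # Explicit caller choice takes precedence; None falls back to the size
    # heuristic that previously gated the block-based path.
    if is_small_objects is None:
        is_small_objects = total_size <= block_size
    return "_send_small_objects" if is_small_objects else "_send"


# Default behaviour: small payloads take the small-objects path.
assert choose_transfer_method(total_size=512, block_size=1024) == "_send_small_objects"
# Large payloads still go through the block-based path.
assert choose_transfer_method(total_size=4096, block_size=1024) == "_send"
# Callers (e.g. the updated tests) can force the block path regardless of size.
assert choose_transfer_method(512, 1024, is_small_objects=False) == "_send"
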
