18 changes: 18 additions & 0 deletions examples/python/remote_storage_example/README.md
@@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:
1. Initiator sends memory descriptors to target
2. Target performs storage-to-memory or memory-to-storage operations
3. Data is transferred between initiator and target memory

Remote reads are implemented as a read from storage followed by a network write.

Remote writes are implemented as a read from the network followed by a storage write.
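In miniature, the two flows above look like this (a toy sketch only: a dict stands in for storage and a list stands in for the network channel, both hypothetical stand-ins for the NIXL descriptors used in the real example):

```python
def remote_read(storage, network, key):
    """Remote read: a storage read followed by a network write."""
    data = storage[key]    # 1. storage -> target memory
    network.append(data)   # 2. target memory -> network (to the initiator)


def remote_write(storage, network, key):
    """Remote write: a network read followed by a storage write."""
    data = network.pop(0)  # 1. network (from the initiator) -> target memory
    storage[key] = data    # 2. target memory -> storage


storage, network = {}, []
network.append(b"payload")              # initiator's data arrives on the wire
remote_write(storage, network, "blk0")  # persist it
remote_read(storage, network, "blk0")   # then serve it back
```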

### Pipelining

To improve the performance of the remote storage server, we can pipeline operations to the network and to storage, allowing multiple threads to work on each request. To maintain correctness, however, the network and storage steps of each individual remote storage operation must still execute in order. To achieve this, we implemented a simple pipelining scheme.

Contributor comment: @tstamler please can you add some more description on how this is implemented with NIXL in this example?

![Remote Operation Pipelines](storage_pipelines.png)
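The scheme in the diagram can be sketched as a two-worker pipeline that overlaps the storage step of iteration i with the network step of iteration i - 1 (a minimal illustration; `do_storage_op` and `do_network_op` are hypothetical stand-ins for the NIXL transfers in the example):

```python
import concurrent.futures


def do_storage_op(i):
    # Placeholder for the storage-side transfer of iteration i.
    return f"storage:{i}"


def do_network_op(i):
    # Placeholder for the network-side transfer of iteration i.
    return f"network:{i}"


def pipeline(iterations):
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Prime the pipeline: the storage step of iteration 0 runs alone.
        storage_future = executor.submit(do_storage_op, 0)
        for i in range(1, iterations):
            # Steady state: submit the storage step of iteration i, ...
            next_storage = executor.submit(do_storage_op, i)
            # ... wait for the storage step of iteration i - 1, so each
            # iteration's storage step finishes before its network step, ...
            storage_future.result()
            # ... then run the network step of i - 1 alongside storage of i.
            results.append(executor.submit(do_network_op, i - 1).result())
            storage_future = next_storage
        # Drain: the network step of the last iteration runs alone.
        storage_future.result()
        results.append(executor.submit(do_network_op, iterations - 1).result())
    return results
```

For example, `pipeline(3)` performs the storage and network steps of iterations 0..2 while preserving the per-iteration ordering.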

### Performance Tips

For high-speed storage and network hardware, you may need to tune performance with a couple of environment variables.
Contributor comment: Please provide a suggested list of env vars (e.g. an example configuration with CX-7 and SPX, and what UCX/GDS tuning has been seen to be beneficial).


First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md).
Contributor comment: We should also highlight the difference between the GDS backends, note that the GDS plugins work in compat mode by default, and explain what is required for true GDS support.


On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection. This can be tweaked by setting UCX_MAX_RMA_RAILS=1. However, with larger batch or message sizes, this might limit bandwidth and a higher number of rails might be needed.
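As a concrete starting point, the tuning above might look like this (illustrative only; the right rail count depends on your batch and message sizes, as noted above):

```shell
# Limit UCX to a single RMA rail for VRAM-to-DRAM remote reads.
# Raise this again if larger batch or message sizes become bandwidth-limited.
export UCX_MAX_RMA_RAILS=1

# For storage, select the GDS_MT backend with its default concurrency when
# creating transfers; see the GDS README for additional cufile options.
```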
266 changes: 220 additions & 46 deletions examples/python/remote_storage_example/nixl_p2p_storage_example.py
@@ -18,6 +18,7 @@
Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
"""

import concurrent.futures
import time

import nixl_storage_utils as nsu
@@ -27,14 +28,20 @@
logger = get_logger(__name__)


def execute_transfer(
    my_agent, local_descs, remote_descs, remote_name, operation, use_backends=[]
):
    handle = my_agent.initialize_xfer(
        operation, local_descs, remote_descs, remote_name, backends=use_backends
    )
my_agent.transfer(handle)
nsu.wait_for_transfer(my_agent, handle)
my_agent.release_xfer_handle(handle)


def remote_storage_transfer(
    my_agent, my_mem_descs, operation, remote_agent_name, iterations
):
"""Initiate remote memory transfer."""
if operation != "READ" and operation != "WRITE":
logger.error("Invalid operation, exiting")
@@ -45,14 +52,22 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
else:
operation = b"READ"

iterations_str = bytes(f"{iterations:04d}", "utf-8")
# Send the descriptors that you want to read into or write from
logger.info(f"Sending {operation} request to {remote_agent_name}")
test_descs_str = my_agent.get_serialized_descs(my_mem_descs)

    start_time = time.time()

    my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str)

while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"):
continue

end_time = time.time()

logger.info(f"Time for {iterations} iterations: {end_time - start_time} seconds")
Contributor comment: f-string -> %s logger formatting



def connect_to_agents(my_agent, agents_file):
target_agents = []
@@ -79,13 +94,145 @@ def connect_to_agents(my_agent, agents_file):
return target_agents


def pipeline_reads(
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
):
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
n = 0
s = 0
futures = []

while n < iterations and s < iterations:
if s == 0:
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
my_file_descs,
my_agent.name,
"READ",
)
)
s += 1
continue
Contributor comment (on lines +106 to +118): Maybe we can move this line before the loop, initializing s to 1? I think it would simplify the loop and help avoid a branch.


if s == iterations:
Contributor comment: How can this flow happen with the s < iterations while condition?

futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
sent_descs,
req_agent,
"WRITE",
)
)
n += 1
continue

# Do two storage and network in parallel
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
my_file_descs,
my_agent.name,
"READ",
)
)
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
sent_descs,
req_agent,
"WRITE",
)
)
s += 1
n += 1

_, not_done = concurrent.futures.wait(
futures, return_when=concurrent.futures.ALL_COMPLETED
)
assert not not_done


def pipeline_writes(
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
):
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
n = 0
s = 0
futures = []

while n < iterations and s < iterations:
if s == 0:
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
sent_descs,
req_agent,
"READ",
)
)
s += 1
continue

if s == iterations:
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
my_file_descs,
my_agent.name,
"WRITE",
)
)
n += 1
continue

# Do two storage and network in parallel
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
sent_descs,
req_agent,
"READ",
)
)
futures.append(
executor.submit(
execute_transfer,
my_agent,
my_mem_descs,
my_file_descs,
my_agent.name,
"WRITE",
)
)
s += 1
n += 1

_, not_done = concurrent.futures.wait(
futures, return_when=concurrent.futures.ALL_COMPLETED
)
assert not not_done


def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
"""Handle remote memory and storage transfers as target."""
# Wait for initiator to send list of memory descriptors
notifs = my_agent.get_new_notifs()

logger.info("Waiting for a remote transfer request...")

while len(notifs) == 0:
notifs = my_agent.get_new_notifs()

@@ -101,57 +248,69 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs)
logger.error("Invalid operation, exiting")
exit(-1)

    iterations = int(recv_msg[4:8])

    logger.info(f"Performing {operation} with {iterations} iterations")

    sent_descs = my_agent.deserialize_descs(recv_msg[8:])

    logger.info("Checking to ensure metadata is loaded...")
    while my_agent.check_remote_metadata(req_agent, sent_descs) is False:
        continue

    if operation == "READ":
        pipeline_reads(
            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
        )
    elif operation == "WRITE":
        pipeline_writes(
            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
        )

# Send completion notification to initiator
my_agent.send_notif(req_agent, b"COMPLETE")

logger.info("One transfer test complete.")


def run_client(
    my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations
):
logger.info("Client initialized, ready for local transfer test...")

# For sample purposes, write to and then read from local storage
logger.info("Starting local transfer test...")

start_time = time.time()

for i in range(1, iterations):
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"WRITE",
["GDS_MT"],
)

end_time = time.time()

elapsed = end_time - start_time
Contributor comment (on lines +290 to +292): Suggested change: `elapsed = time.time() - start_time`


logger.info(f"Time for {iterations} WRITE iterations: {elapsed} seconds")

start_time = time.time()

for i in range(1, iterations):
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"READ",
["GDS_MT"],
)

end_time = time.time()

elapsed = end_time - start_time
Contributor comment (on lines +308 to +310): Suggested change: `elapsed = time.time() - start_time`


logger.info(f"Time for {iterations} READ iterations: {elapsed} seconds")
Contributor comment: Suggested change: `logger.info("Time for %s READ iterations: %s seconds", iterations, elapsed)`. We shouldn't use f-strings in loggers: they're not optimized, and pylint would warn about it.

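For context, the lazy %s style the reviewer is asking for defers string formatting until the log record is actually emitted (a small illustration; the logger name here is a hypothetical stand-in):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nixl_example")  # hypothetical logger name

iterations, elapsed = 100, 1.23

# f-string: the message is formatted eagerly, even if INFO is disabled.
logger.info(f"Time for {iterations} READ iterations: {elapsed} seconds")

# %s style: formatting is deferred until the record is actually emitted,
# which pylint's logging-fstring-interpolation check (W1203) prefers.
logger.info("Time for %s READ iterations: %s seconds", iterations, elapsed)
```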

logger.info("Local transfer test complete")

logger.info("Starting remote transfer test...")
@@ -161,10 +320,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
# For sample purposes, write to and then read from each target agent
for target_agent in target_agents:
remote_storage_transfer(
my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent, iterations
)
remote_storage_transfer(
my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent, iterations
)

logger.info("Remote transfer test complete")
@@ -199,8 +358,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs)
type=str,
help="File containing list of target agents (only needed for client)",
)
parser.add_argument(
"--iterations",
type=int,
default=100,
help="Number of iterations for each transfer",
)
args = parser.parse_args()

mem = "DRAM"

if args.role == "client":
mem = "VRAM"
Contributor comment (on lines +369 to +372): Don't we have some constants for "DRAM", "VRAM", "GDS_MT" and so on?


my_agent = nsu.create_agent_with_plugins(args.name, args.port)

(
Expand All @@ -209,15 +379,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
nixl_mem_reg_descs,
nixl_file_reg_descs,
) = nsu.setup_memory_and_files(
my_agent, args.batch_size, args.buf_size, args.fileprefix, mem
)

if args.role == "client":
if not args.agents_file:
parser.error("--agents_file is required when role is client")
try:
run_client(
my_agent,
nixl_mem_reg_descs,
nixl_file_reg_descs,
args.agents_file,
args.iterations,
)
finally:
nsu.cleanup_resources(
Expand Down