Remote storage pipeline #899
@@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:

1. Initiator sends memory descriptors to target
2. Target performs storage-to-memory or memory-to-storage operations
3. Data is transferred between initiator and target memory

Remote reads are implemented as a read from storage followed by a network write.

Remote writes are implemented as a read from the network followed by a storage write.

### Pipelining

To improve the performance of the remote storage server, we can pipeline operations to network and storage, which allows multiple threads to handle each request. To maintain correctness, however, the network and storage steps of each individual remote storage operation must still execute in order. To achieve this, we implemented a simple pipelining scheme.
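The ordering constraint can be sketched as follows. This is an illustrative Python sketch, not the NIXL implementation itself; `storage_stage` and `network_stage` are hypothetical placeholders for the two halves of a remote operation:

```python
# Illustrative two-stage pipeline: for each request, storage I/O must
# finish before its network transfer, but the storage stage of request i
# can overlap the network stage of request i-1.
import concurrent.futures


def storage_stage(i):
    # placeholder for the storage read/write of request i
    return f"storage-{i}"


def network_stage(i):
    # placeholder for the network write/read of request i
    return f"network-{i}"


def pipeline(iterations):
    completed = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # prime the pipeline with the first storage operation
        prev = executor.submit(storage_stage, 0)
        for i in range(1, iterations):
            prev.result()  # request i-1 finished its storage stage
            # overlap: network stage of i-1 runs while storage stage of i runs
            net = executor.submit(network_stage, i - 1)
            prev = executor.submit(storage_stage, i)
            net.result()
            completed.append(i - 1)
        prev.result()
        network_stage(iterations - 1)
        completed.append(iterations - 1)
    return completed
```

At steady state two workers are busy, yet each request still sees its storage stage strictly before its network stage.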



### Performance Tips

For high-speed storage and network hardware, you may need to tune performance with a couple of environment variables.
> **Review comment:** Please provide a suggested list of env vars (e.g. an example configuration with CX-7 and SPX), and note what UCX/GDS tuning has been seen to be beneficial.
First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md).
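As a sketch, requesting the GDS_MT backend explicitly follows the call shape used in this PR's example script; the agent API shown here (`initialize_xfer` with a `backends` keyword, then `transfer`) is taken from the diff below and should be checked against the NIXL reference:

```python
# Sketch: pin a storage transfer to the GDS_MT backend. The agent API
# mirrors the call shape in this PR's example script; mem_descs and
# file_descs would come from registered memory/file regions.
def storage_read(agent, mem_descs, file_descs):
    handle = agent.initialize_xfer(
        "READ", mem_descs, file_descs, agent.name, backends=["GDS_MT"]
    )
    agent.transfer(handle)
    return handle
```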
> **Review comment:** We should also highlight the difference between the GDS backends: the GDS plugins can work in compat mode by default, so we should document what is required for true GDS support.
On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection, which can be tuned by setting `UCX_MAX_RMA_RAILS=1`. With larger batch or message sizes, however, this might limit bandwidth, and a higher number of rails might be needed.
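For example, these knobs could be applied from the launching process before the transport libraries initialize. `UCX_MAX_RMA_RAILS` is the variable named above; `CUFILE_ENV_PATH_JSON` (pointing cuFile at a custom `cufile.json`) is an assumption here, so consult the GDS README for the authoritative options:

```python
# Hedged sketch: set tuning variables before UCX/cuFile are initialized.
# UCX_MAX_RMA_RAILS is named in the text above; CUFILE_ENV_PATH_JSON is
# an assumed cuFile knob for a custom cufile.json (verify in GDS README).
import os


def apply_tuning(max_rma_rails=1, cufile_json=None):
    env = {"UCX_MAX_RMA_RAILS": str(max_rma_rails)}
    if cufile_json is not None:
        env["CUFILE_ENV_PATH_JSON"] = cufile_json
    # must happen before importing/initializing the transport libraries
    os.environ.update(env)
    return env
```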
@@ -18,6 +18,7 @@
| Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| import concurrent.futures | ||||||||||
| import time | ||||||||||
|
|
||||||||||
| import nixl_storage_utils as nsu | ||||||||||
|
|
@@ -27,14 +28,20 @@
| logger = get_logger(__name__) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def execute_transfer(my_agent, local_descs, remote_descs, remote_name, operation): | ||||||||||
| handle = my_agent.initialize_xfer(operation, local_descs, remote_descs, remote_name) | ||||||||||
| def execute_transfer( | ||||||||||
| my_agent, local_descs, remote_descs, remote_name, operation, use_backends=[] | ||||||||||
| ): | ||||||||||
| handle = my_agent.initialize_xfer( | ||||||||||
| operation, local_descs, remote_descs, remote_name, backends=use_backends | ||||||||||
| ) | ||||||||||
| my_agent.transfer(handle) | ||||||||||
| nsu.wait_for_transfer(my_agent, handle) | ||||||||||
| my_agent.release_xfer_handle(handle) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name): | ||||||||||
| def remote_storage_transfer( | ||||||||||
| my_agent, my_mem_descs, operation, remote_agent_name, iterations | ||||||||||
| ): | ||||||||||
| """Initiate remote memory transfer.""" | ||||||||||
| if operation != "READ" and operation != "WRITE": | ||||||||||
| logger.error("Invalid operation, exiting") | ||||||||||
|
|
@@ -45,14 +52,22 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
| else: | ||||||||||
| operation = b"READ" | ||||||||||
|
|
||||||||||
| iterations_str = bytes(f"{iterations:04d}", "utf-8") | ||||||||||
| # Send the descriptors that you want to read into or write from | ||||||||||
| logger.info(f"Sending {operation} request to {remote_agent_name}") | ||||||||||
| test_descs_str = my_agent.get_serialized_descs(my_mem_descs) | ||||||||||
| my_agent.send_notif(remote_agent_name, operation + test_descs_str) | ||||||||||
|
|
||||||||||
| start_time = time.time() | ||||||||||
|
|
||||||||||
| my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str) | ||||||||||
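The notification framing used here (an operation tag, a 4-digit zero-padded iteration count, then the serialized descriptors) can be sketched roughly as below; the fixed 4-byte tag width is an assumption inferred from the server's `recv_msg[4:8]` parse, not confirmed by the diff:

```python
# Hedged sketch of the request wire format: 4-byte operation tag,
# 4-digit zero-padded iteration count, serialized descriptors.
# The 4-byte tag (truncating/padding READ/WRITE) is an assumption.
def pack_request(operation: bytes, iterations: int, descs: bytes) -> bytes:
    return operation[:4].ljust(4) + bytes(f"{iterations:04d}", "utf-8") + descs


def unpack_request(msg: bytes):
    op = msg[:4]
    iterations = int(msg[4:8])  # mirrors recv_msg[4:8] in the server
    descs = msg[8:]
    return op, iterations, descs
```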
|
|
||||||||||
| while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"): | ||||||||||
| continue | ||||||||||
|
|
||||||||||
| end_time = time.time() | ||||||||||
|
|
||||||||||
| logger.info(f"Time for {iterations} iterations: {end_time - start_time} seconds") | ||||||||||
|
> **Review comment:** f-string -> `%s` logger formatting.
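The suggestion above in minimal form; with %-style arguments, formatting is deferred until the record is actually emitted, so the cost is skipped when the level is disabled:

```python
# Lazy logger formatting: pass values as arguments instead of
# pre-rendering an f-string; the message is only interpolated if the
# record is actually emitted.
import logging

logger = logging.getLogger(__name__)


def log_timing(iterations, elapsed):
    # preferred over: logger.info(f"Time for {iterations} iterations: {elapsed} seconds")
    logger.info("Time for %s iterations: %s seconds", iterations, elapsed)
```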
|
|
||||||||||
|
|
||||||||||
| def connect_to_agents(my_agent, agents_file): | ||||||||||
| target_agents = [] | ||||||||||
|
|
@@ -79,13 +94,145 @@ def connect_to_agents(my_agent, agents_file):
| return target_agents | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def pipeline_reads( | ||||||||||
| my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations | ||||||||||
| ): | ||||||||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: | ||||||||||
| n = 0 | ||||||||||
| s = 0 | ||||||||||
| futures = [] | ||||||||||
|
|
||||||||||
| while n < iterations and s < iterations: | ||||||||||
| if s == 0: | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| my_file_descs, | ||||||||||
| my_agent.name, | ||||||||||
| "READ", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| s += 1 | ||||||||||
| continue | ||||||||||
|
> **Review comment (lines +106 to +118):** Maybe we can move this line before the loop; I think it would simplify the loop and help avoid a branch.
|
|
||||||||||
| if s == iterations: | ||||||||||
|
> **Review comment:** How can this flow happen with the …
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| sent_descs, | ||||||||||
| req_agent, | ||||||||||
| "WRITE", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| n += 1 | ||||||||||
| continue | ||||||||||
|
|
||||||||||
| # Do two storage and network in parallel | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| my_file_descs, | ||||||||||
| my_agent.name, | ||||||||||
| "READ", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| sent_descs, | ||||||||||
| req_agent, | ||||||||||
| "WRITE", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| s += 1 | ||||||||||
| n += 1 | ||||||||||
|
|
||||||||||
| _, not_done = concurrent.futures.wait( | ||||||||||
| futures, return_when=concurrent.futures.ALL_COMPLETED | ||||||||||
| ) | ||||||||||
| assert not not_done | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def pipeline_writes( | ||||||||||
| my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations | ||||||||||
| ): | ||||||||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: | ||||||||||
| n = 0 | ||||||||||
| s = 0 | ||||||||||
| futures = [] | ||||||||||
|
|
||||||||||
| while n < iterations and s < iterations: | ||||||||||
| if s == 0: | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| sent_descs, | ||||||||||
| req_agent, | ||||||||||
| "READ", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| s += 1 | ||||||||||
| continue | ||||||||||
|
|
||||||||||
| if s == iterations: | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| my_file_descs, | ||||||||||
| my_agent.name, | ||||||||||
| "WRITE", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| n += 1 | ||||||||||
| continue | ||||||||||
|
|
||||||||||
| # Do two storage and network in parallel | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| sent_descs, | ||||||||||
| req_agent, | ||||||||||
| "READ", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| futures.append( | ||||||||||
| executor.submit( | ||||||||||
| execute_transfer, | ||||||||||
| my_agent, | ||||||||||
| my_mem_descs, | ||||||||||
| my_file_descs, | ||||||||||
| my_agent.name, | ||||||||||
| "WRITE", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| s += 1 | ||||||||||
| n += 1 | ||||||||||
|
|
||||||||||
| _, not_done = concurrent.futures.wait( | ||||||||||
| futures, return_when=concurrent.futures.ALL_COMPLETED | ||||||||||
| ) | ||||||||||
| assert not not_done | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs): | ||||||||||
| """Handle remote memory and storage transfers as target.""" | ||||||||||
| # Wait for initiator to send list of memory descriptors | ||||||||||
| notifs = my_agent.get_new_notifs() | ||||||||||
|
|
||||||||||
| logger.info("Waiting for a remote transfer request...") | ||||||||||
|
|
||||||||||
| while len(notifs) == 0: | ||||||||||
| notifs = my_agent.get_new_notifs() | ||||||||||
|
|
||||||||||
|
|
@@ -101,57 +248,69 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
| logger.error("Invalid operation, exiting") | ||||||||||
| exit(-1) | ||||||||||
|
|
||||||||||
| sent_descs = my_agent.deserialize_descs(recv_msg[4:]) | ||||||||||
| iterations = int(recv_msg[4:8]) | ||||||||||
|
|
||||||||||
| logger.info("Checking to ensure metadata is loaded...") | ||||||||||
| while my_agent.check_remote_metadata(req_agent, sent_descs) is False: | ||||||||||
| continue | ||||||||||
| logger.info(f"Performing {operation} with {iterations} iterations") | ||||||||||
|
|
||||||||||
| if operation == "READ": | ||||||||||
| logger.info("Starting READ operation") | ||||||||||
| sent_descs = my_agent.deserialize_descs(recv_msg[8:]) | ||||||||||
|
|
||||||||||
| # Read from file first | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ" | ||||||||||
| if operation == "READ": | ||||||||||
| pipeline_reads( | ||||||||||
| my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations | ||||||||||
| ) | ||||||||||
| # Send to client | ||||||||||
| execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "WRITE") | ||||||||||
|
|
||||||||||
| elif operation == "WRITE": | ||||||||||
| logger.info("Starting WRITE operation") | ||||||||||
|
|
||||||||||
| # Read from client first | ||||||||||
| execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "READ") | ||||||||||
| # Write to storage | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, my_mem_descs, my_file_descs, my_agent.name, "WRITE" | ||||||||||
| pipeline_writes( | ||||||||||
| my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # Send completion notification to initiator | ||||||||||
| my_agent.send_notif(req_agent, b"COMPLETE") | ||||||||||
|
|
||||||||||
| logger.info("One transfer test complete.") | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file): | ||||||||||
| def run_client( | ||||||||||
| my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations | ||||||||||
| ): | ||||||||||
| logger.info("Client initialized, ready for local transfer test...") | ||||||||||
|
|
||||||||||
| # For sample purposes, write to and then read from local storage | ||||||||||
| logger.info("Starting local transfer test...") | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, | ||||||||||
| nixl_mem_reg_descs.trim(), | ||||||||||
| nixl_file_reg_descs.trim(), | ||||||||||
| my_agent.name, | ||||||||||
| "WRITE", | ||||||||||
| ) | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, | ||||||||||
| nixl_mem_reg_descs.trim(), | ||||||||||
| nixl_file_reg_descs.trim(), | ||||||||||
| my_agent.name, | ||||||||||
| "READ", | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| start_time = time.time() | ||||||||||
|
|
||||||||||
| for i in range(1, iterations): | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, | ||||||||||
| nixl_mem_reg_descs.trim(), | ||||||||||
| nixl_file_reg_descs.trim(), | ||||||||||
| my_agent.name, | ||||||||||
| "WRITE", | ||||||||||
| ["GDS_MT"], | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| end_time = time.time() | ||||||||||
|
|
||||||||||
| elapsed = end_time - start_time | ||||||||||
|
> **Review comment (lines +290 to +292):** *Suggested change.*
||||||||||
|
|
||||||||||
| logger.info(f"Time for {iterations} WRITE iterations: {elapsed} seconds") | ||||||||||
|
|
||||||||||
| start_time = time.time() | ||||||||||
|
|
||||||||||
| for i in range(1, iterations): | ||||||||||
| execute_transfer( | ||||||||||
| my_agent, | ||||||||||
| nixl_mem_reg_descs.trim(), | ||||||||||
| nixl_file_reg_descs.trim(), | ||||||||||
| my_agent.name, | ||||||||||
| "READ", | ||||||||||
| ["GDS_MT"], | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| end_time = time.time() | ||||||||||
|
|
||||||||||
| elapsed = end_time - start_time | ||||||||||
|
> **Review comment (lines +308 to +310):** *Suggested change.*
||||||||||
|
|
||||||||||
| logger.info(f"Time for {iterations} READ iterations: {elapsed} seconds") | ||||||||||
|
> **Review comment:** Suggested change: we shouldn't use f-strings in loggers; it's not optimized, and pylint would warn about it.
We shouldn't use f-strings in loggers, it's not optimized, pylint would warn about it |
||||||||||
|
|
||||||||||
| logger.info("Local transfer test complete") | ||||||||||
|
|
||||||||||
| logger.info("Starting remote transfer test...") | ||||||||||
|
|
@@ -161,10 +320,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
| # For sample purposes, write to and then read from each target agent | ||||||||||
| for target_agent in target_agents: | ||||||||||
| remote_storage_transfer( | ||||||||||
| my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent | ||||||||||
| my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent, iterations | ||||||||||
| ) | ||||||||||
| remote_storage_transfer( | ||||||||||
| my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent | ||||||||||
| my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent, iterations | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| logger.info("Remote transfer test complete") | ||||||||||
|
|
@@ -199,8 +358,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
| type=str, | ||||||||||
| help="File containing list of target agents (only needed for client)", | ||||||||||
| ) | ||||||||||
| parser.add_argument( | ||||||||||
| "--iterations", | ||||||||||
| type=int, | ||||||||||
| default=100, | ||||||||||
| help="Number of iterations for each transfer", | ||||||||||
| ) | ||||||||||
| args = parser.parse_args() | ||||||||||
|
|
||||||||||
| mem = "DRAM" | ||||||||||
|
|
||||||||||
| if args.role == "client": | ||||||||||
| mem = "VRAM" | ||||||||||
|
> **Review comment (lines +369 to +372):** Don't we have some constants for "DRAM", "VRAM", "GDS_MT", and so on?
||||||||||
|
|
||||||||||
| my_agent = nsu.create_agent_with_plugins(args.name, args.port) | ||||||||||
|
|
||||||||||
| ( | ||||||||||
|
|
@@ -209,15 +379,19 @@
| nixl_mem_reg_descs, | ||||||||||
| nixl_file_reg_descs, | ||||||||||
| ) = nsu.setup_memory_and_files( | ||||||||||
| my_agent, args.batch_size, args.buf_size, args.fileprefix | ||||||||||
| my_agent, args.batch_size, args.buf_size, args.fileprefix, mem | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| if args.role == "client": | ||||||||||
| if not args.agents_file: | ||||||||||
| parser.error("--agents_file is required when role is client") | ||||||||||
| try: | ||||||||||
| run_client( | ||||||||||
| my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, args.agents_file | ||||||||||
| my_agent, | ||||||||||
| nixl_mem_reg_descs, | ||||||||||
| nixl_file_reg_descs, | ||||||||||
| args.agents_file, | ||||||||||
| args.iterations, | ||||||||||
| ) | ||||||||||
| finally: | ||||||||||
| nsu.cleanup_resources( | ||||||||||
|
|
||||||||||
> **Review comment:** @tstamler please can you add some more description of how this is implemented with NIXL in this example?