@@ -319,6 +319,21 @@ static void ucp_proto_common_tl_perf_reset(ucp_proto_common_tl_perf_t *tl_perf)
319
319
tl_perf -> max_frag = SIZE_MAX ;
320
320
}
321
321
322
+ static void ucp_proto_common_perf_attr_set_mem_type (
323
+ const ucp_proto_common_init_params_t * params ,
324
+ uct_perf_attr_t * perf_attr )
325
+ {
326
+ const ucp_rkey_config_key_t * rkey_config_key = params -> super .rkey_config_key ;
327
+
328
+ perf_attr -> field_mask |= UCT_PERF_ATTR_FIELD_LOCAL_MEMORY_TYPE ;
329
+ perf_attr -> local_memory_type = params -> reg_mem_info .type ;
330
+
331
+ if (rkey_config_key != NULL ) {
332
+ perf_attr -> field_mask |= UCT_PERF_ATTR_FIELD_REMOTE_MEMORY_TYPE ;
333
+ perf_attr -> remote_memory_type = rkey_config_key -> mem_type ;
334
+ }
335
+ }
336
+
322
337
ucs_status_t
323
338
ucp_proto_common_get_lane_perf (const ucp_proto_common_init_params_t * params ,
324
339
ucp_lane_index_t lane ,
@@ -367,6 +382,7 @@ ucp_proto_common_get_lane_perf(const ucp_proto_common_init_params_t *params,
367
382
UCT_PERF_ATTR_FIELD_PATH_BANDWIDTH |
368
383
UCT_PERF_ATTR_FIELD_LATENCY ;
369
384
perf_attr .operation = params -> send_op ;
385
+ ucp_proto_common_perf_attr_set_mem_type (params , & perf_attr );
370
386
371
387
status = ucp_worker_iface_estimate_perf (wiface , & perf_attr );
372
388
if (status != UCS_OK ) {
@@ -647,6 +663,22 @@ ucp_proto_common_reg_md_map(const ucp_proto_common_init_params_t *params,
647
663
return reg_md_map ;
648
664
}
649
665
666
+ static int ucp_proto_common_find_lanes_check_mem_type (
667
+ const ucp_proto_common_init_params_t * params , ucp_lane_index_t lane )
668
+ {
669
+ uct_perf_attr_t perf_attr = {0 };
670
+ ucp_rsc_index_t rsc_index ;
671
+ ucp_worker_iface_t * wiface ;
672
+
673
+ ucp_proto_common_perf_attr_set_mem_type (params , & perf_attr );
674
+
675
+ rsc_index = ucp_proto_common_get_rsc_index (& params -> super , lane );
676
+ wiface = ucp_worker_iface (params -> super .worker , rsc_index );
677
+ /* TODO: Use memory reachability UCT API, when available, to check memory
678
+ type support */
679
+ return uct_iface_estimate_perf (wiface -> iface , & perf_attr ) == UCS_OK ;
680
+ }
681
+
650
682
ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag (
651
683
const ucp_proto_common_init_params_t * params , ucp_lane_type_t lane_type ,
652
684
uint64_t tl_cap_flags , ucp_lane_index_t max_lanes ,
@@ -688,6 +720,10 @@ ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag(
688
720
continue ;
689
721
}
690
722
723
+ if (!ucp_proto_common_find_lanes_check_mem_type (params , lane )) {
724
+ continue ;
725
+ }
726
+
691
727
lanes [num_valid_lanes ++ ] = lane ;
692
728
if (num_valid_lanes >= max_lanes ) {
693
729
break ;
0 commit comments