Skip to content

Commit cd9f7f8

Browse files
UCT/GDA: Use wqe_idx for rc_gda progress (#10928)
1 parent 5959ce7 commit cd9f7f8

File tree

12 files changed

+159
-164
lines changed

12 files changed

+159
-164
lines changed

src/tools/perf/cuda/ucp_cuda_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ ucp_perf_cuda_send_sync(ucp_perf_cuda_params &params, ucx_perf_counter_t idx,
228228
ucp_device_request_t &req)
229229
{
230230
ucs_status_t status = ucp_perf_cuda_send_nbx<level, cmd>(params, idx, req);
231-
if (status != UCS_OK) {
231+
if (UCS_STATUS_IS_ERR(status)) {
232232
return status;
233233
}
234234

@@ -262,7 +262,7 @@ ucp_perf_cuda_put_multi_bw_kernel(ucx_perf_cuda_context &ctx,
262262

263263
ucp_device_request_t &req = request_mgr.get_request();
264264
status = ucp_perf_cuda_send_nbx<level, cmd>(params, idx, req);
265-
if (status != UCS_OK) {
265+
if (UCS_STATUS_IS_ERR(status)) {
266266
ucs_device_error("send failed: %d", status);
267267
goto out;
268268
}

src/ucp/api/device/ucp_device_impl.h

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
*/
2525
typedef struct ucp_device_request {
2626
uct_device_completion_t comp;
27+
ucs_status_t status;
2728
uct_device_ep_h device_ep;
2829
} ucp_device_request_t;
2930

@@ -51,9 +52,6 @@ UCS_F_DEVICE void ucp_device_request_init(uct_device_ep_t *device_ep,
5152
if (req != nullptr) {
5253
comp = &req->comp;
5354
req->device_ep = device_ep;
54-
uct_device_completion_init(comp);
55-
/* TODO: Handle multiple device posts with same req? */
56-
++comp->count;
5755
} else {
5856
comp = nullptr;
5957
}
@@ -63,16 +61,20 @@ UCS_F_DEVICE void ucp_device_request_init(uct_device_ep_t *device_ep,
6361
/**
6462
* Macro for device put operations with retry logic
6563
*/
66-
#define UCP_DEVICE_SEND_BLOCKING(_level, _uct_device_ep_send, _device_ep, ...) \
64+
#define UCP_DEVICE_SEND_BLOCKING(_level, _uct_device_ep_send, _device_ep, \
65+
_req, ...) \
6766
({ \
6867
ucs_status_t _status; \
6968
do { \
7069
_status = _uct_device_ep_send<_level>(_device_ep, __VA_ARGS__); \
7170
if (_status != UCS_ERR_NO_RESOURCE) { \
7271
break; \
7372
} \
74-
_status = uct_device_ep_progress<_level>(_device_ep); \
75-
} while (!UCS_STATUS_IS_ERR(_status)); \
73+
uct_device_ep_progress<_level>(_device_ep); \
74+
} while (1); \
75+
if (_req != nullptr) { \
76+
_req->status = _status; \
77+
} \
7678
_status; \
7779
})
7880

@@ -148,8 +150,8 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_single(
148150
}
149151

150152
return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_put_single, device_ep,
151-
uct_elem, address, remote_address, length,
152-
flags, comp);
153+
req, uct_elem, address, remote_address,
154+
length, flags, comp);
153155
}
154156

155157

@@ -199,8 +201,8 @@ UCS_F_DEVICE ucs_status_t ucp_device_counter_inc(
199201
}
200202

201203
return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_atomic_add, device_ep,
202-
uct_elem, inc_value, remote_address, flags,
203-
comp);
204+
req, uct_elem, inc_value, remote_address,
205+
flags, comp);
204206
}
205207

206208

@@ -263,8 +265,9 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_multi(
263265
}
264266

265267
return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_put_multi, device_ep,
266-
uct_mem_list, mem_list_h->mem_list_length,
267-
addresses, remote_addresses, lengths,
268+
req, uct_mem_list,
269+
mem_list_h->mem_list_length, addresses,
270+
remote_addresses, lengths,
268271
counter_inc_value, counter_remote_address,
269272
flags, comp);
270273
}
@@ -338,10 +341,11 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_multi_partial(
338341
}
339342

340343
return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_put_multi_partial,
341-
device_ep, uct_mem_list, mem_list_indices,
342-
mem_list_count, addresses, remote_addresses,
343-
lengths, counter_index, counter_inc_value,
344-
counter_remote_address, flags, comp);
344+
device_ep, req, uct_mem_list,
345+
mem_list_indices, mem_list_count, addresses,
346+
remote_addresses, lengths, counter_index,
347+
counter_inc_value, counter_remote_address,
348+
flags, comp);
345349
}
346350

347351

@@ -409,19 +413,14 @@ UCS_F_DEVICE void ucp_device_counter_write(void *counter_ptr, uint64_t value)
409413
template<ucs_device_level_t level = UCS_DEVICE_LEVEL_THREAD>
410414
UCS_F_DEVICE ucs_status_t ucp_device_progress_req(ucp_device_request_t *req)
411415
{
412-
ucs_status_t status;
413-
414-
if (ucs_likely(req->comp.count == 0)) {
415-
return req->comp.status;
416-
}
417-
418-
status = uct_device_ep_progress<level>(req->device_ep);
419-
if (status != UCS_OK) {
420-
return status;
416+
if (ucs_likely(req->status != UCS_INPROGRESS)) {
417+
return req->status;
421418
}
422419

423-
return (ucs_likely(req->comp.count == 0)) ? req->comp.status :
424-
UCS_INPROGRESS;
420+
uct_device_ep_progress<level>(req->device_ep);
421+
req->status = uct_device_ep_check_completion<level>(req->device_ep,
422+
&req->comp);
423+
return req->status;
425424
}
426425

427426
#endif /* UCP_DEVICE_IMPL_H */

src/uct/api/device/uct_device_impl.h

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515

1616
#include <uct/ib/mlx5/gdaki/gdaki.cuh>
1717

18+
union uct_device_completion {
19+
uct_rc_gda_completion_t rc_gda;
20+
uct_cuda_ipc_completion_t cuda_ipc;
21+
};
22+
1823

1924
/**
2025
* @ingroup UCT_DEVICE
@@ -242,34 +247,37 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_multi_partial(
242247
* @brief Progress all operations on device endpoint @a device_ep.
243248
*
244249
* @param [in] device_ep Device endpoint to be used for the operation.
245-
*
246-
* @return UCS_OK - Some operation was completed.
247-
* @return UCS_INPROGRESS - No progress on the endpoint.
248-
* @return Error code as defined by @ref ucs_status_t
249250
*/
250251
template<ucs_device_level_t level>
251-
UCS_F_DEVICE ucs_status_t uct_device_ep_progress(uct_device_ep_h device_ep)
252+
UCS_F_DEVICE void uct_device_ep_progress(uct_device_ep_h device_ep)
252253
{
253254
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
254-
return uct_rc_mlx5_gda_ep_progress<level>(device_ep);
255-
} else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
256-
return UCS_OK;
255+
uct_rc_mlx5_gda_ep_progress<level>(device_ep);
257256
}
258-
259-
return UCS_ERR_UNSUPPORTED;
260257
}
261258

262259

263260
/**
264261
* @ingroup UCT_DEVICE
265-
* @brief Initialize a device completion object.
262+
* @brief Check whether opetation executed on device endpoint @a device_ep was
263+
* completed.
264+
*
265+
* @param [in] device_ep Device endpoint to be used for the operation.
266+
* @param [in] comp Completion object tracking operation progress.
266267
*
267-
* @param [out] comp Device completion object to initialize.
268+
* @return UCS_OK - Some operation was completed.
269+
* @return UCS_INPROGRESS - No progress on the endpoint.
270+
* @return Error code as defined by @ref ucs_status_t
268271
*/
269-
UCS_F_DEVICE void uct_device_completion_init(uct_device_completion_t *comp)
272+
template<ucs_device_level_t level>
273+
UCS_F_DEVICE ucs_status_t uct_device_ep_check_completion(
274+
uct_device_ep_h device_ep, uct_device_completion_t *comp)
270275
{
271-
comp->count = 0;
272-
comp->status = UCS_OK;
276+
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
277+
return uct_rc_mlx5_gda_ep_check_completion<level>(device_ep, comp);
278+
}
279+
280+
return UCS_ERR_UNSUPPORTED;
273281
}
274282

275283
#endif

src/uct/api/device/uct_device_types.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ typedef struct uct_device_ep {
4242

4343

4444
/* Completion object for device operations */
45-
typedef struct uct_device_completion {
46-
uint32_t count; /* How many operations are pending */
47-
ucs_status_t status; /* Status of the operation */
48-
} uct_device_completion_t;
45+
typedef union uct_device_completion uct_device_completion_t;
4946

5047

5148
/* Base structure for all device memory elements */

src/uct/cuda/cuda_ipc/cuda_ipc.cuh

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,6 @@ uct_cuda_ipc_ep_put_single(uct_device_ep_h device_ep,
304304
mapped_rem_addr = uct_cuda_ipc_map_remote(cuda_ipc_mem_element, remote_address);
305305
uct_cuda_ipc_copy_level<level>(mapped_rem_addr, address, length);
306306
uct_cuda_ipc_level_sync<level>();
307-
--comp->count;
308-
309307
return UCS_OK;
310308
}
311309

@@ -339,10 +337,6 @@ uct_cuda_ipc_ep_put_multi(uct_device_ep_h device_ep,
339337
}
340338

341339
uct_cuda_ipc_level_sync<level>();
342-
if (lane_id == 0) {
343-
--comp->count;
344-
}
345-
346340
return UCS_OK;
347341
}
348342

@@ -376,10 +370,6 @@ uct_cuda_ipc_ep_put_multi_partial(uct_device_ep_h device_ep,
376370
}
377371

378372
uct_cuda_ipc_level_sync<level>();
379-
if (lane_id == 0) {
380-
--comp->count;
381-
}
382-
383373
return UCS_OK;
384374
}
385375

@@ -403,10 +393,6 @@ uct_cuda_ipc_ep_atomic_add(uct_device_ep_h device_ep,
403393
}
404394

405395
uct_cuda_ipc_level_sync<level>();
406-
if (lane_id == 0) {
407-
--comp->count;
408-
}
409-
410396
return UCS_OK;
411397
}
412398

src/uct/cuda/cuda_ipc/cuda_ipc_device.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,8 @@ typedef struct {
1212
ptrdiff_t mapped_offset;
1313
} uct_cuda_ipc_device_mem_element_t;
1414

15+
16+
typedef struct {
17+
} uct_cuda_ipc_completion_t;
18+
1519
#endif

src/uct/ib/mlx5/gdaki/gdaki.c

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,8 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
9292
uct_ib_mlx5_wq_calc_sizes(&qp_attr);
9393

9494
cq_attr.flags |= UCT_IB_MLX5_CQ_IGNORE_OVERRUN;
95-
cq_attr.umem_offset = ucs_align_up_pow2(
96-
sizeof(uct_rc_gdaki_dev_ep_t) +
97-
qp_attr.max_tx * sizeof(uct_rc_gdaki_op_t),
98-
ucs_get_page_size());
95+
cq_attr.umem_offset = ucs_align_up_pow2(sizeof(uct_rc_gdaki_dev_ep_t),
96+
ucs_get_page_size());
9997

10098
qp_attr.mmio_mode = UCT_IB_MLX5_MMIO_MODE_DB;
10199
qp_attr.super.srq_num = 0;
@@ -109,9 +107,9 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
109107
dev_ep_size = qp_attr.umem_offset + qp_attr.len;
110108
/*
111109
* dev_ep layout:
112-
* +---------------------+-------+---------+---------+
113-
* | counters, dbr | ops | cq buff | wq buff |
114-
* +---------------------+-------+---------+---------+
110+
* +---------------------+---------+---------+
111+
* | counters, dbr | cq buff | wq buff |
112+
* +---------------------+---------+---------+
115113
*/
116114
status = uct_rc_gdaki_alloc(dev_ep_size, ucs_get_page_size(),
117115
(void**)&self->ep_gpu, &self->ep_raw);

0 commit comments

Comments
 (0)