diff --git a/src/host/proxy/proxy.cpp b/src/host/proxy/proxy.cpp index 2fcaa7fc..a1504c7d 100644 --- a/src/host/proxy/proxy.cpp +++ b/src/host/proxy/proxy.cpp @@ -687,51 +687,23 @@ int process_channel_amo(proxy_state_t *state, proxy_channel_t *ch, int *is_proce } void enforce_cst(proxy_state_t *proxy_state) { -#if defined(NVSHMEM_X86_64) - nvshmemi_state_t *state = proxy_state->nvshmemi_state; -#endif - int status = 0; if (nvshmemi_options.BYPASS_FLUSH) return; - if (proxy_state->is_consistency_api_supported) { - if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && - CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { - status = - CUPFN(nvshmemi_cuda_syms, - cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, - CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); - /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable - consistent access of data on any GPU (and not just self GPU) with - wait_until, quiet, barrier, etc. 
**/ - if (status != CUDA_SUCCESS) { - NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); - } - } - return; - } -#if defined(NVSHMEM_PPC64LE) - status = cudaEventRecord(proxy_state->cuev, proxy_state->stream); - if (unlikely(status != CUDA_SUCCESS)) { - NVSHMEMI_ERROR_EXIT("cuEventRecord() failed in the proxy thread \n"); - } -#elif defined(NVSHMEM_X86_64) - for (int i = 0; i < state->num_initialized_transports; i++) { - if (!((state->transport_bitmap) & (1 << i))) continue; - struct nvshmem_transport *tcurr = state->transports[i]; - if (!tcurr->host_ops.enforce_cst) continue; - - // assuming the transport is connected - IB RC - if (tcurr->attr & NVSHMEM_TRANSPORT_ATTR_CONNECTED) { - status = tcurr->host_ops.enforce_cst(tcurr); - if (status) { - NVSHMEMI_ERROR_PRINT("aborting due to error in progress_cst \n"); - exit(-1); - } + if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && + CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { + status = + CUPFN(nvshmemi_cuda_syms, + cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); + /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable + consistent access of data on any GPU (and not just self GPU) with + wait_until, quiet, barrier, etc. **/ + if (status != CUDA_SUCCESS) { + NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); } } -#endif } inline void quiet_ack_channels(proxy_state_t *proxy_state) { diff --git a/src/host/topo/topo.cpp b/src/host/topo/topo.cpp index ac86a9eb..753bbe9d 100644 --- a/src/host/topo/topo.cpp +++ b/src/host/topo/topo.cpp @@ -9,6 +9,7 @@ #include // for CUDA_SUCCESS #include // for cudaDevice... #include // for cudaDevice... 
+#include // for opendir, readdir #include // for PATH_MAX #include // for NULL, fclose #include // for free, calloc @@ -49,6 +50,8 @@ enum pci_distance { static const int pci_distance_perf[PATH_COUNT] = {4, 4, 3, 2, 1}; static const char *pci_distance_string[PATH_COUNT] = {"PIX", "PXB", "PHB", "NODE", "SYS"}; +#define NVIDIA_DRIVER_PATH "/sys/bus/pci/drivers/nvidia" + static int get_cuda_bus_id(int cuda_dev, char *bus_id) { int status = NVSHMEMX_SUCCESS; cudaError_t err; @@ -106,6 +109,71 @@ static int get_device_path(char *bus_id, char **path) { return status; } +static int is_pci_addr(const char *name) { + // Match XXXX:XX:XX.X pattern + return strlen(name) == 12 && name[4] == ':' && name[7] == ':' && name[10] == '.'; +} + +int get_nvidia_gpu_count(void) { + DIR *dir = opendir(NVIDIA_DRIVER_PATH); + if (!dir) return 0; + int count = 0; + struct dirent *ent; + while ((ent = readdir(dir)) != NULL) { + if (is_pci_addr(ent->d_name)) count++; + } + closedir(dir); + return count; +} + +static int get_gpu_paths_and_index(int cuda_device_id, char **cuda_device_paths, + int *out_mygpu_index) { + int status = NVSHMEMX_SUCCESS; + char my_bus_id[MAX_BUSID_SIZE]; + DIR *nvidia_dir = NULL; + + status = get_cuda_bus_id(cuda_device_id, my_bus_id); + if (status != NVSHMEMX_SUCCESS) return status; + for (int k = 0; k < MAX_BUSID_SIZE; k++) + my_bus_id[k] = tolower(my_bus_id[k]); + + nvidia_dir = opendir(NVIDIA_DRIVER_PATH); + if (!nvidia_dir) { + NVSHMEMI_ERROR_PRINT("Failed to open " NVIDIA_DRIVER_PATH "\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + int gpu_id = 0; + *out_mygpu_index = -1; + struct dirent *ent; + while ((ent = readdir(nvidia_dir)) != NULL) { + if (!is_pci_addr(ent->d_name)) continue; + char bus_id[MAX_BUSID_SIZE]; + strncpy(bus_id, ent->d_name, MAX_BUSID_SIZE - 1); + bus_id[MAX_BUSID_SIZE - 1] = '\0'; + + status = get_device_path(bus_id, &cuda_device_paths[gpu_id]); + if (status != NVSHMEMX_SUCCESS) { + NVSHMEMI_ERROR_PRINT("get cuda path failed\n"); + 
closedir(nvidia_dir); + return status; + } + + if (strncmp(my_bus_id, bus_id, MAX_BUSID_SIZE) == 0) + *out_mygpu_index = gpu_id; + + gpu_id++; + } + closedir(nvidia_dir); + + if (*out_mygpu_index < 0) { + NVSHMEMI_ERROR_PRINT("Could not find current GPU in sysfs\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + return NVSHMEMX_SUCCESS; +} + static enum pci_distance get_pci_distance(char *cuda_path, char *mlx_path) { int score = 0; int depth = 0; @@ -130,7 +198,7 @@ static enum pci_distance get_pci_distance(char *cuda_path, char *mlx_path) { } typedef struct nvshmemi_path_pair_info { - int pe_idx; + int gpu_idx; int dev_idx; enum pci_distance pcie_distance; } nvshmemi_path_pair_info_t; @@ -146,37 +214,32 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, char gpu_bus_id[MAX_BUSID_SIZE]; } gpu_info, *gpu_info_all = NULL; - std::list pe_dev_pairs; + std::list gpu_dev_pairs; std::list::iterator pairs_iter; int ndev = tcurr->n_devices; int mype = nvshmemi_state->mype; int n_pes = nvshmemi_state->npes; int n_pes_node = nvshmemi_state->npes_node; - CUdevice gpu_device_id; char **cuda_device_paths = NULL; - int *pe_selected_devices = NULL; - enum pci_distance *pe_device_distance = NULL; + int *gpu_selected_devices = NULL; + enum pci_distance *gpu_device_distance = NULL; int *used_devs = NULL; + int n_gpus_node = 0; - int mype_array_index = -1, mydev_index = -1; - int i, dev_id, pe_id, pe_pair_index; + int mygpu_index = -1, mydev_index = -1; + int i, dev_id, gpu_id, gpu_pair_index; int devices_assigned = 0; - int mype_device_count = 0; + int mygpu_device_count = 0; int status = NVSHMEMX_ERROR_INTERNAL; + int mygpu_array_index; if (ndev <= 0) { NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "transport devices (setup_connections) failed \n"); } - status = CUPFN(nvshmemi_cuda_syms, cuCtxGetDevice(&gpu_device_id)); - if (status != CUDA_SUCCESS) { - status = NVSHMEMX_ERROR_INTERNAL; - goto out; - } - /* Allocate data structures start */ /* 
Array of dev_info structures of size # of local NICs */ dev_info_all = (struct dev_info *)calloc(ndev, sizeof(struct dev_info)); @@ -188,66 +251,45 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, NVSHMEMI_NULL_ERROR_JMP(gpu_info_all, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "gpu_info_all allocation failed \n"); - /* array linking each GPU on our node with it's pcie path */ - cuda_device_paths = (char **)calloc(n_pes_node, sizeof(char *)); - NVSHMEMI_NULL_ERROR_JMP(cuda_device_paths, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - - /* Array of size n_pes_node * max_dev_per_pe storing the accepted mappings of PE to Dev(s) */ - pe_selected_devices = (int *)calloc(n_pes_node * max_dev_per_pe, sizeof(int)); - NVSHMEMI_NULL_ERROR_JMP(pe_selected_devices, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { - for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - pe_selected_devices[pe_id * max_dev_per_pe + dev_id] = -1; - } - } - - pe_device_distance = - (enum pci_distance *)calloc(n_pes_node * max_dev_per_pe, sizeof(enum pci_distance)); - NVSHMEMI_NULL_ERROR_JMP(pe_device_distance, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { - for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - pe_device_distance[pe_id * max_dev_per_pe + dev_id] = PATH_SYS; - } - } - used_devs = (int *)calloc(ndev, sizeof(int)); NVSHMEMI_NULL_ERROR_JMP(used_devs, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate memory for PE/NIC Mapping.\n"); /* Allocate data structures end */ /* Gather GPU and NIC paths start */ - status = get_cuda_bus_id(gpu_device_id, gpu_info.gpu_bus_id); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "get cuda busid failed \n"); - - status = 
nvshmemi_boot_handle.allgather((void *)&gpu_info, (void *)gpu_info_all, - sizeof(struct gpu_info), &nvshmemi_boot_handle); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "allgather of gpu_info failed \n"); - - pe_id = 0; - for (i = 0; i < n_pes; i++) { - if (nvshmemi_state->pe_info[i].hostHash != nvshmemi_state->pe_info[mype].hostHash) { - continue; - } + n_gpus_node = get_nvidia_gpu_count(); + if (n_gpus_node <= 0) { + NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No NVIDIA GPUs found in " NVIDIA_DRIVER_PATH "\n"); + } - status = get_device_path(gpu_info_all[i].gpu_bus_id, &cuda_device_paths[pe_id]); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "get cuda path failed \n"); - /* to get back to our PE after the algorithm finishes. */ - if (i == mype) { - mype_array_index = pe_id * max_dev_per_pe; - } + cuda_device_paths = (char **)calloc(n_gpus_node, sizeof(char *)); + NVSHMEMI_NULL_ERROR_JMP(cuda_device_paths, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); - pe_id++; - if (pe_id == n_pes_node) { - break; + status = get_gpu_paths_and_index(nvshmemi_state->device_id, cuda_device_paths, &mygpu_index); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "get_gpu_paths_and_index failed\n"); + mygpu_array_index = mygpu_index * max_dev_per_pe; + + /* Allocate GPU-based arrays */ + gpu_selected_devices = (int *)calloc(n_gpus_node * max_dev_per_pe, sizeof(int)); + NVSHMEMI_NULL_ERROR_JMP(gpu_selected_devices, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { + for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { + gpu_selected_devices[gpu_id * max_dev_per_pe + dev_id] = -1; } } - if (pe_id != n_pes_node || mype_array_index == -1) { - NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Number of PEs found doesn't match the PE node count.\n"); + 
gpu_device_distance = + (enum pci_distance *)calloc(n_gpus_node * max_dev_per_pe, sizeof(enum pci_distance)); + NVSHMEMI_NULL_ERROR_JMP(gpu_device_distance, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { + for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { + gpu_device_distance[gpu_id * max_dev_per_pe + dev_id] = PATH_SYS; + } } for (i = 0; i < ndev; i++) { @@ -257,37 +299,37 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, /* Gather GPU and NIC paths end */ /* Get path distances start */ - /* construct a n_pes_node * ndev array of distance measurements */ - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { + /* construct a n_gpus_node * ndev array of distance measurements */ + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { for (dev_id = 0; dev_id < ndev; dev_id++) { enum pci_distance distance_compare; distance_compare = - get_pci_distance(cuda_device_paths[pe_id], dev_info_all[dev_id].dev_path); - if (unlikely(pe_dev_pairs.empty())) { - pe_dev_pairs.push_front({pe_id, dev_id, distance_compare}); + get_pci_distance(cuda_device_paths[gpu_id], dev_info_all[dev_id].dev_path); + if (unlikely(gpu_dev_pairs.empty())) { + gpu_dev_pairs.push_front({gpu_id, dev_id, distance_compare}); } else { - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { if (distance_compare < (*pairs_iter).pcie_distance) { break; } } - INFO(NVSHMEM_TOPO, "PE %d: %s dev %d: %s distance: %d\n", pe_id, - cuda_device_paths[pe_id], dev_id, dev_info_all[dev_id].dev_path, + INFO(NVSHMEM_TOPO, "GPU %d: %s dev %d: %s distance: %d\n", gpu_id, + cuda_device_paths[gpu_id], dev_id, dev_info_all[dev_id].dev_path, distance_compare); - pe_dev_pairs.insert(pairs_iter, {pe_id, dev_id, distance_compare}); + gpu_dev_pairs.insert(pairs_iter, {gpu_id, dev_id, 
distance_compare}); } } } /* Get path distances end */ /* loop one, do initial assignments of NIC(s) to each GPU */ - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); pairs_iter++) { + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { bool need_more_assignments = 0; - int pe_base_index = (*pairs_iter).pe_idx * max_dev_per_pe; + int gpu_base_index = (*pairs_iter).gpu_idx * max_dev_per_pe; /* skip pairs where the GPU already has a partner in the first loop */ - for (pe_pair_index = 0; pe_pair_index < max_dev_per_pe; pe_pair_index++) - if (pe_selected_devices[pe_base_index + pe_pair_index] == PE_DEVICE_NOT_ASSIGNED) { + for (gpu_pair_index = 0; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) + if (gpu_selected_devices[gpu_base_index + gpu_pair_index] == PE_DEVICE_NOT_ASSIGNED) { need_more_assignments = 1; break; } @@ -297,13 +339,13 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } if (pci_distance_perf[(*pairs_iter).pcie_distance] < - pci_distance_perf[pe_device_distance[pe_base_index]]) { + pci_distance_perf[gpu_device_distance[gpu_base_index]]) { /* This NIC and all subsequent ones are less optimal than the already selected NICs * They can be safely ignored and we assign -2 to indicate that there are no more * optimal NICs for this GPU. */ - for (; pe_pair_index < max_dev_per_pe; pe_pair_index++) { - pe_selected_devices[pe_base_index + pe_pair_index] = + for (; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) { + gpu_selected_devices[gpu_base_index + gpu_pair_index] = PE_DEVICE_NO_OPTIMAL_ASSIGNMENT; /* While not technically assigned, we need to account for these NICs to make * forward progress. @@ -312,61 +354,61 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } } else { /* This NIC is optimal for this GPU. 
*/ - INFO(NVSHMEM_TOPO, "Pairing PE %d with device %d at distance %d\n", - (*pairs_iter).pe_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); - pe_selected_devices[pe_base_index + pe_pair_index] = (*pairs_iter).dev_idx; - pe_device_distance[pe_base_index + pe_pair_index] = (*pairs_iter).pcie_distance; + INFO(NVSHMEM_TOPO, "Pairing GPU %d with device %d at distance %d\n", + (*pairs_iter).gpu_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); + gpu_selected_devices[gpu_base_index + gpu_pair_index] = (*pairs_iter).dev_idx; + gpu_device_distance[gpu_base_index + gpu_pair_index] = (*pairs_iter).pcie_distance; used_devs[(*pairs_iter).dev_idx]++; devices_assigned++; } - if (devices_assigned == n_pes_node * max_dev_per_pe) { + if (devices_assigned == n_gpus_node * max_dev_per_pe) { break; } } /* loop two, load balance the NICs. */ - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - int pe_pair_idx = pe_id * max_dev_per_pe + dev_id; + int gpu_pair_idx = gpu_id * max_dev_per_pe + dev_id; int nic_density; - if (pe_selected_devices[pe_pair_idx] < 0) { + if (gpu_selected_devices[gpu_pair_idx] < 0) { continue; } - nic_density = used_devs[pe_selected_devices[pe_pair_idx]]; + nic_density = used_devs[gpu_selected_devices[gpu_pair_idx]]; /* Can't find a less populated NIC if ours is only assigned to 1 gpu. */ if (nic_density < 2) { continue; } - /* Calculate PE Index from nic_id. Each PE gets max_dev_per_pe assigned to them. - * If there are 8 NIC's and 4 PE's, the nic -> PE mapping looks like + /* Calculate GPU Index from nic_id. Each GPU gets max_dev_per_pe assigned to them. 
+ * If there are 8 NIC's and 4 GPU's, the nic -> GPU mapping looks like * nic_id: 0 1 2 3 4 5 6 7 - * pe_idx: 0 0 1 1 2 2 3 3 + * gpu_idx: 0 0 1 1 2 2 3 3 */ - int pe_idx = (pe_pair_idx - (pe_pair_idx % max_dev_per_pe)) / max_dev_per_pe; - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); + int gpu_idx = (gpu_pair_idx - (gpu_pair_idx % max_dev_per_pe)) / max_dev_per_pe; + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { /* Never change for a less optimal NIC. */ - if ((*pairs_iter).pe_idx != pe_idx) { + if ((*pairs_iter).gpu_idx != gpu_idx) { continue; } if (pci_distance_perf[(*pairs_iter).pcie_distance] < - pci_distance_perf[pe_device_distance[pe_pair_idx]]) { + pci_distance_perf[gpu_device_distance[gpu_pair_idx]]) { break; } if ((nic_density - used_devs[(*pairs_iter).dev_idx]) >= 2) { - INFO(NVSHMEM_TOPO, "Re-Pairing PE %d with device %d at distance %d\n", - (*pairs_iter).pe_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); - used_devs[pe_selected_devices[pe_pair_idx]]--; + INFO(NVSHMEM_TOPO, "Re-Pairing GPU %d with device %d at distance %d\n", + (*pairs_iter).gpu_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); + used_devs[gpu_selected_devices[gpu_pair_idx]]--; used_devs[(*pairs_iter).dev_idx]++; nic_density = used_devs[(*pairs_iter).dev_idx]; - pe_selected_devices[pe_pair_idx] = (*pairs_iter).dev_idx; - pe_device_distance[pe_pair_idx] = (*pairs_iter).pcie_distance; + gpu_selected_devices[gpu_pair_idx] = (*pairs_iter).dev_idx; + gpu_device_distance[gpu_pair_idx] = (*pairs_iter).pcie_distance; if (nic_density < 2) { break; } @@ -374,32 +416,34 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } } - for (pe_pair_index = 0; pe_pair_index < max_dev_per_pe; pe_pair_index++) { - if (pe_selected_devices[mype_array_index + pe_pair_index] >= 0) { - mydev_index = pe_selected_devices[mype_array_index + pe_pair_index]; - device_arr[pe_pair_index] = mydev_index; - 
mype_device_count++; - INFO(NVSHMEM_TOPO, "Our PE is sharing its NIC at index %d with %d other PEs.\n", - used_devs[mydev_index], mype_device_count); + for (gpu_pair_index = 0; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) { + if (gpu_selected_devices[mygpu_array_index + gpu_pair_index] >= 0) { + mydev_index = gpu_selected_devices[mygpu_array_index + gpu_pair_index]; + device_arr[gpu_pair_index] = mydev_index; + mygpu_device_count++; + INFO(NVSHMEM_TOPO, "Our GPU is sharing its NIC at index %d with %d other GPUs.\n", + used_devs[mydev_index], mygpu_device_count); } } - if (mype_device_count == 0) { + if (mygpu_device_count == 0) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No NICs were assigned to our PE.\n"); + "No NICs were assigned to our GPU.\n"); } /* No need to report this in a loop - All Devices will have the same perf characteristics. */ - if (pci_distance_perf[pe_device_distance[mype_array_index]] < pci_distance_perf[PATH_PIX]) { + if (pci_distance_perf[gpu_device_distance[mygpu_array_index]] < pci_distance_perf[PATH_PIX]) { nvshmemi_state->are_nics_ll128_compliant = false; INFO(NVSHMEM_TOPO, - "Our PE is connected to a NIC with pci distance %s." + "Our GPU is connected to a NIC with pci distance %s." 
"this will provide less than optimal performance.\n", - pci_distance_string[pe_device_distance[mype_array_index]]); + pci_distance_string[gpu_device_distance[mygpu_array_index]]); } } + status = NVSHMEMX_SUCCESS; + out: if (dev_info_all) { free(dev_info_all); @@ -410,7 +454,7 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } if (cuda_device_paths) { - for (i = 0; i < n_pes_node; i++) { + for (i = 0; i < n_gpus_node; i++) { if (cuda_device_paths[i]) { free(cuda_device_paths[i]); } @@ -418,18 +462,18 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, free(cuda_device_paths); } - pe_dev_pairs.clear(); + gpu_dev_pairs.clear(); - if (pe_selected_devices) { - free(pe_selected_devices); + if (gpu_selected_devices) { + free(gpu_selected_devices); } if (used_devs) { free(used_devs); } - if (pe_device_distance) { - free(pe_device_distance); + if (gpu_device_distance) { + free(gpu_device_distance); } return status; diff --git a/src/host/topo/topo.h b/src/host/topo/topo.h index cdf07cd4..475c6250 100644 --- a/src/host/topo/topo.h +++ b/src/host/topo/topo.h @@ -10,6 +10,7 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, struct nvshmem_transport *tcurr); +int get_nvidia_gpu_count(void); int nvshmemi_detect_same_device(nvshmemi_state_t *state); int nvshmemi_build_transport_map(nvshmemi_state_t *state); diff --git a/src/host/transport/transport.cpp b/src/host/transport/transport.cpp index 67077e7e..26e9b499 100644 --- a/src/host/transport/transport.cpp +++ b/src/host/transport/transport.cpp @@ -384,7 +384,12 @@ int nvshmemi_setup_connections(nvshmemi_state_t *state) { continue; } - int devices_temp = tcurr->n_devices / state->npes_node; + int n_gpus_node = get_nvidia_gpu_count(); + if (n_gpus_node <= 0) { + n_gpus_node = state->npes_node; + } + + int devices_temp = tcurr->n_devices / n_gpus_node; if (devices_temp == 0) devices_temp = 1; const int max_devices_per_pe = devices_temp; int 
selected_devices[max_devices_per_pe]; diff --git a/src/include/internal/host_transport/transport.h b/src/include/internal/host_transport/transport.h index f3fc7c14..f36b9959 100644 --- a/src/include/internal/host_transport/transport.h +++ b/src/include/internal/host_transport/transport.h @@ -148,7 +148,6 @@ struct nvshmem_transport_host_ops { fence_handle fence; quiet_handle quiet; put_signal_handle put_signal; - int (*enforce_cst)(struct nvshmem_transport *transport); int (*enforce_cst_at_target)(struct nvshmem_transport *transport); int (*add_device_remote_mem_handles)(struct nvshmem_transport *transport, int transport_stride, nvshmem_mem_handle_t *mem_handles, uint64_t heap_offset, diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index 086dc016..654d8dfc 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -98,6 +98,13 @@ NVSHMEMI_ENV_DEF(DISABLE_LOCAL_ONLY_PROXY, bool, false, NVSHMEMI_ENV_CAT_TRANSPO NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, "Set the feature set provider for the libfabric transport: cxi, efa, verbs") +NVSHMEMI_ENV_DEF(LIBFABRIC_MAX_NIC_PER_PE, int, 16, NVSHMEMI_ENV_CAT_TRANSPORT, + "Set the maximum number of NIC's per PE to use for libfabric provider") + +NVSHMEMI_ENV_DEF(LIBFABRIC_PROXY_REQUEST_BATCH_MAX, int, 32, NVSHMEMI_ENV_CAT_TRANSPORT, + "Maximum number of requests that the libfabric transport processes per queue " + "in a single iteration of the progress loop.") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/ibdevx/ibdevx.cpp b/src/modules/transport/ibdevx/ibdevx.cpp index edc11086..b65d4740 100644 --- a/src/modules/transport/ibdevx/ibdevx.cpp +++ b/src/modules/transport/ibdevx/ibdevx.cpp @@ -1440,46 +1440,6 @@ int 
nvshmemt_ibdevx_amo(struct nvshmem_transport *tcurr, int pe, void *curetptr, return status; } -int nvshmemt_ibdevx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - nvshmemt_ib_common_state_t ibdevx_state = (nvshmemt_ib_common_state_t)tcurr->state; - struct ibdevx_ep *ep = (struct ibdevx_ep *)ibdevx_state->cst_ep; - struct ibdevx_rw_wqe *wqe; - - int status = 0; - - uintptr_t wqe_bb_idx_64 = ep->wqe_bb_idx; - uint32_t wqe_bb_idx_32 = ep->wqe_bb_idx; - size_t wqe_size; - - wqe = (struct ibdevx_rw_wqe *)((char *)ep->wq_buf + - ((wqe_bb_idx_64 % get_ibdevx_qp_depth(ibdevx_state)) - << NVSHMEMT_IBDEVX_WQE_BB_SHIFT)); - wqe_size = sizeof(struct ibdevx_rw_wqe); - memset(wqe, 0, sizeof(struct ibdevx_rw_wqe)); - - wqe->ctrl.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; - wqe->ctrl.qpn_ds = - htobe32((uint32_t)(wqe_size / NVSHMEMT_IBDEVX_MLX5_SEND_WQE_DS) | ep->qpid << 8); - wqe->ctrl.opmod_idx_opcode = htobe32(MLX5_OPCODE_RDMA_READ | (wqe_bb_idx_32 << 8)); - - wqe->raddr.raddr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - wqe->raddr.rkey = htobe32(local_dummy_mr.rkey); - - wqe->data.data_seg.byte_count = htobe32((uint32_t)4); - wqe->data.data_seg.lkey = htobe32(local_dummy_mr.lkey); - wqe->data.data_seg.addr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - - assert(wqe_size <= MLX5_SEND_WQE_BB); - ep->wqe_bb_idx++; - nvshmemt_ibdevx_post_send(ep, (void *)wqe, 1); - - status = nvshmemt_ib_common_check_poll_avail(tcurr, ep, NVSHMEMT_IB_COMMON_WAIT_ALL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "check_poll failed \n"); - -out: - return status; -} - // Using common fence and quiet functions from transport_ib_common int nvshmemt_ibdevx_ep_create(struct ibdevx_ep **ep, int devid, nvshmem_transport_t t, @@ -1922,7 +1882,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ibdevx_finalize; transport->host_ops.show_info = nvshmemt_ibdevx_show_info; transport->host_ops.progress = 
nvshmemt_ibdevx_progress; - transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 115f6b66..d73232ed 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -4903,7 +4903,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.amo = NULL; transport->host_ops.fence = NULL; transport->host_ops.quiet = NULL; - transport->host_ops.enforce_cst = NULL; transport->host_ops.add_device_remote_mem_handles = nvshmemt_ibgda_add_device_remote_mem_handles; transport->host_ops.put_signal = NULL; diff --git a/src/modules/transport/ibrc/ibrc.cpp b/src/modules/transport/ibrc/ibrc.cpp index f7c9ce06..b0fdddf1 100644 --- a/src/modules/transport/ibrc/ibrc.cpp +++ b/src/modules/transport/ibrc/ibrc.cpp @@ -1810,7 +1810,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.progress = nvshmemt_ibrc_progress; transport->host_ops.put_signal = nvshmemt_put_signal; - transport->host_ops.enforce_cst = nvshmemt_ibrc_enforce_cst_at_target; #if !defined(NVSHMEM_PPC64LE) && !defined(NVSHMEM_AARCH64) if (!use_gdrcopy) #endif diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 70bce5a5..80b362b0 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -62,11 +62,15 @@ static bool use_gdrcopy = false; #endif #define MAX_COMPLETIONS_PER_CQ_POLL 300 +#define MAX_COMPLETIONS_PER_CQ_POLL_EFA 32 #define NVSHMEM_STAGED_AMO_WIREDATA_SIZE \ sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) + static bool use_staged_atomics = false; -threadSafeOpQueue nvshmemtLibfabricOpQueue; +static bool use_auto_progress = false; + std::recursive_mutex gdrRecvMutex; typedef enum { @@ -88,8 +92,7 @@ typedef enum { NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_ENFORCE_CST } nvshmemt_libfabric_try_again_call_site_t; -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport); -int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport); +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index); int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr); @@ -99,33 +102,117 @@ static nvshmemt_libfabric_imm_cq_data_hdr_t nvshmemt_get_write_with_imm_hdr(uint NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT); } -static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_endpoint_t *ep, - struct fi_cq_data_entry *entry) { +static inline nvshmemt_libfabric_endpoint_t *nvshmemt_libfabric_get_next_ep( + nvshmemt_libfabric_state_t *state, int qp_index) { + int selected_ep; + + if (qp_index == NVSHMEMX_QP_HOST) { + selected_ep = 0; + } else { + /* + * Return the current EP, and increment the next EP in round robin fashion + * between 1 and state->num_selected_domains - 1. state->cur_proxy_ep_index + * is initialized to 1. This round-robin goes through the proxy EP's and + * ignores the host EP. 
+ */ + selected_ep = state->cur_proxy_ep_index; + state->cur_proxy_ep_index = (state->cur_proxy_ep_index + 1) % state->num_selected_domains; + if (!state->cur_proxy_ep_index) state->cur_proxy_ep_index = 1; + } + + return state->eps[selected_ep]; +} + +static inline int convert_addr_to_pe(nvshmemt_libfabric_state_t *state, + nvshmemt_libfabric_endpoint_t *ep, + fi_addr_t addr) +{ + // addr = pe * libfabric_state->num_selected_domains + ep->domain_index + // so + // pe = (addr - ep->domain_index) / (libfabric_state->num_selected_domains) + int base_ep_index = addr - ep->domain_index; + assert((base_ep_index % state->num_selected_domains) == 0); + + return base_ep_index / state->num_selected_domains; +} + +static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_state_t *state, + nvshmemt_libfabric_endpoint_t *ep, + struct fi_cq_data_entry *entry, + fi_addr_t addr) { uint32_t seq_num = entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; if (seq_num != NVSHMEM_STAGED_AMO_SEQ_NUM) { - ep->put_signal_seq_counter.return_acked_seq_num(seq_num); + /* Use host_signal_state for eps[0], proxy_signal_state for eps[1+] */ + nvshmemt_libfabric_signal_state_t *signal_state = + (ep->domain_index == 0) ? 
&state->host_signal_state : &state->proxy_signal_state; + + int pe = convert_addr_to_pe(state, ep, addr); + nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = + nvshmemt_get_write_with_imm_hdr(entry->data); + + if (imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK) { + (*signal_state->put_signal_seq_counter_per_pe)[pe] + .return_acked_seq_num_range_for_put(seq_num); + } else { + (*signal_state->put_signal_seq_counter_per_pe)[pe] + .return_acked_seq_num(seq_num); + } } ep->completed_staged_atomics++; } +static inline bool is_signal_only_op(nvshmemi_amo_t op) { + return (op == NVSHMEMI_AMO_SIGNAL || op == NVSHMEMI_AMO_SIGNAL_SET || + op == NVSHMEMI_AMO_SIGNAL_ADD); +} + +inline int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, + nvshmemt_libfabric_gdr_op_ctx_t *op) { + nvshmemt_libfabric_gdr_ret_amo_op_t *ret = &op->ret_amo; + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + nvshmemt_libfabric_memhandle_info_t *handle_info; + g_elem_t *elem; + void *valid_cpu_ptr; + + handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get( + transport, libfabric_state->cache, ret->ret_addr); + if (!handle_info) { + NVSHMEMI_ERROR_PRINT("Unable to get handle info for atomic response.\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + valid_cpu_ptr = + (void *)((char *)handle_info->cpu_ptr + ((char *)ret->ret_addr - (char *)handle_info->ptr)); + assert(valid_cpu_ptr); + elem = (g_elem_t *)valid_cpu_ptr; + elem->data = ret->elem.data; + elem->flag = ret->elem.flag; + + return 0; +} + static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr) { int status = 0; nvshmemt_libfabric_gdr_op_ctx_t *op; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; /* Write w/imm doesn't have op->op_context, must be checked first */ if (entry->flags & FI_REMOTE_CQ_DATA) 
{ nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = nvshmemt_get_write_with_imm_hdr(entry->data); - if (NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ == imm_header) { + if (NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ == imm_header) { status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); goto out; - } else if (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK == imm_header) { - nvshmemt_libfabric_put_signal_ack_completion(ep, entry); + } else if (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK == imm_header) { + nvshmemt_libfabric_put_signal_ack_completion(state, ep, entry, *addr); goto out; } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -140,19 +227,28 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo op->src_addr = *addr; if (entry->flags & FI_SEND) { - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); + ep->completed_ops++; } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ - nvshmemtLibfabricOpQueue.putToSend(op); - } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { + state->op_queue[ep->domain_index]->putToSend(op); + ep->completed_ops++; + } else if ((op->type == NVSHMEMT_LIBFABRIC_MATCH) && (entry->flags & FI_RECV)) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); } else if (entry->flags & FI_RECV) { op->ep = ep; if (op->type == NVSHMEMT_LIBFABRIC_ACK) { - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + + status = fi_recv(op->ep->endpoint, (void *)op, 
NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); } else { - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); } } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -163,91 +259,126 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo return status; } -static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; +static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_t *ep) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int max_per_poll = (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) + ? 
MAX_COMPLETIONS_PER_CQ_POLL_EFA + : MAX_COMPLETIONS_PER_CQ_POLL; + char buf[max_per_poll * sizeof(struct fi_cq_data_entry)]; + fi_addr_t src_addr[max_per_poll]; + fi_addr_t *addr; + ssize_t qstatus; + struct fi_cq_data_entry *entry; int status = 0; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - uint64_t cnt = fi_cntr_readerr(libfabric_state->eps[i].counter); - - if (cnt > 0) { - NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP %d (%" PRIu64 ")\n", i, cnt); + int ret = 0; - struct fi_cq_err_entry err; - memset(&err, 0, sizeof(struct fi_cq_err_entry)); - ssize_t nerr = fi_cq_readerr(libfabric_state->eps[i].cq, &err, 0); - - if (nerr > 0) { - char str[100] = "\0"; - const char *err_str = fi_cq_strerror(libfabric_state->eps[i].cq, err.prov_errno, - err.err_data, str, 100); - NVSHMEMI_WARN_PRINT( - "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " - "info: %s\n", - err.err, fi_strerror(err.err), err_str ? err_str : "none", - strlen(str) ? str : "none"); - } else if (nerr == -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); - } else { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, - fi_strerror(-1 * nerr)); + qstatus = fi_cq_readfrom(ep->cq, buf, max_per_poll, src_addr); + /* Note - EFA provider does not support selective completions */ + if (qstatus > 0) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + entry = (struct fi_cq_data_entry *)buf; + addr = src_addr; + for (int i = 0; i < qstatus; i++, entry++, addr++) { + status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); + if (status) return NVSHMEMX_ERROR_INTERNAL; } - return err.err; + } else { + NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } + } else if (qstatus < 0 && qstatus != -FI_EAGAIN) { + /* On call to fi_cq_readerr, Libfabric requires some members of + * err_entry to be zero-initialized or point to valid data. 
For + * simplicity, just zero out the whole struct. + */ + struct fi_cq_err_entry err_entry = {}; - { - char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; - fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; - ssize_t qstatus; - nvshmemt_libfabric_endpoint_t *ep = &libfabric_state->eps[i]; - do { - qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); - /* Note - EFA provider does not support selective completions */ - if (qstatus > 0) { - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - struct fi_cq_data_entry *entry = (struct fi_cq_data_entry *)buf; - fi_addr_t *addr = src_addr; - for (int i = 0; i < qstatus; i++, entry++, addr++) { - status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, - addr); - if (status) return NVSHMEMX_ERROR_INTERNAL; - } - } else { - NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP %d\n", qstatus, i); - } - } - } while (qstatus > 0); - if (qstatus < 0 && qstatus != -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, - fi_strerror(qstatus * -1)); - return NVSHMEMX_ERROR_INTERNAL; - } - } - } - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - status = nvshmemt_libfabric_gdr_process_amos_ack(transport); - if (status) { + ret = fi_cq_readerr(ep->cq, &err_entry, 0); + if (ret == -FI_EAGAIN) { + return 0; + } else if (ret < 0) { + NVSHMEMI_WARN_PRINT("Unable to read from fi_cq_readerr. RC: %d. Error: %s\n", ret, fi_strerror(-ret)); return NVSHMEMX_ERROR_INTERNAL; } + + NVSHMEMI_WARN_PRINT("Received a CQE with error. RC: %d. 
Error: %d (%s)", err_entry.err, err_entry.prov_errno, + fi_cq_strerror(ep->cq, err_entry.prov_errno, err_entry.err_data, NULL, 0)); + return NVSHMEMX_ERROR_INTERNAL; } + return 0; } -static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { +static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { + int status; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int end_iter; + + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } + + for (int i = qp_index; i < end_iter; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + +out: + return status; +} + +static int nvshmemt_libfabric_progress(nvshmem_transport_t transport, int qp_index); +static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) { + return nvshmemt_libfabric_progress(transport, NVSHMEMT_LIBFABRIC_PROXY_EP_IDX); +} + +static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int status; + for (size_t i = 0; i < state->eps.size(); i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + +out: + return status; +} + +static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport, int qp_index) { + int status = 0; + + if (use_auto_progress) + status = nvshmemt_libfabric_auto_progress(transport, qp_index); + else + status = nvshmemt_libfabric_manual_progress(transport); + + return status; +} + + +static int nvshmemt_libfabric_progress(nvshmem_transport_t transport, int qp_index) { 
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; int status; - status = nvshmemt_libfabric_process_completions(transport); + status = nvshmemt_libfabric_process_completions(transport, qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } + int progress_qp_index = (use_auto_progress ? qp_index : NVSHMEMX_QP_ALL); + if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport); - gdrRecvMutex.unlock(); + status = nvshmemt_libfabric_gdr_process_amos(transport, progress_qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } + gdrRecvMutex.unlock(); } } @@ -255,6 +386,7 @@ static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { } static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries, + int qp_index, nvshmemt_libfabric_try_again_call_site_t call_site, bool completions_only = false) { if (likely(*status == 0)) { @@ -270,9 +402,9 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t } (*num_retries)++; if (completions_only) { - *status = nvshmemt_libfabric_process_completions(transport); + *status = nvshmemt_libfabric_process_completions(transport, qp_index); } else { - *status = nvshmemt_libfabric_progress(transport); + *status = nvshmemt_libfabric_progress(transport, qp_index); } } @@ -286,24 +418,45 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t return 1; } +static inline int get_next_seq_num_with_retry(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_seq_counter_t &seq_counter, + uint32_t *sequence_count, + int qp_index, + nvshmemt_libfabric_try_again_call_site_t call_site) { + uint64_t num_retries = 0; + int status; + do { + int32_t seq_num = seq_counter.next_seq_num(); + if (seq_num < 0) { + status = -EAGAIN; + } else { + *sequence_count = seq_num; + status = 0; + } + } while 
(try_again(transport, &status, &num_retries, qp_index, call_site)); + + return status; +} + int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, fi_addr_t dest_addr, uint32_t sequence_count, int pe, - nvshmemt_libfabric_gdr_op_ctx_t **send_elems) { + nvshmemt_libfabric_gdr_op_ctx_t **send_elems, + nvshmemt_libfabric_imm_cq_data_hdr_t ack_header) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *resp_op = NULL; uint64_t num_retries = 0; int status; uint64_t imm_data = 0; + uint64_t rkey_index = pe * libfabric_state->num_selected_domains + ep->domain_index; resp_op = send_elems[0]; - imm_data = (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK - << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | - sequence_count; + imm_data = (ack_header << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; do { - status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, - dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], - libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, + status = fi_writedata( + ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr[ep->domain_index]), imm_data, + dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], + libfabric_state->rkey_staged_amo_ack[rkey_index], &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, ep->domain_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDRCOPY_AMO_ACK, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); @@ -400,7 +553,7 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op /* Post recv before posting TX operations to avoid deadlocks */ status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), 
FI_ADDR_UNSPEC, &op->ofi_context); + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to re-post recv.\n"); if (is_fetch_amo) { @@ -412,9 +565,10 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op resp_op->type = NVSHMEMT_LIBFABRIC_ACK; do { - status = fi_send(ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, + status = fi_send(op->ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), op->src_addr, + &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PERFORM_GDRCOPY_AMO_SEND, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); @@ -423,7 +577,8 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op } status = gdrcopy_amo_ack(transport, ep, src_addr, sequence_count, src_pe, - &send_elems[send_elems_index]); + &send_elems[send_elems_index], + NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK); out: return status; } @@ -453,95 +608,108 @@ int nvshmemt_libfabric_gdr_process_amo(nvshmem_transport_t transport, return status; } -int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, - nvshmemt_libfabric_gdr_op_ctx_t *op) { - nvshmemt_libfabric_gdr_ret_amo_op_t *ret = &op->ret_amo; +int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_memhandle_info_t *handle_info; - g_elem_t *elem; - void *valid_cpu_ptr; + nvshmemt_libfabric_gdr_op_ctx_t *op; + nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; + int 
end_iter; + int status = 0; - handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get( - transport, libfabric_state->cache, ret->ret_addr); - if (!handle_info) { - NVSHMEMI_ERROR_PRINT("Unable to get handle info for atomic response.\n"); - return NVSHMEMX_ERROR_INTERNAL; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; } - valid_cpu_ptr = - (void *)((char *)handle_info->cpu_ptr + ((char *)ret->ret_addr - (char *)handle_info->ptr)); - assert(valid_cpu_ptr); - elem = (g_elem_t *)valid_cpu_ptr; - elem->data = ret->elem.data; - elem->flag = ret->elem.flag; - return 0; -} + for (int i = qp_index; i < end_iter; i++) { + int ops_processed = 0; -int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_gdr_op_ctx_t *op; - nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; - size_t num_retries = 0; - int status = 0; - do { + size_t num_retries = 0; do { - status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, - NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); - } while (try_again(transport, &status, &num_retries, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_ACK, - true)); - num_retries = 0; - - if (op) { - status = nvshmemt_libfabric_gdr_process_ack(transport, op); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to re-post recv.\n"); - } - } while (op); + do { + status = 
libfabric_state->op_queue[i]->getNextAmoOps(send_elems, &op, + NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); + } while (try_again(transport, &status, &num_retries, i, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_ACK, + true)); + num_retries = 0; + + if (op) { + ops_processed++; + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); + } + } while (op && ops_processed < libfabric_state->proxy_request_batch_max); + + } out: return status; } -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport) { +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index) { + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; size_t num_retries = 0; int status = 0; + int end_iter; + + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } + + for (int i = qp_index; i < end_iter; i++) { + int ops_processed = 0; - do { do { - status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, - NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); - } while (try_again(transport, &status, &num_retries, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_NOT_ACK, - true)); - num_retries = 0; - - if (op) { - if (op->type == NVSHMEMT_LIBFABRIC_SEND) { - assert(send_elems[0] != NULL); - assert(send_elems[1] != NULL); - status = 
nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, - NVSHMEM_STAGED_AMO_SEQ_NUM); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - /* Reposts recv in perform_gdrcopy_amo() */ - } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { - assert(send_elems[0] != NULL); - assert(send_elems[1] != NULL); - status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, - op->send_amo.sequence_count); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - /* Reposts recv in perform_gdrcopy_amo() */ + do { + status = libfabric_state->op_queue[i]->getNextAmoOps(send_elems, &op, + NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + } while (try_again(transport, &status, &num_retries, i, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_NOT_ACK, + true)); + num_retries = 0; + + if (op) { + ops_processed++; + if (op->type == NVSHMEMT_LIBFABRIC_SEND) { + assert(send_elems[0] != NULL); + assert(send_elems[1] != NULL); + status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, + NVSHMEM_STAGED_AMO_SEQ_NUM); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + /* Reposts recv in perform_gdrcopy_amo() */ + } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { + assert(send_elems[0] != NULL); + assert(send_elems[1] != NULL); + status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, + op->send_amo.sequence_count); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + /* Reposts recv in perform_gdrcopy_amo() */ + } } - } - } while (op); + } while (op && ops_processed < libfabric_state->proxy_request_batch_max); + + } + out: return status; } @@ -571,12 +739,19 @@ nvshmemt_libfabric_gdr_op_ctx_t *nvshmemt_inplace_copy_sig_op_to_gdr_op( int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct 
fi_cq_data_entry *entry, fi_addr_t *addr) { + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_signal_op *sig_op = NULL; nvshmemt_libfabric_gdr_op_ctx_t *op = NULL; bool is_write_comp = entry->flags & FI_REMOTE_CQ_DATA; - int status = 0, progress_count; + int status = 0, progress_count, pe; uint64_t map_key; - std::unordered_map>::iterator iter; + bool is_standalone_put = false; + std::unordered_map::iterator iter; + + /* Use host_signal_state for eps[0], proxy_signal_state for eps[1+] */ + nvshmemt_libfabric_signal_state_t *signal_state = + (ep->domain_index == 0) ? &libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; if (unlikely(*addr == FI_ADDR_NOTAVAIL)) { status = -1; @@ -584,13 +759,19 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, "Write w/imm returned with invalid src address.\n"); } + pe = convert_addr_to_pe(libfabric_state, ep, *addr); + if (is_write_comp) { - map_key = *addr << 32 | (uint32_t)entry->data; + nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = + nvshmemt_get_write_with_imm_hdr(entry->data); + is_standalone_put = (imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT || + imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ); + map_key = (((uint64_t)pe) << 32) | ((uint32_t)entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK); progress_count = -1; } else { sig_op = (nvshmemt_libfabric_gdr_signal_op *)container_of( entry->op_context, nvshmemt_libfabric_gdr_op_ctx_t, ofi_context); - map_key = *addr << 32 | sig_op->sequence_count; + map_key = (((uint64_t)pe) << 32) | sig_op->sequence_count; progress_count = (int)sig_op->num_writes; /* The EFA provider has an inline send size of 32 bytes. 
@@ -603,23 +784,81 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, op = nvshmemt_inplace_copy_sig_op_to_gdr_op(sig_op, ep); } - iter = ep->proxy_put_signal_comp_map->find(map_key); - if (iter != ep->proxy_put_signal_comp_map->end()) { - if (!is_write_comp) iter->second.first = op; - iter->second.second += progress_count; + if (is_write_comp && nvshmemt_get_write_with_imm_hdr(entry->data) == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ) { + nvshmemt_libfabric_comp_entry_t ack_comp_entry; + ack_comp_entry.type = NVSHMEMT_LIBFABRIC_COMP_ENTRY_PUT_ACK; + ack_comp_entry.ack_entry.src_addr = *addr; + ack_comp_entry.ack_entry.ep = ep; + signal_state->proxy_put_signal_comp_map->insert(std::make_pair(map_key, ack_comp_entry)); } else { - iter = ep->proxy_put_signal_comp_map - ->insert(std::make_pair(map_key, std::make_pair(op, progress_count))) - .first; - } + iter = signal_state->proxy_put_signal_comp_map->find(map_key); + if (iter != signal_state->proxy_put_signal_comp_map->end()) { + if (!is_write_comp) iter->second.signal_entry.op = op; + iter->second.signal_entry.progress_count += progress_count; + } else { + nvshmemt_libfabric_comp_entry_t sig_comp_entry; + sig_comp_entry.type = NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL; + if (is_standalone_put) { + sig_comp_entry.signal_entry.op = nullptr; + sig_comp_entry.signal_entry.progress_count = 0; + } else { + sig_comp_entry.signal_entry.op = op; + sig_comp_entry.signal_entry.progress_count = progress_count; + } + signal_state->proxy_put_signal_comp_map->insert(std::make_pair(map_key, sig_comp_entry)); + iter = signal_state->proxy_put_signal_comp_map->find(map_key); + } - if (!iter->second.second) { - if (is_write_comp) { - op = iter->second.first; + if (iter->second.signal_entry.progress_count != 0) { + goto out; } + } + + { + fi_addr_t src_addr = *addr; + // operator[] will default-construct (initialize to 0) if src_addr doesn't exist + uint32_t &next_seq = 
(*signal_state->next_expected_seq)[pe]; + + while (true) { + // Skip reserved sequence number + if (next_seq == NVSHMEM_STAGED_AMO_SEQ_NUM) { + next_seq = (next_seq + 1) & nvshmemt_libfabric_endpoint_seq_counter_t::sequence_mask; + continue; + } - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); - ep->proxy_put_signal_comp_map->erase(iter); + uint64_t key = (((uint64_t)pe) << 32) | next_seq; + auto it = signal_state->proxy_put_signal_comp_map->find(key); + + if (it == signal_state->proxy_put_signal_comp_map->end()) break; + + if (it->second.type == NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL) { + if (it->second.signal_entry.progress_count != 0) break; + + if (it->second.signal_entry.op != NULL) { + auto op_ep = it->second.signal_entry.op->ep; + libfabric_state->op_queue[op_ep->domain_index]->putToRecv( + it->second.signal_entry.op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + } + } else { + nvshmemt_libfabric_endpoint_t *ack_ep = it->second.ack_entry.ep; + nvshmemt_libfabric_gdr_op_ctx_t *send_elem; + uint64_t num_retries = 0; + do { + status = libfabric_state->op_queue[ack_ep->domain_index]->getNextSends( + (void **)(&send_elem), 1); + } while (try_again(transport, &status, &num_retries, ack_ep->domain_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDRCOPY_AMO_ACK, true)); + + if (status == 0) { + status = gdrcopy_amo_ack(transport, ack_ep, it->second.ack_entry.src_addr, next_seq, pe, + &send_elem, NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK); + } + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "gdrcopy_amo_ack failed\n"); + } + + signal_state->proxy_put_signal_comp_map->erase(it); + next_seq = (next_seq + 1) & nvshmemt_libfabric_endpoint_seq_counter_t::sequence_mask; + } } out: @@ -627,49 +866,33 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, } static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = 
(nvshmemt_libfabric_state_t *)tcurr->state; - nvshmemt_libfabric_endpoint_t *ep; - int is_proxy = qp_index != NVSHMEMX_QP_HOST; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; + bool all_nics_quieted; int status = 0; + int end_iter; - if (is_proxy) { - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; } else { - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; } - if (likely(libfabric_state->prov_info->domain_attr->control_progress == FI_PROGRESS_MANUAL) || - (libfabric_state->prov_info->domain_attr->data_progress == FI_PROGRESS_MANUAL) || - (use_staged_atomics == true) -#ifdef NVSHMEM_USE_GDRCOPY - || (use_gdrcopy == true) -#endif - ) { - uint64_t submitted, completed; - for (;;) { - completed = fi_cntr_read(ep->counter); - submitted = ep->submitted_ops; - if (completed + ep->completed_staged_atomics == submitted) - break; - else { - if (nvshmemt_libfabric_progress(tcurr)) { + for (;;) { + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + if (state->eps[i]->submitted_ops != state->eps[i]->completed_ops) { + all_nics_quieted = false; + if (nvshmemt_libfabric_progress(tcurr, qp_index)) { status = NVSHMEMX_ERROR_INTERNAL; break; } } } - } else { - status = fi_cntr_wait(ep->counter, ep->submitted_ops, NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); - if (status) { - /* note - Status is negative for this function in error cases but - * fi_strerror only accepts positive values. 
- */ - NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, - fi_strerror(status * -1)); - status = NVSHMEMX_ERROR_INTERNAL; - } + if (status || all_nics_quieted) break; } + return status; } @@ -687,40 +910,34 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy, - uint32_t *imm_data) { + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL; void *local_mr_desc = NULL; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; struct iovec p_op_l_iov; struct fi_msg_rma p_op_msg; struct fi_rma_iov p_op_r_iov; - nvshmemt_libfabric_endpoint_t *ep; size_t op_size; uint64_t num_retries = 0; int status = 0; int target_ep; - int ep_idx = 0; void *context = NULL; memset(&p_op_l_iov, 0, sizeof(struct iovec)); memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } + /* put_signal passes in EP to ensure that both operations go through same EP */ + if (!ep) ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&gdr_ctx), 1); - } while (try_again(tcurr, &status, &num_retries, + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&gdr_ctx), 1); + } while 
(try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get context buffer for put request.\n"); @@ -728,26 +945,26 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, /* local->handle may be NULL for small operations (P ops) sent by value/inline */ if (likely(local->handle != NULL)) { - local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep_idx]; + local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep->domain_index]; local_mr_desc = local_handle->local_desc; } } - remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep_idx]; + remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep->domain_index]; op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { - assert(!imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_P on Libfabric transport if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *p_buf = container_of(context, nvshmemt_libfabric_gdr_op_ctx_t, ofi_context); num_retries = 0; + p_buf->p_op.value = *(uint64_t *)local->ptr; + assert(imm_data); // EFA provider requires immediate data for p/put do { - p_buf->p_op.value = *(uint64_t *)local->ptr; - status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, - fi_mr_desc(libfabric_state->mr), target_ep, - (uintptr_t)remote->ptr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + status = fi_writedata(ep->endpoint, &p_buf->p_op.value, op_size, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), *imm_data, target_ep, + (uintptr_t)remote->ptr, remote_handle->key, context); + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_EFA)); } else { p_op_msg.msg_iov = &p_op_l_iov; @@ -760,7 +977,8 
@@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, p_op_l_iov.iov_base = local->ptr; p_op_l_iov.iov_len = op_size; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & + FI_MR_VIRT_ADDR) p_op_r_iov.addr = (uintptr_t)remote->ptr; else p_op_r_iov.addr = (uintptr_t)remote->offset; @@ -772,12 +990,12 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, */ do { status = fi_writemsg(ep->endpoint, &p_op_msg, FI_INJECT); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_NON_EFA)); } } else if (verb.desc == NVSHMEMI_OP_PUT) { uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -790,13 +1008,13 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, else status = fi_write(ep->endpoint, local->ptr, op_size, local_mr_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_PUT)); } else if (verb.desc == NVSHMEMI_OP_G || verb.desc == NVSHMEMI_OP_GET) { assert( !imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_G/GET on Libfabric transport uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -804,7 +1022,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport 
*tcurr, int pe, do { status = fi_read(ep->endpoint, local->ptr, op_size, local_mr_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_GET)); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -824,32 +1042,89 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, is_proxy, NULL); + rma_bytesdesc_t bytesdesc, int qp_index) { + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + uint32_t imm_data_val = 0; + uint32_t *imm_data = NULL; + int status; + nvshmemt_libfabric_endpoint_t *ep = nullptr; + + // Generate sequence number for P and PUT operations when ordering is needed + if (use_staged_atomics && + (verb.desc == NVSHMEMI_OP_P || verb.desc == NVSHMEMI_OP_PUT)) { + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? 
&libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; + uint32_t sequence_count; + + status = get_next_seq_num_with_retry(tcurr, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_PUT); + if (status) return status; + + seq_counter.put_count++; + + nvshmemt_libfabric_imm_cq_data_hdr_t header; + if (seq_counter.put_count >= NVSHMEM_STAGED_AMO_PUT_ACK_FREQ) { + header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ; + seq_counter.put_count = 0; + } else { + header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT; + } + + imm_data_val = (header << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; + imm_data = &imm_data_val; + } + + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, imm_data, + ep); } +static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, + void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, + amo_bytesdesc_t bytesdesc, int qp_index, + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep); + static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy) { + amo_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; - nvshmemt_libfabric_gdr_op_ctx_t *amo; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; + + /* Signal-only operations use gdr_signal path with num_writes=0 */ + if 
(is_signal_only_op(verb.desc)) { + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? &libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; + uint32_t sequence_count; + status = get_next_seq_num_with_retry(transport, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_GET_NEXT_SENDS); + if (status) goto out; + + seq_counter.put_count = 0; + + status = nvshmemt_libfabric_gdr_signal(transport, pe, curetptr, verb, remote, + bytesdesc, qp_index, sequence_count, 0, ep); + goto out; } - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; - + /* Fetch operations use full gdr_op_ctx_t */ + nvshmemt_libfabric_gdr_op_ctx_t *amo; do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&amo), 1); - } while (try_again(transport, &status, &num_retries, + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&amo), 1); + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(amo, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to retrieve AMO operation."); @@ -861,21 +1136,22 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p amo->send_amo.swap_add = remote->val; amo->send_amo.size = bytesdesc.elembytes; amo->send_amo.src_pe = transport->my_pe; - amo->type = NVSHMEMT_LIBFABRIC_SEND; amo->send_amo.comp = remote->cmp; + amo->type = NVSHMEMT_LIBFABRIC_SEND; num_retries = 0; do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); - } while (try_again(transport, &status, &num_retries, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + 
&amo->ofi_context); + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_SEND)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post an AMO operation.\n"); status = NVSHMEMX_ERROR_INTERNAL; } else { - ep->submitted_ops += 2; + ep->submitted_ops++; } out: @@ -884,7 +1160,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, - int is_proxy) { + int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_mem_handle_ep_t *remote_handle = NULL, *local_handle = NULL; nvshmemt_libfabric_endpoint_t *ep; @@ -898,7 +1174,6 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v uint64_t num_retries = 0; int target_ep; int status = 0; - int ep_idx; memset(&amo_msg, 0, sizeof(struct fi_msg_atomic)); memset(&fi_local_iov, 0, sizeof(struct fi_ioc)); @@ -906,19 +1181,14 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; remote_handle = - &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep_idx]; + &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep->domain_index]; if (verb.desc > NVSHMEMI_AMO_END_OF_NONFETCH) { - local_handle = &((nvshmemt_libfabric_mem_handle_t 
*)remote->ret_handle)->hdls[ep_idx]; + local_handle = + &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep->domain_index]; } if (bytesdesc.elembytes == 8) { @@ -985,7 +1255,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v amo_msg.addr = target_ep; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.ptr; else fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.offset; @@ -1020,7 +1290,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v status = fi_fetch_atomicmsg(ep->endpoint, &amo_msg, &fi_ret_iov, &local_handle->local_desc, 1, FI_INJECT); } - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_AMO_ATOMICMSG)); if (status) goto out; // Status set by try_again @@ -1037,30 +1307,24 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy, - uint32_t sequence_count, uint16_t num_writes) { + amo_bytesdesc_t bytesdesc, int qp_index, + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + nvshmemt_libfabric_gdr_op_ctx_t *context; nvshmemt_libfabric_gdr_signal_op_t *signal; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * 
NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >= sizeof(nvshmemt_libfabric_gdr_signal_op_t)); do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&context), 1); - } while (try_again(transport, &status, &num_retries, + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&context), 1); + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(context, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to retrieve signal operation buffer."); @@ -1077,15 +1341,15 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in num_retries = 0; do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), - fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); - } while (try_again(transport, &status, &num_retries, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, &context->ofi_context); + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_SEND)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post a signal operation.\n"); status = NVSHMEMX_ERROR_INTERNAL; } else { - ep->submitted_ops += 2; + ep->submitted_ops++; } out: @@ -1097,44 +1361,35 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_local, std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, - amo_bytesdesc_t sig_bytes_desc, int is_proxy) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + amo_bytesdesc_t sig_bytes_desc, int qp_index) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx; + 
nvshmemt_libfabric_endpoint_t *ep = nvshmemt_libfabric_get_next_ep(state, qp_index); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; + /* Get or create sequence counter for this destination fi_addr_t */ + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? &state->host_signal_state + : &state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; /* Get sequence number for this put-signal, with retry */ - uint64_t num_retries = 0; - do { - int32_t seq_num = ep.put_signal_seq_counter.next_seq_num(); - if (seq_num < 0) { - status = -EAGAIN; - } else { - sequence_count = seq_num; - status = 0; - } - } while (try_again(tcurr, &status, &num_retries, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PUT_SIGNAL_UNORDERED_SEQ)); - + status = get_next_seq_num_with_retry(tcurr, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PUT_SIGNAL_UNORDERED_SEQ); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT("Error in nvshmemt_put_signal_unordered while waiting for category\n"); goto out; } + seq_counter.put_count = 0; + assert(write_remote.size() == write_local.size() && write_local.size() == write_bytes_desc.size()); for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], is_proxy, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count, ep); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1143,8 +1398,9 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v } assert(use_staged_atomics == true); - status = nvshmemt_libfabric_gdr_signal(tcurr, pe, 
NULL, sig_verb, sig_target, sig_bytes_desc, - is_proxy, sequence_count, (uint16_t)write_remote.size()); + status = + nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, + qp_index, sequence_count, (uint16_t)write_remote.size(), ep); out: if (status) { NVSHMEMI_ERROR_PRINT( @@ -1155,77 +1411,12 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v return status; } -static int nvshmemt_libfabric_enforce_cst(struct nvshmem_transport *tcurr) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; - uint64_t num_retries = 0; - int status; - int target_ep; - int mype = tcurr->my_pe; - -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - if (libfabric_state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - int temp; - nvshmemt_libfabric_memhandle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, 0); - if (!mem_handle_info) { - goto skip; - } - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - } - } - -skip: -#endif - - target_ep = mype * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - do { - struct fi_msg_rma msg; - struct iovec l_iov; - struct fi_rma_iov r_iov; - void *desc = libfabric_state->local_mr_desc[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - uint64_t flags = 0; - - memset(&msg, 0, sizeof(struct fi_msg_rma)); - memset(&l_iov, 0, sizeof(struct iovec)); - memset(&r_iov, 0, sizeof(struct fi_rma_iov)); - - l_iov.iov_base = libfabric_state->local_mem_ptr; - l_iov.iov_len = 8; - - r_iov.addr = 0; // Zero offset - r_iov.len = 8; - r_iov.key = libfabric_state->local_mr_key[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - - msg.msg_iov = &l_iov; - msg.desc = &desc; - msg.iov_count = 1; - msg.rma_iov = &r_iov; - msg.rma_iov_count = 1; - msg.context = NULL; - msg.data = 0; - - if (libfabric_state->prov_info->caps & 
FI_FENCE) flags |= FI_FENCE; - - status = - fi_readmsg(libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX].endpoint, &msg, flags); - } while (try_again(tcurr, &status, &num_retries, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_ENFORCE_CST)); - - libfabric_state->eps[target_ep].submitted_ops++; - return status; -} - static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handle, nvshmem_transport_t t) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_mem_handle_t *fabric_handle; void *curr_ptr; - int max_reg, status = 0; + int status = 0; assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; @@ -1255,18 +1446,10 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl } } - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) - max_reg = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - else - max_reg = 1; - - for (int i = 0; i < max_reg; i++) { - if (libfabric_state->local_mr[i] == fabric_handle->hdls[i].mr) - libfabric_state->local_mr[i] = NULL; - + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { - NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, + NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %zu (%d): %s\n", i, status, fi_strerror(status * -1)); } } @@ -1275,6 +1458,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl return status; } +static_assert(sizeof(nvshmemt_libfabric_mem_handle_t) < sizeof(nvshmem_mem_handle_t)); static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, void *buf, size_t length, nvshmem_transport_t t, bool local_only) { @@ -1310,6 +1494,7 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; + 
fabric_handle->buf = buf; status = cudaPointerGetAttributes(&attr, buf); if (status != cudaSuccess) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, @@ -1340,40 +1525,15 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v mr_attr.iface = FI_HMEM_SYSTEM; } - fabric_handle->buf = buf; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = - fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error registering memory region: %s\n", - fi_strerror(status * -1)); - - status = - fi_mr_bind(fabric_handle->hdls[i].mr, &libfabric_state->eps[i].endpoint->fid, 0); - - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error binding MR to EP %d: %s\n", i, fi_strerror(status * -1)); - - status = fi_mr_enable(fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error enabling MR: %s\n", - fi_strerror(status * -1)); - - fabric_handle->hdls[i].key = fi_mr_key(fabric_handle->hdls[i].mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(fabric_handle->hdls[i].mr); - } - } else { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { struct fid_mr *mr; - - status = fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &mr); + status = fi_mr_regattr(libfabric_state->domains[i], &mr_attr, 0, &mr); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error registering memory region: %s\n", fi_strerror(status * -1)); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - fabric_handle->hdls[i].mr = mr; - fabric_handle->hdls[i].key = fi_mr_key(mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); - } + fabric_handle->hdls[i].mr = mr; + fabric_handle->hdls[i].key = fi_mr_key(mr); + fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); } if (!local_only && libfabric_state->provider == 
NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { @@ -1443,15 +1603,6 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v } while (curr_ptr < (char *)buf + length); } - if (libfabric_state->local_mr[0] == NULL && !local_only) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - libfabric_state->local_mr[i] = fabric_handle->hdls[i].mr; - libfabric_state->local_mr_key[i] = fabric_handle->hdls[i].key; - libfabric_state->local_mr_desc[i] = fabric_handle->hdls[i].local_desc; - } - libfabric_state->local_mem_ptr = buf; - } - out: if (status) { if (handle_info) { @@ -1516,179 +1667,224 @@ static int get_pci_path(int dev, char **pci_path, nvshmem_transport_t t) { return status; } +static void nvshmemt_libfabric_cleanup_signal_ordering_state(nvshmemt_libfabric_state_t *state) +{ + if (state->host_signal_state.put_signal_seq_counter_per_pe) { + delete state->host_signal_state.put_signal_seq_counter_per_pe; + state->host_signal_state.put_signal_seq_counter_per_pe = nullptr; + } + if (state->host_signal_state.proxy_put_signal_comp_map) { + delete state->host_signal_state.proxy_put_signal_comp_map; + state->host_signal_state.proxy_put_signal_comp_map = nullptr; + } + if (state->host_signal_state.next_expected_seq) { + delete state->host_signal_state.next_expected_seq; + state->host_signal_state.next_expected_seq = nullptr; + } + if (state->proxy_signal_state.put_signal_seq_counter_per_pe) { + delete state->proxy_signal_state.put_signal_seq_counter_per_pe; + state->proxy_signal_state.put_signal_seq_counter_per_pe = nullptr; + } + if (state->proxy_signal_state.proxy_put_signal_comp_map) { + delete state->proxy_signal_state.proxy_put_signal_comp_map; + state->proxy_signal_state.proxy_put_signal_comp_map = nullptr; + } + if (state->proxy_signal_state.next_expected_seq) { + delete state->proxy_signal_state.next_expected_seq; + state->proxy_signal_state.next_expected_seq = nullptr; + } +} + static int 
nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *selected_dev_ids, int num_selected_devs, int *out_qp_indices, int num_qps) { nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_ep_name_t *all_ep_names = NULL; nvshmemt_libfabric_ep_name_t *local_ep_names = NULL; - struct fi_info *current_fabric; + struct fi_info *current_info; + struct fid_fabric *fabric; + struct fid_domain *domain; + struct fid_av *address; + struct fid_mr *mr; struct fi_av_attr av_attr; struct fi_cq_attr cq_attr; - struct fi_cntr_attr cntr_attr; size_t ep_namelen = NVSHMEMT_LIBFABRIC_EP_LEN; int status = 0; int total_num_eps; - size_t num_recvs_per_pe = 0; + size_t num_recvs_per_ep = 0; int n_pes = t->n_pes; - - if (state->eps) { - NVSHMEMI_WARN_PRINT( - "Device already selected. libfabric only supports one NIC per PE and doesn't support " - "additional QPs.\n"); - goto out_already_connected; + size_t num_sends; + size_t num_recvs; + size_t elem_size; + uint64_t flags; + state->num_selected_devs = MIN(num_selected_devs, state->max_nic_per_pe); + + if (state->eps.size()) { + NVSHMEMI_ERROR_PRINT("PE has previously called connect_endpoints()\n"); + return NVSHMEMX_ERROR_INTERNAL; } - state->eps = (nvshmemt_libfabric_endpoint_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, - sizeof(nvshmemt_libfabric_endpoint_t)); - NVSHMEMI_NULL_ERROR_JMP(state->eps, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EPs."); - - current_fabric = state->all_prov_info; - do { - if (!strncmp(current_fabric->nic->device_attr->name, - state->domain_names[selected_dev_ids[0]].name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { - break; - } - current_fabric = current_fabric->next; - } while (current_fabric != NULL); - NVSHMEMI_NULL_ERROR_JMP(current_fabric, status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to find the selected fabric.\n"); - - state->prov_info = fi_dupinfo(current_fabric); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && - 
strcmp(state->prov_info->fabric_attr->name, "efa-direct")) + if (state->num_selected_devs > NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE) { + state->num_selected_devs = NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE; NVSHMEMI_WARN_PRINT( - "Libfabric transport is using efa fabric instead of efa-direct, " - "use libfabric v2.1.0 or newer for improved performance\n"); - - status = fi_fabric(state->prov_info->fabric_attr, &state->fabric, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate fabric: %d: %s\n", status, fi_strerror(status * -1)); - - status = fi_domain(state->fabric, state->prov_info, &state->domain, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate domain: %d: %s\n", status, fi_strerror(status * -1)); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->num_sends = current_fabric->tx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - state->num_recvs = current_fabric->rx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - size_t elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); - - num_recvs_per_pe = state->num_recvs / NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - - state->recv_buf = calloc(state->num_sends + state->num_recvs, elem_size); - NVSHMEMI_NULL_ERROR_JMP(state->recv_buf, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EFA msg buffer.\n"); - state->send_buf = (char *)state->recv_buf + (elem_size * state->num_recvs); - - status = fi_mr_reg(state->domain, state->recv_buf, - (state->num_sends + state->num_recvs) * elem_size, FI_SEND | FI_RECV, 0, - 0, 0, &state->mr, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - - nvshmemtLibfabricOpQueue.putToSendBulk((char *)state->send_buf, elem_size, - state->num_sends); - } - - t->max_op_len = state->prov_info->ep_attr->max_msg_size; - av_attr.type = FI_AV_TABLE; - av_attr.rx_ctx_bits = 0; - av_attr.count = 
NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; - av_attr.ep_per_node = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - av_attr.name = NULL; - av_attr.map_addr = NULL; - av_attr.flags = 0; - - /* Note - This is needed because EFA will only bind AVs to EPs on a 1:1 basis. - * If EFA ever lifts this requirement, we can reduce the number of AVs required. - */ - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_av_open(state->domain, &av_attr, &state->addresses[i], NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate address vector: %d: %s\n", status, - fi_strerror(status * -1)); + "PE selected %d devices, but the libfabric transport only supports a max of %d " + "devices. Continuing using %d devices.\n", + state->num_selected_devs, NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE, + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE); } + state->num_selected_domains = state->num_selected_devs + 1; - INFO(state->log_level, "Selected provider %s, fabric %s, nic %s, hmem %s", - state->prov_info->fabric_attr->prov_name, state->prov_info->fabric_attr->name, - state->prov_info->nic->device_attr->name, state->prov_info->caps & FI_HMEM ? 
"yes" : "no"); - - assert(state->eps); + /* Initialize configuration which only need to be set once */ + t->max_op_len = UINT64_MAX; /* Set as sential value */ + state->cur_proxy_ep_index = 1; memset(&cq_attr, 0, sizeof(struct fi_cq_attr)); - memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); - - state->prov_info->ep_attr->tx_ctx_cnt = 0; - state->prov_info->caps = FI_RMA; - if ((state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) || - (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS)) { - state->prov_info->caps |= FI_ATOMIC; - } else { - state->prov_info->caps |= FI_MSG; - state->prov_info->caps |= FI_SOURCE; - } - state->prov_info->tx_attr->op_flags = 0; - state->prov_info->tx_attr->mode = 0; - state->prov_info->rx_attr->mode = 0; - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->prov_info->mode = FI_CONTEXT2; - } else { - state->prov_info->mode = 0; - } - - state->prov_info->tx_attr->op_flags = FI_DELIVERY_COMPLETE; - - cntr_attr.events = FI_CNTR_EVENTS_COMP; - cntr_attr.wait_obj = FI_WAIT_UNSPEC; - cntr_attr.wait_set = NULL; - cntr_attr.flags = 0; - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { cq_attr.size = 16; /* CQ is only used to capture error events */ cq_attr.format = FI_CQ_FORMAT_UNSPEC; cq_attr.wait_obj = FI_WAIT_NONE; - } - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { cq_attr.format = FI_CQ_FORMAT_DATA; cq_attr.wait_obj = FI_WAIT_NONE; cq_attr.size = 32768; } - local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, + memset(&av_attr, 0, sizeof(struct fi_av_attr)); + av_attr.type = FI_AV_TABLE; + av_attr.count = state->num_selected_domains * n_pes; + + /* Find fabric info for each selected device */ + for (int dev_idx = 0; dev_idx < state->num_selected_devs; dev_idx++) { + current_info = state->all_prov_info; + do { + if (!strncmp(current_info->nic->device_attr->name, + 
state->domain_names[selected_dev_ids[dev_idx]].name, + NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { + break; + } + current_info = current_info->next; + } while (current_info != NULL); + NVSHMEMI_NULL_ERROR_JMP(current_info, status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to find fabric for device %d.\n", dev_idx); + + /* + * Create two domains (host/proxy domain) for the first NIC. + */ + if (state->prov_infos.size() == 0) state->prov_infos.push_back(current_info); + + state->prov_infos.push_back(current_info); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && + strcmp(current_info->fabric_attr->name, "efa-direct")) + NVSHMEMI_WARN_PRINT( + "Libfabric transport is using efa fabric instead of efa-direct, " + "use libfabric v2.1.0 or newer for improved performance\n"); + } + + /* Allocate out of band AV name exchange buffers */ + local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(state->num_selected_domains, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(local_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - total_num_eps = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; + total_num_eps = n_pes * state->num_selected_domains; all_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(total_num_eps, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(all_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); + /* Initialize state-level signal ordering state */ + state->host_signal_state.put_signal_seq_counter_per_pe = + new std::unordered_map(); + state->host_signal_state.proxy_put_signal_comp_map = + new std::unordered_map(); + state->host_signal_state.next_expected_seq = + new std::unordered_map(); + + state->proxy_signal_state.put_signal_seq_counter_per_pe = + new std::unordered_map(); + 
state->proxy_signal_state.proxy_put_signal_comp_map = + new std::unordered_map(); + state->proxy_signal_state.next_expected_seq = + new std::unordered_map(); + + /* Create Resources For Each Selected Device */ + for (size_t i = 0; i < state->prov_infos.size(); i++) { + INFO(state->log_level, + "Selected provider %s, fabric %s, nic %s, hmem %s multi-rail %zu/%d\n", + state->prov_infos[i]->fabric_attr->prov_name, state->prov_infos[i]->fabric_attr->name, + state->prov_infos[i]->nic->device_attr->name, + state->prov_infos[i]->caps & FI_HMEM ? "yes" : "no", i + 1, num_selected_devs); + + if (state->prov_infos[i]->ep_attr->max_msg_size < t->max_op_len) + t->max_op_len = state->prov_infos[i]->ep_attr->max_msg_size; + + status = fi_fabric(state->prov_infos[i]->fabric_attr, &fabric, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to allocate endpoint: %d: %s\n", status, + "Failed to allocate fabric: %d: %s\n", status, + fi_strerror(status * -1)); + state->fabrics.push_back(fabric); + + status = fi_domain(fabric, state->prov_infos[i], &domain, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate domain: %d: %s\n", status, + fi_strerror(status * -1)); + state->domains.push_back(domain); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + num_sends = state->prov_infos[i]->tx_attr->size; + num_recvs = state->prov_infos[i]->rx_attr->size; + elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); + num_recvs_per_ep = num_recvs; + + state->recv_buf.push_back(calloc(num_sends + num_recvs, elem_size)); + NVSHMEMI_NULL_ERROR_JMP(state->recv_buf[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate EFA msg buffer.\n"); + state->send_buf.push_back((char *)state->recv_buf[i] + (elem_size * num_recvs)); + + status = fi_mr_reg(domain, state->recv_buf[i], (num_sends + num_recvs) * elem_size, + FI_SEND | FI_RECV | FI_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, 
out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->mr.push_back(mr); + + state->op_queue.push_back(new threadSafeOpQueue); + state->op_queue[i]->putToSendBulk((char *)state->send_buf[i], elem_size, num_sends); + state->op_queue[i]->set_auto_progress(use_auto_progress); + } + + status = fi_av_open(domain, &av_attr, &address, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate address vector: %d: %s\n", status, fi_strerror(status * -1)); + state->addresses.push_back(address); + + /* Create nvshmemt_libfabric_endpoint_t resources */ + state->eps.push_back( + (nvshmemt_libfabric_endpoint_t *)calloc(1, sizeof(nvshmemt_libfabric_endpoint_t))); + NVSHMEMI_NULL_ERROR_JMP(state->eps[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to alloc libfabric_tx_progress_group struct.\n"); + state->eps[i]->domain_index = i; - /* Initialize per-endpoint proxy_put_signal_comp_map */ - state->eps[i].proxy_put_signal_comp_map = - new std::unordered_map>(); + state->eps[i]->completed_staged_atomics = 0; + state->eps[i]->submitted_ops = 0; + state->eps[i]->completed_ops = 0; - state->eps[i].put_signal_seq_counter.reset(); - state->eps[i].completed_staged_atomics = 0; + status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open completion queue for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + + status = fi_endpoint(domain, state->prov_infos[i], &state->eps[i]->endpoint, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to allocate endpoint: %d: %s\n", status, + fi_strerror(status * -1)); /* FI_OPT_CUDA_API_PERMITTED was introduced in libfabric 1.18.0 */ if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { bool prohibit_cuda_api = false; - status = fi_setopt(&state->eps[i].endpoint->fid, FI_OPT_ENDPOINT, + status = fi_setopt(&state->eps[i]->endpoint->fid, 
FI_OPT_ENDPOINT, FI_OPT_CUDA_API_PERMITTED, &prohibit_cuda_api, sizeof(bool)); if (status == -FI_ENOPROTOOPT) { NVSHMEMI_WARN_PRINT( @@ -1702,112 +1898,83 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } } - status = fi_cq_open(state->domain, &cq_attr, &state->eps[i].cq, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open completion queue for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_cntr_open(state->domain, &cntr_attr, &state->eps[i].counter, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open counter for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_ep_bind(state->eps[i].endpoint, &state->addresses[i]->fid, 0); + /* Bind Resources To EP */ + status = fi_ep_bind(state->eps[i]->endpoint, &state->addresses[i]->fid, 0); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to address vector: %d: %s\n", status, fi_strerror(status * -1)); - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) /* EFA is documented as not supporting FI_SELECTIVE_COMPLETION */ - status = - fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, FI_TRANSMIT | FI_RECV); - } else { + flags = FI_TRANSMIT | FI_RECV; + else { 
NVSHMEMI_ERROR_PRINT( "Invalid provider identified. This should be impossible. " "Possible memory corruption in the state pointer?"); status = NVSHMEMX_ERROR_INTERNAL; goto out; } + + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->cq->fid, flags); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to completion queue: %d: %s\n", status, fi_strerror(status * -1)); -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, - FI_READ | FI_WRITE | FI_SEND); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } else -#endif - { - int flags = FI_READ | FI_WRITE; - if (use_staged_atomics) { - flags |= FI_SEND; - } - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, flags); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } - - status = fi_enable(state->eps[i].endpoint); + status = fi_enable(state->eps[i]->endpoint); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to enable endpoint: %d: %s\n", status, fi_strerror(status * -1)); if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - for (size_t j = 0; j < num_recvs_per_pe; j++) { - nvshmemt_libfabric_gdr_op_ctx_t *op; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf; - op = op + ((num_recvs_per_pe * i) + j); + nvshmemt_libfabric_gdr_op_ctx_t *op; + op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf[i]; + for (size_t j = 0; j < num_recvs_per_ep; j++, op++) { assert(op != NULL); - status = fi_recv(state->eps[i].endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + status = fi_recv(state->eps[i]->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[i]), 
FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to post recv to ep. Error: %d: %s\n", status, + fi_strerror(status * -1)); } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to post recv to ep. Error: %d: %s\n", status, - fi_strerror(status * -1)); } - status = fi_getname(&state->eps[i].endpoint->fid, local_ep_names[i].name, &ep_namelen); + status = fi_getname(&state->eps[i]->endpoint->fid, local_ep_names[i].name, &ep_namelen); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get name for endpoint: %d: %s\n", status, fi_strerror(status * -1)); - if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) { + if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Name of EP is too long."); - } } + /* Perform out of band address exchange */ status = t->boot_handle->allgather( local_ep_names, all_ep_names, - NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); + state->num_selected_domains * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather endpoint names.\n"); /* We need to insert one at a time since each buffer is larger than the address. 
*/ - for (int j = 0; j < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; j++) { - for (int i = 0; i < total_num_eps; i++) { - status = fi_av_insert(state->addresses[j], &all_ep_names[i], 1, NULL, 0, NULL); + for (int i = 0; i < state->num_selected_domains; i++) { + for (int j = 0; j < total_num_eps; j++) { + status = fi_av_insert(state->addresses[i], &all_ep_names[j], 1, NULL, 0, NULL); if (status < 1) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to insert ep names in address vector: %d: %s\n", status, fi_strerror(status * -1)); } - status = NVSHMEMX_SUCCESS; } } + /* Out of band exchange of a pre-registered write w/imm target for staged_amo acks */ if (use_staged_atomics) { state->remote_addr_staged_amo_ack = (void **)calloc(sizeof(void *), t->n_pes); + state->rkey_staged_amo_ack = + (uint64_t *)calloc(sizeof(uint64_t), t->n_pes * state->num_selected_domains); NVSHMEMI_NULL_ERROR_JMP(state->remote_addr_staged_amo_ack, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate remote address array for staged atomic ack.\n"); @@ -1816,13 +1983,15 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate CUDA memory for staged atomic ack.\n"); - status = fi_mr_reg(state->domain, state->remote_addr_staged_amo_ack[t->my_pe], sizeof(int), - FI_REMOTE_WRITE, 0, 0, 0, &state->mr_staged_amo_ack, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - state->rkey_staged_amo_ack = (uint64_t *)calloc(sizeof(uint64_t), t->n_pes); - state->rkey_staged_amo_ack[t->my_pe] = fi_mr_key(state->mr_staged_amo_ack); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_mr_reg(state->domains[i], state->remote_addr_staged_amo_ack[t->my_pe], + sizeof(int), FI_REMOTE_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, + 
"Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains + i] = fi_mr_key(mr); + state->mr_staged_amo_ack.push_back(mr); + } status = t->boot_handle->allgather(&state->remote_addr_staged_amo_ack[t->my_pe], state->remote_addr_staged_amo_ack, sizeof(void *), @@ -1830,9 +1999,10 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote addresses.\n"); - status = - t->boot_handle->allgather(&state->rkey_staged_amo_ack[t->my_pe], - state->rkey_staged_amo_ack, sizeof(uint64_t), t->boot_handle); + status = t->boot_handle->allgather( + &state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains], + state->rkey_staged_amo_ack, sizeof(uint64_t) * state->num_selected_domains, + t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote keys.\n"); } @@ -1845,30 +2015,30 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele free(state->remote_addr_staged_amo_ack); } if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); - if (state->mr_staged_amo_ack) fi_close(&state->mr_staged_amo_ack->fid); - if (state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (state->eps[i].proxy_put_signal_comp_map) - delete state->eps[i].proxy_put_signal_comp_map; - if (state->eps[i].endpoint) { - fi_close(&state->eps[i].endpoint->fid); - state->eps[i].endpoint = NULL; - } - if (state->eps[i].cq) { - fi_close(&state->eps[i].cq->fid); - state->eps[i].cq = NULL; - } - if (state->eps[i].counter) { - fi_close(&state->eps[i].counter->fid); - state->eps[i].counter = NULL; - } + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) + fi_close(&state->mr_staged_amo_ack[i]->fid); + + /* Cleanup state-level signal ordering state */ + 
nvshmemt_libfabric_cleanup_signal_ordering_state(state); + + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->endpoint) { + fi_close(&state->eps[i]->endpoint->fid); + state->eps[i]->endpoint = NULL; + } + if (state->eps[i]->cq) { + fi_close(&state->eps[i]->cq->fid); + state->eps[i]->cq = NULL; } - free(state->eps); - state->eps = NULL; + if (state->eps[i]->counter) { + fi_close(&state->eps[i]->counter->fid); + state->eps[i]->counter = NULL; + } + free(state->eps[i]); + state->eps[i] = NULL; } } -out_already_connected: free(local_ep_names); free(all_ep_names); @@ -1876,12 +2046,12 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state; + nvshmemt_libfabric_state_t *state; int status; assert(transport); - libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + state = (nvshmemt_libfabric_state_t *)transport->state; if (transport->device_pci_paths) { for (int i = 0; i < transport->n_devices; i++) { @@ -1893,19 +2063,19 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { size_t mem_handle_cache_size; nvshmemt_libfabric_memhandle_info_t *handle_info = NULL, *previous_handle_info = NULL; - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(libfabric_state->cache); + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(state->cache); for (size_t i = 0; i < mem_handle_cache_size; i++) { handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, i); + state->cache, i); if (handle_info && handle_info != previous_handle_info) { free(handle_info); } previous_handle_info = handle_info; } - nvshmemt_mem_handle_cache_fini(libfabric_state->cache); + 
nvshmemt_mem_handle_cache_fini(state->cache); #ifdef NVSHMEM_USE_GDRCOPY if (use_gdrcopy) { nvshmemt_gdrcopy_ftable_fini(&gdrcopy_ftable, &gdr_desc, &gdrcopy_handle); @@ -1913,95 +2083,97 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { #endif } - if (libfabric_state->prov_info) { - fi_freeinfo(libfabric_state->prov_info); - } + /* + * Since fi_dupinfo() is not called, we do not need to clean + * prov_infos + */ + if (state->all_prov_info) fi_freeinfo(state->all_prov_info); - if (libfabric_state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (libfabric_state->eps[i].proxy_put_signal_comp_map) - delete libfabric_state->eps[i].proxy_put_signal_comp_map; - if (libfabric_state->eps[i].endpoint) { - status = fi_close(&libfabric_state->eps[i].endpoint->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, - fi_strerror(status * -1)); - } + /* Cleanup state-level signal ordering state */ + nvshmemt_libfabric_cleanup_signal_ordering_state(state); + + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->endpoint) { + status = fi_close(&state->eps[i]->endpoint->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].cq) { - status = fi_close(&libfabric_state->eps[i].cq->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->cq) { + status = fi_close(&state->eps[i]->cq->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, - fi_strerror(status * -1)); } - if (libfabric_state->eps[i].counter) { - status = fi_close(&libfabric_state->eps[i].counter->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->counter) { 
+ status = fi_close(&state->eps[i]->counter->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, + fi_strerror(status * -1)); } } - free(libfabric_state->eps); + free(state->eps[i]); } - if (libfabric_state->remote_addr_staged_amo_ack) { - if (libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]) - cudaFree(libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]); - free(libfabric_state->remote_addr_staged_amo_ack); + if (state->remote_addr_staged_amo_ack) { + if (state->remote_addr_staged_amo_ack[transport->my_pe]) + cudaFree(state->remote_addr_staged_amo_ack[transport->my_pe]); + free(state->remote_addr_staged_amo_ack); } - if (libfabric_state->rkey_staged_amo_ack) free(libfabric_state->rkey_staged_amo_ack); - if (libfabric_state->mr_staged_amo_ack) { - status = fi_close(&libfabric_state->mr_staged_amo_ack->fid); + if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) { + status = fi_close(&state->mr_staged_amo_ack[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close staged atomic ack MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->mr) { - status = fi_close(&libfabric_state->mr->fid); + for (size_t i = 0; i < state->mr.size(); i++) { + status = fi_close(&state->mr[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->recv_buf) free(libfabric_state->recv_buf); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_close(&libfabric_state->addresses[i]->fid); + for (size_t i = 0; i < state->recv_buf.size(); i++) free(state->recv_buf[i]); + + for (size_t i = 0; i < state->addresses.size(); i++) { + status = fi_close(&state->addresses[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric address vector: %d: %s\n", status, fi_strerror(status * -1)); } } - if 
(libfabric_state->domain) { - status = fi_close(&libfabric_state->domain->fid); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_close(&state->domains[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric domain: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->fabric) { - status = fi_close(&libfabric_state->fabric->fid); + for (size_t i = 0; i < state->fabrics.size(); i++) { + status = fi_close(&state->fabrics[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric: %d: %s\n", status, fi_strerror(status * -1)); } } - free(libfabric_state); - + free(state); free(transport); return 0; } -static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state) { - struct fi_info info; +static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, + struct nvshmemi_options_s *options) { + struct fi_info hints; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; struct fi_ep_attr ep_attr; @@ -2009,65 +2181,88 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr struct fi_fabric_attr fabric_attr; struct fid_nic nic; struct fi_av_attr av_attr; - struct fi_info *returned_fabrics, *current_fabric; + struct fi_info *all_infos, *current_info; int num_fabrics_returned = 0; int status = 0; memset(&ep_attr, 0, sizeof(struct fi_ep_attr)); memset(&av_attr, 0, sizeof(struct fi_av_attr)); - memset(&info, 0, sizeof(struct fi_info)); + memset(&hints, 0, sizeof(struct fi_info)); memset(&tx_attr, 0, sizeof(struct fi_tx_attr)); memset(&rx_attr, 0, sizeof(struct fi_rx_attr)); memset(&domain_attr, 0, sizeof(struct fi_domain_attr)); memset(&fabric_attr, 0, sizeof(struct fi_fabric_attr)); memset(&nic, 0, sizeof(struct fid_nic)); - info.tx_attr = &tx_attr; - info.rx_attr = &rx_attr; - info.ep_attr = &ep_attr; - info.domain_attr = &domain_attr; - info.fabric_attr = &fabric_attr; - info.nic = &nic; + hints.tx_attr = &tx_attr; + 
hints.rx_attr = &rx_attr; + hints.ep_attr = &ep_attr; + hints.domain_attr = &domain_attr; + hints.fabric_attr = &fabric_attr; + hints.nic = &nic; - info.addr_format = FI_FORMAT_UNSPEC; - info.caps = FI_RMA | FI_HMEM; + hints.addr_format = FI_FORMAT_UNSPEC; + hints.caps = FI_RMA | FI_HMEM; if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - info.caps |= FI_ATOMIC; + hints.caps |= FI_ATOMIC; domain_attr.mr_mode = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { /* TODO: Use FI_FENCE to optimize put_with_signal */ - info.caps |= FI_FENCE | FI_ATOMIC; + hints.caps |= FI_FENCE | FI_ATOMIC; domain_attr.mr_mode = FI_MR_ENDPOINT | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { domain_attr.mr_mode = FI_MR_LOCAL | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_HMEM; - info.caps |= FI_MSG; - info.caps |= FI_SOURCE; + hints.caps |= FI_MSG; + hints.caps |= FI_SOURCE; } if (use_staged_atomics) { - info.mode |= FI_CONTEXT2; + hints.mode |= FI_CONTEXT2; } - /* Be thread safe at the level of the endpoint completion context. */ - domain_attr.threading = FI_THREAD_SAFE; - + ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ - info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; + hints.tx_attr->op_flags = FI_DELIVERY_COMPLETE; - ep_attr.type = FI_EP_RDM; // Reliable datagrams + /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ + domain_attr.threading = FI_THREAD_COMPLETION; + hints.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, - NULL, 0, &info, &returned_fabrics); + NULL, 0, &hints, &all_infos); + + /* + * 1. Ensure that at least one fabric was returned + * 2. 
Make sure returned fabric matches the name of selected provider + * + * This has an assumption that the provided fabric option + * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric + * name + */ + if (!status && strstr(all_infos->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + use_auto_progress = true; + } else { + fi_freeinfo(all_infos); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No providers matched fi_getinfo query: %d: %s\n", status, - fi_strerror(status * -1)); - state->all_prov_info = returned_fabrics; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { + /* + * Fallback to FI_PROGRESS_MANUAL path + * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE + */ + domain_attr.threading = FI_THREAD_SAFE; + hints.domain_attr->data_progress = FI_PROGRESS_MANUAL; + status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), + NULL, NULL, 0, &hints, &all_infos); + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No providers matched fi_getinfo query: %d: %s\n", status, + fi_strerror(status * -1)); + } + + state->all_prov_info = all_infos; + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { num_fabrics_returned++; } @@ -2078,53 +2273,51 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr /* Only select unique devices. */ state->num_domains = 0; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { - if (!current_fabric->nic) { + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { + if (!current_info->nic) { INFO(state->log_level, "Interface did not return NIC structure to fi_getinfo. Skipping.\n"); continue; } - if (!current_fabric->tx_attr) { + if (!current_info->tx_attr) { INFO(state->log_level, "Interface did not return TX_ATTR structure to fi_getinfo. 
Skipping.\n"); continue; } TRACE(state->log_level, "fi_getinfo returned provider %s, fabric %s, nic %s", - current_fabric->fabric_attr->prov_name, current_fabric->fabric_attr->name, - current_fabric->nic->device_attr->name); + current_info->fabric_attr->prov_name, current_info->fabric_attr->name, + current_info->nic->device_attr->name); if (state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - if (current_fabric->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { + if (current_info->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { INFO(state->log_level, "Disabling interface due to insufficient inject data size. reported %lu, " "expected " "%u", - current_fabric->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); + current_info->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); continue; } } - if ((current_fabric->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { + if ((current_info->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { INFO(state->log_level, "Disabling interface due to FI_MR_PROV_KEY support"); continue; } for (int i = 0; i <= state->num_domains; i++) { - if (!strncmp(current_fabric->nic->device_attr->name, state->domain_names[i].name, + if (!strncmp(current_info->nic->device_attr->name, state->domain_names[i].name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { break; } else if (i == state->num_domains) { - size_t name_len = strlen(current_fabric->nic->device_attr->name); + size_t name_len = strlen(current_info->nic->device_attr->name); if (name_len >= NVSHMEMT_LIBFABRIC_DOMAIN_LEN) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to copy domain name for libfabric transport."); } (void)strncpy(state->domain_names[state->num_domains].name, - current_fabric->nic->device_attr->name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN); + current_info->nic->device_attr->name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN); state->num_domains++; break; } @@ -2146,8 +2339,6 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr 
nvshmemt_libfabric_finalize(t); } - free(info.fabric_attr->name); - return status; } @@ -2184,9 +2375,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.quiet = nvshmemt_libfabric_quiet; transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; - transport->host_ops.progress = nvshmemt_libfabric_progress; - transport->host_ops.enforce_cst = nvshmemt_libfabric_enforce_cst; - transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; @@ -2199,6 +2387,7 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, "Unable to initialize env options."); libfabric_state->log_level = nvshmemt_common_get_log_level(&options); + libfabric_state->max_nic_per_pe = options.LIBFABRIC_MAX_NIC_PER_PE; if (strcmp(options.LIBFABRIC_PROVIDER, "verbs") == 0) { libfabric_state->provider = NVSHMEMT_LIBFABRIC_PROVIDER_VERBS; @@ -2302,12 +2491,19 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, #undef NVSHMEMI_SET_ENV_VAR /* Prepare fabric state information. 
*/ - status = nvshmemi_libfabric_init_state(transport, libfabric_state); + status = nvshmemi_libfabric_init_state(transport, libfabric_state, &options); if (status) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out_clean, "Failed to initialize the libfabric state.\n"); } + libfabric_state->proxy_request_batch_max = options.LIBFABRIC_PROXY_REQUEST_BATCH_MAX; + + if (use_auto_progress) + transport->host_ops.progress = nvshmemt_libfabric_auto_proxy_progress; + else + transport->host_ops.progress = nvshmemt_libfabric_manual_progress; + *t = transport; out: if (status) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 8a889a69..392516a5 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -31,12 +31,9 @@ #define NVSHMEMT_LIBFABRIC_DOMAIN_LEN 32 #define NVSHMEMT_LIBFABRIC_PROVIDER_LEN 32 #define NVSHMEMT_LIBFABRIC_EP_LEN 128 - -/* one EP for all proxy ops, one for host ops */ -#define NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS 2 +/* Constrained by memhandle size */ +#define NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE 16 #define NVSHMEMT_LIBFABRIC_PROXY_EP_IDX 1 -#define NVSHMEMT_LIBFABRIC_HOST_EP_IDX 0 - #define NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS 20 /* Maximum size of inject data. Currently @@ -70,6 +67,13 @@ typedef struct nvshmemt_libfabric_gdr_op_ctx nvshmemt_libfabric_gdr_op_ctx_t; #define NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK \ ((1U << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) - 1) +/** + * Frequency at which we send an ack for puts (without signal). For puts-only, + * we don't need the ack for every message for semantic reasons, we only need + * an occasional ack to handle sequence number overflow correctly. + */ +#define NVSHMEM_STAGED_AMO_PUT_ACK_FREQ 64 + /** * The last sequence number is reserved for atomic-only operations. * This will not be returned by the sequence counter. 
@@ -98,6 +102,11 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t {
     constexpr static uint32_t num_index_bits = (num_sequence_bits - num_category_bits);
     constexpr static uint32_t index_mask = ((1U << num_index_bits) - 1);
 
+
+    /* Assert that index_mask is large enough to simplify some ranged ack return
+       logic. */
+    static_assert((index_mask + 1) >= (2 * NVSHMEM_STAGED_AMO_PUT_ACK_FREQ),
+                  "Number of indexes should be >= 2 * put_ack_freq");
     constexpr static uint32_t category_mask = (1U << num_index_bits);
     constexpr static uint32_t sequence_mask = NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK;
 
@@ -120,6 +129,14 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t {
     uint32_t sequence_counter;
     uint32_t pending_acks[num_categories];
 
+    uint32_t put_count;
+
+    /**
+     * Default constructor - initializes counter to zero
+     */
+    nvshmemt_libfabric_endpoint_seq_counter_t() {
+        reset();
+    }
 
     /**
      * Reset counter and pending acks to zero
@@ -127,6 +144,7 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t {
     void reset() {
         sequence_counter = 0;
         memset(pending_acks, 0, sizeof(pending_acks));
+        put_count = 0;
     }
 
     /**
@@ -174,6 +192,55 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t {
         assert(pending_acks[category] > 0);
         --pending_acks[category];
     }
+
+    /**
+     * Mark a range of sequence numbers as complete, resulting from receiving a
+     * put ack. The sequence range ends with end_seq.
+     *
+     * We send an ack for every NVSHMEM_STAGED_AMO_PUT_ACK_FREQ puts. Therefore,
+     * a put ack is an acknowledgement of sequence numbers (end_seq -
+     * NVSHMEM_STAGED_AMO_PUT_ACK_FREQ + 1) to end_seq, inclusive. The
+     * wraparound case is also handled.
+     *
+     * This code assumes the sequence range spans at most two categories. This
+     * will be true as long as the index space is sufficiently larger than the
+     * put ack frequency, as static asserted above.
+ */ + void return_acked_seq_num_range_for_put(uint32_t end_seq) { + assert(end_seq != NVSHMEM_STAGED_AMO_SEQ_NUM); + + uint32_t start_seq = (end_seq - NVSHMEM_STAGED_AMO_PUT_ACK_FREQ + 1) & sequence_mask; + + /* Note: in the wraparound case, the (start_seq, end_seq) range will + include NVSHMEM_STAGED_AMO_SEQ_NUM, which is not used. The logic + below handles this correctly, as long as `start_category` is correct + (which is true as long as the index space is sufficiently large that + we can only span two categories, as static-asserted above.) */ + + uint32_t start_category = get_category(start_seq); + uint32_t end_category = get_category(end_seq); + + uint32_t num_indexes; + if (end_seq >= start_seq) { + num_indexes = end_seq - start_seq + 1; + } else { + num_indexes = (NVSHMEM_STAGED_AMO_SEQ_NUM - start_seq + 1) + (end_seq + 1); + } + + if (start_category == end_category) { + assert(pending_acks[start_category] >= num_indexes); + pending_acks[start_category] -= num_indexes; + } else { + uint32_t count_in_start_cat = (index_mask + 1) - get_index(start_seq); + uint32_t count_in_end_cat = get_index(end_seq) + 1; + + assert(pending_acks[start_category] >= count_in_start_cat); + assert(pending_acks[end_category] >= count_in_end_cat); + + pending_acks[start_category] -= count_in_start_cat; + pending_acks[end_category] -= count_in_end_cat; + } + } }; typedef enum { @@ -192,12 +259,38 @@ typedef struct { struct fid_cq *cq; struct fid_cntr *counter; uint64_t submitted_ops; + uint64_t completed_ops; uint64_t completed_staged_atomics; - nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; - std::unordered_map> - *proxy_put_signal_comp_map; + int domain_index; } nvshmemt_libfabric_endpoint_t; +// Entry types for completion map +enum nvshmemt_libfabric_comp_entry_type { + NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL, + NVSHMEMT_LIBFABRIC_COMP_ENTRY_PUT_ACK +}; + +// Entry for signal operations (put-signal, atomic) +struct nvshmemt_libfabric_signal_comp_entry { + 
nvshmemt_libfabric_gdr_op_ctx_t *op; + int progress_count; +}; + +// Entry for puts that need acknowledgment +struct nvshmemt_libfabric_put_ack_entry { + fi_addr_t src_addr; + nvshmemt_libfabric_endpoint_t *ep; +}; + +// Tagged union for completion entries +struct nvshmemt_libfabric_comp_entry_t { + nvshmemt_libfabric_comp_entry_type type; + union { + nvshmemt_libfabric_signal_comp_entry signal_entry; + nvshmemt_libfabric_put_ack_entry ack_entry; + }; +}; + typedef struct nvshmemt_libfabric_gdr_send_p_op { uint64_t value; } nvshmemt_libfabric_gdr_send_p_op_t; @@ -248,18 +341,52 @@ typedef enum { typedef enum { NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ = 0, NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK, } nvshmemt_libfabric_imm_cq_data_hdr_t; +/* + * Conditional lock: skips locking when FI_THREAD_COMPLETION is active. + * + * With FI_PROGRESS_AUTO, we request FI_THREAD_COMPLETION from the provider, + * meaning the host thread and proxy thread each operate on separate endpoints + * (eps[0] vs eps[1+]), so their op_queues are disjoint and no synchronization + * is needed. With FI_PROGRESS_MANUAL, we use FI_THREAD_SAFE because + * manual_progress() iterates all EPs from both threads, requiring locking. 
+ */ +class conditional_mutex { + std::mutex mtx; + bool needs_lock; + + public: + conditional_mutex() : needs_lock(true) {} + void set_needs_lock(bool v) { needs_lock = v; } + void lock() { if (needs_lock) mtx.lock(); } + void unlock() { if (needs_lock) mtx.unlock(); } +}; + class threadSafeOpQueue { private: - std::mutex send_mutex; - std::mutex ack_recv_mutex; - std::mutex other_recv_mutex; + conditional_mutex send_mutex; + conditional_mutex ack_recv_mutex; + conditional_mutex other_recv_mutex; std::vector send; std::deque ack_recv; std::deque other_recv; public: + threadSafeOpQueue() = default; + threadSafeOpQueue(const threadSafeOpQueue &) = delete; + threadSafeOpQueue &operator=(const threadSafeOpQueue &) = delete; + + /* Disable locking when FI_THREAD_COMPLETION keeps host/proxy EPs disjoint. */ + void set_auto_progress(bool auto_progress) { + send_mutex.set_needs_lock(!auto_progress); + ack_recv_mutex.set_needs_lock(!auto_progress); + other_recv_mutex.set_needs_lock(!auto_progress); + } + int getNextSends(void **elems, size_t num_elems = 1) { send_mutex.lock(); if (send.size() < num_elems) { @@ -390,31 +517,47 @@ class threadSafeOpQueue { }; typedef struct { - struct fi_info *prov_info; + std::unordered_map *put_signal_seq_counter_per_pe; + std::unordered_map *proxy_put_signal_comp_map; + std::unordered_map *next_expected_seq; +} nvshmemt_libfabric_signal_state_t; + +typedef struct { struct fi_info *all_prov_info; - struct fid_fabric *fabric; - struct fid_domain *domain; - struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; - nvshmemt_libfabric_endpoint_t *eps; - /* local_mr is used only for consistency ops. 
*/ - struct fid_mr *local_mr[2]; - uint64_t local_mr_key[2]; - void *local_mr_desc[2]; - void *local_mem_ptr; + std::vector prov_infos; + std::vector fabrics; + std::vector domains; + std::vector addresses; + std::vector eps; + nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; int log_level; struct nvshmemi_cuda_fn_table *table; - size_t num_sends; - void *send_buf; - size_t num_recvs; - void *recv_buf; - struct fid_mr *mr; struct transport_mem_handle_info_cache *cache; + + /* Required for multi-rail */ + int max_nic_per_pe; + int num_selected_devs; + int num_selected_domains; + int cur_proxy_ep_index; + + /* Required for staged_amo */ + std::vector op_queue; + std::vector mr; + std::vector send_buf; + std::vector recv_buf; + std::vector mr_staged_amo_ack; void **remote_addr_staged_amo_ack; uint64_t *rkey_staged_amo_ack; - struct fid_mr *mr_staged_amo_ack; + + /* Signal ordering state */ + nvshmemt_libfabric_signal_state_t host_signal_state; + nvshmemt_libfabric_signal_state_t proxy_signal_state; + + /* Max ops per progress iteration */ + int proxy_request_batch_max; } nvshmemt_libfabric_state_t; typedef struct { @@ -435,7 +578,7 @@ typedef struct { typedef struct { void *buf; - nvshmemt_libfabric_mem_handle_ep_t hdls[2]; + nvshmemt_libfabric_mem_handle_ep_t hdls[1 + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE]; } nvshmemt_libfabric_mem_handle_t; /* Wire data for put-signal gdr staged atomics diff --git a/src/modules/transport/ucx/ucx.cpp b/src/modules/transport/ucx/ucx.cpp index 271ed69d..4959d0b4 100644 --- a/src/modules/transport/ucx/ucx.cpp +++ b/src/modules/transport/ucx/ucx.cpp @@ -1180,67 +1180,6 @@ int nvshmemt_ucx_finalize(nvshmem_transport_t transport) { return 0; } -int nvshmemt_ucx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - transport_ucx_state_t *ucx_state = (transport_ucx_state_t *)tcurr->state; - nvshmemt_ucx_mem_handle_info_t *mem_handle_info; - - mem_handle_info = - 
(nvshmemt_ucx_mem_handle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(ucx_state->cache, 0); - - if (!mem_handle_info) return 0; -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - int temp; - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - return 0; - } -#endif - int mype = tcurr->my_pe; - int ep_index = (ucx_state->ep_count * mype + ucx_state->proxy_ep_idx); - ucp_ep_h ep = ucx_state->endpoints[ep_index]; - ucp_request_param_t param; - ucs_status_ptr_t ucs_ptr_rc = NULL; - ucs_status_t ucs_rc; - nvshmemt_ucx_mem_handle_t *mem_handle; - ucp_rkey_h rkey; - int local_int; - - mem_handle = mem_handle_info->mem_handle; - if (unlikely(mem_handle->ep_rkey_host == NULL)) { - ucs_rc = ucp_ep_rkey_unpack(ep, mem_handle->rkey_packed_buf, &mem_handle->ep_rkey_host); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_EXIT("Unable to unpack rkey in UCS transport! Exiting.\n"); - } - } - rkey = mem_handle->ep_rkey_host; - - param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; - param.cb.send = nvshmemt_ucx_send_request_cb; - - ucs_ptr_rc = - ucp_get_nbx(ep, &local_int, sizeof(int), (uint64_t)mem_handle_info->ptr, rkey, ¶m); - - /* Wait for completion of get. 
*/ - if (ucs_ptr_rc != NULL) { - if (UCS_PTR_IS_ERR(ucs_ptr_rc)) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } else { - do { - ucs_rc = ucp_request_check_status(ucs_ptr_rc); - ucp_worker_progress(ucx_state->worker_context); - } while (ucs_rc == UCS_INPROGRESS); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } - } - } - - return 0; -} - int nvshmemt_ucx_show_info(struct nvshmem_transport *transport, int style) { NVSHMEMI_ERROR_PRINT("UCX show info not implemented"); return 0; @@ -1446,7 +1385,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ucx_finalize; transport->host_ops.show_info = nvshmemt_ucx_show_info; transport->host_ops.progress = nvshmemt_ucx_progress; - transport->host_ops.enforce_cst = nvshmemt_ucx_enforce_cst_at_target; transport->host_ops.enforce_cst_at_target = NULL; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/test/unit/mem/transport/remote_unit_tests.cpp b/test/unit/mem/transport/remote_unit_tests.cpp index 2d057e05..eb518858 100644 --- a/test/unit/mem/transport/remote_unit_tests.cpp +++ b/test/unit/mem/transport/remote_unit_tests.cpp @@ -258,7 +258,6 @@ nvshmem_transport_host_ops initialize_nvshmem_transport_host_ops() { .fence = NULL, .quiet = NULL, .put_signal = NULL, - .enforce_cst = NULL, .enforce_cst_at_target = NULL, .add_device_remote_mem_handles = &add_device_remote_mem_handles};