|
14 | 14 | #include <umf/memory_provider_ops.h> |
15 | 15 | #include <umf/providers/provider_level_zero.h> |
16 | 16 |
|
| 17 | +#include "base_alloc_global.h" |
17 | 18 | #include "provider_level_zero_internal.h" |
18 | 19 | #include "utils_load_library.h" |
19 | 20 | #include "utils_log.h" |
@@ -111,7 +112,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) { |
111 | 112 |
|
112 | 113 | #else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER) |
113 | 114 |
|
114 | | -#include "base_alloc_global.h" |
115 | 115 | #include "libumf.h" |
116 | 116 | #include "utils_assert.h" |
117 | 117 | #include "utils_common.h" |
@@ -211,6 +211,49 @@ static umf_result_t ze2umf_result(ze_result_t result) { |
211 | 211 | } |
212 | 212 | } |
213 | 213 |
|
| 214 | +static umf_result_t ze_init_drivers(void *lib_handle, const char *lib_name) { |
| 215 | + ze_result_t (*zeInitDriversFunc)(uint32_t *, ze_driver_handle_t *, |
| 216 | + ze_init_driver_type_desc_t *); |
| 217 | + *(void **)&zeInitDriversFunc = |
| 218 | + utils_get_symbol_addr(lib_handle, "zeInitDrivers", lib_name); |
| 219 | + if (!zeInitDriversFunc) { |
| 220 | + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; |
| 221 | + } |
| 222 | + |
| 223 | + ze_init_driver_type_desc_t desc = { |
| 224 | + .stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC, |
| 225 | + .pNext = NULL, |
| 226 | + .flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU}; |
| 227 | + uint32_t driverCount = 0; |
| 228 | + ze_result_t result = zeInitDriversFunc(&driverCount, NULL, &desc); |
| 229 | + if (result != ZE_RESULT_SUCCESS) { |
| 230 | + return ze2umf_result(result); |
| 231 | + } |
| 232 | + |
| 233 | + ze_driver_handle_t *zeAllDrivers = |
| 234 | + umf_ba_global_alloc(sizeof(ze_driver_handle_t) * driverCount); |
| 235 | + result = zeInitDriversFunc(&driverCount, zeAllDrivers, &desc); |
| 236 | + umf_ba_global_free(zeAllDrivers); |
| 237 | + if (result != ZE_RESULT_SUCCESS) { |
| 238 | + return ze2umf_result(result); |
| 239 | + } |
| 240 | + |
| 241 | + return UMF_RESULT_SUCCESS; |
| 242 | +} |
| 243 | + |
| 244 | +static umf_result_t ze_init(void *lib_handle, const char *lib_name) { |
| 245 | + ze_result_t (*zeInitFunc)(ze_init_flag_t); |
| 246 | + *(void **)&zeInitFunc = |
| 247 | + utils_get_symbol_addr(lib_handle, "zeInit", lib_name); |
| 248 | + |
| 249 | + if (!zeInitFunc) { |
| 250 | + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; |
| 251 | + } |
| 252 | + |
| 253 | + ze_result_t result = zeInitFunc(ZE_INIT_FLAG_GPU_ONLY); |
| 254 | + return ze2umf_result(result); |
| 255 | +} |
| 256 | + |
214 | 257 | static void init_ze_global_state(void) { |
215 | 258 | #ifdef _WIN32 |
216 | 259 | const char *lib_name = "ze_loader.dll"; |
@@ -266,6 +309,19 @@ static void init_ze_global_state(void) { |
266 | 309 | utils_close_library(lib_handle); |
267 | 310 | return; |
268 | 311 | } |
| 312 | + |
| 313 | + if (ze_init_drivers(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { |
| 314 | + LOG_INFO("Initializing Level Zero through zeInitDrivers failed. " |
| 315 | + "Falling back to zeInit."); |
| 316 | + |
| 317 | + if (ze_init(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { |
| 318 | + LOG_FATAL("Failed to initialize Level Zero"); |
| 319 | + Init_ze_global_state_failed = true; |
| 320 | + utils_close_library(lib_handle); |
| 321 | + return; |
| 322 | + } |
| 323 | + } |
| 324 | + |
269 | 325 | ze_lib_handle = lib_handle; |
270 | 326 | } |
271 | 327 |
|
|
0 commit comments