Skip to content

Commit c00f22d

Browse files
committed
Call zeInitDrivers in L0 provider
According to the L0 spec, zeInitDrivers must be called (by every library) before calling any other APIs. Not calling zeInitDrivers causes crash when using statically linked L0 loader in UR.
1 parent c80d2b2 commit c00f22d

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

src/provider/provider_level_zero.c

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <umf/memory_provider_ops.h>
1515
#include <umf/providers/provider_level_zero.h>
1616

17+
#include "base_alloc_global.h"
1718
#include "provider_level_zero_internal.h"
1819
#include "utils_load_library.h"
1920
#include "utils_log.h"
@@ -111,7 +112,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) {
111112

112113
#else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER)
113114

114-
#include "base_alloc_global.h"
115115
#include "libumf.h"
116116
#include "utils_assert.h"
117117
#include "utils_common.h"
@@ -211,6 +211,49 @@ static umf_result_t ze2umf_result(ze_result_t result) {
211211
}
212212
}
213213

214+
static umf_result_t ze_init_drivers(void *lib_handle, const char *lib_name) {
215+
ze_result_t (*zeInitDriversFunc)(uint32_t *, ze_driver_handle_t *,
216+
ze_init_driver_type_desc_t *);
217+
*(void **)&zeInitDriversFunc =
218+
utils_get_symbol_addr(lib_handle, "zeInitDrivers", lib_name);
219+
if (!zeInitDriversFunc) {
220+
return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE;
221+
}
222+
223+
ze_init_driver_type_desc_t desc = {
224+
.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC,
225+
.pNext = NULL,
226+
.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU};
227+
uint32_t driverCount = 0;
228+
ze_result_t result = zeInitDriversFunc(&driverCount, NULL, &desc);
229+
if (result != ZE_RESULT_SUCCESS) {
230+
return ze2umf_result(result);
231+
}
232+
233+
ze_driver_handle_t *zeAllDrivers =
234+
umf_ba_global_alloc(sizeof(ze_driver_handle_t) * driverCount);
235+
result = zeInitDriversFunc(&driverCount, zeAllDrivers, &desc);
236+
umf_ba_global_free(zeAllDrivers);
237+
if (result != ZE_RESULT_SUCCESS) {
238+
return ze2umf_result(result);
239+
}
240+
241+
return UMF_RESULT_SUCCESS;
242+
}
243+
244+
static umf_result_t ze_init(void *lib_handle, const char *lib_name) {
245+
ze_result_t (*zeInitFunc)(ze_init_flag_t);
246+
*(void **)&zeInitFunc =
247+
utils_get_symbol_addr(lib_handle, "zeInit", lib_name);
248+
249+
if (!zeInitFunc) {
250+
return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE;
251+
}
252+
253+
ze_result_t result = zeInitFunc(ZE_INIT_FLAG_GPU_ONLY);
254+
return ze2umf_result(result);
255+
}
256+
214257
static void init_ze_global_state(void) {
215258
#ifdef _WIN32
216259
const char *lib_name = "ze_loader.dll";
@@ -266,6 +309,19 @@ static void init_ze_global_state(void) {
266309
utils_close_library(lib_handle);
267310
return;
268311
}
312+
313+
if (ze_init_drivers(lib_handle, lib_name) != UMF_RESULT_SUCCESS) {
314+
LOG_INFO("Initializing Level Zero through zeInitDrivers failed. "
315+
"Falling back to zeInit.");
316+
317+
if (ze_init(lib_handle, lib_name) != UMF_RESULT_SUCCESS) {
318+
LOG_FATAL("Failed to initialize Level Zero");
319+
Init_ze_global_state_failed = true;
320+
utils_close_library(lib_handle);
321+
return;
322+
}
323+
}
324+
269325
ze_lib_handle = lib_handle;
270326
}
271327

0 commit comments

Comments
 (0)