diff --git a/fastsafetensors/cpp/ext.hpp b/fastsafetensors/cpp/ext.hpp index 961011f..770d3a1 100644 --- a/fastsafetensors/cpp/ext.hpp +++ b/fastsafetensors/cpp/ext.hpp @@ -39,7 +39,12 @@ typedef struct CUfileError { CUfileOpError err; } CUfileError_t; // Define minimal CUDA/HIP types for both platforms to avoid compile-time dependencies // We load all GPU functions dynamically at runtime via dlopen() typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t; +// Platform-specific enum values - CUDA and HIP have different values for HostToDevice +#ifdef USE_ROCM +enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4 }; +#else enum cudaMemcpyKind { cudaMemcpyHostToDevice=2, cudaMemcpyDefault = 4 }; +#endif typedef enum CUfileFeatureFlags {