diff --git a/CMakeLists.txt b/CMakeLists.txt index 922b04b89..88c946f9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,19 @@ # - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler cmake_minimum_required(VERSION 3.22.1) +# On Windows with HIP backend, auto-detect compilers from ROCM_PATH before project() +if(WIN32 AND COMPUTE_BACKEND STREQUAL "hip") + if(DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + if(ROCM_PATH AND NOT DEFINED CMAKE_CXX_COMPILER) + set(CMAKE_CXX_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe") + endif() + if(ROCM_PATH AND NOT DEFINED CMAKE_HIP_COMPILER) + set(CMAKE_HIP_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe") + endif() +endif() + project(bitsandbytes LANGUAGES CXX) # If run without specifying a build type, default to using the Release configuration: @@ -200,6 +213,20 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) elseif(BUILD_HIP) + # Auto-detect GPU architecture on Windows using hipinfo.exe + if(WIN32 AND NOT DEFINED BNB_ROCM_ARCH AND NOT DEFINED AMDGPU_TARGETS AND NOT DEFINED CMAKE_HIP_ARCHITECTURES) + execute_process( + COMMAND hipinfo + OUTPUT_VARIABLE HIPINFO_OUTPUT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(HIPINFO_OUTPUT MATCHES "gcnArchName:[ \t]*([a-z0-9]+)") + set(CMAKE_HIP_ARCHITECTURES "${CMAKE_MATCH_1}") + message(STATUS "Auto-detected HIP architecture: ${CMAKE_HIP_ARCHITECTURES}") + endif() + endif() + enable_language(HIP) message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") if(DEFINED BNB_ROCM_ARCH) @@ -263,6 +290,8 @@ endif() if(WIN32) # Export all symbols set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + # Prevent Windows SDK min/max macros from conflicting with std::min/std::max + add_compile_definitions(NOMINMAX) endif() if(MSVC) @@ -315,10 +344,11 @@ if(BUILD_CUDA) ) endif() if(BUILD_HIP) - if(NOT DEFINED ENV{ROCM_PATH}) - set(ROCM_PATH /opt/rocm) - else() + # Determine ROCM_PATH from environment variable, fallback to /opt/rocm on Linux + if(DEFINED ENV{ROCM_PATH}) set(ROCM_PATH $ENV{ROCM_PATH}) + else() + set(ROCM_PATH /opt/rocm) endif() list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) macro(find_package_and_print_version PACKAGE_NAME) @@ -330,14 +360,24 @@ if(BUILD_HIP) find_package_and_print_version(hipsparse REQUIRED) ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) - set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") - set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") - set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + ## On Windows, we need to link amdhip64 explicitly + if(NOT WIN32) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + endif() target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + # On Windows, link the HIP runtime and rocblas directly using full paths + if(WIN32) + target_link_libraries(bitsandbytes PUBLIC + "${ROCM_PATH}/lib/amdhip64.lib" + "${ROCM_PATH}/lib/rocblas.lib") + endif() + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 71e7568a9..66be5cc93 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,7 @@ import dataclasses from functools import lru_cache import logging +import platform import re import subprocess from typing import Optional @@ -83,10 +84,21 @@ def get_rocm_gpu_arch() -> str: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + # On Windows, use hipinfo.exe; on Linux, use rocminfo + if platform.system() == "Windows": + cmd = ["hipinfo.exe"] + arch_pattern = r"gcnArchName:\s+(gfx[a-zA-Z\d]+)" + else: + cmd = ["rocminfo"] + arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + + result = subprocess.run(cmd, capture_output=True, text=True) + match = re.search(arch_pattern, result.stdout) if match: - return "gfx" + match.group(1) + if platform.system() == "Windows": + return match.group(1) + else: + return "gfx" + match.group(1) else: return "unknown" else: @@ -107,8 +119,17 @@ def get_rocm_warpsize() -> int: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) + # On Windows, use hipinfo.exe; on Linux, use rocminfo + if platform.system() == "Windows": + cmd = ["hipinfo.exe"] + # hipinfo.exe output format: "warpSize: 32" or "warpSize: 64" + warp_pattern = r"warpSize:\s+(\d+)" + else: + cmd = ["rocminfo"] + warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)" + + result = subprocess.run(cmd, capture_output=True, text=True) + match = re.search(warp_pattern, result.stdout) if match: return int(match.group(1)) else: diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 709432dcb..7d22cf9a7 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -10,6 +10,13 @@ #include #include #include +#ifdef _WIN32 +#include +#include +#include +#else +#include +#endif #include #include diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 4eb446206..86392ae5e 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -11,7 +11,17 @@ #include #include #include + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#include +#else #include +#endif #include #include