diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c9525e0a1..c9303bb77 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -11,10 +11,7 @@ #include // std::setw, std::setfill #include #include // std::forward - -// Replace std::filesystem usage with Windows-specific headers -#include -#pragma comment(lib, "shlwapi.lib") +#include //------------------------------------------------------------------------------------------------- // Macro definitions @@ -188,14 +185,26 @@ ParamType* AllocateParamBuffer(std::vector>& paramBuffers, return static_cast(paramBuffers.back().get()); } +std::string DescribeChar(unsigned char ch) { + if (ch >= 32 && ch <= 126) { + return std::string("'") + static_cast(ch) + "'"; + } else { + char buffer[16]; + snprintf(buffer, sizeof(buffer), "U+%04X", ch); + return std::string(buffer); + } +} + // Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with // appropriate arguments SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, const std::vector& paramInfos, std::vector>& paramBuffers) { + LOG("Starting parameter binding. Number of parameters: {}", params.size()); for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) { const auto& param = params[paramIndex]; const ParamInfo& paramInfo = paramInfos[paramIndex]; + LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType); void* dataPtr = nullptr; SQLLEN bufferLength = 0; SQLLEN* strLenOrIndPtr = nullptr; @@ -233,8 +242,45 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, "Streaming parameters is not yet supported. Parameter size" " must be less than 8192 bytes"); } + + // Log detailed parameter information + LOG("SQL_C_WCHAR Parameter[{}]: Length={}, Content='{}'", + paramIndex, + strParam->size(), + (strParam->size() <= 100 + ? WideToUTF8(std::wstring(strParam->begin(), strParam->end())) + : WideToUTF8(std::wstring(strParam->begin(), strParam->begin() + 100)) + "...")); + + // Log each character's code point for debugging + if (strParam->size() <= 20) { + for (size_t i = 0; i < strParam->size(); i++) { + unsigned char ch = static_cast((*strParam)[i]); + LOG(" char[{}] = {} ({})", i, static_cast(ch), DescribeChar(ch)); + } + } +#if defined(__APPLE__) + // On macOS, we need special handling for wide characters + // Create a properly encoded SQLWCHAR buffer for the parameter + std::vector* sqlwcharBuffer = + AllocateParamBuffer>(paramBuffers); + + // Reserve space and convert from wstring to SQLWCHAR array + sqlwcharBuffer->resize(strParam->size() + 1, 0); // +1 for null terminator + + // Convert each wchar_t (4 bytes on macOS) to SQLWCHAR (2 bytes) + for (size_t i = 0; i < strParam->size(); i++) { + (*sqlwcharBuffer)[i] = static_cast((*strParam)[i]); + } + + // Use the SQLWCHAR buffer instead of the wstring directly + dataPtr = sqlwcharBuffer->data(); + bufferLength = (strParam->size() + 1) * sizeof(SQLWCHAR); + LOG("macOS: Created SQLWCHAR buffer for parameter with size: {} bytes", bufferLength); +#else + // On Windows, wchar_t and SQLWCHAR are the same size, so direct cast works dataPtr = const_cast(static_cast(strParam->c_str())); bufferLength = (strParam->size() + 1 /* null terminator */) * sizeof(wchar_t); +#endif strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_NTS; break; @@ -464,6 +510,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, } } } + LOG("Finished parameter binding. Number of parameters: {}", params.size()); return SQL_SUCCESS; } @@ -495,14 +542,6 @@ void LOG(const std::string& formatString, Args&&... args) { logging.attr("debug")(message); } -std::string WideToUTF8(const std::wstring& wstr) { - if (wstr.empty()) return {}; - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), nullptr, 0, nullptr, nullptr); - std::string result(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), result.data(), size_needed, nullptr, nullptr); - return result; -} - // TODO: Add more nuanced exception classes void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } @@ -511,130 +550,178 @@ std::string GetModuleDirectory() { py::object module_path = module.attr("__file__"); std::string module_file = module_path.cast(); +#ifdef _WIN32 + // Windows-specific path handling char path[MAX_PATH]; - strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); + errno_t err = strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); + if (err != 0) { + LOG("strncpy_s failed with error code: {}", err); + return {}; + } PathRemoveFileSpecA(path); return std::string(path); +#else + // macOS/Unix path handling without using std::filesystem + std::string::size_type pos = module_file.find_last_of('/'); + if (pos != std::string::npos) { + std::string dir = module_file.substr(0, pos); + return dir; + } + return module_file; +#endif +} + +// Platform-agnostic function to load the driver dynamic library +DriverHandle LoadDriverLibrary(const std::string& driverPath) { + LOG("Loading driver from path: {}", driverPath); +#ifdef _WIN32 + // Windows: Convert string to wide string for LoadLibraryW + std::wstring widePath(driverPath.begin(), driverPath.end()); + HMODULE handle = LoadLibraryW(widePath.c_str()); + if (!handle) { + LOG("LoadLibraryW failed."); + } + return handle; +#else + // macOS/Unix: Use dlopen + void* handle = dlopen(driverPath.c_str(), RTLD_LAZY); + if (!handle) { + LOG("dlopen failed."); + } + return handle; +#endif +} + +// Platform-agnostic function to get last error message +std::string GetLastErrorMessage() { +#ifdef _WIN32 + // Windows: Use FormatMessageA + DWORD error = GetLastError(); + char* messageBuffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + error, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&messageBuffer, + 0, + NULL + ); + std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; + LocalFree(messageBuffer); + return "Error code: " + std::to_string(error) + " - " + errorMessage; +#else + // macOS/Unix: Use dlerror + const char* error = dlerror(); + return error ? std::string(error) : "Unknown error"; +#endif } // Helper to load the driver // TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit // linking to load this DLL. It will simplify the code a lot. -std::wstring LoadDriverOrThrowException() { - const std::wstring& modulePath = L""; - std::wstring ddbcModulePath = modulePath; - if (ddbcModulePath.empty()) { - // Get the module path if not provided - std::string path = GetModuleDirectory(); - ddbcModulePath = std::wstring(path.begin(), path.end()); +DriverHandle LoadDriverOrThrowException() { + namespace fs = std::filesystem; + std::string moduleDir = GetModuleDirectory(); + LOG("Module directory: {}", moduleDir); + std::string archStr = ARCHITECTURE; + std::string archDir = + (archStr == "win64" || archStr == "amd64" || archStr == "x64") ? "x64" : + (archStr == "arm64") ? "arm64" : + "x86"; + + fs::path driverPath; +#ifdef _WIN32 + fs::path dllDir = fs::path(moduleDir) / "libs" / archDir; + + // Optionally load mssql-auth.dll if it exists + fs::path authDllPath = dllDir / "mssql-auth.dll"; + if (fs::exists(authDllPath)) { + HMODULE hAuth = LoadLibraryW(authDllPath.wstring().c_str()); + if (hAuth) { + LOG("Authentication DLL loaded: {}", authDllPath.string()); + } else { + LOG("Failed to load mssql-auth.dll: {}", GetLastErrorMessage()); + } + } else { + LOG("Note: mssql-auth.dll not found. This is OK if Entra ID is not in use."); } - std::wstring dllDir = ddbcModulePath; - dllDir += L"\\libs\\"; - - // Convert ARCHITECTURE macro to wstring - std::wstring archStr(ARCHITECTURE, ARCHITECTURE + strlen(ARCHITECTURE)); - - // Map architecture identifiers to correct subdirectory names - std::wstring archDir; - if (archStr == L"win64" || archStr == L"amd64" || archStr == L"x64") { - archDir = L"x64"; - } else if (archStr == L"arm64") { - archDir = L"arm64"; + driverPath = dllDir / "msodbcsql18.dll"; +#else // macOS + std::string runtimeArch = + #if defined(__arm64__) || defined(__aarch64__) + "arm64"; + #else + "x86_64"; + #endif + fs::path primaryPath = fs::path(moduleDir) / "libs" / "macos" / runtimeArch / "lib" / "libmsodbcsql.18.dylib"; + if (fs::exists(primaryPath)) { + driverPath = primaryPath; + LOG("macOS driver found at: {}", driverPath.string()); } else { - archDir = L"x86"; + driverPath = fs::path(moduleDir) / "libs" / archDir / "macos/libmsodbcsql.18.dylib"; + LOG("Using fallback macOS driver path: {}", driverPath.string()); } - dllDir += archDir; - std::wstring mssqlauthDllPath = dllDir + L"\\mssql-auth.dll"; - dllDir += L"\\msodbcsql18.dll"; - - // Preload mssql-auth.dll from the same path if available - HMODULE hAuthModule = LoadLibraryW(mssqlauthDllPath.c_str()); - if (hAuthModule) { - LOG("Authentication library loaded successfully from - {}", mssqlauthDllPath.c_str()); - } else { - LOG("Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication.", mssqlauthDllPath.c_str()); +#endif + if (!fs::exists(driverPath)) { + ThrowStdException("ODBC driver not found at: " + driverPath.string()); } - - // Convert wstring to string for logging - LOG("Attempting to load driver from - {}", WideToUTF8(dllDir)); - - HMODULE hModule = LoadLibraryW(dllDir.c_str()); - if (!hModule) { - // Failed to load the DLL, get the error message - DWORD error = GetLastError(); - char* messageBuffer = nullptr; - size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, - 0, - NULL - ); - std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; - LocalFree(messageBuffer); - - // Log the error message - LOG("Failed to load the driver with error code: {} - {}", error, errorMessage); - ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly."); + DriverHandle handle = LoadDriverLibrary(driverPath.string()); + if (!handle) { + LOG("Failed to load driver: {}", GetLastErrorMessage()); + ThrowStdException("Failed to load ODBC driver. Please check installation."); } - - // If we got here, we've successfully loaded the DLL. Now get the function pointers. - // Environment and handle function loading - SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress(hModule, "SQLAllocHandle"); - SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress(hModule, "SQLSetEnvAttr"); - SQLSetConnectAttr_ptr = (SQLSetConnectAttrFunc)GetProcAddress(hModule, "SQLSetConnectAttrW"); - SQLSetStmtAttr_ptr = (SQLSetStmtAttrFunc)GetProcAddress(hModule, "SQLSetStmtAttrW"); - SQLGetConnectAttr_ptr = (SQLGetConnectAttrFunc)GetProcAddress(hModule, "SQLGetConnectAttrW"); - - // Connection and statement function loading - SQLDriverConnect_ptr = (SQLDriverConnectFunc)GetProcAddress(hModule, "SQLDriverConnectW"); - SQLExecDirect_ptr = (SQLExecDirectFunc)GetProcAddress(hModule, "SQLExecDirectW"); - SQLPrepare_ptr = (SQLPrepareFunc)GetProcAddress(hModule, "SQLPrepareW"); - SQLBindParameter_ptr = (SQLBindParameterFunc)GetProcAddress(hModule, "SQLBindParameter"); - SQLExecute_ptr = (SQLExecuteFunc)GetProcAddress(hModule, "SQLExecute"); - SQLRowCount_ptr = (SQLRowCountFunc)GetProcAddress(hModule, "SQLRowCount"); - SQLGetStmtAttr_ptr = (SQLGetStmtAttrFunc)GetProcAddress(hModule, "SQLGetStmtAttrW"); - SQLSetDescField_ptr = (SQLSetDescFieldFunc)GetProcAddress(hModule, "SQLSetDescFieldW"); - - // Fetch and data retrieval function loading - SQLFetch_ptr = (SQLFetchFunc)GetProcAddress(hModule, "SQLFetch"); - SQLFetchScroll_ptr = (SQLFetchScrollFunc)GetProcAddress(hModule, "SQLFetchScroll"); - SQLGetData_ptr = (SQLGetDataFunc)GetProcAddress(hModule, "SQLGetData"); - SQLNumResultCols_ptr = (SQLNumResultColsFunc)GetProcAddress(hModule, "SQLNumResultCols"); - SQLBindCol_ptr = (SQLBindColFunc)GetProcAddress(hModule, "SQLBindCol"); - SQLDescribeCol_ptr = (SQLDescribeColFunc)GetProcAddress(hModule, "SQLDescribeColW"); - SQLMoreResults_ptr = (SQLMoreResultsFunc)GetProcAddress(hModule, "SQLMoreResults"); - SQLColAttribute_ptr = (SQLColAttributeFunc)GetProcAddress(hModule, "SQLColAttributeW"); - - // Transaction functions loading - SQLEndTran_ptr = (SQLEndTranFunc)GetProcAddress(hModule, "SQLEndTran"); - - // Disconnect and free functions loading - SQLFreeHandle_ptr = (SQLFreeHandleFunc)GetProcAddress(hModule, "SQLFreeHandle"); - SQLDisconnect_ptr = (SQLDisconnectFunc)GetProcAddress(hModule, "SQLDisconnect"); - SQLFreeStmt_ptr = (SQLFreeStmtFunc)GetProcAddress(hModule, "SQLFreeStmt"); - - // Diagnostic record function Loading - SQLGetDiagRec_ptr = (SQLGetDiagRecFunc)GetProcAddress(hModule, "SQLGetDiagRecW"); - - bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && - SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && - SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && - SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr && - SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr && - SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && - SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr && - SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + LOG("Driver library successfully loaded."); + + // Load function pointers using helper + SQLAllocHandle_ptr = GetFunctionPointer(handle, "SQLAllocHandle"); + SQLSetEnvAttr_ptr = GetFunctionPointer(handle, "SQLSetEnvAttr"); + SQLSetConnectAttr_ptr = GetFunctionPointer(handle, "SQLSetConnectAttrW"); + SQLSetStmtAttr_ptr = GetFunctionPointer(handle, "SQLSetStmtAttrW"); + SQLGetConnectAttr_ptr = GetFunctionPointer(handle, "SQLGetConnectAttrW"); + + SQLDriverConnect_ptr = GetFunctionPointer(handle, "SQLDriverConnectW"); + SQLExecDirect_ptr = GetFunctionPointer(handle, "SQLExecDirectW"); + SQLPrepare_ptr = GetFunctionPointer(handle, "SQLPrepareW"); + SQLBindParameter_ptr = GetFunctionPointer(handle, "SQLBindParameter"); + SQLExecute_ptr = GetFunctionPointer(handle, "SQLExecute"); + SQLRowCount_ptr = GetFunctionPointer(handle, "SQLRowCount"); + SQLGetStmtAttr_ptr = GetFunctionPointer(handle, "SQLGetStmtAttrW"); + SQLSetDescField_ptr = GetFunctionPointer(handle, "SQLSetDescFieldW"); + + SQLFetch_ptr = GetFunctionPointer(handle, "SQLFetch"); + SQLFetchScroll_ptr = GetFunctionPointer(handle, "SQLFetchScroll"); + SQLGetData_ptr = GetFunctionPointer(handle, "SQLGetData"); + SQLNumResultCols_ptr = GetFunctionPointer(handle, "SQLNumResultCols"); + SQLBindCol_ptr = GetFunctionPointer(handle, "SQLBindCol"); + SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); + SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); + SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); + + SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); + SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); + SQLFreeHandle_ptr = GetFunctionPointer(handle, "SQLFreeHandle"); + SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); + + SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + + bool success = + SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && + SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && + SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && + SQLExecute_ptr && SQLRowCount_ptr && SQLGetStmtAttr_ptr && + SQLSetDescField_ptr && SQLFetch_ptr && SQLFetchScroll_ptr && + SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && + SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && + SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && + SQLFreeStmt_ptr && SQLGetDiagRec_ptr; if (!success) { - ThrowStdException("Failed to load required function pointers from driver"); + ThrowStdException("Failed to load required function pointers from driver."); } - LOG("Successfully loaded function pointers from driver"); - - return dllDir; + LOG("All driver function pointers successfully loaded."); + return handle; } // DriverLoader definition @@ -714,8 +801,15 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET &nativeError, message, SQL_MAX_MESSAGE_LENGTH, &messageLen); if (SQL_SUCCEEDED(diagReturn)) { +#if defined(_WIN32) + // On Windows, SQLWCHAR and wchar_t are compatible errorInfo.sqlState = std::wstring(sqlState); errorInfo.ddbcErrorMsg = std::wstring(message); +#else + // On macOS/Linux, need to convert SQLWCHAR (usually unsigned short) to wchar_t + errorInfo.sqlState = SQLWCHARToWString(sqlState); + errorInfo.ddbcErrorMsg = SQLWCHARToWString(message, messageLen); +#endif } } return errorInfo; @@ -729,7 +823,14 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q DriverLoader::getInstance().loadDriver(); // Load the driver } - SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), const_cast(Query.c_str()), SQL_NTS); + SQLWCHAR* queryPtr; +#if defined(__APPLE__) + std::vector queryBuffer = WStringToSQLWCHAR(Query); + queryPtr = queryBuffer.data(); +#else + queryPtr = const_cast(Query.c_str()); +#endif + SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to execute query directly"); } @@ -761,7 +862,13 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, if (!statementHandle || !statementHandle->get()) { LOG("Statement handle is null or empty"); } - SQLWCHAR* queryPtr = const_cast(query.c_str()); + SQLWCHAR* queryPtr; +#if defined(__APPLE__) + std::vector queryBuffer = WStringToSQLWCHAR(query); + queryPtr = queryBuffer.data(); +#else + queryPtr = const_cast(query.c_str()); +#endif if (params.size() == 0) { // Execute statement directly if the statement is not parametrized. This is the // fastest way to submit a SQL statement for one-time execution according to @@ -861,7 +968,11 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta if (SQL_SUCCEEDED(retcode)) { // Append a named py::dict to ColumnMetadata // TODO: Should we define a struct for this task instead of dict? +#if defined(__APPLE__) + ColumnMetadata.append(py::dict("ColumnName"_a = SQLWCHARToWString(ColumnName, SQL_NTS), +#else ColumnMetadata.append(py::dict("ColumnName"_a = std::wstring(ColumnName), +#endif "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, "DecimalDigits"_a = DecimalDigits, "Nullable"_a = Nullable)); @@ -932,7 +1043,13 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // NOTE: dataBuffer.size() includes null-terminator, dataLen doesn't. Hence use '<'. if (numCharsInData < dataBuffer.size()) { // SQLGetData will null-terminate the data +#if defined(__APPLE__) + std::string fullStr(reinterpret_cast(dataBuffer.data())); + row.append(fullStr); + LOG("macOS: Appended CHAR string of length {} to result row", fullStr.length()); +#else row.append(std::string(reinterpret_cast(dataBuffer.data()))); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit @@ -975,7 +1092,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); if (numCharsInData < dataBuffer.size()) { // SQLGetData will null-terminate the data +#if defined(__APPLE__) + row.append(SQLWCHARToWString(dataBuffer.data(), SQL_NTS)); +#else row.append(std::wstring(dataBuffer.data())); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit @@ -1484,9 +1605,17 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' if (numCharsInData < fetchBufferSize) { // SQLFetch will nullterminate the data +#if defined(__APPLE__) + // Use macOS-specific conversion to handle the wchar_t/SQLWCHAR size difference + SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; + std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); + row.append(wstr); +#else + // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works row.append(std::wstring( reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), numCharsInData)); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index bb050eab8..2ccf28fda 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -7,14 +7,6 @@ #pragma once #include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions - -#include -#include -#include -#include -#include -#include - #include #include #include @@ -23,6 +15,58 @@ namespace py = pybind11; using namespace pybind11::literals; +#include +#include +#include + +#ifdef _WIN32 + // Windows-specific headers + #include // windows.h needs to be included before sql.h + #include + #pragma comment(lib, "shlwapi.lib") + #define IS_WINDOWS 1 +#else + #define IS_WINDOWS 0 +#endif + +#include +#include + +#if defined(__APPLE__) + // macOS-specific headers + #include + + inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { + if (!sqlwStr) return std::wstring(); + + if (length == SQL_NTS) { + size_t i = 0; + while (sqlwStr[i] != 0) ++i; + length = i; + } + + std::wstring result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + result.push_back(static_cast(sqlwStr[i])); + } + return result; + } + + inline std::vector WStringToSQLWCHAR(const std::wstring& str) { + std::vector result(str.size() + 1, 0); // +1 for null terminator + for (size_t i = 0; i < str.size(); ++i) { + result[i] = static_cast(str[i]); + } + return result; + } +#endif + +#if defined(__APPLE__) +#include "mac_utils.h" // For macOS-specific Unicode encoding fixes +#include "mac_buffers.h" // For macOS-specific buffer handling +#endif + //------------------------------------------------------------------------------------------------- // Function pointer typedefs //------------------------------------------------------------------------------------------------- @@ -116,20 +160,37 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr; // Diagnostic APIs extern SQLGetDiagRecFunc SQLGetDiagRec_ptr; - // Logging utility template void LOG(const std::string& formatString, Args&&... args); - // Throws a std::runtime_error with the given message void ThrowStdException(const std::string& message); +// Define a platform-agnostic type for the driver handle +#ifdef _WIN32 +typedef HMODULE DriverHandle; +#else +typedef void* DriverHandle; +#endif + +// Platform-agnostic function to get a function pointer from the loaded library +template +T GetFunctionPointer(DriverHandle handle, const char* functionName) { +#ifdef _WIN32 + // Windows: Use GetProcAddress + return reinterpret_cast(GetProcAddress(handle, functionName)); +#else + // macOS/Unix: Use dlsym + return reinterpret_cast(dlsym(handle, functionName)); +#endif +} + //------------------------------------------------------------------------------------------------- // Loads the ODBC driver and resolves function pointers. // Throws if loading or resolution fails. //------------------------------------------------------------------------------------------------- -std::wstring LoadDriverOrThrowException(); +DriverHandle LoadDriverOrThrowException(); //------------------------------------------------------------------------------------------------- // DriverLoader (Singleton) @@ -178,4 +239,35 @@ struct ErrorInfo { }; ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode); -std::string WideToUTF8(const std::wstring& wstr); \ No newline at end of file +inline std::string WideToUTF8(const std::wstring& wstr) { + if (wstr.empty()) return {}; +#if defined(_WIN32) + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), nullptr, 0, nullptr, nullptr); + if (size_needed == 0) return {}; + std::string result(size_needed, 0); + int converted = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), result.data(), size_needed, nullptr, nullptr); + if (converted == 0) return {}; + return result; +#else + std::wstring_convert> converter; + return converter.to_bytes(wstr); +#endif +} + +inline std::wstring Utf8ToWString(const std::string& str) { + if (str.empty()) return {}; +#if defined(_WIN32) + int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), nullptr, 0); + if (size_needed == 0) { + LOG("MultiByteToWideChar failed."); + return {}; + } + std::wstring result(size_needed, 0); + int converted = MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), result.data(), size_needed); + if (converted == 0) return {}; + return result; +#else + std::wstring_convert> converter; + return converter.from_bytes(str); +#endif +} diff --git a/mssql_python/pybind/ddbc_bindings_mac.cpp b/mssql_python/pybind/ddbc_bindings_mac.cpp index 6505efeea..249e79e2c 100644 --- a/mssql_python/pybind/ddbc_bindings_mac.cpp +++ b/mssql_python/pybind/ddbc_bindings_mac.cpp @@ -70,7 +70,7 @@ #include #if defined(__APPLE__) -#include "mac_fix.h" // For macOS-specific Unicode encoding fixes +#include "mac_utils.h" // For macOS-specific Unicode encoding fixes #include "mac_buffers.h" // For macOS-specific buffer handling #endif diff --git a/mssql_python/pybind/mac_fix.cpp b/mssql_python/pybind/mac_utils.cpp similarity index 94% rename from mssql_python/pybind/mac_fix.cpp rename to mssql_python/pybind/mac_utils.cpp index 06ebb2539..6dbb0ed76 100644 --- a/mssql_python/pybind/mac_fix.cpp +++ b/mssql_python/pybind/mac_utils.cpp @@ -1,5 +1,10 @@ -// Mac OS specific fixes for the C++ code -// This file contains patches to fix issues specific to macOS +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// This file defines utility functions for safely handling SQLWCHAR-based +// wide-character data in ODBC operations on macOS. It includes conversions +// between SQLWCHAR, std::wstring, and UTF-8 strings to bridge encoding +// differences specific to macOS. #if defined(__APPLE__) // Constants for character encoding diff --git a/mssql_python/pybind/mac_fix.h b/mssql_python/pybind/mac_utils.h similarity index 75% rename from mssql_python/pybind/mac_fix.h rename to mssql_python/pybind/mac_utils.h index 04592b048..776ab6447 100644 --- a/mssql_python/pybind/mac_fix.h +++ b/mssql_python/pybind/mac_utils.h @@ -1,3 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// This header defines utility functions for safely handling SQLWCHAR-based +// wide-character data in ODBC operations on macOS. It includes conversions +// between SQLWCHAR, std::wstring, and UTF-8 strings to bridge encoding +// differences specific to macOS. + #pragma once #include