diff --git a/eng/pipelines/pr-validation-pipeline.yml b/eng/pipelines/pr-validation-pipeline.yml index 384fe1c0e..e688fad31 100644 --- a/eng/pipelines/pr-validation-pipeline.yml +++ b/eng/pipelines/pr-validation-pipeline.yml @@ -7,163 +7,48 @@ trigger: - main jobs: -- job: PytestOnWindows - pool: - vmImage: 'windows-latest' - - steps: - - task: UsePythonVersion@0 - inputs: - # TODO: Remove this once Python 3.13 is available in ADO - # ADO indexing will take some time to reflect 3.13 as 3.13.5, right now it is pointing to 3.13.4. - # We're specifying to use Python 3.13.5 since 3.13.4 has issues with compilation - # See https://github.com/python/cpython/issues/135151, next release should fix it. - versionSpec: '3.13.5' - addToPath: true - githubToken: $(GITHUB_TOKEN) - displayName: 'Use Python 3.13' - - - script: | - python -m pip install --upgrade pip - pip install -r requirements.txt - displayName: 'Install dependencies' - - # Start LocalDB instance - - powershell: | - sqllocaldb create MSSQLLocalDB - sqllocaldb start MSSQLLocalDB - displayName: 'Start LocalDB instance' - - # Create database and user - - powershell: | - sqlcmd -S "(localdb)\MSSQLLocalDB" -Q "CREATE DATABASE TestDB" - sqlcmd -S "(localdb)\MSSQLLocalDB" -Q "CREATE LOGIN testuser WITH PASSWORD = '$(DB_PASSWORD)'" - sqlcmd -S "(localdb)\MSSQLLocalDB" -d TestDB -Q "CREATE USER testuser FOR LOGIN testuser" - sqlcmd -S "(localdb)\MSSQLLocalDB" -d TestDB -Q "ALTER ROLE db_owner ADD MEMBER testuser" - displayName: 'Setup database and user' - env: - DB_PASSWORD: $(DB_PASSWORD) - - - script: | - cd mssql_python\pybind - build.bat x64 - displayName: 'Build .pyd file' - - - script: | - python -m pytest -v --junitxml=test-results.xml --cov=. --cov-report=xml --capture=tee-sys --cache-clear - displayName: 'Run tests with coverage' - env: - DB_CONNECTION_STRING: 'Server=(localdb)\MSSQLLocalDB;Database=TestDB;Uid=testuser;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' - - - task: PublishBuildArtifacts@1 - inputs: - PathtoPublish: 'mssql_python/ddbc_bindings.cp313-amd64.pyd' - ArtifactName: 'ddbc_bindings' - publishLocation: 'Container' - displayName: 'Publish pyd file as artifact' - - - task: PublishBuildArtifacts@1 - inputs: - PathtoPublish: 'mssql_python/ddbc_bindings.cp313-amd64.pdb' - ArtifactName: 'ddbc_bindings' - publishLocation: 'Container' - displayName: 'Publish pdb file as artifact' - - - task: PublishTestResults@2 - condition: succeededOrFailed() - inputs: - testResultsFiles: '**/test-results.xml' - testRunTitle: 'Publish test results' - - - task: PublishCodeCoverageResults@1 - inputs: - codeCoverageTool: 'Cobertura' - summaryFileLocation: 'coverage.xml' - displayName: 'Publish code coverage results' - -- job: PytestOnMacOS - pool: - vmImage: 'macos-latest' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.13.5' - addToPath: true - displayName: 'Use Python 3.13 on macOS' - - - script: | - brew update - brew install cmake - displayName: 'Install CMake' - - # - script: | - # brew update - # brew install docker colima - - # # Start Colima with extra resources - # colima start --cpu 4 --memory 8 --disk 50 - - # # Optional: set Docker context (usually automatic) - # docker context use colima >/dev/null || true - - # # Confirm Docker is operational - # docker version - # docker ps - # displayName: 'Install and start Colima-based Docker' - - # - script: | - # # Pull and run SQL Server container - # docker pull mcr.microsoft.com/mssql/server:2022-latest - # docker run \ - # --name sqlserver \ - # -e ACCEPT_EULA=Y \ - # -e MSSQL_SA_PASSWORD="${DB_PASSWORD}" \ - # -p 1433:1433 \ - # -d mcr.microsoft.com/mssql/server:2022-latest - - # # Starting SQL Server container… - # for i in {1..30}; do - # docker exec sqlserver \ - # /opt/mssql-tools18/bin/sqlcmd \ - # -S localhost \ - # -U SA \ - # -P "$DB_PASSWORD" \ - # -C -Q "SELECT 1" && break - # sleep 2 - # done - # displayName: 'Pull & start SQL Server (Docker)' - # env: - # DB_PASSWORD: $(DB_PASSWORD) - - - script: | - python -m pip install --upgrade pip - pip install -r requirements.txt - displayName: 'Install Python dependencies' - - - script: | - cd mssql_python/pybind - ./build.sh - displayName: 'Build pybind bindings (.so)' - - - script: | - echo "Build successful, running tests now" - python -m pytest -v --junitxml=test-results.xml --cov=. --cov-report=xml --capture=tee-sys --cache-clear - displayName: 'Run pytest with coverage' - env: - # Temporarily Use Azure SQL Database connection string for testing purposes since Docker takes too long to install & start in MacOS - DB_CONNECTION_STRING: $(AZURE_CONNECTION_STRING) - # DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' - DB_PASSWORD: $(DB_PASSWORD) - - - task: PublishTestResults@2 - condition: succeededOrFailed() - inputs: - testResultsFiles: '**/test-results.xml' - testRunTitle: 'Publish pytest results on macOS' - - - task: PublishCodeCoverageResults@1 - inputs: - codeCoverageTool: 'Cobertura' - summaryFileLocation: 'coverage.xml' - displayName: 'Publish code coverage results' + - job: TestMacOS + pool: + vmImage: 'macos-latest' + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.13.5' + addToPath: true + displayName: 'Use Python 3.13 on macOS' + + - script: | + brew update + brew install cmake + displayName: 'Install CMake' + + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + displayName: 'Install Python dependencies' + + - script: | + cd mssql_python/pybind + ./build.sh + displayName: 'Build pybind bindings (.so)' + + - script: | + echo "Running main.py on macOS" + python main.py + displayName: 'Run main.py to ensure Python is working' + env: + # Temporarily Use Azure SQL Database connection string for testing purposes since Docker takes too long to install & start in MacOS + DB_CONNECTION_STRING: $(AZURE_CONNECTION_STRING) + # DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_PASSWORD: $(DB_PASSWORD) + + - script: | + echo "Build successful, running tests now" + python -m pytest -v --junitxml=test-results.xml --cov=. --cov-report=xml --capture=tee-sys --cache-clear + displayName: 'Run pytest with coverage' + env: + # Temporarily Use Azure SQL Database connection string for testing purposes since Docker takes too long to install & start in MacOS + DB_CONNECTION_STRING: $(AZURE_CONNECTION_STRING) + # DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_PASSWORD: $(DB_PASSWORD) \ No newline at end of file diff --git a/mssql_python/connection_mac.py b/mssql_python/connection_mac.py deleted file mode 100644 index ab2cb167e..000000000 --- a/mssql_python/connection_mac.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -This module defines the Connection class, which is used to manage a connection to a database. -The class provides methods to establish a connection, create cursors, commit transactions, -roll back transactions, and close the connection. -""" -import ctypes -from mssql_python.cursor_mac import Cursor -from mssql_python.logging_config import get_logger, ENABLE_LOGGING -from mssql_python.constants import ConstantsDDBC as ddbc_sql_const -from mssql_python.helpers import add_driver_to_connection_str, check_error -from mssql_python import ddbc_bindings - -logger = get_logger() - - -class Connection: - """ - A class to manage a connection to a database, compliant with DB-API 2.0 specifications. - - This class provides methods to establish a connection to a database, create cursors, - commit transactions, roll back transactions, and close the connection. It is designed - to be used in a context where database operations are required, such as executing queries - and fetching results. - - Methods: - __init__(database: str) -> None: - connect_to_db() -> None: - cursor() -> Cursor: - commit() -> None: - rollback() -> None: - close() -> None: - """ - - def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> None: - """ - Initialize the connection object with the specified connection string and parameters. - - Args: - - connection_str (str): The connection string to connect to. - - autocommit (bool): If True, causes a commit to be performed after each SQL statement. - **kwargs: Additional key/value pairs for the connection string. - Not including below properties since we are driver doesn't support this: - - Returns: - None - - Raises: - ValueError: If the connection string is invalid or connection fails. - - This method sets up the initial state for the connection object, - preparing it for further operations such as connecting to the - database, executing queries, etc. - """ - self.henv = ctypes.c_void_p() - self.hdbc = ctypes.c_void_p() - self.connection_str = self._construct_connection_string( - connection_str, **kwargs - ) - self._initializer() - self._autocommit = autocommit - self.setautocommit(autocommit) - - def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str: - """ - Construct the connection string by concatenating the connection string - with key/value pairs from kwargs. - - Args: - connection_str (str): The base connection string. - **kwargs: Additional key/value pairs for the connection string. - - Returns: - str: The constructed connection string. - """ - # Add the driver attribute to the connection string - conn_str = add_driver_to_connection_str(connection_str) - - # Add additional key-value pairs to the connection string - for key, value in kwargs.items(): - if key.lower() == "host" or key.lower() == "server": - key = "Server" - elif key.lower() == "user" or key.lower() == "uid": - key = "Uid" - elif key.lower() == "password" or key.lower() == "pwd": - key = "Pwd" - elif key.lower() == "database": - key = "Database" - elif key.lower() == "encrypt": - key = "Encrypt" - elif key.lower() == "trust_server_certificate": - key = "TrustServerCertificate" - else: - continue - conn_str += f"{key}={value};" - - return conn_str - - def _initializer(self) -> None: - """ - Initialize the environment and connection handles. - - This method is responsible for setting up the environment and connection - handles, allocating memory for them, and setting the necessary attributes. - It should be called before establishing a connection to the database. - """ - self._allocate_environment_handle() - self._set_environment_attributes() - self._allocate_connection_handle() - self._set_connection_attributes() - self._connect_to_db() - - def _allocate_environment_handle(self): - """ - Allocate the environment handle. - """ - ret = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_ENV.value, # SQL environment handle type - 0, # SQL input handle - ctypes.cast( - ctypes.pointer(self.henv), ctypes.c_void_p - ).value, # SQL output handle pointer - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, self.henv.value, ret) - - def _set_environment_attributes(self): - """ - Set the environment attributes. - """ - ret = ddbc_bindings.DDBCSQLSetEnvAttr( - self.henv.value, # Environment handle - ddbc_sql_const.SQL_ATTR_DDBC_VERSION.value, # Attribute - ddbc_sql_const.SQL_OV_DDBC3_80.value, # String Length - 0, # Null-terminated string - ) - check_error(ddbc_sql_const.SQL_HANDLE_ENV.value, self.henv.value, ret) - - def _allocate_connection_handle(self): - """ - Allocate the connection handle. - """ - ret = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_DBC.value, # SQL connection handle type - self.henv.value, # SQL environment handle - ctypes.cast( - ctypes.pointer(self.hdbc), ctypes.c_void_p - ).value, # SQL output handle pointer - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - - def _set_connection_attributes(self): - """ - Set the connection attributes before connecting. - """ - if self.autocommit: - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc.value, - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value, - 0, - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - - def _connect_to_db(self) -> None: - """ - Establish a connection to the database. - - This method is responsible for creating a connection to the specified database. - It does not take any arguments and does not return any value. The connection - details such as database name, user credentials, host, and port should be - configured within the class or passed during the class instantiation. - - Raises: - DatabaseError: If there is an error while trying to connect to the database. - InterfaceError: If there is an error related to the database interface. - """ - if ENABLE_LOGGING: - logger.info("Connecting to the database") - ret = ddbc_bindings.DDBCSQLDriverConnect( - self.hdbc.value, # Connection handle - 0, # Window handle - self.connection_str, # Connection string - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - if ENABLE_LOGGING: - logger.info("Connection established successfully.") - - @property - def autocommit(self) -> bool: - """ - Return the current autocommit mode of the connection. - Returns: - bool: True if autocommit is enabled, False otherwise. - """ - autocommit_mode = ddbc_bindings.DDBCSQLGetConnectionAttr( - self.hdbc.value, # Connection handle - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ) - check_error( - ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, autocommit_mode - ) - return autocommit_mode == ddbc_sql_const.SQL_AUTOCOMMIT_ON.value - - @autocommit.setter - def autocommit(self, value: bool) -> None: - """ - Set the autocommit mode of the connection. - Args: - value (bool): True to enable autocommit, False to disable it. - Returns: - None - Raises: - DatabaseError: If there is an error while setting the autocommit mode. - """ - ret = ddbc_bindings.DDBCSQLSetConnectAttr( - self.hdbc.value, # Connection handle - ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value, # Attribute - ( - ddbc_sql_const.SQL_AUTOCOMMIT_ON.value - if value - else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value - ), # Value - 0, # String length - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - self._autocommit = value - if ENABLE_LOGGING: - logger.info("Autocommit mode set to %s.", value) - - def setautocommit(self, value: bool = True) -> None: - """ - Set the autocommit mode of the connection. - Args: - value (bool): True to enable autocommit, False to disable it. - Returns: - None - Raises: - DatabaseError: If there is an error while setting the autocommit mode. - """ - self.autocommit = value - - def cursor(self) -> Cursor: - """ - Return a new Cursor object using the connection. - - This method creates and returns a new cursor object that can be used to - execute SQL queries and fetch results. The cursor is associated with the - current connection and allows interaction with the database. - - Returns: - Cursor: A new cursor object for executing SQL queries. - - Raises: - DatabaseError: If there is an error while creating the cursor. - InterfaceError: If there is an error related to the database interface. - """ - return Cursor(self) - - def commit(self) -> None: - """ - Commit the current transaction. - - This method commits the current transaction to the database, making all - changes made during the transaction permanent. It should be called after - executing a series of SQL statements that modify the database to ensure - that the changes are saved. - - Raises: - DatabaseError: If there is an error while committing the transaction. - """ - # Commit the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc.value, # Connection handle - ddbc_sql_const.SQL_COMMIT.value, # Commit the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - if ENABLE_LOGGING: - logger.info("Transaction committed successfully.") - - def rollback(self) -> None: - """ - Roll back the current transaction. - - This method rolls back the current transaction, undoing all changes made - during the transaction. It should be called if an error occurs during the - transaction or if the changes should not be saved. - - Raises: - DatabaseError: If there is an error while rolling back the transaction. - """ - # Roll back the current transaction - ret = ddbc_bindings.DDBCSQLEndTran( - ddbc_sql_const.SQL_HANDLE_DBC.value, # Handle type - self.hdbc.value, # Connection handle - ddbc_sql_const.SQL_ROLLBACK.value, # Roll back the transaction - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - if ENABLE_LOGGING: - logger.info("Transaction rolled back successfully.") - - def close(self) -> None: - """ - Close the connection now (rather than whenever .__del__() is called). - - This method closes the connection to the database, releasing any resources - associated with it. After calling this method, the connection object should - not be used for any further operations. The same applies to all cursor objects - trying to use the connection. Note that closing a connection without committing - the changes first will cause an implicit rollback to be performed. - - Raises: - DatabaseError: If there is an error while closing the connection. - """ - # Disconnect from the database - ret = ddbc_bindings.DDBCSQLDisconnect(self.hdbc.value) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - - # Free the connection handle - ret = ddbc_bindings.DDBCSQLFreeHandle( - ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value - ) - check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret) - - if ENABLE_LOGGING: - logger.info("Connection closed successfully.") diff --git a/mssql_python/cursor_mac.py b/mssql_python/cursor_mac.py deleted file mode 100644 index 5a75ce6eb..000000000 --- a/mssql_python/cursor_mac.py +++ /dev/null @@ -1,741 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -This module contains the Cursor class, which represents a database cursor. -""" -import ctypes -import decimal -import uuid -import datetime -from typing import List, Union -from mssql_python.constants import ConstantsDDBC as ddbc_sql_const -from mssql_python.helpers import check_error -from mssql_python.logging_config import get_logger, ENABLE_LOGGING -from mssql_python import ddbc_bindings -from .row import Row - -logger = get_logger() - - -class Cursor: - """ - Represents a database cursor, which is used to manage the context of a fetch operation. - - Attributes: - connection: Database connection object. - description: Sequence of 7-item sequences describing one result column. - rowcount: Number of rows produced or affected by the last execute operation. - arraysize: Number of rows to fetch at a time with fetchmany(). - - Methods: - __init__(connection_str) -> None. - callproc(procname, parameters=None) -> - Modified copy of the input sequence with output parameters. - close() -> None. - execute(operation, parameters=None) -> None. - executemany(operation, seq_of_parameters) -> None. - fetchone() -> Single Row object or None if no more data is available. - fetchmany(size=None) -> List of Row objects. - fetchall() -> List of Row objects. - nextset() -> True if there is another result set, None otherwise. - setinputsizes(sizes) -> None. - setoutputsize(size, column=None) -> None. - """ - - def __init__(self, connection) -> None: - """ - Initialize the cursor with a database connection. - - Args: - connection: Database connection object. - """ - self.connection = connection - # self.connection.autocommit = False - self.hstmt = ctypes.c_void_p() - self._initialize_cursor() - self.description = None - self.rowcount = -1 - self.arraysize = ( - 1 # Default number of rows to fetch at a time is 1, user can change it - ) - self.buffer_length = 1024 # Default buffer length for string data - self.closed = False # Flag to indicate if the cursor is closed - self.last_executed_stmt = ( - "" # Stores the last statement executed by this cursor - ) - self.is_stmt_prepared = [ - False - ] # Indicates if last_executed_stmt was prepared by ddbc shim. - # Is a list instead of a bool coz bools in Python are immutable. - # Hence, we can't pass around bools by reference & modify them. - # Therefore, it must be a list with exactly one bool element. - - def _is_unicode_string(self, param): - """ - Check if a string contains non-ASCII characters. - - Args: - param: The string to check. - - Returns: - True if the string contains non-ASCII characters, False otherwise. - """ - try: - param.encode("ascii") - return False # Can be encoded to ASCII, so not Unicode - except UnicodeEncodeError: - return True # Contains non-ASCII characters, so treat as Unicode - - def _parse_date(self, param): - """ - Attempt to parse a string as a date. - - Args: - param: The string to parse. - - Returns: - A datetime.date object if parsing is successful, else None. - """ - formats = ["%Y-%m-%d"] - for fmt in formats: - try: - return datetime.datetime.strptime(param, fmt).date() - except ValueError: - continue - return None - - def _parse_datetime(self, param): - """ - Attempt to parse a string as a datetime, smalldatetime, datetime2, timestamp. - - Args: - param: The string to parse. - - Returns: - A datetime.datetime object if parsing is successful, else None. - """ - formats = [ - "%Y-%m-%dT%H:%M:%S.%f", # ISO 8601 datetime with fractional seconds - "%Y-%m-%dT%H:%M:%S", # ISO 8601 datetime - "%Y-%m-%d %H:%M:%S.%f", # Datetime with fractional seconds - "%Y-%m-%d %H:%M:%S", # Datetime without fractional seconds - ] - for fmt in formats: - try: - return datetime.datetime.strptime(param, fmt) # Valid datetime - except ValueError: - continue # Try next format - - return None # If all formats fail, return None - - def _parse_time(self, param): - """ - Attempt to parse a string as a time. - - Args: - param: The string to parse. - - Returns: - A datetime.time object if parsing is successful, else None. - """ - formats = [ - "%H:%M:%S", # Time only - "%H:%M:%S.%f", # Time with fractional seconds - ] - for fmt in formats: - try: - return datetime.datetime.strptime(param, fmt).time() - except ValueError: - continue - return None - - def _get_numeric_data(self, param): - """ - Get the data for a numeric parameter. - - Args: - param: The numeric parameter. - - Returns: - numeric_data: A NumericData struct containing - the numeric data. - """ - decimal_as_tuple = param.as_tuple() - num_digits = len(decimal_as_tuple.digits) - exponent = decimal_as_tuple.exponent - - # Calculate the SQL precision & scale - # precision = no. of significant digits - # scale = no. digits after decimal point - if exponent >= 0: - # digits=314, exp=2 ---> '31400' --> precision=5, scale=0 - precision = num_digits + exponent - scale = 0 - elif (-1 * exponent) <= num_digits: - # digits=3140, exp=-3 ---> '3.140' --> precision=4, scale=3 - precision = num_digits - scale = exponent * -1 - else: - # digits=3140, exp=-5 ---> '0.03140' --> precision=5, scale=5 - # TODO: double check the precision calculation here with SQL documentation - precision = exponent * -1 - scale = exponent * -1 - - # TODO: Revisit this check, do we want this restriction? - if precision > 15: - raise ValueError( - "Precision of the numeric value is too high - " - + str(param) - + ". Should be less than or equal to 15" - ) - Numeric_Data = ddbc_bindings.NumericData - numeric_data = Numeric_Data() - numeric_data.scale = scale - numeric_data.precision = precision - numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0 - # strip decimal point from param & convert the significant digits to integer - # Ex: 12.34 ---> 1234 - val = str(param) - if "." in val or "-" in val: - val = val.replace(".", "") - val = val.replace("-", "") - val = int(val) - numeric_data.val = val - return numeric_data - - def _map_sql_type(self, param, parameters_list, i): - """ - Map a Python data type to the corresponding SQL type, - C type, Column size, and Decimal digits. - Takes: - - param: The parameter to map. - - parameters_list: The list of parameters to bind. - - i: The index of the parameter in the list. - Returns: - - A tuple containing the SQL type, C type, column size, and decimal digits. - """ - if param is None: - return ( - ddbc_sql_const.SQL_VARCHAR.value, # TODO: Add SQLDescribeParam to get correct type - ddbc_sql_const.SQL_C_DEFAULT.value, - 1, - 0, - ) - - if isinstance(param, bool): - return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0 - - if isinstance(param, int): - if 0 <= param <= 255: - return ( - ddbc_sql_const.SQL_TINYINT.value, - ddbc_sql_const.SQL_C_TINYINT.value, - 3, - 0, - ) - if -32768 <= param <= 32767: - return ( - ddbc_sql_const.SQL_SMALLINT.value, - ddbc_sql_const.SQL_C_SHORT.value, - 5, - 0, - ) - if -2147483648 <= param <= 2147483647: - return ( - ddbc_sql_const.SQL_INTEGER.value, - ddbc_sql_const.SQL_C_LONG.value, - 10, - 0, - ) - return ( - ddbc_sql_const.SQL_BIGINT.value, - ddbc_sql_const.SQL_C_SBIGINT.value, - 19, - 0, - ) - - if isinstance(param, float): - return ( - ddbc_sql_const.SQL_DOUBLE.value, - ddbc_sql_const.SQL_C_DOUBLE.value, - 15, - 0, - ) - - if isinstance(param, decimal.Decimal): - parameters_list[i] = self._get_numeric_data( - param - ) # Replace the parameter with the dictionary - return ( - ddbc_sql_const.SQL_NUMERIC.value, - ddbc_sql_const.SQL_C_NUMERIC.value, - parameters_list[i].precision, - parameters_list[i].scale, - ) - - if isinstance(param, str): - if ( - param.startswith("POINT") - or param.startswith("LINESTRING") - or param.startswith("POLYGON") - ): - return ( - ddbc_sql_const.SQL_WVARCHAR.value, - ddbc_sql_const.SQL_C_WCHAR.value, - len(param), - 0, - ) - - # Attempt to parse as date, datetime, datetime2, timestamp, smalldatetime or time - if self._parse_date(param): - parameters_list[i] = self._parse_date( - param - ) # Replace the parameter with the date object - return ( - ddbc_sql_const.SQL_DATE.value, - ddbc_sql_const.SQL_C_TYPE_DATE.value, - 10, - 0, - ) - if self._parse_datetime(param): - parameters_list[i] = self._parse_datetime(param) - return ( - ddbc_sql_const.SQL_TIMESTAMP.value, - ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, - 26, - 6, - ) - if self._parse_time(param): - parameters_list[i] = self._parse_time(param) - return ( - ddbc_sql_const.SQL_TIME.value, - ddbc_sql_const.SQL_C_TYPE_TIME.value, - 8, - 0, - ) - - # String mapping logic here - is_unicode = self._is_unicode_string(param) - # TODO: revisit - if len(param) > 4000: # Long strings - if is_unicode: - return ( - ddbc_sql_const.SQL_WLONGVARCHAR.value, - ddbc_sql_const.SQL_C_WCHAR.value, - len(param), - 0, - ) - return ( - ddbc_sql_const.SQL_LONGVARCHAR.value, - ddbc_sql_const.SQL_C_CHAR.value, - len(param), - 0, - ) - if is_unicode: # Short Unicode strings - return ( - ddbc_sql_const.SQL_WVARCHAR.value, - ddbc_sql_const.SQL_C_WCHAR.value, - len(param), - 0, - ) - return ( - ddbc_sql_const.SQL_VARCHAR.value, - ddbc_sql_const.SQL_C_CHAR.value, - len(param), - 0, - ) - - if isinstance(param, bytes): - if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays - return ( - ddbc_sql_const.SQL_VARBINARY.value, - ddbc_sql_const.SQL_C_BINARY.value, - len(param), - 0, - ) - return ( - ddbc_sql_const.SQL_BINARY.value, - ddbc_sql_const.SQL_C_BINARY.value, - len(param), - 0, - ) - - if isinstance(param, bytearray): - if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays - return ( - ddbc_sql_const.SQL_VARBINARY.value, - ddbc_sql_const.SQL_C_BINARY.value, - len(param), - 0, - ) - return ( - ddbc_sql_const.SQL_BINARY.value, - ddbc_sql_const.SQL_C_BINARY.value, - len(param), - 0, - ) - - if isinstance(param, datetime.datetime): - return ( - ddbc_sql_const.SQL_TIMESTAMP.value, - ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, - 26, - 6, - ) - - if isinstance(param, datetime.date): - return ( - ddbc_sql_const.SQL_DATE.value, - ddbc_sql_const.SQL_C_TYPE_DATE.value, - 10, - 0, - ) - - if isinstance(param, datetime.time): - return ( - ddbc_sql_const.SQL_TIME.value, - ddbc_sql_const.SQL_C_TYPE_TIME.value, - 8, - 0, - ) - - return ( - ddbc_sql_const.SQL_VARCHAR.value, - ddbc_sql_const.SQL_C_CHAR.value, - len(str(param)), - 0, - ) - - def _initialize_cursor(self) -> None: - """ - Initialize the DDBC statement handle. - """ - self._allocate_statement_handle() - - def _allocate_statement_handle(self): - """ - Allocate the DDBC statement handle. - """ - ret = ddbc_bindings.DDBCSQLAllocHandle( - ddbc_sql_const.SQL_HANDLE_STMT.value, - self.connection.hdbc.value, - ctypes.cast(ctypes.pointer(self.hstmt), ctypes.c_void_p).value, - ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - - def _reset_cursor(self) -> None: - """ - Reset the DDBC statement handle. - """ - # Free the existing statement handle - if self.hstmt.value: - ddbc_bindings.DDBCSQLFreeHandle( - ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value - ) - # Reinitialize the statement handle - self._initialize_cursor() - - def close(self) -> None: - """ - Close the cursor now (rather than whenever __del__ is called). - - Raises: - Error: If any operation is attempted with the cursor after it is closed. - """ - if self.closed: - raise RuntimeError("Cursor is already closed.") - - if self.hstmt.value: - ret = ddbc_bindings.DDBCSQLFreeHandle( - ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value - ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - self.hstmt.value = None - - self.closed = True - - def _check_closed(self): - """ - Check if the cursor is closed and raise an exception if it is. - - Raises: - Error: If the cursor is closed. - """ - if self.closed: - raise RuntimeError("Operation cannot be performed: the cursor is closed.") - - def _create_parameter_types_list(self, parameter, param_info, parameters_list, i): - """ - Maps parameter types for the given parameter. - - Args: - parameter: parameter to bind. - - Returns: - paraminfo. - """ - paraminfo = param_info() - sql_type, c_type, column_size, decimal_digits = self._map_sql_type( - parameter, parameters_list, i - ) - paraminfo.paramCType = c_type - paraminfo.paramSQLType = sql_type - paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value - paraminfo.columnSize = column_size - paraminfo.decimalDigits = decimal_digits - return paraminfo - - def _initialize_description(self): - """ - Initialize the description attribute using SQLDescribeCol. - """ - col_metadata = [] - ret = ddbc_bindings.DDBCSQLDescribeCol(self.hstmt.value, col_metadata) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - - self.description = [ - ( - col["ColumnName"], - self._map_data_type(col["DataType"]), - None, - col["ColumnSize"], - col["ColumnSize"], - col["DecimalDigits"], - col["Nullable"] == ddbc_sql_const.SQL_NULLABLE.value, - ) - for col in col_metadata - ] - - def _map_data_type(self, sql_type): - """ - Map SQL data type to Python data type. - - Args: - sql_type: SQL data type. - - Returns: - Corresponding Python data type. - """ - sql_to_python_type = { - ddbc_sql_const.SQL_INTEGER.value: int, - ddbc_sql_const.SQL_VARCHAR.value: str, - ddbc_sql_const.SQL_WVARCHAR.value: str, - ddbc_sql_const.SQL_CHAR.value: str, - ddbc_sql_const.SQL_WCHAR.value: str, - ddbc_sql_const.SQL_FLOAT.value: float, - ddbc_sql_const.SQL_DOUBLE.value: float, - ddbc_sql_const.SQL_DECIMAL.value: decimal.Decimal, - ddbc_sql_const.SQL_NUMERIC.value: decimal.Decimal, - ddbc_sql_const.SQL_DATE.value: datetime.date, - ddbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime, - ddbc_sql_const.SQL_TIME.value: datetime.time, - ddbc_sql_const.SQL_BIT.value: bool, - ddbc_sql_const.SQL_TINYINT.value: int, - ddbc_sql_const.SQL_SMALLINT.value: int, - ddbc_sql_const.SQL_BIGINT.value: int, - ddbc_sql_const.SQL_BINARY.value: bytes, - ddbc_sql_const.SQL_VARBINARY.value: bytes, - ddbc_sql_const.SQL_LONGVARBINARY.value: bytes, - ddbc_sql_const.SQL_GUID.value: uuid.UUID, - # Add more mappings as needed - } - return sql_to_python_type.get(sql_type, str) - - def execute( - self, - operation: str, - *parameters, - use_prepare: bool = True, - reset_cursor: bool = True - ) -> None: - """ - Prepare and execute a database operation (query or command). - - Args: - operation: SQL query or command. - parameters: Sequence of parameters to bind. - use_prepare: Whether to use SQLPrepareW (default) or SQLExecDirectW. - reset_cursor: Whether to reset the cursor before execution. - """ - self._check_closed() # Check if the cursor is closed - - if reset_cursor: - self._reset_cursor() - - param_info = ddbc_bindings.ParamInfo - parameters_type = [] - - # Flatten parameters if a single tuple or list is passed - if len(parameters) == 1 and isinstance(parameters[0], (tuple, list)): - parameters = parameters[0] - - parameters = list(parameters) - - if parameters: - for i, param in enumerate(parameters): - paraminfo = self._create_parameter_types_list( - param, param_info, parameters, i - ) - parameters_type.append(paraminfo) - - # TODO: Use a more sophisticated string compare that handles redundant spaces etc. - # Also consider storing last query's hash instead of full query string. This will help - # in low-memory conditions - # (Ex: huge number of parallel queries with huge query string sizes) - if operation != self.last_executed_stmt: -# Executing a new statement. Reset is_stmt_prepared to false - self.is_stmt_prepared = [False] - - if ENABLE_LOGGING: - logger.debug("Executing query: %s", operation) - for i, param in enumerate(parameters): - logger.debug( - """Parameter number: %s, Parameter: %s, - Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", - i + 1, - param, - str(type(param)), - parameters_type[i].paramSQLType, - parameters_type[i].paramCType, - parameters_type[i].columnSize, - parameters_type[i].decimalDigits, - parameters_type[i].inputOutputType, - ) - - ret = ddbc_bindings.DDBCSQLExecute( - self.hstmt.value, - operation, - parameters, - parameters_type, - self.is_stmt_prepared, - use_prepare, - ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - self.last_executed_stmt = operation - - # Update rowcount after execution - # TODO: rowcount return code from SQL needs to be handled - self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt.value) - - # Initialize description after execution - self._initialize_description() - - def executemany(self, operation: str, seq_of_parameters: list) -> None: - """ - Prepare a database operation and execute it against all parameter sequences. - - Args: - operation: SQL query or command. - seq_of_parameters: Sequence of sequences or mappings of parameters. - - Raises: - Error: If the operation fails. - """ - self._check_closed() # Check if the cursor is closed - - self._reset_cursor() - - first_execution = True - total_rowcount = 0 - for parameters in seq_of_parameters: - parameters = list(parameters) - if ENABLE_LOGGING: - logger.info("Executing query with parameters: %s", parameters) - prepare_stmt = first_execution - first_execution = False - self.execute( - operation, parameters, use_prepare=prepare_stmt, reset_cursor=False - ) - if self.rowcount != -1: - total_rowcount += self.rowcount - else: - total_rowcount = -1 - self.rowcount = total_rowcount - - def fetchone(self) -> Union[None, Row]: - """ - Fetch the next row of a query result set. - - Returns: - Single Row object or None if no more data is available. - """ - self._check_closed() # Check if the cursor is closed - - # Fetch raw data - row_data = [] - ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt.value, row_data) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - - if ret == ddbc_sql_const.SQL_NO_DATA.value: - return None - - # Create and return a Row object - return Row(row_data, self.description) - - def fetchmany(self, size: int = None) -> List[Row]: - """ - Fetch the next set of rows of a query result. - - Args: - size: Number of rows to fetch at a time. - - Returns: - List of Row objects. - """ - self._check_closed() # Check if the cursor is closed - - if size is None: - size = self.arraysize - - if size <= 0: - return [] - - # Fetch raw data - rows_data = [] - ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt.value, rows_data, size) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - - if ret == ddbc_sql_const.SQL_NO_DATA.value: - return [] - - # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result. - - Returns: - List of Row objects. - """ - self._check_closed() # Check if the cursor is closed - - # Fetch raw data - rows_data = [] - ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt.value, rows_data) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - - if ret != ddbc_sql_const.SQL_NO_DATA.value: - return [] - - # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] - - def nextset(self) -> Union[bool, None]: - """ - Skip to the next available result set. - - Returns: - True if there is another result set, None otherwise. - - Raises: - Error: If the previous call to execute did not produce any result set. - """ - self._check_closed() # Check if the cursor is closed - - # Skip to the next result set - ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt.value) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret) - if ret == ddbc_sql_const.SQL_NO_DATA.value: - return False - return True diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index 6668f2b91..9c688ac61 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -3,11 +3,7 @@ Licensed under the MIT license. This module provides a way to create a new connection object to interact with the database. """ -import platform -if platform.system() == 'Windows': - from mssql_python.connection import Connection -else: - from mssql_python.connection_mac import Connection +from mssql_python.connection import Connection def connect(connection_str: str = "", autocommit: bool = True, attrs_before: dict = None, **kwargs) -> Connection: """ diff --git a/mssql_python/pybind/CMakeLists.txt b/mssql_python/pybind/CMakeLists.txt index 4b4c1f990..068d83b6f 100644 --- a/mssql_python/pybind/CMakeLists.txt +++ b/mssql_python/pybind/CMakeLists.txt @@ -178,20 +178,11 @@ endif() message(STATUS "Final Python library directory: ${PYTHON_LIB_DIR}") -# Determine which source file to use based on platform -if(APPLE) - set(DDBC_SOURCE "ddbc_bindings_mac.cpp") - message(STATUS "Using macOS-specific source file: ${DDBC_SOURCE}") - # Only ddbc_bindings_mac.cpp is used on macOS - # TODO: Implement smart pointer in macOS - add_library(ddbc_bindings MODULE ${DDBC_SOURCE}) -else() - # This is just windows block - set(DDBC_SOURCE "ddbc_bindings.cpp") - message(STATUS "Using standard source file: ${DDBC_SOURCE}") - # Include connection module for Windows - add_library(ddbc_bindings MODULE ${DDBC_SOURCE} connection/connection.cpp connection/connection_pool.cpp) -endif() +set(DDBC_SOURCE "ddbc_bindings.cpp") +message(STATUS "Using standard source file: ${DDBC_SOURCE}") +# Include connection module for Windows +add_library(ddbc_bindings MODULE ${DDBC_SOURCE} connection/connection.cpp connection/connection_pool.cpp) +# endif() # Set the output name to include Python version and architecture # Use appropriate file extension based on platform diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index b5184f4e6..06e79f02a 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -67,9 +67,20 @@ void Connection::connect(const py::dict& attrs_before) { setAutocommit(_autocommit); } } + SQLWCHAR* connStrPtr; +#if defined(__APPLE__) // macOS specific code + LOG("Creating connection string buffer for macOS"); + std::vector connStrBuffer = WStringToSQLWCHAR(_connStr); + // Ensure the buffer is null-terminated + LOG("Connection string buffer size - {}", connStrBuffer.size()); + connStrPtr = connStrBuffer.data(); + LOG("Connection string buffer created"); +#else + connStrPtr = const_cast(_connStr.c_str()); +#endif SQLRETURN ret = SQLDriverConnect_ptr( _dbcHandle->get(), nullptr, - (SQLWCHAR*)_connStr.c_str(), SQL_NTS, + connStrPtr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); checkError(ret); updateLastUsed(); @@ -236,12 +247,13 @@ std::chrono::steady_clock::time_point Connection::lastUsed() const { return _lastUsed; } -ConnectionHandle::ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore) - : _connStr(connStr), _usePool(usePool) { +ConnectionHandle::ConnectionHandle(const std::string& connStr, bool usePool, const py::dict& attrsBefore) + : _usePool(usePool) { + _connStr = Utf8ToWString(connStr); if (_usePool) { - _conn = ConnectionPoolManager::getInstance().acquireConnection(connStr, attrsBefore); + _conn = ConnectionPoolManager::getInstance().acquireConnection(_connStr, attrsBefore); } else { - _conn = std::make_shared(connStr, false); + _conn = std::make_shared(_connStr, false); _conn->connect(attrsBefore); } } diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index b9cc50b68..6129125e1 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -57,7 +57,7 @@ class Connection { class ConnectionHandle { public: - ConnectionHandle(const std::wstring& connStr, bool usePool, const py::dict& attrsBefore = py::dict()); + ConnectionHandle(const std::string& connStr, bool usePool, const py::dict& attrsBefore = py::dict()); ~ConnectionHandle(); void close(); diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp index 45e4ff3dc..249b0800e 100644 --- a/mssql_python/pybind/connection/connection_pool.cpp +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -5,7 +5,6 @@ // taken up in future. #include "connection_pool.h" -#include #include ConnectionPool::ConnectionPool(size_t max_size, int idle_timeout_secs) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c9525e0a1..07da2572e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -11,10 +11,7 @@ #include // std::setw, std::setfill #include #include // std::forward - -// Replace std::filesystem usage with Windows-specific headers -#include -#pragma comment(lib, "shlwapi.lib") +#include //------------------------------------------------------------------------------------------------- // Macro definitions @@ -193,9 +190,11 @@ ParamType* AllocateParamBuffer(std::vector>& paramBuffers, SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, const std::vector& paramInfos, std::vector>& paramBuffers) { + LOG("Starting parameter binding. Number of parameters: {}", params.size()); for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) { const auto& param = params[paramIndex]; const ParamInfo& paramInfo = paramInfos[paramIndex]; + LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType); void* dataPtr = nullptr; SQLLEN bufferLength = 0; SQLLEN* strLenOrIndPtr = nullptr; @@ -233,8 +232,46 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, "Streaming parameters is not yet supported. Parameter size" " must be less than 8192 bytes"); } + + // Log detailed parameter information + LOG("SQL_C_WCHAR Parameter[{}]: Length={}, Content='{}'", + paramIndex, + strParam->size(), + (strParam->size() <= 100 + ? WideToUTF8(std::wstring(strParam->begin(), strParam->end())) + : WideToUTF8(std::wstring(strParam->begin(), strParam->begin() + 100)) + "...")); + + // Log each character's code point for debugging + if (strParam->size() <= 20) { + for (size_t i = 0; i < strParam->size(); i++) { + LOG(" char[{}] = {} ({})", i, static_cast((*strParam)[i]), + ((*strParam)[i] >= 32 && (*strParam)[i] <= 126) ? + static_cast((*strParam)[i]) : '?'); + } + } +#if defined(__APPLE__) + // On macOS, we need special handling for wide characters + // Create a properly encoded SQLWCHAR buffer for the parameter + std::vector* sqlwcharBuffer = + AllocateParamBuffer>(paramBuffers); + + // Reserve space and convert from wstring to SQLWCHAR array + sqlwcharBuffer->resize(strParam->size() + 1, 0); // +1 for null terminator + + // Convert each wchar_t (4 bytes on macOS) to SQLWCHAR (2 bytes) + for (size_t i = 0; i < strParam->size(); i++) { + (*sqlwcharBuffer)[i] = static_cast((*strParam)[i]); + } + + // Use the SQLWCHAR buffer instead of the wstring directly + dataPtr = sqlwcharBuffer->data(); + bufferLength = (strParam->size() + 1) * sizeof(SQLWCHAR); + LOG("macOS: Created SQLWCHAR buffer for parameter with size: {} bytes", bufferLength); +#else + // On Windows, wchar_t and SQLWCHAR are the same size, so direct cast works dataPtr = const_cast(static_cast(strParam->c_str())); bufferLength = (strParam->size() + 1 /* null terminator */) * sizeof(wchar_t); +#endif strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_NTS; break; @@ -464,6 +501,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, } } } + LOG("Finished parameter binding. Number of parameters: {}", params.size()); return SQL_SUCCESS; } @@ -495,14 +533,6 @@ void LOG(const std::string& formatString, Args&&... args) { logging.attr("debug")(message); } -std::string WideToUTF8(const std::wstring& wstr) { - if (wstr.empty()) return {}; - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), nullptr, 0, nullptr, nullptr); - std::string result(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), result.data(), size_needed, nullptr, nullptr); - return result; -} - // TODO: Add more nuanced exception classes void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } @@ -511,130 +541,174 @@ std::string GetModuleDirectory() { py::object module_path = module.attr("__file__"); std::string module_file = module_path.cast(); +#ifdef _WIN32 + // Windows-specific path handling char path[MAX_PATH]; strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); PathRemoveFileSpecA(path); return std::string(path); -} - -// Helper to load the driver -// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit -// linking to load this DLL. It will simplify the code a lot. -std::wstring LoadDriverOrThrowException() { - const std::wstring& modulePath = L""; - std::wstring ddbcModulePath = modulePath; - if (ddbcModulePath.empty()) { - // Get the module path if not provided - std::string path = GetModuleDirectory(); - ddbcModulePath = std::wstring(path.begin(), path.end()); +#else + // macOS/Unix path handling without using std::filesystem + std::string::size_type pos = module_file.find_last_of('/'); + if (pos != std::string::npos) { + std::string dir = module_file.substr(0, pos); + return dir; } + std::cerr << "DEBUG: Could not extract directory from path: " << module_file << std::endl; + return module_file; +#endif +} - std::wstring dllDir = ddbcModulePath; - dllDir += L"\\libs\\"; - - // Convert ARCHITECTURE macro to wstring - std::wstring archStr(ARCHITECTURE, ARCHITECTURE + strlen(ARCHITECTURE)); +// Platform-agnostic function to load the driver dynamic library +DriverHandle LoadDriverLibrary(const std::string& driverPath) { + LOG("Loading driver from path: {}", driverPath); - // Map architecture identifiers to correct subdirectory names - std::wstring archDir; - if (archStr == L"win64" || archStr == L"amd64" || archStr == L"x64") { - archDir = L"x64"; - } else if (archStr == L"arm64") { - archDir = L"arm64"; +#ifdef _WIN32 + // Windows: Convert string to wide string for LoadLibraryW + std::wstring widePath(driverPath.begin(), driverPath.end()); + return LoadLibraryW(widePath.c_str()); +#else + // macOS/Unix: Use dlopen + return dlopen(driverPath.c_str(), RTLD_LAZY); +#endif +} + +// Platform-agnostic function to get last error message +std::string GetLastErrorMessage() { +#ifdef _WIN32 + // Windows: Use FormatMessageA + DWORD error = GetLastError(); + char* messageBuffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + error, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&messageBuffer, + 0, + NULL + ); + std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; + LocalFree(messageBuffer); + return "Error code: " + std::to_string(error) + " - " + errorMessage; +#else + // macOS/Unix: Use dlerror + const char* error = dlerror(); + return error ? std::string(error) : "Unknown error"; +#endif +} + +DriverHandle LoadDriverOrThrowException() { + namespace fs = std::filesystem; + + std::string moduleDir = GetModuleDirectory(); + LOG("Module directory: {}", moduleDir); + + std::string archStr = ARCHITECTURE; + std::string archDir = + (archStr == "win64" || archStr == "amd64" || archStr == "x64") ? "x64" : + (archStr == "arm64") ? "arm64" : + "x86"; + + fs::path driverPath; + +#ifdef _WIN32 + fs::path dllDir = fs::path(moduleDir) / "libs" / archDir; + + // Optionally load mssql-auth.dll if it exists + fs::path authDllPath = dllDir / "mssql-auth.dll"; + if (fs::exists(authDllPath)) { + HMODULE hAuth = LoadLibraryW(std::wstring(authDllPath.native().begin(), authDllPath.native().end()).c_str()); + if (hAuth) { + LOG("Authentication DLL loaded: {}", authDllPath.string()); + } else { + LOG("Failed to load mssql-auth.dll: {}", GetLastErrorMessage()); + } } else { - archDir = L"x86"; + LOG("Note: mssql-auth.dll not found. This is OK if Entra ID is not in use."); } - dllDir += archDir; - std::wstring mssqlauthDllPath = dllDir + L"\\mssql-auth.dll"; - dllDir += L"\\msodbcsql18.dll"; - - // Preload mssql-auth.dll from the same path if available - HMODULE hAuthModule = LoadLibraryW(mssqlauthDllPath.c_str()); - if (hAuthModule) { - LOG("Authentication library loaded successfully from - {}", mssqlauthDllPath.c_str()); + + driverPath = dllDir / "msodbcsql18.dll"; + +#else // macOS + std::string runtimeArch = + #if defined(__arm64__) || defined(__aarch64__) + "arm64"; + #else + "x86_64"; + #endif + + fs::path primaryPath = fs::path(moduleDir) / "libs" / "macos" / runtimeArch / "lib" / "libmsodbcsql.18.dylib"; + if (fs::exists(primaryPath)) { + driverPath = primaryPath; + LOG("macOS driver found at: {}", driverPath.string()); } else { - LOG("Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication.", mssqlauthDllPath.c_str()); + driverPath = fs::path(moduleDir) / "libs" / archDir / "macos" / "lib" / "libmsodbcsql.18.dylib"; + LOG("Using fallback macOS driver path: {}", driverPath.string()); } +#endif - // Convert wstring to string for logging - LOG("Attempting to load driver from - {}", WideToUTF8(dllDir)); - - HMODULE hModule = LoadLibraryW(dllDir.c_str()); - if (!hModule) { - // Failed to load the DLL, get the error message - DWORD error = GetLastError(); - char* messageBuffer = nullptr; - size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, - 0, - NULL - ); - std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; - LocalFree(messageBuffer); - - // Log the error message - LOG("Failed to load the driver with error code: {} - {}", error, errorMessage); - ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly."); + if (!fs::exists(driverPath)) { + ThrowStdException("ODBC driver not found at: " + driverPath.string()); } - // If we got here, we've successfully loaded the DLL. Now get the function pointers. - // Environment and handle function loading - SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress(hModule, "SQLAllocHandle"); - SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress(hModule, "SQLSetEnvAttr"); - SQLSetConnectAttr_ptr = (SQLSetConnectAttrFunc)GetProcAddress(hModule, "SQLSetConnectAttrW"); - SQLSetStmtAttr_ptr = (SQLSetStmtAttrFunc)GetProcAddress(hModule, "SQLSetStmtAttrW"); - SQLGetConnectAttr_ptr = (SQLGetConnectAttrFunc)GetProcAddress(hModule, "SQLGetConnectAttrW"); - - // Connection and statement function loading - SQLDriverConnect_ptr = (SQLDriverConnectFunc)GetProcAddress(hModule, "SQLDriverConnectW"); - SQLExecDirect_ptr = (SQLExecDirectFunc)GetProcAddress(hModule, "SQLExecDirectW"); - SQLPrepare_ptr = (SQLPrepareFunc)GetProcAddress(hModule, "SQLPrepareW"); - SQLBindParameter_ptr = (SQLBindParameterFunc)GetProcAddress(hModule, "SQLBindParameter"); - SQLExecute_ptr = (SQLExecuteFunc)GetProcAddress(hModule, "SQLExecute"); - SQLRowCount_ptr = (SQLRowCountFunc)GetProcAddress(hModule, "SQLRowCount"); - SQLGetStmtAttr_ptr = (SQLGetStmtAttrFunc)GetProcAddress(hModule, "SQLGetStmtAttrW"); - SQLSetDescField_ptr = (SQLSetDescFieldFunc)GetProcAddress(hModule, "SQLSetDescFieldW"); - - // Fetch and data retrieval function loading - SQLFetch_ptr = (SQLFetchFunc)GetProcAddress(hModule, "SQLFetch"); - SQLFetchScroll_ptr = (SQLFetchScrollFunc)GetProcAddress(hModule, "SQLFetchScroll"); - SQLGetData_ptr = (SQLGetDataFunc)GetProcAddress(hModule, "SQLGetData"); - SQLNumResultCols_ptr = (SQLNumResultColsFunc)GetProcAddress(hModule, "SQLNumResultCols"); - SQLBindCol_ptr = (SQLBindColFunc)GetProcAddress(hModule, "SQLBindCol"); - SQLDescribeCol_ptr = (SQLDescribeColFunc)GetProcAddress(hModule, "SQLDescribeColW"); - SQLMoreResults_ptr = (SQLMoreResultsFunc)GetProcAddress(hModule, "SQLMoreResults"); - SQLColAttribute_ptr = (SQLColAttributeFunc)GetProcAddress(hModule, "SQLColAttributeW"); - - // Transaction functions loading - SQLEndTran_ptr = (SQLEndTranFunc)GetProcAddress(hModule, "SQLEndTran"); - - // Disconnect and free functions loading - SQLFreeHandle_ptr = (SQLFreeHandleFunc)GetProcAddress(hModule, "SQLFreeHandle"); - SQLDisconnect_ptr = (SQLDisconnectFunc)GetProcAddress(hModule, "SQLDisconnect"); - SQLFreeStmt_ptr = (SQLFreeStmtFunc)GetProcAddress(hModule, "SQLFreeStmt"); - - // Diagnostic record function Loading - SQLGetDiagRec_ptr = (SQLGetDiagRecFunc)GetProcAddress(hModule, "SQLGetDiagRecW"); - - bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && - SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && - SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && - SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr && - SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr && - SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && - SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr && - SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + DriverHandle handle = LoadDriverLibrary(driverPath.string()); + if (!handle) { + LOG("Failed to load driver: {}", GetLastErrorMessage()); + ThrowStdException("Failed to load ODBC driver. Please check installation."); + } + + LOG("Driver library successfully loaded."); + + // Load function pointers using helper + SQLAllocHandle_ptr = GetFunctionPointer(handle, "SQLAllocHandle"); + SQLSetEnvAttr_ptr = GetFunctionPointer(handle, "SQLSetEnvAttr"); + SQLSetConnectAttr_ptr = GetFunctionPointer(handle, "SQLSetConnectAttrW"); + SQLSetStmtAttr_ptr = GetFunctionPointer(handle, "SQLSetStmtAttrW"); + SQLGetConnectAttr_ptr = GetFunctionPointer(handle, "SQLGetConnectAttrW"); + + SQLDriverConnect_ptr = GetFunctionPointer(handle, "SQLDriverConnectW"); + SQLExecDirect_ptr = GetFunctionPointer(handle, "SQLExecDirectW"); + SQLPrepare_ptr = GetFunctionPointer(handle, "SQLPrepareW"); + SQLBindParameter_ptr = GetFunctionPointer(handle, "SQLBindParameter"); + SQLExecute_ptr = GetFunctionPointer(handle, "SQLExecute"); + SQLRowCount_ptr = GetFunctionPointer(handle, "SQLRowCount"); + SQLGetStmtAttr_ptr = GetFunctionPointer(handle, "SQLGetStmtAttrW"); + SQLSetDescField_ptr = GetFunctionPointer(handle, "SQLSetDescFieldW"); + + SQLFetch_ptr = GetFunctionPointer(handle, "SQLFetch"); + SQLFetchScroll_ptr = GetFunctionPointer(handle, "SQLFetchScroll"); + SQLGetData_ptr = GetFunctionPointer(handle, "SQLGetData"); + SQLNumResultCols_ptr = GetFunctionPointer(handle, "SQLNumResultCols"); + SQLBindCol_ptr = GetFunctionPointer(handle, "SQLBindCol"); + SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); + SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); + SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); + + SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); + SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); + SQLFreeHandle_ptr = GetFunctionPointer(handle, "SQLFreeHandle"); + SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); + + SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + + bool success = + SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && + SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && + SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && + SQLExecute_ptr && SQLRowCount_ptr && SQLGetStmtAttr_ptr && + SQLSetDescField_ptr && SQLFetch_ptr && SQLFetchScroll_ptr && + SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && + SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && + SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && + SQLFreeStmt_ptr && SQLGetDiagRec_ptr; if (!success) { - ThrowStdException("Failed to load required function pointers from driver"); + ThrowStdException("Missing required ODBC function pointers."); } - LOG("Successfully loaded function pointers from driver"); - - return dllDir; + + LOG("All driver function pointers successfully loaded."); + return handle; } // DriverLoader definition @@ -714,8 +788,15 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET &nativeError, message, SQL_MAX_MESSAGE_LENGTH, &messageLen); if (SQL_SUCCEEDED(diagReturn)) { +#if defined(_WIN32) + // On Windows, SQLWCHAR and wchar_t are compatible errorInfo.sqlState = std::wstring(sqlState); errorInfo.ddbcErrorMsg = std::wstring(message); +#else + // On macOS/Linux, need to convert SQLWCHAR (usually unsigned short) to wchar_t + errorInfo.sqlState = SQLWCHARToWString(sqlState); + errorInfo.ddbcErrorMsg = SQLWCHARToWString(message, messageLen); +#endif } } return errorInfo; @@ -729,7 +810,14 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q DriverLoader::getInstance().loadDriver(); // Load the driver } - SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), const_cast(Query.c_str()), SQL_NTS); + SQLWCHAR* queryPtr; +#if defined(__APPLE__) + std::vector queryBuffer = WStringToSQLWCHAR(Query); + queryPtr = queryBuffer.data(); +#else + queryPtr = const_cast(Query.c_str()); +#endif + SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to execute query directly"); } @@ -761,7 +849,13 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, if (!statementHandle || !statementHandle->get()) { LOG("Statement handle is null or empty"); } - SQLWCHAR* queryPtr = const_cast(query.c_str()); + SQLWCHAR* queryPtr; +#if defined(__APPLE__) + std::vector queryBuffer = WStringToSQLWCHAR(query); + queryPtr = queryBuffer.data(); +#else + queryPtr = const_cast(query.c_str()); +#endif if (params.size() == 0) { // Execute statement directly if the statement is not parametrized. This is the // fastest way to submit a SQL statement for one-time execution according to @@ -861,7 +955,11 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta if (SQL_SUCCEEDED(retcode)) { // Append a named py::dict to ColumnMetadata // TODO: Should we define a struct for this task instead of dict? +#if defined(__APPLE__) + ColumnMetadata.append(py::dict("ColumnName"_a = SQLWCHARToWString(ColumnName, SQL_NTS), +#else ColumnMetadata.append(py::dict("ColumnName"_a = std::wstring(ColumnName), +#endif "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, "DecimalDigits"_a = DecimalDigits, "Nullable"_a = Nullable)); @@ -932,7 +1030,13 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // NOTE: dataBuffer.size() includes null-terminator, dataLen doesn't. Hence use '<'. if (numCharsInData < dataBuffer.size()) { // SQLGetData will null-terminate the data +#if defined(__APPLE__) + std::string fullStr(reinterpret_cast(dataBuffer.data())); + row.append(fullStr); + LOG("macOS: Appended CHAR string of length {} to result row", fullStr.length()); +#else row.append(std::string(reinterpret_cast(dataBuffer.data()))); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit @@ -975,7 +1079,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); if (numCharsInData < dataBuffer.size()) { // SQLGetData will null-terminate the data +#if defined(__APPLE__) + row.append(SQLWCHARToWString(dataBuffer.data(), SQL_NTS)); +#else row.append(std::wstring(dataBuffer.data())); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit @@ -1484,9 +1592,17 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' if (numCharsInData < fetchBufferSize) { // SQLFetch will nullterminate the data +#if defined(__APPLE__) + // Use macOS-specific conversion to handle the wchar_t/SQLWCHAR size difference + SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; + std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); + row.append(wstr); +#else + // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works row.append(std::wstring( reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), numCharsInData)); +#endif } else { // In this case, buffer size is smaller, and data to be retrieved is longer // TODO: Revisit @@ -1961,7 +2077,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("free", &SqlHandle::free, "Free the handle"); py::class_(m, "Connection") - .def(py::init(), py::arg("conn_str"), py::arg("use_pool"), py::arg("attrs_before") = py::dict()) + .def(py::init(), py::arg("conn_str"), py::arg("use_pool"), py::arg("attrs_before") = py::dict()) .def("close", &ConnectionHandle::close, "Close the connection") .def("commit", &ConnectionHandle::commit, "Commit the current transaction") .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index bb050eab8..555945747 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -7,14 +7,53 @@ #pragma once #include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions - -#include #include -#include -#include #include #include +#ifdef _WIN32 + // Windows-specific headers + #include // windows.h needs to be included before sql.h + #include + #pragma comment(lib, "shlwapi.lib") + #define IS_WINDOWS 1 +#else + #define IS_WINDOWS 0 +#endif + +#include +#include + +#if defined(__APPLE__) + // macOS-specific headers + #include + + inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { + if (!sqlwStr) return std::wstring(); + + if (length == SQL_NTS) { + size_t i = 0; + while (sqlwStr[i] != 0) ++i; + length = i; + } + + std::wstring result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + result.push_back(static_cast(sqlwStr[i])); + } + return result; + } + + inline std::vector WStringToSQLWCHAR(const std::wstring& str) { + std::vector result(str.size() + 1, 0); // +1 for null terminator + for (size_t i = 0; i < str.size(); ++i) { + result[i] = static_cast(str[i]); + } + return result; + } +#endif + #include #include #include @@ -23,6 +62,11 @@ namespace py = pybind11; using namespace pybind11::literals; +#if defined(__APPLE__) +#include "mac_utils.h" // For macOS-specific Unicode encoding fixes +#include "mac_buffers.h" // For macOS-specific buffer handling +#endif + //------------------------------------------------------------------------------------------------- // Function pointer typedefs //------------------------------------------------------------------------------------------------- @@ -125,11 +169,30 @@ void LOG(const std::string& formatString, Args&&... args); // Throws a std::runtime_error with the given message void ThrowStdException(const std::string& message); +// Define a platform-agnostic type for the driver handle +#ifdef _WIN32 +typedef HMODULE DriverHandle; +#else +typedef void* DriverHandle; +#endif + +// Platform-agnostic function to get a function pointer from the loaded library +template +T GetFunctionPointer(DriverHandle handle, const char* functionName) { +#ifdef _WIN32 + // Windows: Use GetProcAddress + return reinterpret_cast(GetProcAddress(handle, functionName)); +#else + // macOS/Unix: Use dlsym + return reinterpret_cast(dlsym(handle, functionName)); +#endif +} + //------------------------------------------------------------------------------------------------- // Loads the ODBC driver and resolves function pointers. // Throws if loading or resolution fails. //------------------------------------------------------------------------------------------------- -std::wstring LoadDriverOrThrowException(); +DriverHandle LoadDriverOrThrowException(); //------------------------------------------------------------------------------------------------- // DriverLoader (Singleton) @@ -178,4 +241,41 @@ struct ErrorInfo { }; ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode); -std::string WideToUTF8(const std::wstring& wstr); \ No newline at end of file +inline std::string WideToUTF8(const std::wstring& wstr) { + if (wstr.empty()) return {}; +#if defined(_WIN32) + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), nullptr, 0, nullptr, nullptr); + std::string result(size_needed, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.data(), static_cast(wstr.size()), result.data(), size_needed, nullptr, nullptr); + return result; +#else + std::string result; + result.reserve(wstr.size()); + for (wchar_t wc : wstr) { + if (wc < 0x80) { + result.push_back(static_cast(wc)); + } else { + result.push_back('?'); + } + } + return result; +#endif +} + +inline std::wstring Utf8ToWString(const std::string& str) { + if (str.empty()) return {}; +#if defined(_WIN32) + int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), nullptr, 0); + std::wstring result(size_needed, 0); + MultiByteToWideChar(CP_UTF8, 0, str.data(), static_cast(str.size()), result.data(), size_needed); + return result; +#else + std::wstring result; + result.reserve(str.size()); + for (char c : str) { + result.push_back(static_cast(c)); + } + return result; +#endif +} + diff --git a/mssql_python/pybind/ddbc_bindings_mac.cpp b/mssql_python/pybind/ddbc_bindings_mac.cpp deleted file mode 100644 index 6505efeea..000000000 --- a/mssql_python/pybind/ddbc_bindings_mac.cpp +++ /dev/null @@ -1,2457 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -// INFO|TODO - Note that is file is MacOS specific right now. Making it arch agnostic will be -// taken up in upcoming releases - -#include // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions - -#include -#include // std::setw, std::setfill -#include -#include -#include // std::forward - -// Platform-specific headers -#include -#include - -#ifdef _WIN32 - // Windows-specific headers - #include // windows.h needs to be included before sql.h - #include - #pragma comment(lib, "shlwapi.lib") - #define IS_WINDOWS 1 -#elif defined(__APPLE__) - // macOS-specific headers - #include - #include - #define IS_WINDOWS 0 - - // String conversion helpers for macOS - wchar_t is 4 bytes on macOS, but SQLWCHAR is 2 bytes - inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) { - if (!sqlwStr) return std::wstring(); - - if (length == SQL_NTS) { - // Determine length if not provided - size_t i = 0; - while (sqlwStr[i] != 0) ++i; - length = i; - } - - std::wstring result; - result.reserve(length); - for (size_t i = 0; i < length; ++i) { - result.push_back(static_cast(sqlwStr[i])); - } - return result; - } - - inline std::vector WStringToSQLWCHAR(const std::wstring& str) { - std::vector result(str.size() + 1, 0); // +1 for null terminator - for (size_t i = 0; i < str.size(); ++i) { - result[i] = static_cast(str[i]); - } - return result; - } -#else - // Other platforms - #include - #include - #define IS_WINDOWS 0 -#endif - -#include -#include -#include -#include // Add this line for datetime support -#include -#include -#include - -#if defined(__APPLE__) -#include "mac_fix.h" // For macOS-specific Unicode encoding fixes -#include "mac_buffers.h" // For macOS-specific buffer handling -#endif - - -namespace py = pybind11; -using namespace pybind11::literals; - -//------------------------------------------------------------------------------------------------- -// Macro definitions -//------------------------------------------------------------------------------------------------- - -// This constant is not exposed via sql.h, hence define it here -#define SQL_SS_TIME2 (-154) - -#define MAX_DIGITS_IN_NUMERIC 64 - -#define STRINGIFY_FOR_CASE(x) \ - case x: \ - return #x - -// Architecture-specific defines -#ifndef ARCHITECTURE -#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation -#endif - -//------------------------------------------------------------------------------------------------- -// Class definitions -//------------------------------------------------------------------------------------------------- - -// Struct to hold parameter information for binding. Used by SQLBindParameter. -// This struct is shared between C++ & Python code. -struct ParamInfo { - SQLSMALLINT inputOutputType; - SQLSMALLINT paramCType; - SQLSMALLINT paramSQLType; - SQLULEN columnSize; - SQLSMALLINT decimalDigits; - // TODO: Reuse python buffer for large data using Python buffer protocol - // Stores pointer to the python object that holds parameter value - // py::object* dataPtr; -}; - -// Mirrors the SQL_NUMERIC_STRUCT. But redefined to replace val char array -// with std::string, because pybind doesn't allow binding char array. -// This struct is shared between C++ & Python code. -struct NumericData { - SQLCHAR precision; - SQLSCHAR scale; - SQLCHAR sign; // 1=pos, 0=neg - std::uint64_t val; // 123.45 -> 12345 - - NumericData() : precision(0), scale(0), sign(0), val(0) {} - - NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value) - : precision(precision), scale(scale), sign(sign), val(value) {} -}; - -// Struct to hold data buffers and indicators for each column -struct ColumnBuffers { - std::vector> charBuffers; - std::vector> wcharBuffers; - std::vector> intBuffers; - std::vector> smallIntBuffers; - std::vector> realBuffers; - std::vector> doubleBuffers; - std::vector> timestampBuffers; - std::vector> bigIntBuffers; - std::vector> dateBuffers; - std::vector> timeBuffers; - std::vector> guidBuffers; - std::vector> indicators; - - ColumnBuffers(SQLSMALLINT numCols, int fetchSize) - : charBuffers(numCols), - wcharBuffers(numCols), - intBuffers(numCols), - smallIntBuffers(numCols), - realBuffers(numCols), - doubleBuffers(numCols), - timestampBuffers(numCols), - bigIntBuffers(numCols), - dateBuffers(numCols), - timeBuffers(numCols), - guidBuffers(numCols), - indicators(numCols, std::vector(fetchSize)) {} -}; - -// This struct is used to relay error info obtained from SQLDiagRec API to the Python module -struct ErrorInfo { - std::wstring sqlState; - std::wstring ddbcErrorMsg; -}; - - -//------------------------------------------------------------------------------------------------- -// Function pointer typedefs -//------------------------------------------------------------------------------------------------- - -// Handle APIs -typedef SQLRETURN (*SQLAllocHandleFunc)(SQLSMALLINT, SQLHANDLE, SQLHANDLE*); -typedef SQLRETURN (*SQLSetEnvAttrFunc)(SQLHANDLE, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLSetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLSetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLGetConnectAttrFunc)(SQLHDBC, SQLINTEGER, SQLPOINTER, SQLINTEGER, - SQLINTEGER*); - -// Connection and Execution APIs -typedef SQLRETURN (*SQLDriverConnectFunc)(SQLHANDLE, SQLHWND, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, - SQLSMALLINT, SQLSMALLINT*, SQLUSMALLINT); -typedef SQLRETURN (*SQLExecDirectFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (*SQLPrepareFunc)(SQLHANDLE, SQLWCHAR*, SQLINTEGER); -typedef SQLRETURN (*SQLBindParameterFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLSMALLINT, - SQLSMALLINT, SQLULEN, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLExecuteFunc)(SQLHANDLE); -typedef SQLRETURN (*SQLRowCountFunc)(SQLHSTMT, SQLLEN*); -typedef SQLRETURN (*SQLSetDescFieldFunc)(SQLHDESC, SQLSMALLINT, SQLSMALLINT, SQLPOINTER, SQLINTEGER); -typedef SQLRETURN (*SQLGetStmtAttrFunc)(SQLHSTMT, SQLINTEGER, SQLPOINTER, SQLINTEGER, SQLINTEGER*); - -// Data retrieval APIs -typedef SQLRETURN (*SQLFetchFunc)(SQLHANDLE); -typedef SQLRETURN (*SQLFetchScrollFunc)(SQLHANDLE, SQLSMALLINT, SQLLEN); -typedef SQLRETURN (*SQLGetDataFunc)(SQLHANDLE, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLNumResultColsFunc)(SQLHSTMT, SQLSMALLINT*); -typedef SQLRETURN (*SQLBindColFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT, SQLPOINTER, SQLLEN, - SQLLEN*); -typedef SQLRETURN (*SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, - SQLSMALLINT*, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, - SQLSMALLINT*); -typedef SQLRETURN (*SQLMoreResultsFunc)(SQLHSTMT); -typedef SQLRETURN (*SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, - SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); - -// Transaction APIs -typedef SQLRETURN (*SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); - -// Disconnect/free APIs -typedef SQLRETURN (*SQLFreeHandleFunc)(SQLSMALLINT, SQLHANDLE); -typedef SQLRETURN (*SQLDisconnectFunc)(SQLHDBC); -typedef SQLRETURN (*SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT); -typedef SQLRETURN (*SQLCloseCursorFunc)(SQLHSTMT); -typedef SQLRETURN (*SQLCancelFunc)(SQLHSTMT); - -// Catalog functions -typedef SQLRETURN (*SQLTablesFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (*SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (*SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (*SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (*SQLTablePrivilegesFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); -typedef SQLRETURN (*SQLColumnPrivilegesFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); - -// Diagnostic APIs -typedef SQLRETURN (*SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*, - SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*); - -//------------------------------------------------------------------------------------------------- -// Function pointer initialization -//------------------------------------------------------------------------------------------------- - -// Handle APIs -SQLAllocHandleFunc SQLAllocHandle_ptr = nullptr; -SQLSetEnvAttrFunc SQLSetEnvAttr_ptr = nullptr; -SQLSetConnectAttrFunc SQLSetConnectAttr_ptr = nullptr; -SQLSetStmtAttrFunc SQLSetStmtAttr_ptr = nullptr; -SQLGetConnectAttrFunc SQLGetConnectAttr_ptr = nullptr; - -// Connection and Execution APIs -SQLDriverConnectFunc SQLDriverConnect_ptr = nullptr; -SQLExecDirectFunc SQLExecDirect_ptr = nullptr; -SQLPrepareFunc SQLPrepare_ptr = nullptr; -SQLBindParameterFunc SQLBindParameter_ptr = nullptr; -SQLExecuteFunc SQLExecute_ptr = nullptr; -SQLRowCountFunc SQLRowCount_ptr = nullptr; -SQLGetStmtAttrFunc SQLGetStmtAttr_ptr = nullptr; -SQLSetDescFieldFunc SQLSetDescField_ptr = nullptr; - -// Data retrieval APIs -SQLFetchFunc SQLFetch_ptr = nullptr; -SQLFetchScrollFunc SQLFetchScroll_ptr = nullptr; -SQLGetDataFunc SQLGetData_ptr = nullptr; -SQLNumResultColsFunc SQLNumResultCols_ptr = nullptr; -SQLBindColFunc SQLBindCol_ptr = nullptr; -SQLDescribeColFunc SQLDescribeCol_ptr = nullptr; -SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; -SQLColAttributeFunc SQLColAttribute_ptr = nullptr; - -// Transaction APIs -SQLEndTranFunc SQLEndTran_ptr = nullptr; - -// Disconnect/free APIs -SQLFreeHandleFunc SQLFreeHandle_ptr = nullptr; -SQLDisconnectFunc SQLDisconnect_ptr = nullptr; -SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; -SQLCloseCursorFunc SQLCloseCursor_ptr = nullptr; -SQLCancelFunc SQLCancel_ptr = nullptr; - -// Catalog functions -SQLTablesFunc SQLTables_ptr = nullptr; -SQLColumnsFunc SQLColumns_ptr = nullptr; -SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; -SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr; -SQLTablePrivilegesFunc SQLTablePrivileges_ptr = nullptr; -SQLColumnPrivilegesFunc SQLColumnPrivileges_ptr = nullptr; - -// Diagnostic APIs -SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; - -// Move GetModuleDirectory outside namespace to resolve ambiguity -std::string GetModuleDirectory() { - py::object module = py::module::import("mssql_python"); - py::object module_path = module.attr("__file__"); - std::string module_file = module_path.cast(); - -#ifdef _WIN32 - // Windows-specific path handling - char path[MAX_PATH]; - strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); - PathRemoveFileSpecA(path); - return std::string(path); -#else - // macOS/Unix path handling without using std::filesystem - std::string::size_type pos = module_file.find_last_of('/'); - if (pos != std::string::npos) { - std::string dir = module_file.substr(0, pos); - return dir; - } - std::cerr << "DEBUG: Could not extract directory from path: " << module_file << std::endl; - return module_file; -#endif -} - -namespace { - -// TODO: Revisit GIL considerations if we're using python's logger -template -void LOG(const std::string& formatString, Args&&... args) { - // Get the logger each time instead of caching it to ensure we get the latest state - py::object logging_module = py::module_::import("mssql_python.logging_config"); - py::object logger = logging_module.attr("get_logger")(); - - // If logger is None, don't try to log - if (py::isinstance(logger)) { - return; - } - - // Format the message and log it - std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); - logger.attr("debug")(message); -} - -// Define a platform-agnostic type for the driver handle -#ifdef _WIN32 -typedef HMODULE DriverHandle; -#else -typedef void* DriverHandle; -#endif - -// Platform-agnostic function to load the driver dynamic library -DriverHandle LoadDriverLibrary(const std::string& driverPath) { - LOG("Loading driver from path: {}", driverPath); - -#ifdef _WIN32 - // Windows: Convert string to wide string for LoadLibraryW - std::wstring widePath(driverPath.begin(), driverPath.end()); - return LoadLibraryW(widePath.c_str()); -#else - // macOS/Unix: Use dlopen - return dlopen(driverPath.c_str(), RTLD_LAZY); -#endif -} - -// Platform-agnostic function to get last error message -std::string GetLastErrorMessage() { -#ifdef _WIN32 - // Windows: Use FormatMessageA - DWORD error = GetLastError(); - char* messageBuffer = nullptr; - size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, - 0, - NULL - ); - std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; - LocalFree(messageBuffer); - return "Error code: " + std::to_string(error) + " - " + errorMessage; -#else - // macOS/Unix: Use dlerror - const char* error = dlerror(); - return error ? std::string(error) : "Unknown error"; -#endif -} - -// Platform-agnostic function to get a function pointer from the loaded library -template -T GetFunctionPointer(DriverHandle handle, const char* functionName) { -#ifdef _WIN32 - // Windows: Use GetProcAddress - return reinterpret_cast(GetProcAddress(handle, functionName)); -#else - // macOS/Unix: Use dlsym - return reinterpret_cast(dlsym(handle, functionName)); -#endif -} - -// TODO: Add more nuanced exception classes -void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } - -// Cross-platform helper to load the ODBC driver and get function pointers -DriverHandle LoadDriverOrThrowException() { - // Get the module directory path - std::string moduleDir = GetModuleDirectory(); - LOG("Module directory: {}", moduleDir); - - // Build path to the architecture-specific driver library - std::string archStr(ARCHITECTURE); - std::string archDir; - - // Map architecture identifiers to correct subdirectory names - if (archStr == "win64" || archStr == "amd64" || archStr == "x64") { - archDir = "x64"; - } else if (archStr == "arm64") { - archDir = "arm64"; - } else { - archDir = "x86"; - } - - std::string driverPath; -#ifdef _WIN32 - // Windows: Build path to msodbcsql18.dll - driverPath = moduleDir + "\\libs\\" + archDir + "\\msodbcsql18.dll"; -#elif defined(__APPLE__) - // macOS: Build path to libmsodbcsql.18.dylib - first check in lib subdirectory - // We are supporting both Intel and Apple Silicon architectures, so we need to check the architecture - std::string runtimeArch = - #ifdef __arm64__ - "arm64"; - #else - "x86_64"; - #endif - std::string macosDriverPath = moduleDir + "/libs/macos/" + runtimeArch + "/lib/libmsodbcsql.18.dylib"; - - // Check if file exists using traditional C file functions instead of std::filesystem - FILE* file = fopen(macosDriverPath.c_str(), "r"); - if (file) { - fclose(file); - driverPath = macosDriverPath; - LOG("Found macOS driver in lib subdirectory: {}", driverPath); - } else { - // Fallback to the older path structure - driverPath = moduleDir + "/libs/" + archDir + "/macos/libmsodbcsql.18.dylib"; - LOG("Using fallback path for macOS driver: {}", driverPath); - } -#else - // Linux: Build path to libmsodbcsql.so.18.0 - driverPath = moduleDir + "/libs/" + archDir + "/libmsodbcsql.so.18.0"; -#endif - - LOG("Attempting to load driver from - {}", driverPath); - - // Load the driver library - DriverHandle driverHandle = LoadDriverLibrary(driverPath); - if (!driverHandle) { - std::string errorMessage = GetLastErrorMessage(); - LOG("Failed to load the driver: {}", errorMessage); - ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly."); - } - - LOG("Successfully loaded the ODBC driver"); - - // Load all the required function pointers - // Environment and handle function loading - SQLAllocHandle_ptr = GetFunctionPointer(driverHandle, "SQLAllocHandle"); - SQLSetEnvAttr_ptr = GetFunctionPointer(driverHandle, "SQLSetEnvAttr"); - SQLSetConnectAttr_ptr = GetFunctionPointer(driverHandle, "SQLSetConnectAttrW"); - SQLSetStmtAttr_ptr = GetFunctionPointer(driverHandle, "SQLSetStmtAttrW"); - SQLGetConnectAttr_ptr = GetFunctionPointer(driverHandle, "SQLGetConnectAttrW"); - - // Connection and statement function loading - SQLDriverConnect_ptr = GetFunctionPointer(driverHandle, "SQLDriverConnectW"); - SQLExecDirect_ptr = GetFunctionPointer(driverHandle, "SQLExecDirectW"); - SQLPrepare_ptr = GetFunctionPointer(driverHandle, "SQLPrepareW"); - SQLBindParameter_ptr = GetFunctionPointer(driverHandle, "SQLBindParameter"); - SQLExecute_ptr = GetFunctionPointer(driverHandle, "SQLExecute"); - SQLRowCount_ptr = GetFunctionPointer(driverHandle, "SQLRowCount"); - SQLGetStmtAttr_ptr = GetFunctionPointer(driverHandle, "SQLGetStmtAttrW"); - SQLSetDescField_ptr = GetFunctionPointer(driverHandle, "SQLSetDescFieldW"); // Fetch and data retrieval function loading - SQLFetch_ptr = GetFunctionPointer(driverHandle, "SQLFetch"); - SQLFetchScroll_ptr = GetFunctionPointer(driverHandle, "SQLFetchScroll"); - SQLGetData_ptr = GetFunctionPointer(driverHandle, "SQLGetData"); - SQLNumResultCols_ptr = GetFunctionPointer(driverHandle, "SQLNumResultCols"); - SQLBindCol_ptr = GetFunctionPointer(driverHandle, "SQLBindCol"); - SQLDescribeCol_ptr = GetFunctionPointer(driverHandle, "SQLDescribeColW"); - SQLMoreResults_ptr = GetFunctionPointer(driverHandle, "SQLMoreResults"); - SQLColAttribute_ptr = GetFunctionPointer(driverHandle, "SQLColAttributeW"); - - // Transaction functions loading - SQLEndTran_ptr = GetFunctionPointer(driverHandle, "SQLEndTran"); - - // Disconnect and free functions loading - SQLDisconnect_ptr = GetFunctionPointer(driverHandle, "SQLDisconnect"); - SQLFreeHandle_ptr = GetFunctionPointer(driverHandle, "SQLFreeHandle"); - SQLFreeStmt_ptr = GetFunctionPointer(driverHandle, "SQLFreeStmt"); - SQLCloseCursor_ptr = GetFunctionPointer(driverHandle, "SQLCloseCursor"); - SQLCancel_ptr = GetFunctionPointer(driverHandle, "SQLCancel"); - - // Catalog functions loading - SQLTables_ptr = GetFunctionPointer(driverHandle, "SQLTablesW"); - SQLColumns_ptr = GetFunctionPointer(driverHandle, "SQLColumnsW"); - SQLPrimaryKeys_ptr = GetFunctionPointer(driverHandle, "SQLPrimaryKeysW"); - SQLForeignKeys_ptr = GetFunctionPointer(driverHandle, "SQLForeignKeysW"); - SQLTablePrivileges_ptr = GetFunctionPointer(driverHandle, "SQLTablePrivilegesW"); - SQLColumnPrivileges_ptr = GetFunctionPointer(driverHandle, "SQLColumnPrivilegesW"); - - // Diagnostic record function Loading - SQLGetDiagRec_ptr = GetFunctionPointer(driverHandle, "SQLGetDiagRecW"); - - bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && - SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr && - SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && - SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr && - SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr && - SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && - SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr && - SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; - - if (!success) { - LOG("Failed to load required function pointers from driver"); - ThrowStdException("Failed to load required function pointers from driver"); - } - LOG("Successfully loaded function pointers from driver"); - - return driverHandle; -} - -// This section was removed because these functions are now properly declared -// in the main function pointer typedefs section above. - -const char* GetSqlCTypeAsString(const SQLSMALLINT cType) { - switch (cType) { - STRINGIFY_FOR_CASE(SQL_C_CHAR); - STRINGIFY_FOR_CASE(SQL_C_WCHAR); - STRINGIFY_FOR_CASE(SQL_C_SSHORT); - STRINGIFY_FOR_CASE(SQL_C_USHORT); - STRINGIFY_FOR_CASE(SQL_C_SHORT); - STRINGIFY_FOR_CASE(SQL_C_SLONG); - STRINGIFY_FOR_CASE(SQL_C_ULONG); - STRINGIFY_FOR_CASE(SQL_C_LONG); - STRINGIFY_FOR_CASE(SQL_C_STINYINT); - STRINGIFY_FOR_CASE(SQL_C_UTINYINT); - STRINGIFY_FOR_CASE(SQL_C_TINYINT); - STRINGIFY_FOR_CASE(SQL_C_SBIGINT); - STRINGIFY_FOR_CASE(SQL_C_UBIGINT); - STRINGIFY_FOR_CASE(SQL_C_FLOAT); - STRINGIFY_FOR_CASE(SQL_C_DOUBLE); - STRINGIFY_FOR_CASE(SQL_C_BIT); - STRINGIFY_FOR_CASE(SQL_C_BINARY); - STRINGIFY_FOR_CASE(SQL_C_TYPE_DATE); - STRINGIFY_FOR_CASE(SQL_C_TYPE_TIME); - STRINGIFY_FOR_CASE(SQL_C_TYPE_TIMESTAMP); - STRINGIFY_FOR_CASE(SQL_C_NUMERIC); - STRINGIFY_FOR_CASE(SQL_C_GUID); - STRINGIFY_FOR_CASE(SQL_C_DEFAULT); - default: - return "Unknown"; - } -} - -std::string MakeParamMismatchErrorStr(const SQLSMALLINT cType, const int paramIndex) { - std::string errorString = - "Parameter's object type does not match parameter's C type. paramIndex - " + - std::to_string(paramIndex) + ", C type - " + GetSqlCTypeAsString(cType); - return errorString; -} - -// This function allocates a buffer of ParamType, stores it as a void* in paramBuffers for -// book-keeping and then returns a ParamType* to the allocated memory. -// ctorArgs are the arguments to ParamType's constructor used while creating/allocating ParamType -template -ParamType* AllocateParamBuffer(std::vector>& paramBuffers, - CtorArgs&&... ctorArgs) { - paramBuffers.emplace_back(new ParamType(std::forward(ctorArgs)...), - std::default_delete()); - return static_cast(paramBuffers.back().get()); -} - -// Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with -// appropriate arguments -SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, - const std::vector& paramInfos, - std::vector>& paramBuffers) { - LOG("Starting parameter binding. Number of parameters: {}", params.size()); - for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) { - const auto& param = params[paramIndex]; - const ParamInfo& paramInfo = paramInfos[paramIndex]; - LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType); - void* dataPtr = nullptr; - SQLLEN bufferLength = 0; - SQLLEN* strLenOrIndPtr = nullptr; - // TODO: Add more data types like money, guid, interval, TVPs etc. - switch (paramInfo.paramCType) { - case SQL_C_CHAR: - case SQL_C_BINARY: { - if (!py::isinstance(param) && !py::isinstance(param) && - !py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - std::string* strParam = - AllocateParamBuffer(paramBuffers, param.cast()); - if (strParam->size() > 8192 /* TODO: Fix max length */) { - ThrowStdException( - "Streaming parameters is not yet supported. Parameter size" - " must be less than 8192 bytes"); - } - dataPtr = const_cast(static_cast(strParam->c_str())); - bufferLength = strParam->size() + 1 /* null terminator */; - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = SQL_NTS; - break; - } - case SQL_C_WCHAR: { - if (!py::isinstance(param) && !py::isinstance(param) && - !py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - std::wstring* strParam = - AllocateParamBuffer(paramBuffers, param.cast()); - if (strParam->size() > 4096 /* TODO: Fix max length */) { - ThrowStdException( - "Streaming parameters is not yet supported. Parameter size" - " must be less than 8192 bytes"); - } - - // Log detailed parameter information - LOG("SQL_C_WCHAR Parameter[{}]: Length={}, Content='{}'", - paramIndex, strParam->size(), - strParam->size() <= 100 ? std::string(strParam->begin(), strParam->end()) : - std::string(strParam->begin(), strParam->begin() + 100) + "..."); - - // Log each character's code point for debugging - if (strParam->size() <= 20) { - for (size_t i = 0; i < strParam->size(); i++) { - LOG(" char[{}] = {} ({})", i, static_cast((*strParam)[i]), - ((*strParam)[i] >= 32 && (*strParam)[i] <= 126) ? - static_cast((*strParam)[i]) : '?'); - } - } - -#if defined(__APPLE__) - // On macOS, we need special handling for wide characters - // Create a properly encoded SQLWCHAR buffer for the parameter - std::vector* sqlwcharBuffer = - AllocateParamBuffer>(paramBuffers); - - // Reserve space and convert from wstring to SQLWCHAR array - sqlwcharBuffer->resize(strParam->size() + 1, 0); // +1 for null terminator - - // Convert each wchar_t (4 bytes on macOS) to SQLWCHAR (2 bytes) - for (size_t i = 0; i < strParam->size(); i++) { - (*sqlwcharBuffer)[i] = static_cast((*strParam)[i]); - } - - // Use the SQLWCHAR buffer instead of the wstring directly - dataPtr = sqlwcharBuffer->data(); - bufferLength = (strParam->size() + 1) * sizeof(SQLWCHAR); - LOG("macOS: Created SQLWCHAR buffer for parameter with size: {} bytes", bufferLength); -#else - // On Windows, wchar_t and SQLWCHAR are the same size, so direct cast works - dataPtr = const_cast(static_cast(strParam->c_str())); - bufferLength = (strParam->size() + 1 /* null terminator */) * sizeof(wchar_t); -#endif - - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = SQL_NTS; - break; - } - case SQL_C_BIT: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = - static_cast(AllocateParamBuffer(paramBuffers, param.cast())); - break; - } - case SQL_C_DEFAULT: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - // TODO: This wont work for None values added to BINARY/VARBINARY columns. None values - // of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY - dataPtr = nullptr; - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = SQL_NULL_DATA; - break; - } case SQL_C_STINYINT: - case SQL_C_TINYINT: - case SQL_C_SSHORT: - case SQL_C_SHORT: { - LOG("BINDING INTEGER PARAMETER"); - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = - static_cast(AllocateParamBuffer(paramBuffers, param.cast())); - LOG("INITIALIZED INTEGER PARAMETER"); - // Set the strLenOrIndPtr for integer parameters - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = sizeof(int); - LOG("INITIALIZED INTEGER PARAMETER LENGTH"); - break; - } - case SQL_C_UTINYINT: - case SQL_C_USHORT: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); - // Set the strLenOrIndPtr for unsigned integer parameters - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = sizeof(unsigned int); - break; - } case SQL_C_SBIGINT: - case SQL_C_SLONG: - case SQL_C_LONG: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); - // Set the strLenOrIndPtr for long integer parameters - strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - *strLenOrIndPtr = sizeof(int64_t); - break; - } - case SQL_C_UBIGINT: - case SQL_C_ULONG: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); - break; - } - case SQL_C_FLOAT: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); - break; - } - case SQL_C_DOUBLE: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); - break; - } - case SQL_C_TYPE_DATE: { - py::object dateType = py::module_::import("datetime").attr("date"); - if (!py::isinstance(param, dateType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - // TODO: can be moved to python by registering SQL_DATE_STRUCT in pybind - SQL_DATE_STRUCT* sqlDatePtr = AllocateParamBuffer(paramBuffers); - sqlDatePtr->year = param.attr("year").cast(); - sqlDatePtr->month = param.attr("month").cast(); - sqlDatePtr->day = param.attr("day").cast(); - dataPtr = static_cast(sqlDatePtr); - break; - } - case SQL_C_TYPE_TIME: { - py::object timeType = py::module_::import("datetime").attr("time"); - if (!py::isinstance(param, timeType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - // TODO: can be moved to python by registering SQL_TIME_STRUCT in pybind - SQL_TIME_STRUCT* sqlTimePtr = AllocateParamBuffer(paramBuffers); - sqlTimePtr->hour = param.attr("hour").cast(); - sqlTimePtr->minute = param.attr("minute").cast(); - sqlTimePtr->second = param.attr("second").cast(); - dataPtr = static_cast(sqlTimePtr); - break; - } - case SQL_C_TYPE_TIMESTAMP: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); - if (!py::isinstance(param, datetimeType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - SQL_TIMESTAMP_STRUCT* sqlTimestampPtr = - AllocateParamBuffer(paramBuffers); - sqlTimestampPtr->year = param.attr("year").cast(); - sqlTimestampPtr->month = param.attr("month").cast(); - sqlTimestampPtr->day = param.attr("day").cast(); - sqlTimestampPtr->hour = param.attr("hour").cast(); - sqlTimestampPtr->minute = param.attr("minute").cast(); - sqlTimestampPtr->second = param.attr("second").cast(); - // SQL server supports in ns, but python datetime supports in µs - sqlTimestampPtr->fraction = static_cast( - param.attr("microsecond").cast() * 1000); // Convert µs to ns - dataPtr = static_cast(sqlTimestampPtr); - break; - } - case SQL_C_NUMERIC: { - if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); - } - NumericData decimalParam = param.cast(); - LOG("Received numeric parameter: precision - {}, scale- {}, sign - {}, value - {}", - decimalParam.precision, decimalParam.scale, decimalParam.sign, - decimalParam.val); - SQL_NUMERIC_STRUCT* decimalPtr = - AllocateParamBuffer(paramBuffers); - decimalPtr->precision = decimalParam.precision; - decimalPtr->scale = decimalParam.scale; - decimalPtr->sign = decimalParam.sign; - // Convert the integer decimalParam.val to char array - std:memset(static_cast(decimalPtr->val), 0, sizeof(decimalPtr->val)); - std::memcpy(static_cast(decimalPtr->val), - reinterpret_cast(&decimalParam.val), - sizeof(decimalParam.val)); - dataPtr = static_cast(decimalPtr); - // TODO: Remove these lines - //strLenOrIndPtr = AllocateParamBuffer(paramBuffers); - //*strLenOrIndPtr = sizeof(SQL_NUMERIC_STRUCT); - break; - } - case SQL_C_GUID: { - // TODO - } - default: { - std::ostringstream errorString; - errorString << "Unsupported parameter type - " << paramInfo.paramCType - << " for parameter - " << paramIndex; - ThrowStdException(errorString.str()); - } - } - assert(SQLBindParameter_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr); - - LOG("BINDING PARAM!!"); - RETCODE rc = SQLBindParameter_ptr( - hStmt, paramIndex + 1 /* 1-based indexing */, paramInfo.inputOutputType, - paramInfo.paramCType, paramInfo.paramSQLType, paramInfo.columnSize, - paramInfo.decimalDigits, dataPtr, bufferLength, strLenOrIndPtr); - LOG("RETCODE - {}", rc); - if (!SQL_SUCCEEDED(rc)) { - LOG("Error when binding parameter - {}", paramIndex); - return rc; - } - // Special handling for Numeric type - - // https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/retrieve-numeric-data-sql-numeric-struct-kb222831?view=sql-server-ver16#sql_c_numeric-overview - if (paramInfo.paramCType == SQL_C_NUMERIC) { - SQLHDESC hDesc = nullptr; - RETCODE rc = SQLGetStmtAttr_ptr(hStmt, SQL_ATTR_APP_PARAM_DESC, &hDesc, 0, NULL); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when getting statement attribute - {}", paramIndex); - return rc; - } - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_TYPE, (SQLPOINTER) SQL_C_NUMERIC, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_TYPE - {}", paramIndex); - return rc; - } - SQL_NUMERIC_STRUCT* numericPtr = reinterpret_cast(dataPtr); - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_PRECISION, - (SQLPOINTER) numericPtr->precision, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_PRECISION - {}", paramIndex); - return rc; - } - - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_SCALE, - (SQLPOINTER) numericPtr->scale, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_SCALE - {}", paramIndex); - return rc; - } - - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_DATA_PTR, (SQLPOINTER) numericPtr, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_DATA_PTR - {}", paramIndex); - return rc; - } - } - } - LOG("Finished parameter binding. Number of parameters: {}", params.size()); - return SQL_SUCCESS; -} - - -// This is temporary hack to avoid crash when SQLDescribeCol returns 0 as columnSize -// for NVARCHAR(MAX) & similar types. Variable length data needs more nuanced handling. -// TODO: Fix this in beta -// This function sets the buffer allocated to fetch NVARCHAR(MAX) & similar types to -// 4096 chars. So we'll retrieve data upto 4096. Anything greater then that will throw -// error -void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) { - if (columnSize == 0) { - columnSize = 4096; - } -} - -} // namespace - -// Wrap SQLAllocHandle -SQLRETURN SQLAllocHandle_wrap(SQLSMALLINT HandleType, intptr_t InputHandle, intptr_t OutputHandle) { - LOG("Allocate SQL Handle"); - if (!SQLAllocHandle_ptr) { - LoadDriverOrThrowException(); - } - - SQLHANDLE* pOutputHandle = reinterpret_cast(OutputHandle); - return SQLAllocHandle_ptr(HandleType, reinterpret_cast(InputHandle), pOutputHandle); -} - -// Wrap SQLSetEnvAttr -SQLRETURN SQLSetEnvAttr_wrap(intptr_t EnvHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL environment Attribute"); - if (!SQLSetEnvAttr_ptr) { - LoadDriverOrThrowException(); - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - return SQLSetEnvAttr_ptr(reinterpret_cast(EnvHandle), Attribute, - reinterpret_cast(ValuePtr), StringLength); -} - -// Wrap SQLSetConnectAttr -SQLRETURN SQLSetConnectAttr_wrap(intptr_t ConnectionHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL Connection Attribute"); - if (!SQLSetConnectAttr_ptr) { - LoadDriverOrThrowException(); - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - return SQLSetConnectAttr_ptr(reinterpret_cast(ConnectionHandle), Attribute, - reinterpret_cast(ValuePtr), StringLength); -} - -// Wrap SQLSetStmtAttr -SQLRETURN SQLSetStmtAttr_wrap(intptr_t ConnectionHandle, SQLINTEGER Attribute, intptr_t ValuePtr, - SQLINTEGER StringLength) { - LOG("Set SQL Statement Attribute"); - if (!SQLSetConnectAttr_ptr) { - LoadDriverOrThrowException(); - } - - // TODO: Does ValuePtr need to be converted from Python to C++ object? - return SQLSetStmtAttr_ptr(reinterpret_cast(ConnectionHandle), Attribute, - reinterpret_cast(ValuePtr), StringLength); -} - -// Wrap SQLGetConnectionAttrA -// Currently only supports retrieval of int-valued attributes -// TODO: add support to retrieve all types of attributes -SQLINTEGER SQLGetConnectionAttr_wrap(intptr_t ConnectionHandle, SQLINTEGER attribute) { - LOG("Get SQL COnnection Attribute"); - if (!SQLGetConnectAttr_ptr) { - LoadDriverOrThrowException(); - } - - SQLINTEGER stringLength; - SQLINTEGER intValue; - - // Try to get the attribute as an integer - SQLGetConnectAttr_ptr(reinterpret_cast(ConnectionHandle), attribute, &intValue, - sizeof(SQLINTEGER), &stringLength); - return intValue; -} - -// Helper function to check for driver errors -ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, intptr_t handle, SQLRETURN retcode) { - LOG("Checking errors for retcode - {}" , retcode); - ErrorInfo errorInfo; - - // Initialize default error message in case we can't get specific diagnostics - errorInfo.sqlState = L"HY000"; // General error - - // Handle null handle case - if (handle == 0) { - LOG("Null handle received"); - errorInfo.ddbcErrorMsg = L"Null handle!"; - return errorInfo; - } - // Only try to get diagnostic info if we got an error and have a valid handle - if (!SQL_SUCCEEDED(retcode)) { - LOG("Error occurred!!!!!, retcode - {}", retcode); - if (!SQLGetDiagRec_ptr) { - try { - LoadDriverOrThrowException(); - } catch (const std::exception& e) { - LOG("Failed to load driver for error check: {}", e.what()); - errorInfo.ddbcErrorMsg = L"Driver error: Could not load ODBC driver"; - return errorInfo; - } - } - - try { - LOG("Getting diagnostic information for handle type - {}", handleType); -#if defined(__APPLE__) - // macOS: Use enhanced error collection with mac_buffers - if (!SQLGetDiagRec_ptr) { - try { - LoadDriverOrThrowException(); - } catch (const std::exception& e) { - LOG("Failed to load driver for error check: {}", e.what()); - errorInfo.ddbcErrorMsg = L"Driver error: Could not load ODBC driver"; - return errorInfo; - } - } - - // Use our enhanced DiagnosticRecords to collect all errors - mac_buffers::DiagnosticRecords diagnostics; - int recordNum = 1; - - while (true) { - // Create buffers for error information - mac_buffers::SQLWCHARBuffer sqlState(6); - mac_buffers::SQLWCHARBuffer message(SQL_MAX_MESSAGE_LENGTH); - SQLINTEGER nativeError = 0; - SQLSMALLINT messageLen = 0; - - // Get the diagnostic record (similar to Python PoC _check_ret) - SQLRETURN diagReturn = SQLGetDiagRec_ptr( - handleType, - reinterpret_cast(handle), - recordNum, - sqlState.data(), - &nativeError, - message.data(), - message.size(), - &messageLen - ); - - LOG("Diagnostic record {} retrieval result: {}", recordNum, diagReturn); - - if (diagReturn == SQL_NO_DATA) { - // No more diagnostic records - similar to Python PoC - if (!diagnostics.empty()) { - // We have collected errors, return them - errorInfo.sqlState = diagnostics.getSQLState(); - errorInfo.ddbcErrorMsg = diagnostics.getFullErrorMessage(); - return errorInfo; - } - break; - } else if (diagReturn == SQL_INVALID_HANDLE) { - errorInfo.ddbcErrorMsg = L"SQL_INVALID_HANDLE"; - return errorInfo; - } else if (diagReturn == SQL_ERROR) { - errorInfo.ddbcErrorMsg = L"SQL_ERROR"; - return errorInfo; - } else if (SQL_SUCCEEDED(diagReturn)) { - // Successfully retrieved a diagnostic record - std::wstring stateStr = sqlState.toString(); - std::wstring msgStr = message.toString(messageLen); - - // Add to our collection of errors (similar to err_list.append in Python PoC) - diagnostics.addRecord(stateStr, msgStr, nativeError); - - // Continue to the next record - recordNum++; - } else { - break; - } - } - - // Process collected errors - if (!diagnostics.empty()) { - // Use the DiagnosticRecords utility to format error messages - errorInfo.sqlState = diagnostics.getSQLState(); - errorInfo.ddbcErrorMsg = diagnostics.getFullErrorMessage(); - } else { - // No diagnostic records found - errorInfo.ddbcErrorMsg = L"No diagnostic information available"; - } -#else - // Windows/other platforms: Use the first diagnostic record only - SQLRETURN diagReturn = - SQLGetDiagRec_ptr(handleType, reinterpret_cast(handle), 1, sqlState, - &nativeError, message, SQL_MAX_MESSAGE_LENGTH, &messageLen); - LOG("Diagnostic information retrieved, diagReturn - {}", diagReturn); - if (SQL_SUCCEEDED(diagReturn)) { - // Process diagnostic information - errorInfo.sqlState = std::wstring(sqlState); - errorInfo.ddbcErrorMsg = std::wstring(message); - } else { - // Failed to get diagnostic info, provide a generic error - LOG("Failed to get diagnostic information, diagReturn = {}", diagReturn); - errorInfo.ddbcErrorMsg = L"Failed to retrieve error information"; - } -#endif - } catch (const std::exception& e) { - // Handle exceptions during diagnostic retrieval - LOG("Exception during error diagnostics: {}", e.what()); - errorInfo.ddbcErrorMsg = L"Exception during error diagnostics"; - } - } - - // If we have no error message but there was an error, provide a generic one - if (errorInfo.ddbcErrorMsg.empty() && !SQL_SUCCEEDED(retcode)) { - errorInfo.ddbcErrorMsg = L"Unknown error occurred"; - } - - return errorInfo; -} - - -// Sanitize connection string to remove sensitive information -std::wstring SanitizeConnectionString(const std::wstring& connectionString) { - // This function will remove the UID and Pwd parameters for security reasons - std::wstring sanitizedString = connectionString; - std::wstring lowerCaseString = sanitizedString; - // Convert the string to lowercase for case-insensitive search - // Using lowerCaseString to avoid modifying the original string - // This is necessary because towlower works on wide characters - std::transform(lowerCaseString.begin(), lowerCaseString.end(), lowerCaseString.begin(), - ::towlower); - // Can be UID or uid or UID, test only on lowercase uid - size_t uidPos = lowerCaseString.find(L"uid="); - if (uidPos != std::wstring::npos) { - size_t endPos = sanitizedString.find(L';', uidPos); - if (endPos != std::wstring::npos) { - sanitizedString.erase(uidPos, endPos - uidPos + 1); - lowerCaseString.erase(uidPos, endPos - uidPos + 1); - } else { - sanitizedString.erase(uidPos); - lowerCaseString.erase(uidPos); - } - } - // Can be Pwd or pwd or PWD, test only on lowercase pwd - size_t pwdPos = lowerCaseString.find(L"pwd="); - if (pwdPos != std::wstring::npos) { - size_t endPos = sanitizedString.find(L';', pwdPos); - if (endPos != std::wstring::npos) { - sanitizedString.erase(pwdPos, endPos - pwdPos + 1); - lowerCaseString.erase(pwdPos, endPos - pwdPos + 1); - } else { - sanitizedString.erase(pwdPos); - lowerCaseString.erase(pwdPos); - } - } - return sanitizedString; -} - -// Wrap SQLDriverConnect -SQLRETURN SQLDriverConnect_wrap(intptr_t ConnectionHandle, intptr_t WindowHandle, - const std::wstring& ConnectionString) { - LOG("Driver Connect to MSSQL"); - if (!SQLDriverConnect_ptr) { - LoadDriverOrThrowException(); - } - // LOG("DECLARE SQLDriverConnect_ptr - {}"); - SQLWCHAR* connStrPtr; -#if defined(__APPLE__) // macOS specific code - LOG("Creating connection string buffer for macOS"); - std::vector connStrBuffer = WStringToSQLWCHAR(ConnectionString); - // Ensure the buffer is null-terminated - LOG("Connection string buffer size - {}", connStrBuffer.size()); - connStrPtr = connStrBuffer.data(); - LOG("Connection string buffer created"); -#else - connStrPtr = const_cast(ConnectionString.c_str()); -#endif - // Log the sanitized connection string - LOG("Connection string - {}", SanitizeConnectionString(ConnectionString).c_str()); - return SQLDriverConnect_ptr(reinterpret_cast(ConnectionHandle), - reinterpret_cast(WindowHandle), - connStrPtr, SQL_NTS, nullptr, - 0, nullptr, SQL_DRIVER_NOPROMPT); -} - -// Wrap SQLExecDirect -SQLRETURN SQLExecDirect_wrap(intptr_t StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); - if (!SQLExecDirect_ptr) { - LoadDriverOrThrowException(); - } - - SQLWCHAR* queryPtr; -#if defined(__APPLE__) - std::vector queryBuffer = WStringToSQLWCHAR(Query); - queryPtr = queryBuffer.data(); -#else - queryPtr = const_cast(Query.c_str()); -#endif - - return SQLExecDirect_ptr(reinterpret_cast(StatementHandle), queryPtr, SQL_NTS); -} - -// Executes the provided query. If the query is parametrized, it prepares the statement and -// binds the parameters. Otherwise, it executes the query directly. -// 'usePrepare' parameter can be used to disable the prepare step for queries that might already -// be prepared in a previous call. -SQLRETURN SQLExecute_wrap(const intptr_t statementHandle, - const std::wstring& query /* TODO: Use SQLTCHAR? */, - const py::list& params, const std::vector& paramInfos, - py::list& isStmtPrepared, const bool usePrepare = true) { - LOG("Execute SQL Query - {}", query.c_str()); - if (!SQLPrepare_ptr) { - LoadDriverOrThrowException(); - } - assert(SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && SQLExecDirect_ptr); - - if (params.size() != paramInfos.size()) { - // TODO: This should be a special internal exception, that python wont relay to users as is - ThrowStdException("Number of parameters and paramInfos do not match"); - } RETCODE rc; - SQLHANDLE hStmt = reinterpret_cast(statementHandle); - SQLWCHAR* queryPtr; -#if defined(__APPLE__) - std::vector queryBuffer = WStringToSQLWCHAR(query); - queryPtr = queryBuffer.data(); -#else - queryPtr = const_cast(query.c_str()); -#endif - - if (params.size() == 0) { - // Execute statement directly if the statement is not parametrized. This is the - // fastest way to submit a SQL statement for one-time execution according to - // DDBC documentation - - // https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlexecdirect-function?view=sql-server-ver16 - rc = SQLExecDirect_ptr(hStmt, queryPtr, SQL_NTS); - if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) { - LOG("Error during direct execution of the statement"); - } - return rc; - } else { - // isStmtPrepared is a list instead of a bool coz bools in Python are immutable. - // Hence, we can't pass around bools by reference & modify them. Therefore, isStmtPrepared - // must be a list with exactly one bool element - assert(isStmtPrepared.size() == 1); - if (usePrepare) { - rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); - if (!SQL_SUCCEEDED(rc)) { - LOG("Error while preparing the statement"); - return rc; - } - isStmtPrepared[0] = py::cast(true); - } else { - // Make sure the statement has been prepared earlier if we're not preparing now - bool isStmtPreparedAsBool = isStmtPrepared[0].cast(); - if (!isStmtPreparedAsBool) { - // TODO: Print the query - ThrowStdException("Cannot execute unprepared statement"); - } - } - - // This vector manages the heap memory allocated for parameter buffers. - // It must be in scope until SQLExecute is done. - std::vector> paramBuffers; - rc = BindParameters(hStmt, params, paramInfos, paramBuffers); - if (!SQL_SUCCEEDED(rc)) { - return rc; - } - - LOG("Executing the statement with bound parameters"); - rc = SQLExecute_ptr(hStmt); - LOG("SQLExecute return code - {}", rc); - if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) { - LOG("DDBCSQLExecute: Error during execution of the statement"); - return rc; - } - // TODO: Handle huge input parameters by checking rc == SQL_NEED_DATA - - // Unbind the bound buffers for all parameters coz the buffers' memory will - // be freed when this function exits (parambuffers goes out of scope) - LOG("Unbinding the parameters after execution"); - rc = SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); - LOG("SQLFreeStmt return code - {}", rc); - - return rc; - } -} - -// Wrap SQLNumResultCols -SQLSMALLINT SQLNumResultCols_wrap(intptr_t statementHandle) { - LOG("Get number of columns in result set"); - if (!SQLNumResultCols_ptr) { - LoadDriverOrThrowException(); - } - - SQLSMALLINT columnCount; - // TODO: Handle the return code - SQLNumResultCols_ptr(reinterpret_cast(statementHandle), &columnCount); - return columnCount; -} - -// Wrap SQLDescribeCol -SQLRETURN SQLDescribeCol_wrap(intptr_t StatementHandle, py::list& ColumnMetadata) { - LOG("Get column description"); - if (!SQLDescribeCol_ptr) { - LoadDriverOrThrowException(); - } - - SQLSMALLINT ColumnCount; - SQLRETURN retcode = - SQLNumResultCols_ptr(reinterpret_cast(StatementHandle), &ColumnCount); - if (!SQL_SUCCEEDED(retcode)) { - LOG("Failed to get number of columns"); - return retcode; - } - - for (SQLUSMALLINT i = 1; i <= ColumnCount; ++i) { - SQLWCHAR ColumnName[256]; - SQLSMALLINT NameLength; - SQLSMALLINT DataType; - SQLULEN ColumnSize; - SQLSMALLINT DecimalDigits; - SQLSMALLINT Nullable; - - retcode = SQLDescribeCol_ptr(reinterpret_cast(StatementHandle), i, ColumnName, - sizeof(ColumnName) / sizeof(SQLWCHAR), &NameLength, &DataType, - &ColumnSize, &DecimalDigits, &Nullable); - - if (SQL_SUCCEEDED(retcode)) { - // Append a named py::dict to ColumnMetadata - // TODO: Should we define a struct for this task instead of dict? -#if defined(__APPLE__) - ColumnMetadata.append(py::dict("ColumnName"_a = SQLWCHARToWString(ColumnName), -#else - ColumnMetadata.append(py::dict("ColumnName"_a = std::wstring(ColumnName), -#endif - "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, - "DecimalDigits"_a = DecimalDigits, - "Nullable"_a = Nullable)); - } else { - return retcode; - } - } - return SQL_SUCCESS; -} - -// Wrap SQLFetch to retrieve rows -SQLRETURN SQLFetch_wrap(intptr_t StatementHandle) { - LOG("Fetch next row"); - if (!SQLFetch_ptr) { - LoadDriverOrThrowException(); - } - - return SQLFetch_ptr(reinterpret_cast(StatementHandle)); -} - -// Helper function to retrieve column data -// TODO: Handle variable length data correctly -SQLRETURN SQLGetData_wrap(intptr_t StatementHandle, SQLUSMALLINT colCount, py::list& row) { - LOG("Get data from columns"); - if (!SQLGetData_ptr) { - LoadDriverOrThrowException(); - } - - SQLRETURN ret; - SQLHSTMT hStmt = reinterpret_cast(StatementHandle); - for (SQLSMALLINT i = 1; i <= colCount; ++i) { - SQLWCHAR columnName[256]; - SQLSMALLINT columnNameLen; - SQLSMALLINT dataType; - SQLULEN columnSize; - SQLSMALLINT decimalDigits; - SQLSMALLINT nullable; - - ret = SQLDescribeCol_ptr(hStmt, i, columnName, sizeof(columnName) / sizeof(SQLWCHAR), - &columnNameLen, &dataType, &columnSize, &decimalDigits, &nullable); - if (!SQL_SUCCEEDED(ret)) { - LOG("Error retrieving data for column - {}, SQLDescribeCol return code - {}", i, ret); - row.append(py::none()); - // TODO: Do we want to continue in this case or return? - continue; - } - - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: { - // TODO: revisit - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - std::vector dataBuffer(fetchBufferSize); - SQLLEN dataLen; - // TODO: Handle the return code better - ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(), - &dataLen); - - if (SQL_SUCCEEDED(ret)) { - // TODO: Refactor these if's across other switches to avoid code duplication - // columnSize is in chars, dataLen is in bytes - if (dataLen > 0) { - uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - // NOTE: dataBuffer.size() includes null-terminator, dataLen doesn't. Hence use '<'. - if (numCharsInData < dataBuffer.size()) { - // SQLGetData will null-terminate the data -#if defined(__APPLE__) - std::string fullStr(reinterpret_cast(dataBuffer.data())); - row.append(fullStr); - LOG("macOS: Appended CHAR string of length {} to result row", fullStr.length()); -#else - row.append(std::string(reinterpret_cast(dataBuffer.data()))); -#endif - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << dataBuffer.size()-1 << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << i << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - } else if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - } else { - assert(dataLen == SQL_NO_TOTAL); - LOG("SQLGetData couldn't determine the length of the data. " - "Returning NULL value instead. Column ID - {}", i); - row.append(py::none()); - } - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - // TODO: revisit - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - std::vector dataBuffer(fetchBufferSize); - SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), - dataBuffer.size() * sizeof(SQLWCHAR), &dataLen); - - if (SQL_SUCCEEDED(ret)) { - // TODO: Refactor these if's across other switches to avoid code duplication - if (dataLen > 0) { - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); if (numCharsInData < dataBuffer.size()) { - // SQLGetData will null-terminate the data -#if defined(__APPLE__) - row.append(SQLWCHARToWString(dataBuffer.data())); -#else - row.append(std::wstring(dataBuffer.data())); -#endif - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << dataBuffer.size()-1 << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << i << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - } else if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - } else { - assert(dataLen == SQL_NO_TOTAL); - LOG("SQLGetData couldn't determine the length of the data. " - "Returning NULL value instead. Column ID - {}", i); - row.append(py::none()); - } - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_INTEGER: { - SQLINTEGER intValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_LONG, &intValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(intValue)); - } else { - row.append(py::none()); - } - break; - } - case SQL_SMALLINT: { - SQLSMALLINT smallIntValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_SHORT, &smallIntValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(smallIntValue)); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_REAL: { - SQLREAL realValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_FLOAT, &realValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(realValue); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_DECIMAL: - case SQL_NUMERIC: { - SQLCHAR numericStr[MAX_DIGITS_IN_NUMERIC] = {0}; - SQLLEN indicator; - ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator); - - if (SQL_SUCCEEDED(ret)) { - try{ - // Convert numericStr to py::decimal.Decimal and append to row - row.append(py::module_::import("decimal").attr("Decimal")( - std::string(reinterpret_cast(numericStr), indicator))); - } catch (const py::error_already_set& e) { - // If the conversion fails, append None - LOG("Error converting to decimal: {}", e.what()); - row.append(py::none()); - } - } - else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_DOUBLE: - case SQL_FLOAT: { - SQLDOUBLE doubleValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_DOUBLE, &doubleValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(doubleValue); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_BIGINT: { - SQLBIGINT bigintValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_SBIGINT, &bigintValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(bigintValue)); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_TYPE_DATE: { - SQL_DATE_STRUCT dateValue; - ret = - SQLGetData_ptr(hStmt, i, SQL_C_TYPE_DATE, &dateValue, sizeof(dateValue), NULL); - if (SQL_SUCCEEDED(ret)) { - row.append( - py::module_::import("datetime").attr("date")( - dateValue.year, - dateValue.month, - dateValue.day - ) - ); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: { - SQL_TIME_STRUCT timeValue; - ret = - SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, sizeof(timeValue), NULL); - if (SQL_SUCCEEDED(ret)) { - row.append( - py::module_::import("datetime").attr("time")( - timeValue.hour, - timeValue.minute, - timeValue.second - ) - ); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT timestampValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIMESTAMP, ×tampValue, - sizeof(timestampValue), NULL); - if (SQL_SUCCEEDED(ret)) { - row.append( - py::module_::import("datetime").attr("datetime")( - timestampValue.year, - timestampValue.month, - timestampValue.day, - timestampValue.hour, - timestampValue.minute, - timestampValue.second, - timestampValue.fraction / 1000 // Convert back ns to µs - ) - ); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - // TODO: revisit - HandleZeroColumnSizeAtFetch(columnSize); - std::unique_ptr dataBuffer(new SQLCHAR[columnSize]); - SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.get(), columnSize, &dataLen); - - if (SQL_SUCCEEDED(ret)) { - // TODO: Refactor these if's across other switches to avoid code duplication - if (dataLen > 0) { - if (dataLen <= columnSize) { - row.append(py::bytes(reinterpret_cast( - dataBuffer.get()), dataLen)); - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << dataLen << "). ColumnID - " - << i << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - } else if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - } else { - assert(dataLen == SQL_NO_TOTAL); - LOG("SQLGetData couldn't determine the length of the data. " - "Returning NULL value instead. Column ID - {}", i); - row.append(py::none()); - } - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_TINYINT: { - SQLCHAR tinyIntValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_TINYINT, &tinyIntValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(tinyIntValue)); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } - case SQL_BIT: { - SQLCHAR bitValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_BIT, &bitValue, 0, NULL); - if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(bitValue)); - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } -#if (ODBCVER >= 0x0350) - case SQL_GUID: { - SQLGUID guidValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), NULL); - if (SQL_SUCCEEDED(ret)) { - std::ostringstream oss; - oss << std::hex << std::setfill('0') << std::setw(8) << guidValue.Data1 << '-' - << std::setw(4) << guidValue.Data2 << '-' << std::setw(4) << guidValue.Data3 - << '-' << std::setw(2) << static_cast(guidValue.Data4[0]) - << std::setw(2) << static_cast(guidValue.Data4[1]) << '-' << std::hex - << std::setw(2) << static_cast(guidValue.Data4[2]) << std::setw(2) - << static_cast(guidValue.Data4[3]) << std::setw(2) - << static_cast(guidValue.Data4[4]) << std::setw(2) - << static_cast(guidValue.Data4[5]) << std::setw(2) - << static_cast(guidValue.Data4[6]) << std::setw(2) - << static_cast(guidValue.Data4[7]); - row.append(oss.str()); // Append GUID as a string - } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", - i, dataType, ret); - row.append(py::none()); - } - break; - } -#endif - default: - std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName << ", Type - " - << dataType << ", column ID - " << i; - LOG(errorString.str()); - ThrowStdException(errorString.str()); - break; - } - } - return ret; -} - -// For column in the result set, binds a buffer to retrieve column data -// TODO: Move to anonymous namespace, since it is not used outside this file -SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - SQLUSMALLINT numCols, int fetchSize) { - SQLRETURN ret = SQL_SUCCESS; - // Bind columns based on their data types - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: { - // TODO: handle variable length data correctly. This logic wont suffice - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - // TODO: For LONGVARCHAR/BINARY types, columnSize is returned as 2GB-1 by - // SQLDescribeCol. So fetchBufferSize = 2GB. fetchSize=1 if columnSize>1GB. - // So we'll allocate a vector of size 2GB. If a query fetches multiple (say N) - // LONG... columns, we will have allocated multiple (N) 2GB sized vectors. This - // will make driver very slow. And if the N is high enough, we could hit the OS - // limit for heap memory that we can allocate, & hence get a std::bad_alloc. The - // process could also be killed by OS for consuming too much memory. - // Hence this will be revisited in beta to not allocate 2GB+ memory, - // & use streaming instead - buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), - fetchBufferSize * sizeof(SQLCHAR), - buffers.indicators[col - 1].data()); - break; - } - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - // TODO: handle variable length data correctly. This logic wont suffice - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - buffers.wcharBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(), - fetchBufferSize * sizeof(SQLWCHAR), - buffers.indicators[col - 1].data()); - break; - } - case SQL_INTEGER: - buffers.intBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_SLONG, buffers.intBuffers[col - 1].data(), - sizeof(SQLINTEGER), buffers.indicators[col - 1].data()); - break; - case SQL_SMALLINT: - buffers.smallIntBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_SSHORT, - buffers.smallIntBuffers[col - 1].data(), sizeof(SQLSMALLINT), - buffers.indicators[col - 1].data()); - break; - case SQL_TINYINT: - buffers.charBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_TINYINT, buffers.charBuffers[col - 1].data(), - sizeof(SQLCHAR), buffers.indicators[col - 1].data()); - break; - case SQL_BIT: - buffers.charBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_BIT, buffers.charBuffers[col - 1].data(), - sizeof(SQLCHAR), buffers.indicators[col - 1].data()); - break; - case SQL_REAL: - buffers.realBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_FLOAT, buffers.realBuffers[col - 1].data(), - sizeof(SQLREAL), buffers.indicators[col - 1].data()); - break; - case SQL_DECIMAL: - case SQL_NUMERIC: - buffers.charBuffers[col - 1].resize(fetchSize * MAX_DIGITS_IN_NUMERIC); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), - MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), - buffers.indicators[col - 1].data()); - break; - case SQL_DOUBLE: - case SQL_FLOAT: - buffers.doubleBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_DOUBLE, buffers.doubleBuffers[col - 1].data(), - sizeof(SQLDOUBLE), buffers.indicators[col - 1].data()); - break; - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: - buffers.timestampBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr( - hStmt, col, SQL_C_TYPE_TIMESTAMP, buffers.timestampBuffers[col - 1].data(), - sizeof(SQL_TIMESTAMP_STRUCT), buffers.indicators[col - 1].data()); - break; - case SQL_BIGINT: - buffers.bigIntBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_SBIGINT, buffers.bigIntBuffers[col - 1].data(), - sizeof(SQLBIGINT), buffers.indicators[col - 1].data()); - break; - case SQL_TYPE_DATE: - buffers.dateBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_DATE, buffers.dateBuffers[col - 1].data(), - sizeof(SQL_DATE_STRUCT), buffers.indicators[col - 1].data()); - break; - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: - buffers.timeBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_TIME, buffers.timeBuffers[col - 1].data(), - sizeof(SQL_TIME_STRUCT), buffers.indicators[col - 1].data()); - break; - case SQL_GUID: - buffers.guidBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_GUID, buffers.guidBuffers[col - 1].data(), - sizeof(SQLGUID), buffers.indicators[col - 1].data()); - break; - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: - // TODO: handle variable length data correctly. This logic wont suffice - HandleZeroColumnSizeAtFetch(columnSize); - buffers.charBuffers[col - 1].resize(fetchSize * columnSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_BINARY, buffers.charBuffers[col - 1].data(), - columnSize, buffers.indicators[col - 1].data()); - break; - default: - std::wstring columnName = columnMeta["ColumnName"].cast(); - std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; - LOG(errorString.str()); - ThrowStdException(errorString.str()); - break; - } - if (!SQL_SUCCEEDED(ret)) { - std::wstring columnName = columnMeta["ColumnName"].cast(); - std::ostringstream errorString; - errorString << "Failed to bind column - " << columnName.c_str() << ", Type - " - << dataType << ", column ID - " << col; - LOG(errorString.str()); - ThrowStdException(errorString.str()); - return ret; - } - } - return ret; -} - -// Fetch rows in batches -// TODO: Move to anonymous namespace, since it is not used outside this file -SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) { - LOG("Fetching data in batches"); - SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); - if (ret == SQL_NO_DATA) { - LOG("No data to fetch"); - return ret; - } - if (!SQL_SUCCEEDED(ret)) { - LOG("Error while fetching rows in batches"); - return ret; - } - // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by - // SQLFetchScroll - for (SQLULEN i = 0; i < numRowsFetched; i++) { - py::list row; - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); - SQLLEN dataLen = buffers.indicators[col - 1][i]; - - if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - continue; - } - // TODO: variable length data needs special handling, this logic wont suffice - // This value indicates that the driver cannot determine the length of the data - if (dataLen == SQL_NO_TOTAL) { - LOG("Cannot determine the length of the data. Returning NULL value instead." - "Column ID - {}", col); - row.append(py::none()); - continue; - } - assert(dataLen > 0 && "Must be > 0 since SQL_NULL_DATA & SQL_NO_DATA is already handled"); - - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: { - // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data - row.append(std::string( - reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - break; - } case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data -#if defined(__APPLE__) - // Use macOS-specific conversion to handle the wchar_t/SQLWCHAR size difference - SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; - std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); - row.append(wstr); -#else - // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works - row.append(std::wstring( - reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); -#endif - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << numCharsInData << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - break; - } - case SQL_INTEGER: { - row.append(buffers.intBuffers[col - 1][i]); - break; - } - case SQL_SMALLINT: { - row.append(buffers.smallIntBuffers[col - 1][i]); - break; - } - case SQL_TINYINT: { - row.append(buffers.charBuffers[col - 1][i]); - break; - } - case SQL_BIT: { - row.append(static_cast(buffers.charBuffers[col - 1][i])); - break; - } - case SQL_REAL: { - row.append(buffers.realBuffers[col - 1][i]); - break; - } - case SQL_DECIMAL: - case SQL_NUMERIC: { - try { - // Convert numericStr to py::decimal.Decimal and append to row - row.append(py::module_::import("decimal").attr("Decimal")(std::string( - reinterpret_cast( - &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), - buffers.indicators[col - 1][i]))); - } catch (const py::error_already_set& e) { - // Handle the exception, e.g., log the error and append py::none() - LOG("Error converting to decimal: {}", e.what()); - row.append(py::none()); - } - break; - } - case SQL_DOUBLE: - case SQL_FLOAT: { - row.append(buffers.doubleBuffers[col - 1][i]); - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: { - row.append(py::module_::import("datetime") - .attr("datetime")(buffers.timestampBuffers[col - 1][i].year, - buffers.timestampBuffers[col - 1][i].month, - buffers.timestampBuffers[col - 1][i].day, - buffers.timestampBuffers[col - 1][i].hour, - buffers.timestampBuffers[col - 1][i].minute, - buffers.timestampBuffers[col - 1][i].second, - buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs */)); - break; - } - case SQL_BIGINT: { - row.append(buffers.bigIntBuffers[col - 1][i]); - break; - } - case SQL_TYPE_DATE: { - row.append(py::module_::import("datetime") - .attr("date")(buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day)); - break; - } - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: { - row.append(py::module_::import("datetime") - .attr("time")(buffers.timeBuffers[col - 1][i].hour, - buffers.timeBuffers[col - 1][i].minute, - buffers.timeBuffers[col - 1][i].second)); - break; - } - case SQL_GUID: { - row.append( - py::bytes(reinterpret_cast(&buffers.guidBuffers[col - 1][i]), - sizeof(SQLGUID))); - break; - } - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - if (dataLen <= columnSize) { - row.append(py::bytes(reinterpret_cast( - &buffers.charBuffers[col - 1][i * columnSize]), - dataLen)); - } else { - // In this case, buffer size is smaller, and data to be retrieved is longer - // TODO: Revisit - std::ostringstream oss; - oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data " - << "to be retrieved is longer (" << dataLen << "). ColumnID - " - << col << ", datatype - " << dataType; - ThrowStdException(oss.str()); - } - break; - } - default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); - std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; - LOG(errorString.str()); - ThrowStdException(errorString.str()); - break; - } - } - } - rows.append(row); - } - return ret; -} - -// Given a list of columns that are a part of single row in the result set, calculates -// the max size of the row -// TODO: Move to anonymous namespace, since it is not used outside this file -size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { - size_t rowSize = 0; - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: - rowSize += columnSize; - break; - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: - rowSize += columnSize * sizeof(SQLWCHAR); - break; - case SQL_INTEGER: - rowSize += sizeof(SQLINTEGER); - break; - case SQL_SMALLINT: - rowSize += sizeof(SQLSMALLINT); - break; - case SQL_REAL: - rowSize += sizeof(SQLREAL); - break; - case SQL_FLOAT: - rowSize += sizeof(SQLFLOAT); - break; - case SQL_DOUBLE: - rowSize += sizeof(SQLDOUBLE); - break; - case SQL_DECIMAL: - case SQL_NUMERIC: - rowSize += MAX_DIGITS_IN_NUMERIC; - break; - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: - rowSize += sizeof(SQL_TIMESTAMP_STRUCT); - break; - case SQL_BIGINT: - rowSize += sizeof(SQLBIGINT); - break; - case SQL_TYPE_DATE: - rowSize += sizeof(SQL_DATE_STRUCT); - break; - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: - rowSize += sizeof(SQL_TIME_STRUCT); - break; - case SQL_GUID: - rowSize += sizeof(SQLGUID); - break; - case SQL_TINYINT: - case SQL_BIT: - rowSize += sizeof(SQLCHAR); - break; - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: - rowSize += columnSize; - break; - default: - std::wstring columnName = columnMeta["ColumnName"].cast(); - std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; - LOG(errorString.str()); - ThrowStdException(errorString.str()); - break; - } - } - return rowSize; -} - -// FetchMany_wrap - Fetches multiple rows of data from the result set. -// -// @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param rows: A Python list that will be populated with the fetched rows of data. -// @param fetchSize: The number of rows to fetch. Default value is 1. -// -// @return SQLRETURN: SQL_SUCCESS if data is fetched successfully, -// SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. -// -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. It fetches the specified number of rows from the result set and populates the provided -// Python list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an -// error occurs during fetching, it throws a runtime error. -SQLRETURN FetchMany_wrap(intptr_t StatementHandle, py::list& rows, int fetchSize = 1) { - SQLRETURN ret; - SQLHSTMT hStmt = reinterpret_cast(StatementHandle); - // Retrieve column count - SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); - - // Retrieve column metadata - py::list columnNames; - ret = SQLDescribeCol_wrap(StatementHandle, columnNames); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to get column descriptions"); - return ret; - } - - // Initialize column buffers - ColumnBuffers buffers(numCols, fetchSize); - - // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); - if (!SQL_SUCCEEDED(ret)) { - LOG("Error when binding columns"); - return ret; - } - - SQLULEN numRowsFetched; - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)fetchSize, 0); - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched); - if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { - LOG("Error when fetching data"); - return ret; - } - - return ret; -} - -// FetchAll_wrap - Fetches all rows of data from the result set. -// -// @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param rows: A Python list that will be populated with the fetched rows of data. -// -// @return SQLRETURN: SQL_SUCCESS if data is fetched successfully, -// SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. -// -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. It fetches all rows from the result set and populates the provided Python list with the -// row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error occurs during -// fetching, it throws a runtime error. -SQLRETURN FetchAll_wrap(intptr_t StatementHandle, py::list& rows) { - SQLRETURN ret; - SQLHSTMT hStmt = reinterpret_cast(StatementHandle); - // Retrieve column count - SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); - - // Retrieve column metadata - py::list columnNames; - ret = SQLDescribeCol_wrap(StatementHandle, columnNames); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to get column descriptions"); - return ret; - } - - // Define a memory limit (1 GB) - const size_t memoryLimit = 1ULL * 1024 * 1024 * 1024; // 1 GB - size_t totalRowSize = calculateRowSize(columnNames, numCols); - - // Calculate fetch size based on the total row size and memory limit - size_t numRowsInMemLimit; - if (totalRowSize > 0) { - numRowsInMemLimit = static_cast(memoryLimit / totalRowSize); - } else { - // Handle case where totalRowSize is 0 to avoid division by zero. - // This can happen for NVARCHAR(MAX) cols. SQLDescribeCol returns 0 - // for column size of such columns. - // TODO: Find why NVARCHAR(MAX) returns columnsize 0 - // TODO: What if a row has 2 cols, an int & NVARCHAR(MAX)? - // totalRowSize will be 4+0 = 4. It wont take NVARCHAR(MAX) - // into account. So, we will end up fetching 1000 rows at a time. - numRowsInMemLimit = 1; // fetchsize will be 10 - } - // TODO: Revisit this logic. Eventhough we're fetching fetchSize rows at a time, - // fetchall will keep all rows in memory anyway. So what are we gaining by fetching - // fetchSize rows at a time? - // Also, say the table has only 10 rows, each row size if 100 bytes. Here, we'll have - // fetchSize = 1000, so we'll allocate memory for 1000 rows inside SQLBindCol_wrap, while - // actually only need to retrieve 10 rows - int fetchSize; - if (numRowsInMemLimit == 0) { - // If the row size is larger than the memory limit, fetch one row at a time - fetchSize = 1; - } else if (numRowsInMemLimit > 0 && numRowsInMemLimit <= 100) { - // If between 1-100 rows fit in memoryLimit, fetch 10 rows at a time - fetchSize = 10; - } else if (numRowsInMemLimit > 100 && numRowsInMemLimit <= 1000) { - // If between 100-1000 rows fit in memoryLimit, fetch 100 rows at a time - fetchSize = 100; - } else { - fetchSize = 1000; - } - LOG("Fetching data in batch sizes of {}", fetchSize); - - ColumnBuffers buffers(numCols, fetchSize); - - // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); - if (!SQL_SUCCEEDED(ret)) { - LOG("Error when binding columns"); - return ret; - } - - SQLULEN numRowsFetched; - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)fetchSize, 0); - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - - while (ret != SQL_NO_DATA) { - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched); - if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { - LOG("Error when fetching data"); - return ret; - } - } - - return ret; -} - -// FetchOne_wrap - Fetches a single row of data from the result set. -// -// @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param row: A Python list that will be populated with the fetched row data. -// -// @return SQLRETURN: SQL_SUCCESS or SQL_SUCCESS_WITH_INFO if data is fetched successfully, -// SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. -// -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. It fetches the next row of data from the result set and populates the provided Python -// list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error -// occurs during fetching, it throws a runtime error. -SQLRETURN FetchOne_wrap(intptr_t StatementHandle, py::list& row) { - SQLRETURN ret; - SQLHSTMT hStmt = reinterpret_cast(StatementHandle); - - // Assume hStmt is already allocated and a query has been executed - ret = SQLFetch_ptr(hStmt); - if (SQL_SUCCEEDED(ret)) { - // Retrieve column count - SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); - ret = SQLGetData_wrap(StatementHandle, colCount, row); - } else if (ret != SQL_NO_DATA) { - LOG("Error when fetching data"); - } - return ret; -} - -// Wrap SQLMoreResults -SQLRETURN SQLMoreResults_wrap(intptr_t StatementHandle) { - LOG("Check for more results"); - if (!SQLMoreResults_ptr) { - LoadDriverOrThrowException(); - } - - return SQLMoreResults_ptr(reinterpret_cast(StatementHandle)); -} - -// Wrap SQLEndTran -SQLRETURN SQLEndTran_wrap(SQLSMALLINT HandleType, intptr_t Handle, SQLSMALLINT CompletionType) { - LOG("End SQL Transaction"); - if (!SQLEndTran_ptr) { - LoadDriverOrThrowException(); - } - - return SQLEndTran_ptr(HandleType, reinterpret_cast(Handle), CompletionType); -} - -// Wrap SQLFreeHandle -SQLRETURN SQLFreeHandle_wrap(SQLSMALLINT HandleType, intptr_t Handle) { - LOG("Free SQL handle"); - if (!SQLAllocHandle_ptr) { - LoadDriverOrThrowException(); - } - - return SQLFreeHandle_ptr(HandleType, reinterpret_cast(Handle)); -} - -// Wrap SQLDisconnect -SQLRETURN SQLDisconnect_wrap(intptr_t ConnectionHandle) { - LOG("Disconnect from MSSQL"); - if (!SQLDisconnect_ptr) { - LoadDriverOrThrowException(); - } - - return SQLDisconnect_ptr(reinterpret_cast(ConnectionHandle)); -} - -// Wrap SQLRowCount -SQLLEN SQLRowCount_wrap(intptr_t StatementHandle) { - LOG("Get number of row affected by last execute"); - if (!SQLRowCount_ptr) { - LoadDriverOrThrowException(); - } - - SQLLEN rowCount; - SQLRETURN ret = SQLRowCount_ptr(reinterpret_cast(StatementHandle), &rowCount); - if (!SQL_SUCCEEDED(ret)) { - LOG("SQLRowCount failed with error code - {}", ret); - return ret; - } - LOG("SQLRowCount returned {}", rowCount); - return rowCount; -} - -// Functions/data to be exposed to Python as a part of ddbc_bindings module -PYBIND11_MODULE(ddbc_bindings, m) { - m.doc() = "msodbcsql driver api bindings for Python"; - - // Add architecture information as module attribute - m.attr("__architecture__") = ARCHITECTURE; - - // Expose architecture-specific constants - m.attr("ARCHITECTURE") = ARCHITECTURE; - - // Expose the C++ functions to Python - m.def("ThrowStdException", &ThrowStdException); - - // Define parameter info class - py::class_(m, "ParamInfo") - .def(py::init<>()) - .def_readwrite("inputOutputType", &ParamInfo::inputOutputType) - .def_readwrite("paramCType", &ParamInfo::paramCType) - .def_readwrite("paramSQLType", &ParamInfo::paramSQLType) - .def_readwrite("columnSize", &ParamInfo::columnSize) - .def_readwrite("decimalDigits", &ParamInfo::decimalDigits); - - // Define numeric data class - py::class_(m, "NumericData") - .def(py::init<>()) - .def(py::init()) - .def_readwrite("precision", &NumericData::precision) - .def_readwrite("scale", &NumericData::scale) - .def_readwrite("sign", &NumericData::sign) - .def_readwrite("val", &NumericData::val); - - // Define error info class - py::class_(m, "ErrorInfo") - .def_readwrite("sqlState", &ErrorInfo::sqlState) - .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); - - // Expose all the SQL functions with proper error handling - m.def("DDBCSQLAllocHandle", &SQLAllocHandle_wrap, - "Allocate an environment, connection, statement, or descriptor handle"); - m.def("DDBCSQLSetEnvAttr", &SQLSetEnvAttr_wrap, - "Set an attribute that governs aspects of environments"); - m.def("DDBCSQLSetConnectAttr", &SQLSetConnectAttr_wrap, - "Set an attribute that governs aspects of connections"); - m.def("DDBCSQLSetStmtAttr", &SQLSetStmtAttr_wrap, - "Set an attribute that governs aspects of statements"); - m.def("DDBCSQLGetConnectionAttr", &SQLGetConnectionAttr_wrap, - "Get an attribute that governs aspects of connections"); - m.def("DDBCSQLDriverConnect", &SQLDriverConnect_wrap, - "Connect to a data source with a connection string"); - m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); - m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements"); - m.def("DDBCSQLRowCount", &SQLRowCount_wrap, - "Get the number of rows affected by the last statement"); - m.def("DDBCSQLFetch", &SQLFetch_wrap, "Fetch the next row from the result set"); - m.def("DDBCSQLNumResultCols", &SQLNumResultCols_wrap, - "Get the number of columns in the result set"); - m.def("DDBCSQLDescribeCol", &SQLDescribeCol_wrap, - "Get information about a column in the result set"); - m.def("DDBCSQLGetData", &SQLGetData_wrap, "Retrieve data from the result set"); - m.def("DDBCSQLMoreResults", &SQLMoreResults_wrap, "Check for more results in the result set"); - m.def("DDBCSQLFetchOne", &FetchOne_wrap, "Fetch one row from the result set"); - m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), - py::arg("fetchSize") = 1, "Fetch many rows from the result set"); - m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); - m.def("DDBCSQLEndTran", &SQLEndTran_wrap, "End a transaction"); - m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); - m.def("DDBCSQLDisconnect", &SQLDisconnect_wrap, "Disconnect from a data source"); - m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); - - // Add a version attribute - m.attr("__version__") = "1.0.0"; - - try { - // Try loading the ODBC driver when the module is imported - LoadDriverOrThrowException(); - } catch (const std::exception& e) { - // Log the error but don't throw - let the error happen when functions are called - LOG("Failed to load ODBC driver during module initialization: {}", e.what()); - } -} diff --git a/mssql_python/pybind/mac_fix.cpp b/mssql_python/pybind/mac_utils.cpp similarity index 94% rename from mssql_python/pybind/mac_fix.cpp rename to mssql_python/pybind/mac_utils.cpp index 06ebb2539..bd792d8ab 100644 --- a/mssql_python/pybind/mac_fix.cpp +++ b/mssql_python/pybind/mac_utils.cpp @@ -1,5 +1,10 @@ -// Mac OS specific fixes for the C++ code -// This file contains patches to fix issues specific to macOS +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// This header defines utility functions for safely handling SQLWCHAR-based +// wide-character data in ODBC operations on macOS. It includes conversions +// between SQLWCHAR, std::wstring, and UTF-8 strings to bridge encoding +// differences specific to macOS. #if defined(__APPLE__) // Constants for character encoding diff --git a/mssql_python/pybind/mac_fix.h b/mssql_python/pybind/mac_utils.h similarity index 75% rename from mssql_python/pybind/mac_fix.h rename to mssql_python/pybind/mac_utils.h index 04592b048..776ab6447 100644 --- a/mssql_python/pybind/mac_fix.h +++ b/mssql_python/pybind/mac_utils.h @@ -1,3 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// This header defines utility functions for safely handling SQLWCHAR-based +// wide-character data in ODBC operations on macOS. It includes conversions +// between SQLWCHAR, std::wstring, and UTF-8 strings to bridge encoding +// differences specific to macOS. + #pragma once #include diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 635979747..a71255ac3 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -11,7 +11,8 @@ """ import pytest -from mssql_python import Connection, connect +import time +from mssql_python import Connection, connect, pooling def drop_table_if_exists(cursor, table_name): """Drop the table if it exists""" @@ -178,3 +179,47 @@ def test_connection_close(conn_str): temp_conn = connect(conn_str) # Check if the database connection can be closed temp_conn.close() + +def test_connection_pooling_speed(conn_str): + # No pooling + start_no_pool = time.perf_counter() + conn1 = connect(conn_str) + conn1.close() + end_no_pool = time.perf_counter() + no_pool_duration = end_no_pool - start_no_pool + + # Second connection + start2 = time.perf_counter() + conn2 = connect(conn_str) + conn2.close() + end2 = time.perf_counter() + duration2 = end2 - start2 + + # Pooling enabled + pooling(max_size=2, idle_timeout=10) + connect(conn_str).close() + + # Pooled connection (should be reused, hence faster) + start_pool = time.perf_counter() + conn2 = connect(conn_str) + conn2.close() + end_pool = time.perf_counter() + pool_duration = end_pool - start_pool + assert pool_duration < no_pool_duration, "Expected faster connection with pooling" + +def test_connection_pooling_basic(conn_str): + # Enable pooling with small pool size + pooling(max_size=2, idle_timeout=5) + conn1 = connect(conn_str) + conn2 = connect(conn_str) + assert conn1 is not None + assert conn2 is not None + try: + conn3 = connect(conn_str) + assert conn3 is not None, "Third connection failed — pooling is not working or limit is too strict" + conn3.close() + except Exception as e: + print(f"Expected: Could not open third connection due to max_size=2: {e}") + + conn1.close() + conn2.close()