Skip to content
47 changes: 29 additions & 18 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,25 @@ def get_token_struct(token: str) -> bytes:

@staticmethod
def get_token(auth_type: str) -> bytes:
    """Get DDBC token struct for the specified authentication type.

    Delegates to ``AADAuth._acquire_token`` and keeps only the first element
    of the pair it returns (the DDBC struct); the raw JWT is discarded.

    Args:
        auth_type: Authentication type identifier, forwarded unchanged to
            ``AADAuth._acquire_token``.

    Returns:
        The token struct as bytes.
    """
    # Only one docstring literal is kept: a second string statement after the
    # docstring would be a dead expression and never reach ``__doc__``.
    token_struct, _ = AADAuth._acquire_token(auth_type)
    return token_struct

@staticmethod
def get_raw_token(auth_type: str) -> str:
    """Return a newly acquired raw JWT for the mssql-py-core connection (bulk copy).

    Nothing is stored by this method: every call goes back through
    ``AADAuth._acquire_token`` and returns the second element of its result.
    Acquiring at call time is meant to avoid expired-token errors when
    bulkcopy() runs long after the original DDBC connect().

    NOTE(review): any token caching is left to Azure Identity. The credential
    instance appears to be created anew on each acquisition, so confirm
    whether a per-credential cache is actually reused between calls.

    Args:
        auth_type: Authentication type identifier, forwarded unchanged to
            ``AADAuth._acquire_token``.

    Returns:
        The raw token string.
    """
    _struct, jwt = AADAuth._acquire_token(auth_type)
    return jwt

@staticmethod
def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
"""Internal: acquire token and return (ddbc_struct, raw_jwt)."""
# Import Azure libraries inside method to support test mocking
# pylint: disable=import-outside-toplevel
try:
Expand Down Expand Up @@ -61,22 +79,15 @@ def get_token(auth_type: str) -> bytes:
)

try:
logger.debug(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed unnecessary log statements; there is an info log above which already captures these.

"get_token: Creating credential instance - credential_class=%s",
credential_class.__name__,
)
credential = credential_class()
logger.debug(
"get_token: Requesting token from Azure AD - scope=https://database.windows.net/.default"
)
token = credential.get_token("https://database.windows.net/.default").token
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"get_token: Azure AD token acquired successfully - token_length=%d chars",
len(token),
len(raw_token),
)
return AADAuth.get_token_struct(token)
token_struct = AADAuth.get_token_struct(raw_token)
return token_struct, raw_token
except ClientAuthenticationError as e:
# Re-raise with more specific context about Azure AD authentication failure
logger.error(
"get_token: Azure AD authentication failed - credential_class=%s, error=%s",
credential_class.__name__,
Expand All @@ -88,7 +99,6 @@ def get_token(auth_type: str) -> bytes:
f"user cancellation, network issues, or unsupported configuration."
) from e
except Exception as e:
# Catch any other unexpected exceptions
logger.error(
"get_token: Unexpected error during credential creation - credential_class=%s, error=%s",
credential_class.__name__,
Expand Down Expand Up @@ -180,7 +190,7 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:


def get_auth_token(auth_type: str) -> Optional[bytes]:
"""Get authentication token based on auth type"""
"""Get DDBC authentication token struct based on auth type."""
logger.debug("get_auth_token: Starting - auth_type=%s", auth_type)
if not auth_type:
logger.debug("get_auth_token: No auth_type specified, returning None")
Expand All @@ -204,15 +214,16 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:

def process_connection_string(
connection_string: str,
) -> Tuple[str, Optional[Dict[int, bytes]]]:
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]:
"""
Process connection string and handle authentication.

Args:
connection_string: The connection string to process

Returns:
Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
attrs_before dict if needed, and auth_type string for bulk copy token acquisition

Raises:
ValueError: If the connection string is invalid or empty
Expand Down Expand Up @@ -259,7 +270,7 @@ def process_connection_string(
"process_connection_string: Token authentication configured successfully - auth_type=%s",
auth_type,
)
return ";".join(modified_parameters) + ";", {1256: token_struct}
return ";".join(modified_parameters) + ";", {1256: token_struct}, auth_type
else:
logger.warning(
"process_connection_string: Token acquisition failed, proceeding without token"
Expand All @@ -269,4 +280,4 @@ def process_connection_string(
"process_connection_string: Connection string processing complete - has_auth=%s",
bool(auth_type),
)
return ";".join(modified_parameters) + ";", None
return ";".join(modified_parameters) + ";", None, None
7 changes: 7 additions & 0 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ def __init__(
},
}

# Auth type for acquiring fresh tokens at bulk copy time.
# We intentionally do NOT cache the token — a fresh one is acquired
# each time bulkcopy() is called to avoid expired-token errors.
self._auth_type = None

# Check if the connection string contains authentication parameters
# This is important for processing the connection string correctly.
# If authentication is specified, it will be processed to handle
Expand All @@ -272,6 +277,8 @@ def __init__(
self.connection_str = connection_result[0]
if connection_result[1]:
self._attrs_before.update(connection_result[1])
# Store auth type so bulkcopy() can acquire a fresh token later
self._auth_type = connection_result[2]

self._closed = False
self._timeout = timeout
Expand Down
29 changes: 22 additions & 7 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2607,15 +2607,30 @@ def _bulkcopy(
context = {
"server": params.get("server"),
"database": params.get("database"),
"user_name": params.get("uid", ""),
"trust_server_certificate": trust_cert,
"encryption": encryption,
}

# Extract password separately to avoid storing it in generic context that may be logged
password = params.get("pwd", "")
# Build pycore_context with appropriate authentication.
# For Azure AD: acquire a FRESH token right now instead of reusing
# the one from connect() time — avoids expired-token errors when
# bulkcopy() is called long after the original connection.
pycore_context = dict(context)
pycore_context["password"] = password

if self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection
from mssql_python.auth import AADAuth

raw_token = AADAuth.get_raw_token(self.connection._auth_type)
pycore_context["access_token"] = raw_token
logger.debug(
"Bulk copy: acquired fresh Azure AD token for auth_type=%s",
self.connection._auth_type,
)
else:
# SQL Server authentication — use uid/password from connection string
pycore_context["user_name"] = params.get("uid", "")
pycore_context["password"] = params.get("pwd", "")

pycore_connection = None
pycore_cursor = None
Expand Down Expand Up @@ -2653,10 +2668,10 @@ def _bulkcopy(

finally:
# Clear sensitive data to minimize memory exposure
password = ""
if pycore_context:
pycore_context["password"] = ""
pycore_context["user_name"] = ""
pycore_context.pop("password", None)
pycore_context.pop("user_name", None)
pycore_context.pop("access_token", None)
# Clean up bulk copy resources
for resource in (pycore_cursor, pycore_connection):
if resource and hasattr(resource, "close"):
Expand Down
26 changes: 23 additions & 3 deletions tests/test_008_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import platform
import sys
from unittest.mock import patch, MagicMock
from mssql_python.auth import (
AADAuth,
process_auth_parameters,
Expand Down Expand Up @@ -82,6 +83,11 @@ def test_get_token_struct(self):
assert isinstance(token_struct, bytes)
assert len(token_struct) > 4

def test_get_raw_token_default(self):
    """get_raw_token("default") returns a str equal to SAMPLE_TOKEN."""
    jwt = AADAuth.get_raw_token("default")
    # Check the type first so a bytes/struct result fails with a clear message
    assert isinstance(jwt, str)
    assert jwt == SAMPLE_TOKEN

def test_get_token_default(self):
token_struct = AADAuth.get_token("default")
assert isinstance(token_struct, bytes)
Expand Down Expand Up @@ -326,34 +332,37 @@ def test_remove_sensitive_parameters(self):
class TestProcessConnectionString:
    """Tests for the (connection string, attrs_before, auth_type) result of
    process_connection_string."""

    def test_process_connection_string_with_default_auth(self):
        """ActiveDirectoryDefault yields a token attribute and auth_type "default"."""
        conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb"
        # process_connection_string returns three values; unpack all of them
        result_str, attrs, auth_type = process_connection_string(conn_str)

        assert "Server=test" in result_str
        assert "Database=testdb" in result_str
        assert attrs is not None
        # 1256 is the attrs_before key under which the token struct is returned
        assert 1256 in attrs
        assert isinstance(attrs[1256], bytes)
        assert auth_type == "default"

    def test_process_connection_string_no_auth(self):
        """Without an Authentication keyword, no attrs and no auth_type are returned."""
        conn_str = "Server=test;Database=testdb;UID=user;PWD=password"
        result_str, attrs, auth_type = process_connection_string(conn_str)

        assert "Server=test" in result_str
        assert "Database=testdb" in result_str
        assert "UID=user" in result_str
        assert "PWD=password" in result_str
        assert attrs is None
        assert auth_type is None

    def test_process_connection_string_interactive_non_windows(self, monkeypatch):
        """On a non-Windows platform, interactive auth yields a token attribute
        and auth_type "interactive"."""
        # Report the platform as macOS for the duration of this test
        monkeypatch.setattr(platform, "system", lambda: "Darwin")
        conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb"
        result_str, attrs, auth_type = process_connection_string(conn_str)

        assert "Server=test" in result_str
        assert "Database=testdb" in result_str
        assert attrs is not None
        assert 1256 in attrs
        assert isinstance(attrs[1256], bytes)
        assert auth_type == "interactive"


def test_error_handling():
Expand All @@ -368,3 +377,14 @@ def test_error_handling():
# Test non-string input
with pytest.raises(ValueError, match="Connection string must be a string"):
process_connection_string(None)


class TestConnectionAuthType:
    """Tests that the connection object records the auth type it was created with."""

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_auth_type_stored_on_connection(self, mock_ddbc_conn):
        """connect() with ActiveDirectoryDefault stores "default" in conn._auth_type."""
        # Replace the native DDBC connection so no real server is contacted
        mock_ddbc_conn.return_value = MagicMock()
        from mssql_python import connect

        conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault")
        try:
            assert conn._auth_type == "default"
        finally:
            # Close even when the assertion fails so the connection is not leaked
            conn.close()
Loading