Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/devskim.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@
uses: github/codeql-action/upload-sarif@v3
with:
sarif_file: devskim-results.sarif
# Exclude TODO (DS176209) and localhost (DS162092) alerts
# Since they are used only for future improvements & tests respectively
exclude-rules: DS176209,DS162092
24 changes: 12 additions & 12 deletions mssql_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

# Connection Objects
from .db_connection import connect, Connection
from .pooling import PoolingManager

# Cursor Objects
from .cursor import Cursor
Expand All @@ -58,20 +59,19 @@
paramstyle = "qmark"
threadsafety = 1

from .pooling import PoolingManager
def pooling(max_size=100, idle_timeout=600, enabled=True):
# """
# Enable connection pooling with the specified parameters.
# By default:
# - If not explicitly called, pooling will be auto-enabled with default values.
"""
Enable connection pooling with the specified parameters.
By default:
- If not explicitly called, pooling will be auto-enabled with default values.

Args:
max_size (int): Maximum number of connections in the pool.
idle_timeout (int): Time in seconds before idle connections are closed.

# Args:
# max_size (int): Maximum number of connections in the pool.
# idle_timeout (int): Time in seconds before idle connections are closed.

# Returns:
# None
# """
Returns:
None
"""
if not enabled:
PoolingManager.disable()
else:
Expand Down
58 changes: 35 additions & 23 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import platform
import struct
from typing import Tuple, Dict, Optional, Union
from typing import Tuple, Dict, Optional
from mssql_python.constants import AuthType


class AADAuth:
"""Handles Azure Active Directory authentication"""

@staticmethod
def get_token_struct(token: str) -> bytes:
"""Convert token to SQL Server compatible format"""
Expand All @@ -22,21 +23,21 @@ def get_token_struct(token: str) -> bytes:
def get_token(auth_type: str) -> bytes:
"""Get token using the specified authentication type"""
from azure.identity import (
DefaultAzureCredential,
DeviceCodeCredential,
InteractiveBrowserCredential
DefaultAzureCredential,
DeviceCodeCredential,
InteractiveBrowserCredential,
)
from azure.core.exceptions import ClientAuthenticationError

# Mapping of auth types to credential classes
credential_map = {
"default": DefaultAzureCredential,
"devicecode": DeviceCodeCredential,
"interactive": InteractiveBrowserCredential,
}

credential_class = credential_map[auth_type]

try:
credential = credential_class()
token = credential.get_token("https://database.windows.net/.default").token
Expand All @@ -50,18 +51,21 @@ def get_token(auth_type: str) -> bytes:
) from e
except Exception as e:
# Catch any other unexpected exceptions
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e
raise RuntimeError(
f"Failed to create {credential_class.__name__}: {e}"
) from e


def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
"""
Process connection parameters and extract authentication type.

Args:
parameters: List of connection string parameters

Returns:
Tuple[list, Optional[str]]: Modified parameters and authentication type

Raises:
ValueError: If an invalid authentication type is provided
"""
Expand All @@ -88,7 +92,7 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
# Interactive authentication (browser-based); only append parameter for non-Windows
if platform.system().lower() == "windows":
auth_type = None # Let Windows handle AADInteractive natively

elif value_lower == AuthType.DEVICE_CODE.value:
# Device code authentication (for devices without browser)
auth_type = "devicecode"
Expand All @@ -99,40 +103,48 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:

return modified_parameters, auth_type


def remove_sensitive_params(parameters: list) -> list:
"""Remove sensitive parameters from connection string"""
exclude_keys = [
"uid=", "pwd=", "encrypt=", "trustservercertificate=", "authentication="
"uid=",
"pwd=",
"encrypt=",
"trustservercertificate=",
"authentication=",
]
return [
param for param in parameters
param
for param in parameters
if not any(param.lower().startswith(exclude) for exclude in exclude_keys)
]


def get_auth_token(auth_type: str) -> Optional[bytes]:
"""Get authentication token based on auth type"""
if not auth_type:
return None

# Handle platform-specific logic for interactive auth
if auth_type == "interactive" and platform.system().lower() == "windows":
return None # Let Windows handle AADInteractive natively

try:
return AADAuth.get_token(auth_type)
except (ValueError, RuntimeError):
return None


def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]:
"""
Process connection string and handle authentication.

Args:
connection_string: The connection string to process

Returns:
Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed

Raises:
ValueError: If the connection string is invalid or empty
"""
Expand All @@ -145,9 +157,9 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
raise ValueError("Connection string cannot be empty")

parameters = connection_string.split(";")

# Validate that there's at least one valid parameter
if not any('=' in param for param in parameters):
if not any("=" in param for param in parameters):
raise ValueError("Invalid connection string format")

modified_parameters, auth_type = process_auth_parameters(parameters)
Expand All @@ -158,4 +170,4 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
if token_struct:
return ";".join(modified_parameters) + ";", {1256: token_struct}

return ";".join(modified_parameters) + ";", None
return ";".join(modified_parameters) + ";", None
24 changes: 21 additions & 3 deletions mssql_python/bcp_options.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
This module provides options for bulk copy operations.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Literal

Expand Down Expand Up @@ -31,6 +37,9 @@ class ColumnFormat:
file_col: int = 1

def __post_init__(self):
"""
Validate column format options.
"""
if self.prefix_len < 0:
raise ValueError("prefix_len must be a non-negative integer.")
if self.data_len < 0:
Expand Down Expand Up @@ -88,12 +97,21 @@ class BCPOptions:
columns: List[ColumnFormat] = field(default_factory=list)

def __post_init__(self):
"""
Validate BCP options.
"""
if self.direction not in ["in", "out"]:
raise ValueError("direction must be 'in' or 'out'.")
if not self.data_file:
raise ValueError("data_file must be provided and non-empty for 'in' or 'out' directions.")
if self.error_file is None or not self.error_file: # Making error_file mandatory for in/out
raise ValueError("error_file must be provided and non-empty for 'in' or 'out' directions.")
raise ValueError(
"data_file must be provided and non-empty for 'in' or 'out' directions."
)
if (
self.error_file is None or not self.error_file
): # Making error_file mandatory for in/out
raise ValueError(
"error_file must be provided and non-empty for 'in' or 'out' directions."
)

if self.format_file is not None and not self.format_file:
raise ValueError("format_file, if provided, must not be an empty string.")
Expand Down
Loading
Loading