diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 03f8dfffef..4221f03215 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -1,12 +1,12 @@ import dataclasses -from typing import ClassVar, Final, Optional, Any, Dict, List +from typing import Callable, ClassVar, Final, Optional, Any, Dict, List, cast from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError -from databricks.sdk.core import Config, oauth_service_principal, CredentialsProvider +from databricks.sdk.core import Config, oauth_service_principal DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -42,13 +42,13 @@ def on_resolved(self) -> None: "If you are using an access token for authentication, omit the 'auth_type' field." ) - def _get_oauth_credentials(self) -> Optional[CredentialsProvider]: + def _get_oauth_credentials(self) -> Optional[Callable[[], Dict[str, str]]]: config = Config( host=f"https://{self.server_hostname}", client_id=self.client_id, client_secret=self.client_secret, ) - return oauth_service_principal(config) + return cast(Callable[[], Dict[str, str]], oauth_service_principal(config)) def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 8bff4e0d73..010652dc48 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -13,8 +13,7 @@ Dict, ) - -from databricks import sql as databricks_lib +import databricks.sql from databricks.sql.client import ( Connection as DatabricksSqlConnection, Cursor as DatabricksSqlCursor, @@ -60,7 +59,7 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): - dbapi: ClassVar[DBApi] = databricks_lib + dbapi: ClassVar[DBApi] = databricks.sql def __init__( self, @@ -75,7 +74,7 @@ def __init__( def open_connection(self) -> DatabricksSqlConnection: conn_params = self.credentials.to_connector_params() - self._conn = databricks_lib.connect( + self._conn = databricks.sql.connect( **conn_params, schema=self.dataset_name, use_inline_params="silent" ) return self._conn @@ -156,7 +155,7 @@ def catalog_name(self, escape: bool = True) -> Optional[str]: @staticmethod def _make_database_exception(ex: Exception) -> Exception: - if isinstance(ex, databricks_lib.ServerOperationError): + if isinstance(ex, databricks.sql.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) elif "SCHEMA_NOT_FOUND" in str(ex): @@ -164,15 +163,15 @@ def _make_database_exception(ex: Exception) -> Exception: elif "PARSE_SYNTAX_ERROR" in str(ex): return DatabaseTransientException(ex) return DatabaseTerminalException(ex) - elif isinstance(ex, databricks_lib.OperationalError): + elif isinstance(ex, databricks.sql.OperationalError): return DatabaseTransientException(ex) - elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): + elif isinstance(ex, (databricks.sql.ProgrammingError, databricks.sql.IntegrityError)): return DatabaseTerminalException(ex) - elif isinstance(ex, databricks_lib.DatabaseError): + elif isinstance(ex, databricks.sql.DatabaseError): return DatabaseTransientException(ex) else: return DatabaseTransientException(ex) @staticmethod def is_dbapi_exception(ex: Exception) -> bool: - return isinstance(ex, databricks_lib.DatabaseError) + return isinstance(ex, databricks.sql.DatabaseError)