diff --git a/.changes/unreleased/Features-20240610-171026.yaml b/.changes/unreleased/Features-20240610-171026.yaml new file mode 100644 index 000000000..5cc055160 --- /dev/null +++ b/.changes/unreleased/Features-20240610-171026.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support JWT Authentication +time: 2024-06-10T17:10:26.421463-04:00 +custom: + Author: llam15 + Issue: 1079 726 diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index aca115b4b..ba79e03d1 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -43,7 +43,7 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.events.logging import AdapterLogger from dbt_common.events.functions import warn_or_error -from dbt.adapters.events.types import AdapterEventWarning +from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError from dbt_common.ui import line_wrap_message, warning_tag @@ -70,7 +70,7 @@ class SnowflakeAdapterResponse(AdapterResponse): @dataclass class SnowflakeCredentials(Credentials): account: str - user: str + user: Optional[str] = None warehouse: Optional[str] = None role: Optional[str] = None password: Optional[str] = None @@ -96,15 +96,31 @@ class SnowflakeCredentials(Credentials): reuse_connections: Optional[bool] = None def __post_init__(self): - if self.authenticator != "oauth" and ( - self.oauth_client_secret or self.oauth_client_id or self.token - ): + if self.authenticator != "oauth" and (self.oauth_client_secret or self.oauth_client_id): # the user probably forgot to set 'authenticator' like I keep doing warn_or_error( AdapterEventWarning( base_msg="Authenticator is not set to oauth, but an oauth-only parameter is set! Did you mean to set authenticator: oauth?" ) ) + + if self.authenticator not in ["oauth", "jwt"]: + if self.token: + warn_or_error( + AdapterEventWarning( + base_msg=( + "The token parameter was set, but the authenticator was " + "not set to 'oauth' or 'jwt'." + ) + ) + ) + + if not self.user: + # The user attribute is only optional if 'authenticator' is 'jwt' or 'oauth' + warn_or_error( + AdapterEventError(base_msg="Invalid profile: 'user' is a required property.") + ) + self.account = self.account.replace("_", "-") @property @@ -146,6 +162,8 @@ def auth_args(self): # Pull all of the optional authentication args for the connector, # let connector handle the actual arg validation result = {} + if self.user: + result["user"] = self.user if self.password: result["password"] = self.password if self.host: @@ -180,6 +198,14 @@ def auth_args(self): ) result["token"] = token + + elif self.authenticator == "jwt": + # If authenticator is 'jwt', then the 'token' value should be used + # unmodified. We expose this as 'jwt' in the profile, but the value + # passed into the snowflake.connect method should still be 'oauth' + result["token"] = self.token + result["authenticator"] = "oauth" + # enable id token cache for linux result["client_store_temporary_credential"] = True # enable mfa token cache for linux @@ -346,7 +372,6 @@ def connect(): handle = snowflake.connector.connect( account=creds.account, - user=creds.user, database=creds.database, schema=creds.schema, warehouse=creds.warehouse, diff --git a/tests/functional/oauth/test_jwt.py b/tests/functional/oauth/test_jwt.py new file mode 100644 index 000000000..fbe8e20e6 --- /dev/null +++ b/tests/functional/oauth/test_jwt.py @@ -0,0 +1,91 @@ +""" +Please follow the instructions in test_oauth.py for instructions on how to set up +the security integration required to retrieve a JWT from Snowflake. +""" + +import pytest +import os +from dbt.tests.util import run_dbt, check_relations_equal + +from dbt.adapters.snowflake import SnowflakeCredentials + +_MODELS__MODEL_1_SQL = """ +select 1 as id +""" + + +_MODELS__MODEL_2_SQL = """ +select 2 as id +""" + + +_MODELS__MODEL_3_SQL = """ +select * from {{ ref('model_1') }} +union all +select * from {{ ref('model_2') }} +""" + + +_MODELS__MODEL_4_SQL = """ +select 1 as id +union all +select 2 as id +""" + + +class TestSnowflakeJWT: + """Tests that setting authenticator: jwt allows setting token to a plain JWT + that will be passed into the Snowflake connection without modification.""" + + @pytest.fixture(scope="class", autouse=True) + def access_token(self): + """Because JWTs are short-lived, we need to get a fresh JWT via the refresh + token flow before running the test. + + This fixture leverages the existing SnowflakeCredentials._get_access_token + method to retrieve a valid JWT from Snowflake. + """ + client_id = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_ID") + client_secret = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET") + refresh_token = os.getenv("SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN") + + credentials = SnowflakeCredentials( + account=os.getenv("SNOWFLAKE_TEST_ACCOUNT"), + database="", + schema="", + authenticator="oauth", + oauth_client_id=client_id, + oauth_client_secret=client_secret, + token=refresh_token, + ) + + yield credentials._get_access_token() + + @pytest.fixture(scope="class", autouse=True) + def dbt_profile_target(self, access_token): + """A dbt_profile that has authenticator set to JWT, and token set to + a JWT accepted by Snowflake. Also omits the user, as the user attribute + is optional when the authenticator is set to JWT. + """ + return { + "type": "snowflake", + "threads": 4, + "account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"), + "database": os.getenv("SNOWFLAKE_TEST_DATABASE"), + "warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"), + "authenticator": "jwt", + "token": access_token, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "model_1.sql": _MODELS__MODEL_1_SQL, + "model_2.sql": _MODELS__MODEL_2_SQL, + "model_3.sql": _MODELS__MODEL_3_SQL, + "model_4.sql": _MODELS__MODEL_4_SQL, + } + + def test_snowflake_basic(self, project): + run_dbt() + check_relations_equal(project.adapter, ["MODEL_3", "MODEL_4"]) diff --git a/tests/unit/test_snowflake_adapter.py b/tests/unit/test_snowflake_adapter.py index ff92b9b65..f6a768da8 100644 --- a/tests/unit/test_snowflake_adapter.py +++ b/tests/unit/test_snowflake_adapter.py @@ -550,6 +550,38 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p ] ) + def test_authenticator_jwt_authentication(self): + self.config.credentials = self.config.credentials.replace( + authenticator="jwt", token="my-jwt-token", user=None + ) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) + conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") + + self.snowflake.assert_not_called() + conn.handle + self.snowflake.assert_has_calls( + [ + mock.call( + account="test-account", + autocommit=True, + client_session_keep_alive=False, + database="test_database", + role=None, + schema="public", + warehouse="test_warehouse", + authenticator="oauth", + token="my-jwt-token", + private_key=None, + application="dbt", + client_request_mfa_token=True, + client_store_temporary_credential=True, + insecure_mode=False, + session_parameters={}, + reuse_connections=None, + ) + ] + ) + def test_query_tag(self): self.config.credentials = self.config.credentials.replace( password="test_password", query_tag="test_query_tag"