Skip to content

Commit

Permalink
[ENTERPRISE-1418] Add support for plain JWT authentication (#1078)
Browse files Browse the repository at this point in the history
* [ENTERPRISE-1418] Add support for plain JWT authentication

* Run changie new

* wip: functional test for JWT

* clean up, and add some comments
  • Loading branch information
llam15 authored Jun 11, 2024
1 parent 4480734 commit 5ede6fe
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240610-171026.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support JWT Authentication
time: 2024-06-10T17:10:26.421463-04:00
custom:
Author: llam15
Issue: 1079 726
37 changes: 31 additions & 6 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag


Expand All @@ -70,7 +70,7 @@ class SnowflakeAdapterResponse(AdapterResponse):
@dataclass
class SnowflakeCredentials(Credentials):
account: str
user: str
user: Optional[str] = None
warehouse: Optional[str] = None
role: Optional[str] = None
password: Optional[str] = None
Expand All @@ -96,15 +96,31 @@ class SnowflakeCredentials(Credentials):
reuse_connections: Optional[bool] = None

def __post_init__(self):
if self.authenticator != "oauth" and (
self.oauth_client_secret or self.oauth_client_id or self.token
):
if self.authenticator != "oauth" and (self.oauth_client_secret or self.oauth_client_id):
# the user probably forgot to set 'authenticator' like I keep doing
warn_or_error(
AdapterEventWarning(
base_msg="Authenticator is not set to oauth, but an oauth-only parameter is set! Did you mean to set authenticator: oauth?"
)
)

if self.authenticator not in ["oauth", "jwt"]:
if self.token:
warn_or_error(
AdapterEventWarning(
base_msg=(
"The token parameter was set, but the authenticator was "
"not set to 'oauth' or 'jwt'."
)
)
)

if not self.user:
# The user attribute is only optional if 'authenticator' is 'jwt' or 'oauth'
warn_or_error(
AdapterEventError(base_msg="Invalid profile: 'user' is a required property.")
)

self.account = self.account.replace("_", "-")

@property
Expand Down Expand Up @@ -146,6 +162,8 @@ def auth_args(self):
# Pull all of the optional authentication args for the connector,
# let connector handle the actual arg validation
result = {}
if self.user:
result["user"] = self.user
if self.password:
result["password"] = self.password
if self.host:
Expand Down Expand Up @@ -180,6 +198,14 @@ def auth_args(self):
)

result["token"] = token

elif self.authenticator == "jwt":
# If authenticator is 'jwt', then the 'token' value should be used
# unmodified. We expose this as 'jwt' in the profile, but the value
# passed into the snowflake.connect method should still be 'oauth'
result["token"] = self.token
result["authenticator"] = "oauth"

# enable id token cache for linux
result["client_store_temporary_credential"] = True
# enable mfa token cache for linux
Expand Down Expand Up @@ -346,7 +372,6 @@ def connect():

handle = snowflake.connector.connect(
account=creds.account,
user=creds.user,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
Expand Down
91 changes: 91 additions & 0 deletions tests/functional/oauth/test_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Please follow the instructions in test_oauth.py for instructions on how to set up
the security integration required to retrieve a JWT from Snowflake.
"""

import pytest
import os
from dbt.tests.util import run_dbt, check_relations_equal

from dbt.adapters.snowflake import SnowflakeCredentials

_MODELS__MODEL_1_SQL = """
select 1 as id
"""


_MODELS__MODEL_2_SQL = """
select 2 as id
"""


_MODELS__MODEL_3_SQL = """
select * from {{ ref('model_1') }}
union all
select * from {{ ref('model_2') }}
"""


_MODELS__MODEL_4_SQL = """
select 1 as id
union all
select 2 as id
"""


class TestSnowflakeJWT:
"""Tests that setting authenticator: jwt allows setting token to a plain JWT
that will be passed into the Snowflake connection without modification."""

@pytest.fixture(scope="class", autouse=True)
def access_token(self):
"""Because JWTs are short-lived, we need to get a fresh JWT via the refresh
token flow before running the test.
This fixture leverages the existing SnowflakeCredentials._get_access_token
method to retrieve a valid JWT from Snowflake.
"""
client_id = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_ID")
client_secret = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET")
refresh_token = os.getenv("SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN")

credentials = SnowflakeCredentials(
account=os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
database="",
schema="",
authenticator="oauth",
oauth_client_id=client_id,
oauth_client_secret=client_secret,
token=refresh_token,
)

yield credentials._get_access_token()

@pytest.fixture(scope="class", autouse=True)
def dbt_profile_target(self, access_token):
"""A dbt_profile that has authenticator set to JWT, and token set to
a JWT accepted by Snowflake. Also omits the user, as the user attribute
is optional when the authenticator is set to JWT.
"""
return {
"type": "snowflake",
"threads": 4,
"account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
"database": os.getenv("SNOWFLAKE_TEST_DATABASE"),
"warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"),
"authenticator": "jwt",
"token": access_token,
}

@pytest.fixture(scope="class")
def models(self):
return {
"model_1.sql": _MODELS__MODEL_1_SQL,
"model_2.sql": _MODELS__MODEL_2_SQL,
"model_3.sql": _MODELS__MODEL_3_SQL,
"model_4.sql": _MODELS__MODEL_4_SQL,
}

def test_snowflake_basic(self, project):
run_dbt()
check_relations_equal(project.adapter, ["MODEL_3", "MODEL_4"])
32 changes: 32 additions & 0 deletions tests/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,38 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p
]
)

def test_authenticator_jwt_authentication(self):
self.config.credentials = self.config.credentials.replace(
authenticator="jwt", token="my-jwt-token", user=None
)
self.adapter = SnowflakeAdapter(self.config, get_context("spawn"))
conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config")

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls(
[
mock.call(
account="test-account",
autocommit=True,
client_session_keep_alive=False,
database="test_database",
role=None,
schema="public",
warehouse="test_warehouse",
authenticator="oauth",
token="my-jwt-token",
private_key=None,
application="dbt",
client_request_mfa_token=True,
client_store_temporary_credential=True,
insecure_mode=False,
session_parameters={},
reuse_connections=None,
)
]
)

def test_query_tag(self):
self.config.credentials = self.config.credentials.replace(
password="test_password", query_tag="test_query_tag"
Expand Down

0 comments on commit 5ede6fe

Please sign in to comment.