Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for retrying certain types of exceptions we see when running models with DuckDB #298

Merged
merged 6 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions dbt/adapters/duckdb/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import time
from dataclasses import dataclass
from dataclasses import field
from functools import lru_cache
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -70,6 +71,20 @@ class Remote(dbtClassMixin):
password: Optional[str] = None


@dataclass
class Retries(dbtClassMixin):
    # Retry settings for DuckDB connections and queries.

    # How many times the initial duckdb.connect call is attempted, giving
    # another process the chance to release its lock on the DB file.
    connect_attempts: int = 1

    # How many times a DuckDB query is attempted when it raises one of the
    # retryable exceptions listed below.
    query_attempts: Optional[int] = None

    # Names of the exception classes we are willing to retry on.
    retryable_exceptions: List[str] = field(default_factory=lambda: ["IOException"])


@dataclass
class DuckDBCredentials(Credentials):
database: str = "main"
Expand Down Expand Up @@ -135,6 +150,11 @@ class DuckDBCredentials(Credentials):
# provide helper functions for dbt Python models.
module_paths: Optional[List[str]] = None

# An optional strategy for allowing retries when certain types of
# exceptions occur on a model run (e.g., IOExceptions that were caused
# by networking issues)
retries: Optional[Retries] = None

@classmethod
def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
data = super().__pre_deserialize__(data)
Expand Down
72 changes: 71 additions & 1 deletion dbt/adapters/duckdb/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
import sys
import tempfile
import time
from typing import Dict
from typing import List
from typing import Optional

import duckdb
Expand Down Expand Up @@ -31,6 +33,44 @@ def _ensure_event_loop():
asyncio.set_event_loop(loop)


class RetryableCursor:
    """Cursor wrapper whose ``execute`` retries on a configured set of exceptions.

    Every attribute other than ``execute`` is forwarded to the wrapped cursor.
    """

    def __init__(self, cursor, retry_attempts: int, retryable_exceptions: List[str]):
        """Wrap ``cursor``.

        :param cursor: The underlying cursor object that actually runs queries.
        :param retry_attempts: Maximum number of times ``execute`` is attempted.
        :param retryable_exceptions: Exception class names (compared against
            ``type(e).__name__``) that trigger another attempt.
        """
        self._cursor = cursor
        self._retry_attempts = retry_attempts
        self._retryable_exceptions = retryable_exceptions

    def execute(self, sql: str, bindings=None):
        """Run ``sql`` on the wrapped cursor, retrying on retryable exceptions.

        Between attempts this sleeps ``2**attempt`` seconds (1, 2, 4, ...).
        No sleep happens after the final attempt, so a persistent failure is
        reported without an extra delay.

        :param sql: The statement to execute.
        :param bindings: Optional bindings; only passed through when not None.
        :returns: This wrapper, so calls can be chained.
        :raises Exception: A non-retryable exception is re-raised immediately;
            a retryable one is re-raised once all attempts are used up.
        :raises RuntimeError: If no attempt was made (``retry_attempts`` < 1).
        """
        last_exc = None
        for attempt in range(self._retry_attempts):
            try:
                if bindings is None:
                    self._cursor.execute(sql)
                else:
                    self._cursor.execute(sql, bindings)
                return self
            except Exception as e:
                exception_name = type(e).__name__
                if exception_name not in self._retryable_exceptions:
                    print(f"Did not retry exception named '{exception_name}'")
                    raise
                last_exc = e
                # Only back off when another attempt will actually follow
                if attempt + 1 < self._retry_attempts:
                    time.sleep(2**attempt)
        if last_exc is not None:
            raise last_exc
        raise RuntimeError(
            "execute call failed, but no exceptions raised- this should be impossible"
        )

    # forward along all non-execute() methods/attribute look-ups
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If _cursor itself is
        # missing (e.g. during copy/unpickle, before __init__ has run),
        # touching self._cursor would recurse forever.
        if name == "_cursor":
            raise AttributeError(name)
        return getattr(self._cursor, name)


class Environment(abc.ABC):
"""An Environment is an abstraction to describe *where* the code you execute in your dbt-duckdb project
actually runs. This could be the local Python process that runs dbt (which is the default),
Expand Down Expand Up @@ -74,7 +114,32 @@ def initialize_db(
cls, creds: DuckDBCredentials, plugins: Optional[Dict[str, BasePlugin]] = None
):
config = creds.config_options or {}
conn = duckdb.connect(creds.path, read_only=False, config=config)

if creds.retries:
success, attempt, exc = False, 0, None
while not success and attempt < creds.retries.connect_attempts:
try:
conn = duckdb.connect(creds.path, read_only=False, config=config)
success = True
except Exception as e:
exception_name = type(e).__name__
if exception_name in creds.retries.retryable_exceptions:
time.sleep(2**attempt)
exc = e
attempt += 1
else:
print(f"Did not retry exception named '{exception_name}'")
raise e
if not success:
if exc:
raise exc
else:
raise RuntimeError(
"connect call failed, but no exceptions raised- this should be impossible"
)

else:
conn = duckdb.connect(creds.path, read_only=False, config=config)

# install any extensions on the connection
if creds.extensions is not None:
Expand Down Expand Up @@ -127,6 +192,11 @@ def initialize_cursor(
for df_name, df in registered_df.items():
cursor.register(df_name, df)

if creds.retries and creds.retries.query_attempts:
cursor = RetryableCursor(
cursor, creds.retries.query_attempts, creds.retries.retryable_exceptions
)

return cursor

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/functional/plugins/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def profiles_config_update(self, dbt_profile_target, sqlite_test_db):
"type": "duckdb",
"path": dbt_profile_target.get("path", ":memory:"),
"plugins": plugins,
"retries": {"query_attempts": 2},
}
},
"target": "dev",
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_retries_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from unittest.mock import patch

from duckdb.duckdb import IOException

from dbt.adapters.duckdb.credentials import DuckDBCredentials
from dbt.adapters.duckdb.credentials import Retries
from dbt.adapters.duckdb.environments import Environment

class TestConnectRetries:
    """Unit tests for the connect-retry loop in ``Environment.initialize_db``."""

    @pytest.fixture
    def creds(self):
        # Credentials allowing two connect attempts, with both IOException
        # and ArithmeticError treated as retryable
        return DuckDBCredentials(
            path="foo.db",
            retries=Retries(
                connect_attempts=2,
                retryable_exceptions=["IOException", "ArithmeticError"],
            ),
        )

    @pytest.mark.parametrize("exception", [None, IOException, ArithmeticError, ValueError])
    def test_initialize_db(self, creds, exception):
        # Mock duckdb.connect, and patch time.sleep so the backoff between
        # connect attempts does not make the test wait for real
        with patch("duckdb.connect") as mock_connect, patch("time.sleep"):
            if exception:
                # First call raises, second call succeeds
                mock_connect.side_effect = [exception, None]

            if exception == ValueError:
                # ValueError is not in retryable_exceptions, so it propagates
                with pytest.raises(ValueError):
                    Environment.initialize_db(creds)
            else:
                Environment.initialize_db(creds)

            if exception in {IOException, ArithmeticError}:
                assert mock_connect.call_count == creds.retries.connect_attempts
            else:
                # No failure, or a non-retryable failure: exactly one connect call
                mock_connect.assert_called_once_with(creds.path, read_only=False, config={})
57 changes: 57 additions & 0 deletions tests/unit/test_retries_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from unittest.mock import MagicMock
from unittest.mock import patch

import duckdb

from dbt.adapters.duckdb.credentials import Retries
from dbt.adapters.duckdb.environments import RetryableCursor

class TestRetryableCursor:
    """Unit tests for the retry and backoff behavior of ``RetryableCursor.execute``."""

    @pytest.fixture
    def mock_cursor(self):
        return MagicMock()

    @pytest.fixture
    def mock_retries(self):
        return Retries(query_attempts=3)

    @pytest.fixture
    def retry_cursor(self, mock_cursor, mock_retries):
        return RetryableCursor(
            mock_cursor,
            mock_retries.query_attempts,
            mock_retries.retryable_exceptions,
        )

    def test_successful_execute(self, mock_cursor, retry_cursor):
        """ Test that execute successfully runs the SQL query. """
        sql_query = "SELECT * FROM table"
        retry_cursor.execute(sql_query)
        mock_cursor.execute.assert_called_once_with(sql_query)

    def test_retry_on_failure(self, mock_cursor, retry_cursor):
        """ Test that execute retries the SQL query on failure. """
        mock_cursor.execute.side_effect = [duckdb.duckdb.IOException, None]
        sql_query = "SELECT * FROM table"
        # Patch time.sleep so the backoff does not make the test wait for real
        with patch("time.sleep"):
            retry_cursor.execute(sql_query)
        assert mock_cursor.execute.call_count == 2

    def test_no_retry_on_non_retryable_exception(self, mock_cursor, retry_cursor):
        """ Test that a non-retryable exception is not retried. """
        mock_cursor.execute.side_effect = ValueError
        sql_query = "SELECT * FROM table"
        with pytest.raises(ValueError):
            retry_cursor.execute(sql_query)
        mock_cursor.execute.assert_called_once_with(sql_query)

    def test_raises_after_exhausting_attempts(self, mock_cursor, retry_cursor):
        """ Test that the retryable exception is raised once all attempts fail. """
        mock_cursor.execute.side_effect = duckdb.duckdb.IOException
        sql_query = "SELECT * FROM table"
        with patch("time.sleep"), pytest.raises(duckdb.duckdb.IOException):
            retry_cursor.execute(sql_query)
        assert mock_cursor.execute.call_count == 3

    def test_exponential_backoff(self, mock_cursor, retry_cursor):
        """ Test that exponential backoff is applied between retries. """
        mock_cursor.execute.side_effect = [duckdb.duckdb.IOException, duckdb.duckdb.IOException, None]
        sql_query = "SELECT * FROM table"

        with patch("time.sleep") as mock_sleep:
            retry_cursor.execute(sql_query)
        assert mock_cursor.execute.call_count == 3
        assert mock_sleep.call_count == 2
        # Positional args of each sleep call, in order: 1 second, then 2 seconds
        assert [c[0] for c in mock_sleep.call_args_list] == [(1,), (2,)]