Skip to content

Commit

Permalink
one connection per thread
Browse files Browse the repository at this point in the history
  • Loading branch information
abyesilyurt committed May 20, 2024
1 parent 3d153ac commit f8fba21
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions packages/syft/src/syft/store/sqlite_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

# stdlib
from collections.abc import Generator
import contextlib
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import sqlite3
import tempfile
import threading
from typing import Any

# third party
Expand Down Expand Up @@ -70,6 +70,7 @@ def __init__(
path = Path(self.file_path)
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
self._local = threading.local()
self.create_table()

@property
Expand All @@ -84,34 +85,37 @@ def create_table(self) -> None:
) as _:
pass

def _initialize_connection(self, con: sqlite3.Connection) -> None:
"""Set PRAGMA settings for the connection."""
con.execute("PRAGMA journal_mode = WAL")
con.execute("PRAGMA busy_timeout = 5000")
con.execute("PRAGMA temp_store = 2")
con.execute("PRAGMA synchronous = 1")

def _get_connection(self) -> sqlite3.Connection:
if not hasattr(self._local, "connection"):
timeout = (
self.store_config.client_config.timeout
if self.store_config.client_config
else 5
)
self._local.connection = sqlite3.connect(
self.file_path, timeout=timeout, check_same_thread=False
)
self._initialize_connection(self._local.connection)
return self._local.connection

@contextmanager
def _cursor(
self, sql: str, *args: list[Any] | None
) -> Generator[sqlite3.Cursor, None, None]:
timeout = (
self.store_config.client_config.timeout
if self.store_config.client_config
else 5
)
with contextlib.closing(
sqlite3.connect(
self.file_path,
timeout=timeout,
check_same_thread=False, # do we need this if we use the lock?
# check_same_thread=self.store_config.client_config.check_same_thread,
)
) as con:
# Set journal mode to WAL.
con.execute("PRAGMA journal_mode = WAL")
con.execute("PRAGMA busy_timeout = 5000")
con.execute("PRAGMA temp_store = 2")
con.execute("PRAGMA synchronous = 1")
cur = con.cursor()
yield cur.execute(sql, *args)
try:
con.commit()
finally:
cur.close()
con = con = self._get_connection()
cur = con.cursor()
yield cur.execute(sql, *args)
try:
con.commit()
finally:
cur.close()

def _set(self, key: UID, value: Any) -> None:
if self._exists(key):
Expand Down Expand Up @@ -236,6 +240,10 @@ def __contains__(self, key: Any) -> bool:
def __iter__(self) -> Any:
return iter(self.keys())

def __del__(self) -> None:
if hasattr(self._local, "connection"):
self._local.connection.close()


@serializable()
class SQLiteStorePartition(KeyValueStorePartition):
Expand Down

0 comments on commit f8fba21

Please sign in to comment.