Skip to content

Commit

Permalink
refactor: use safe context
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrd committed Nov 17, 2022
1 parent 0b90200 commit f0608bc
Showing 1 changed file with 32 additions and 41 deletions.
73 changes: 32 additions & 41 deletions dictdatabase/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Tuple, TypeVar, Generic, Any, Callable
from . import utils, io_unsafe, locking

from contextlib import contextmanager


T = TypeVar("T")
Expand Down Expand Up @@ -44,6 +45,28 @@ def write(self):



@contextmanager
def safe_context(super, self, *, db_names_to_lock=None):
"""
If an exception happens in the context, the __exit__ method of the passed super
class will be called.
"""
super.__enter__()
try:
if isinstance(db_names_to_lock, str):
self.write_lock = locking.WriteLock(self.db_name)
self.write_lock._lock()
elif isinstance(db_names_to_lock, list):
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
for lock in self.write_lock:
lock._lock()
yield
except BaseException as e:
super.__exit__(type(e), e, e.__traceback__)
raise e



class SessionFileFull(SessionBase, Generic[T]):
"""
Context manager for read-write access to a full file.
Expand All @@ -52,19 +75,10 @@ class SessionFileFull(SessionBase, Generic[T]):
Reads and writes the entire file.
"""

def __init__(self, db_name: str, as_type: T = None):
super().__init__(db_name, as_type)

def __enter__(self) -> Tuple[SessionFileFull, JSONSerializable | T]:
super().__enter__()
try:
self.write_lock = locking.WriteLock(self.db_name)
self.write_lock._lock()
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.data_handle = io_unsafe.read(self.db_name)
return self, type_cast(self.data_handle, self.as_type)
except BaseException as e:
super().__exit__(type(e), e, e.__traceback__)
raise e

def write(self):
super().write()
Expand All @@ -82,21 +96,15 @@ class SessionFileKey(SessionBase, Generic[T]):
the key-value are written.
"""

def __init__(self, db_name: str, key: str, as_type: T = None):
def __init__(self, db_name: str, key: str, as_type: T):
super().__init__(db_name, as_type)
self.key = key

def __enter__(self) -> Tuple[SessionFileKey, JSONSerializable | T]:
super().__enter__()
try:
self.write_lock = locking.WriteLock(self.db_name)
self.write_lock._lock()
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.partial_handle = io_unsafe.get_partial_file_handle(self.db_name, self.key)
self.data_handle = self.partial_handle.partial_dict.value
return self, type_cast(self.data_handle, self.as_type)
except BaseException as e:
super().__exit__(type(e), e, e.__traceback__)
raise e

def write(self):
super().write()
Expand All @@ -113,23 +121,17 @@ class SessionFileWhere(SessionBase, Generic[T]):
Reads and writes the entire file, so it is not more efficient than
SessionFileFull.
"""
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T = None):
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
super().__init__(db_name, as_type)
self.where = where

def __enter__(self) -> Tuple[SessionFileWhere, JSONSerializable | T]:
super().__enter__()
try:
self.write_lock = locking.WriteLock(self.db_name)
self.write_lock._lock()
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.original_data = io_unsafe.read(self.db_name)
for k, v in self.original_data.items():
if self.where(k, v):
self.data_handle[k] = v
return self, type_cast(self.data_handle, self.as_type)
except BaseException as e:
super().__exit__(type(e), e, e.__traceback__)
raise e

def write(self):
super().write()
Expand All @@ -147,20 +149,13 @@ class SessionDirFull(SessionBase, Generic[T]):
Efficiency:
Fully reads and writes all files.
"""
def __init__(self, db_name: str, as_type: T = None):
def __init__(self, db_name: str, as_type: T):
super().__init__(utils.find_all(db_name), as_type)

def __enter__(self) -> Tuple[SessionDirFull, JSONSerializable | T]:
super().__enter__()
try:
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
for lock in self.write_lock:
lock._lock()
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.data_handle = {n.split("/")[-1]: io_unsafe.read(n) for n in self.db_name}
return self, type_cast(self.data_handle, self.as_type)
except BaseException as e:
super().__exit__(type(e), e, e.__traceback__)
raise e

def write(self):
super().write()
Expand All @@ -177,13 +172,12 @@ class SessionDirWhere(SessionBase, Generic[T]):
Efficiency:
Fully reads all files, but only writes the selected files.
"""
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T = None):
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
super().__init__(utils.find_all(db_name), as_type)
self.where = where

def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
super().__enter__()
try:
with safe_context(super(), self):
selected_db_names, write_lock = [], []
for db_name in self.db_name:
lock = locking.WriteLock(db_name)
Expand All @@ -198,9 +192,6 @@ def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
self.write_lock = write_lock
self.db_name = selected_db_names
return self, type_cast(self.data_handle, self.as_type)
except BaseException as e:
super().__exit__(type(e), e, e.__traceback__)
raise e

def write(self):
super().write()
Expand Down

0 comments on commit f0608bc

Please sign in to comment.