chore: improve sqlite transaction handling
makkus committed Jan 31, 2024
1 parent 6a5ea9e commit fcc8d50
Showing 7 changed files with 122 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/kiara/context/config.py
@@ -374,7 +374,7 @@ def load_from_file(cls, path: Union[Path, str, None] = None) -> "KiaraConfig":
)
default_store_type: Literal["sqlite", "filesystem"] = Field(
description="The default store type to ues if not specified.",
default="filesystem",
default="sqlite",
)
auto_generate_contexts: bool = Field(
description="Whether to auto-generate requested contexts if they don't exist yet.",
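Note: new contexts now default to the sqlite store. A minimal sketch of opting back into the previous behaviour, assuming KiaraConfig is importable from kiara.context and accepts this field directly:

    from kiara.context import KiaraConfig

    # "filesystem" was the default before this commit
    cfg = KiaraConfig(default_store_type="filesystem")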
2 changes: 1 addition & 1 deletion src/kiara/context/runtime_config.py
@@ -19,7 +19,7 @@ class KiaraRuntimeConfig(BaseSettings):

job_cache: JobCacheStrategy = Field(
description="Name of the strategy that determines when to re-run jobs or use cached results.",
default=JobCacheStrategy.data_hash,
default=JobCacheStrategy.no_cache,
)
allow_external: bool = Field(
description="Whether to allow external external pipelines.", default=True
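Job caching is likewise disabled by default now. A hedged sketch of restoring the old strategy (import path assumed from the file above):

    from kiara.context.runtime_config import JobCacheStrategy, KiaraRuntimeConfig

    # re-enable hash-based job caching, the default before this commit
    rc = KiaraRuntimeConfig(job_cache=JobCacheStrategy.data_hash)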
27 changes: 20 additions & 7 deletions src/kiara/interfaces/python_api/__init__.py
@@ -1644,6 +1644,9 @@ def store_value(
"""
Store the specified value in the (default) value store.
This method does not raise an error if the storing of the value fails, so you have to investigate the
'StoreValueResult' instance that is returned to see if the storing was successful.
Arguments:
value: the value (or a reference to it)
alias: (Optional) aliases for the value
@@ -1682,13 +1685,18 @@ def store_value(
def store_values(
self,
values: Mapping[str, Union[str, uuid.UUID, Value]],
alias_map: Union[Mapping[str, Iterable[str]], None] = None,
alias_map: Union[Mapping[str, Iterable[str]], bool] = False,
allow_overwrite: bool = True,
) -> StoreValuesResult:
"""
Store multiple values into the (default) kiara value store.
Values are identified by unique keys in both input arguments; the alias map references the keys used in the
'values' argument.
If alias_map is 'False', no aliases will be registered. If 'True', the key in the 'values' argument will be used.
Alternatively, if a map is provided, the key in the 'values' argument will be used to look up the alias(es) in the
'alias_map' argument.
This method does not raise an error if the storing of the value fails, so you have to investigate the
'StoreValuesResult' instance that is returned to see if the storing was successful.
Arguments:
values: a map of value keys/values
@@ -1699,12 +1707,17 @@ def store_values(
"""
result = {}
for field_name, value in values.items():
if alias_map:
aliases = alias_map.get(field_name)
else:
if alias_map is False:
aliases = None
elif alias_map is True:
aliases = [field_name]
else:
aliases = alias_map.get(field_name)

value_obj = self.get_value(value)
store_result = self.store_value(value=value_obj, alias=aliases)
store_result = self.store_value(
value=value_obj, alias=aliases, allow_overwrite=allow_overwrite
)
result[field_name] = store_result

return StoreValuesResult(root=result)
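The new tri-state 'alias_map' in practice — a minimal usage sketch, assuming a KiaraAPI instance named 'api' and two already-registered values (the value objects and keys are placeholders):

    results = api.store_values(
        {"table_a": value_a, "table_b": value_b},
        alias_map=True,  # True: reuse the dict keys ("table_a", "table_b") as aliases
        # alias_map=False                      -> register no aliases (the new default)
        # alias_map={"table_a": ["my_alias"]}  -> look up aliases per key
    )
    # store_values() does not raise on failure, so inspect each StoreValueResult:
    for field_name, store_result in results.root.items():
        print(field_name, store_result)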
4 changes: 4 additions & 0 deletions src/kiara/models/archives.py
@@ -206,7 +206,11 @@ def create_from_context(
def combined_size(self) -> int:

combined = 0
archive_ids = set()
for archive_info in self.item_infos.values():
if archive_info.archive_id in archive_ids:
continue
archive_ids.add(archive_info.archive_id)
size = archive_info.details.size
if size and size > 0:
combined = combined + size
8 changes: 4 additions & 4 deletions src/kiara/registries/__init__.py
@@ -70,7 +70,7 @@ class ArchiveDetails(BaseModel):
)


class ArchiveMetadta(BaseModel):
class ArchiveMetadata(BaseModel):

archive_id: Union[uuid.UUID, None] = Field(
description="The id of the stored archive.", default=None
@@ -165,11 +165,11 @@ def __init__(
self._archive_metadata: Union[Mapping[str, Any], None] = None

@property
def archive_metadata(self) -> ArchiveMetadta:
def archive_metadata(self) -> ArchiveMetadata:

if self._archive_metadata is None:
archive_metadata = self._retrieve_archive_metadata()
self._archive_metadata = ArchiveMetadta(**archive_metadata)
self._archive_metadata = ArchiveMetadata(**archive_metadata)

return self._archive_metadata

@@ -388,7 +388,7 @@ class CHUNK_COMPRESSION_TYPE(Enum):

class SqliteDataStoreConfig(SqliteArchiveConfig):

default_compression_type: Literal["none", "lz4", "zstd"] = Field(
default_chunk_compression: Literal["none", "lz4", "zstd"] = Field(
description="The default compression type to use for data in this store.",
default="zstd",
)
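A hypothetical construction after the rename (any other required fields of SqliteArchiveConfig are omitted; the old keyword 'default_compression_type' would now fail validation):

    config = SqliteDataStoreConfig(default_chunk_compression="lz4")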
51 changes: 34 additions & 17 deletions src/kiara/registries/data/data_store/__init__.py
@@ -9,7 +9,7 @@
import typing
import uuid
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Set, Union

import structlog
from rich.console import RenderableType
@@ -28,7 +28,7 @@
from kiara.registries import ARCHIVE_CONFIG_CLS, BaseArchive

if TYPE_CHECKING:
pass
from multiformats import CID

logger = structlog.getLogger()

@@ -390,7 +390,7 @@ def store_value(self, value: Value) -> PersistedData:
return persisted_value

@abc.abstractmethod
def _persist_chunks(self, chunks: typing.Iterator[Tuple[str, BytesIO]]):
def _persist_chunks(self, chunks: Mapping["CID", BytesIO]):
"""Persist the specified chunk, and return the chunk id.
If the chunk is a string, it represents a local file path, otherwise it is a BytesIO instance representing the actual data of the chunk.
@@ -403,13 +403,18 @@ def _persist_value_data(self, value: Value) -> PersistedData:

# dbg(serialized_value.model_dump())

SIZE_LIMIT = 100000000

chunk_id_map = {}
chunks_to_persist = {}
chunks_persisted = set()
current_size = 0
for key in serialized_value.get_keys():

data_model = serialized_value.get_serialized_data(key)

if data_model.type == "chunk": # type: ignore
chunks: Iterable[Union[str, BytesIO]] = [BytesIO(data_model.chunk)] # type: ignore
chunks: Iterable[BytesIO] = [BytesIO(data_model.chunk)] # type: ignore
elif data_model.type == "chunks": # type: ignore
chunks = (BytesIO(c) for c in data_model.chunks) # type: ignore
elif data_model.type == "file": # type: ignore
@@ -432,17 +437,7 @@ def _persist_value_data(self, value: Value) -> PersistedData:

cids = serialized_value.get_cids_for_key(key)
chunk_iterable = zip(cids, chunks)
# chunks_to_persist.update(chunk_iterable)
# print(key)
# print(type(chunks))
# self._persist_chunks(chunk_iterable)

chunk_ids = []
for item in chunk_iterable:
cid = item[0]
_chunk = item[1]
self._persist_chunk(str(cid), _chunk)
chunk_ids.append(str(cid))
chunks_to_persist.update(chunk_iterable)

chunk_ids = [str(cid) for cid in cids]
scids = SerializedChunkIDs(
@@ -453,8 +448,30 @@
scids._data_registry = self.kiara_context.data_registry
chunk_id_map[key] = scids

# print("chunks_to_persist")
# print(chunks_to_persist)
key_size = data_model.get_size()
current_size += key_size
# this is not super-exact, because the actual size of all chunks to be persisted is not known
# since some of them might be filtered out, should be good enough to not let the memory blow up too much
if current_size > SIZE_LIMIT:
self._persist_chunks(
chunks={
k: v
for k, v in chunks_to_persist.items()
if k not in chunks_persisted
}
)
chunks_persisted.update(chunks_to_persist.keys())
chunks_to_persist = {}
current_size = 0

if chunks_to_persist:
self._persist_chunks(
chunks={
k: v
for k, v in chunks_to_persist.items()
if k not in chunks_persisted
}
)

pers_value = PersistedData(
archive_id=self.archive_id,
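The buffering logic above, reduced to a generic size-bounded flush pattern (function and variable names are illustrative, not part of the kiara API):

    SIZE_LIMIT = 100000000  # ~100 MB, the same limit used above

    def persist_in_batches(items, get_size, persist_batch):
        buffered, persisted, current_size = {}, set(), 0
        for key, data in items:
            buffered[key] = data
            # approximate accounting, as in the store: some buffered chunks
            # may later be filtered out as duplicates
            current_size += get_size(data)
            if current_size > SIZE_LIMIT:
                persist_batch({k: v for k, v in buffered.items() if k not in persisted})
                persisted.update(buffered)
                buffered, current_size = {}, 0
        if buffered:  # final flush for whatever remains
            persist_batch({k: v for k, v in buffered.items() if k not in persisted})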
81 changes: 58 additions & 23 deletions src/kiara/registries/data/data_store/sqlite_store.py
@@ -3,16 +3,25 @@
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Generic, Iterable, Iterator, Mapping, Set, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Mapping,
Set,
Union,
)

from orjson import orjson
from sqlalchemy import Engine, create_engine, text
from sqlalchemy import Connection, Engine, create_engine, text

from kiara.defaults import kiara_app_dirs
from kiara.models.values.value import PersistedData, Value
from kiara.registries import (
ARCHIVE_CONFIG_CLS,
CHUNK_COMPRESSION_TYPE,
ArchiveDetails,
SqliteArchiveConfig,
SqliteDataStoreConfig,
)
@@ -22,6 +31,9 @@
from kiara.utils.json import orjson_dumps
from kiara.utils.windows import fix_windows_longpath

if TYPE_CHECKING:
from multiformats import CID


class SqliteDataArchive(DataArchive[SqliteArchiveConfig], Generic[ARCHIVE_CONFIG_CLS]):

@@ -245,6 +257,14 @@ def _retrieve_all_value_ids(
self._value_id_cache = result_set
return result_set

def retrieve_all_chunk_ids(self) -> Iterable[str]:

sql = text("SELECT chunk_id FROM values_data")
with self.sqlite_engine.connect() as conn:
cursor = conn.execute(sql)
result = cursor.fetchall()
return {x[0] for x in result}

def _find_values_with_hash(
self,
value_hash: str,
@@ -331,6 +351,11 @@ def retrieve_chunk(
def _delete_archive(self):
os.unlink(self.sqlite_path)

def get_archive_details(self) -> ArchiveDetails:

size = self._db_path.stat().st_size
return ArchiveDetails(size=size)


class SqliteDataStore(SqliteDataArchive[SqliteDataStoreConfig], BaseDataStore):

@@ -400,12 +425,23 @@ def _persist_environment_details(
#
# raise NotImplementedError()

def _persist_chunks(self, chunks: Iterator[Tuple[str, Union[str, BytesIO]]]):
def _persist_chunks(self, chunks: Mapping["CID", Union[str, BytesIO]]):

all_chunk_ids = self.retrieve_all_chunk_ids()

with self.sqlite_engine.connect() as conn:

for chunk_id, chunk in chunks.items():
cid_str = str(chunk_id)
if cid_str in all_chunk_ids:
continue
self._persist_chunk(conn, cid_str, chunk)

for chunk_id, chunk in chunks:
self._persist_chunk(str(chunk_id), chunk)
conn.commit()

def _persist_chunk(self, chunk_id: str, chunk: Union[str, BytesIO]):
def _persist_chunk(
self, conn: Connection, chunk_id: str, chunk: Union[str, BytesIO]
):

import lzma

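This is the heart of the commit: chunk inserts now share one connection and commit once per batch, instead of opening a connection and committing for every chunk. The same pattern in plain SQLAlchemy — a sketch that assumes the values_data table already exists; engine URL and data are made up:

    from io import BytesIO
    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///example.db")
    chunks = {"cid-1": BytesIO(b"aaa"), "cid-2": BytesIO(b"bbb")}
    existing = set()  # the store pre-fetches these via retrieve_all_chunk_ids()

    insert = text(
        "INSERT INTO values_data (chunk_id, chunk_data, compression_type) "
        "VALUES (:chunk_id, :chunk_data, :compression_type)"
    )

    with engine.connect() as conn:
        for chunk_id, chunk in chunks.items():
            if chunk_id in existing:  # skip chunks that are already stored
                continue
            conn.execute(insert, {
                "chunk_id": chunk_id,
                "chunk_data": chunk.read(),
                "compression_type": None,
            })
        conn.commit()  # a single transaction for the whole batch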
@@ -414,13 +450,13 @@ def _persist_chunk(self, chunk_id: str, chunk: Union[str, BytesIO]):

cctx = ZstdCompressor()

sql = text(
"SELECT EXISTS(SELECT 1 FROM values_data WHERE chunk_id = :chunk_id)"
)
with self.sqlite_engine.connect() as conn:
result = conn.execute(sql, {"chunk_id": chunk_id}).scalar()
if result:
return
# sql = text(
# "SELECT EXISTS(SELECT 1 FROM values_data WHERE chunk_id = :chunk_id)"
# )
# with self.sqlite_engine.connect() as conn:
# result = conn.execute(sql, {"chunk_id": chunk_id}).scalar()
# if result:
# return

if isinstance(chunk, str):
with open(chunk, "rb") as file:
Expand All @@ -430,7 +466,7 @@ def _persist_chunk(self, chunk_id: str, chunk: Union[str, BytesIO]):
bytes_io = chunk

compression_type = CHUNK_COMPRESSION_TYPE[
self.config.default_compression_type.upper()
self.config.default_chunk_compression.upper()
]

if compression_type == CHUNK_COMPRESSION_TYPE.NONE:
@@ -447,7 +483,7 @@ def _persist_chunk(self, chunk_id: str, chunk: Union[str, BytesIO]):
final_bytes = lz4.frame.compress(data)
else:
raise ValueError(
f"Unsupported compression type: {self.config.default_compression_type}"
f"Unsupported compression type: {self.config.default_chunk_compression}"
)

compression_type_value = (
@@ -458,15 +494,14 @@ def _persist_chunk(self, chunk_id: str, chunk: Union[str, BytesIO]):
sql = text(
"INSERT INTO values_data (chunk_id, chunk_data, compression_type) VALUES (:chunk_id, :chunk_data, :compression_type)"
)
with self.sqlite_engine.connect() as conn:
params = {
"chunk_id": chunk_id,
"chunk_data": final_bytes,
"compression_type": compression_type_value,
}
params = {
"chunk_id": chunk_id,
"chunk_data": final_bytes,
"compression_type": compression_type_value,
}

conn.execute(sql, params)
conn.commit()
conn.execute(sql, params)
# conn.commit()

def _persist_stored_value_info(self, value: Value, persisted_value: PersistedData):

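For reference, the compression dispatch inside _persist_chunk as a standalone helper (assumes the zstandard and lz4 packages the store itself imports):

    def compress_chunk(data: bytes, compression: str) -> bytes:
        # mirrors the CHUNK_COMPRESSION_TYPE handling above: none, zstd or lz4
        if compression == "none":
            return data
        if compression == "zstd":
            from zstandard import ZstdCompressor
            return ZstdCompressor().compress(data)
        if compression == "lz4":
            import lz4.frame
            return lz4.frame.compress(data)
        raise ValueError(f"Unsupported compression type: {compression}")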
