Skip to content

Commit

Permalink
Feature/SK-1289 | Added test for store list functionality (#808)
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-andersson authored Feb 10, 2025
1 parent 9e5ce66 commit 2058d71
Show file tree
Hide file tree
Showing 28 changed files with 1,131 additions and 35 deletions.
23 changes: 23 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "PyTest All",
"type": "python",
"request": "launch",
"module": "pytest",
"justMyCode": true
},
{
"args": [
"--nf",
"--lf"
],
"name": "PyTest New and Failing",
"type": "python",
"request": "launch",
"module": "pytest",
"justMyCode": true
},
]
}
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
"esbonio.sphinx.confDir": "",
"python.testing.pytestArgs": [
"fedn/tests",
"--color=yes",
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
}
14 changes: 7 additions & 7 deletions fedn/network/storage/dbconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __new__(cls, *, force_create_new: bool = False) -> "DatabaseConnection":
"""Create a new instance of DatabaseConnection or return the existing singleton instance.
Args:
force_create_new (bool): If True, a new instance will be created regardless of the singleton pattern.
force_create_new (bool): If True, a new instance will be created regardless of the singleton pattern. Only used for testing purpose.
Returns:
DatabaseConnection: A new instance if force_create_new is True, otherwise the existing singleton instance.
Expand All @@ -72,11 +72,9 @@ def __new__(cls, *, force_create_new: bool = False) -> "DatabaseConnection":

return cls._instance

def _init_connection(self, statestore_config: dict = None, network_id: dict = None) -> None:
if statestore_config is None:
statestore_config = get_statestore_config()
if network_id is None:
network_id = get_network_config()
def _init_connection(self) -> None:
statestore_config = get_statestore_config()
network_id = get_network_config()

if statestore_config["type"] == "MongoDB":
mdb: Database = self._setup_mongo(statestore_config, network_id)
Expand Down Expand Up @@ -123,7 +121,9 @@ def _setup_mongo(self, statestore_config: dict, network_id: str) -> "DatabaseCon

def _setup_sql(self, statestore_config: dict) -> "DatabaseConnection":
if statestore_config["type"] == "SQLite":
engine = create_engine("sqlite:///my_database.db", echo=False)
sqlite_config = statestore_config["sqlite_config"]
dbname = sqlite_config["dbname"]
engine = create_engine(f"sqlite:///{dbname}", echo=False)
elif statestore_config["type"] == "PostgreSQL":
postgres_config = statestore_config["postgres_config"]
username = postgres_config["username"]
Expand Down
4 changes: 3 additions & 1 deletion fedn/network/storage/statestore/stores/client_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.scalars(stmt).all()

Expand Down
4 changes: 3 additions & 1 deletion fedn/network/storage/statestore/stores/combiner_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.scalars(stmt).all()

Expand Down
4 changes: 3 additions & 1 deletion fedn/network/storage/statestore/stores/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.scalars(stmt).all()

Expand Down
17 changes: 10 additions & 7 deletions fedn/network/storage/statestore/stores/package_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,18 @@ def get(self, id: str) -> Package:
return from_document(document, response_active)

def _complement(self, item: Package):
if "id" not in item or item.id is None:
if "id" not in item or item["id"] is None:
item["id"] = str(uuid.uuid4())

if "key" not in item or item.key is None:
if "key" not in item or item["key"] is None:
item["key"] = "package_trail"

if "committed_at" not in item or item.committed_at is None:
if "committed_at" not in item or item["committed_at"] is None:
item["committed_at"] = datetime.now()

extension = item["file_name"].rsplit(".", 1)[1].lower()

if "storage_file_name" not in item or item.storage_file_name is None:
if "storage_file_name" not in item or item["storage_file_name"] is None:
storage_file_name = secure_filename(f"{str(uuid.uuid4())}.{extension}")
item["storage_file_name"] = storage_file_name

Expand Down Expand Up @@ -310,12 +310,13 @@ def __init__(self, Session):
super().__init__(Session)

def _complement(self, item: Package):
if "committed_at" not in item or item.committed_at is None:
# TODO: Not complemented the same way as in MongoDBStore
if "committed_at" not in item or item["committed_at"] is None:
item["committed_at"] = datetime.now()

extension = item["file_name"].rsplit(".", 1)[1].lower()

if "storage_file_name" not in item or item.storage_file_name is None:
if "storage_file_name" not in item or item["storage_file_name"] is None:
storage_file_name = secure_filename(f"{str(uuid.uuid4())}.{extension}")
item["storage_file_name"] = storage_file_name

Expand Down Expand Up @@ -374,8 +375,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.scalars(stmt).all()

Expand Down
12 changes: 7 additions & 5 deletions fedn/network/storage/statestore/stores/prediction_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def add(self, item: Prediction) -> Tuple[bool, Any]:
correlation_id=item.get("correlationId") or item.get("correlation_id"),
data=item.get("data"),
model_id=item.get("modelId") or item.get("model_id"),
receiver_name=receiver.get("name"),
receiver_role=receiver.get("role"),
sender_name=sender.get("name"),
sender_role=sender.get("role"),
receiver_name=receiver.get("name") if receiver else None,
receiver_role=receiver.get("role") if receiver else None,
sender_name=sender.get("name") if sender else None,
sender_role=sender.get("role") if sender else None,
prediction_id=item.get("predictionId") or item.get("prediction_id"),
timestamp=item.get("timestamp"),
)
Expand Down Expand Up @@ -187,8 +187,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.execute(stmt)

Expand Down
4 changes: 3 additions & 1 deletion fedn/network/storage/statestore/stores/round_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
if skip:
stmt = stmt.offset(skip)

items = session.execute(stmt)

Expand Down
8 changes: 5 additions & 3 deletions fedn/network/storage/statestore/stores/session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def add(self, item: Session) -> Tuple[bool, Any]:
parent_item = SessionModel(
id=item["session_id"], status=item["status"], name=item["name"] if "name" in item else None, committed_at=item["committed_at"] or None
)
session.add(parent_item)

session_config = item["session_config"]

Expand All @@ -311,10 +310,11 @@ def add(self, item: Session) -> Tuple[bool, Any]:
clients_required=session_config["clients_required"],
validate=session_config["validate"],
helper_type=session_config["helper_type"],
session_id=parent_item.id,
)
child_item.session = parent_item

session.add(child_item)
session.add(parent_item)
session.commit()

combined_dict = {
Expand Down Expand Up @@ -357,8 +357,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.execute(stmt)

Expand Down
2 changes: 1 addition & 1 deletion fedn/network/storage/statestore/stores/sql/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class SessionConfigModel(MyAbstractBase):
validate: Mapped[bool]
helper_type: Mapped[str] = mapped_column(String(255))
model_id: Mapped[str] = mapped_column(ForeignKey("models.id"))
session_id: Mapped[str] = mapped_column(ForeignKey("sessions.id"))
session: Mapped["SessionModel"] = relationship(back_populates="session_config")


Expand All @@ -26,6 +25,7 @@ class SessionModel(MyAbstractBase):

name: Mapped[Optional[str]] = mapped_column(String(255))
status: Mapped[str] = mapped_column(String(255))
session_config_id: Mapped[str] = mapped_column(ForeignKey("session_configs.id"))
session_config: Mapped["SessionConfigModel"] = relationship(back_populates="session")
models: Mapped[List["ModelModel"]] = relationship(back_populates="session")

Expand Down
8 changes: 5 additions & 3 deletions fedn/network/storage/statestore/stores/status_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def add(self, item: Status) -> Tuple[bool, Any]:

status = StatusModel(
log_level=item.get("log_level") or item.get("logLevel"),
sender_name=sender.get("name"),
sender_role=sender.get("role"),
sender_name=sender.get("name") if sender else None,
sender_role=sender.get("role") if sender else None,
status=item.get("status"),
timestamp=item.get("timestamp"),
type=item.get("type"),
Expand Down Expand Up @@ -177,8 +177,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.execute(stmt)

Expand Down
12 changes: 7 additions & 5 deletions fedn/network/storage/statestore/stores/validation_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ def add(self, item: Validation) -> Tuple[bool, Any]:
correlation_id=item.get("correlationId") or item.get("correlation_id"),
data=item.get("data"),
model_id=item.get("modelId") or item.get("model_id"),
receiver_name=receiver.get("name"),
receiver_role=receiver.get("role"),
sender_name=sender.get("name"),
sender_role=sender.get("role"),
receiver_name=receiver.get("name") if receiver else None,
receiver_role=receiver.get("role") if receiver else None,
sender_name=sender.get("name") if sender else None,
sender_role=sender.get("role") if sender else None,
session_id=item.get("sessionId") or item.get("session_id"),
timestamp=item.get("timestamp"),
)
Expand Down Expand Up @@ -190,8 +190,10 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI

stmt = stmt.order_by(sort_obj)

if limit != 0:
if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)

items = session.execute(stmt)

Expand Down
Empty file added fedn/tests/stores/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions fedn/tests/stores/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

import sys
import pytest

from fedn.tests.stores.helpers.database_helper import mongo_connection, sql_connection, postgres_connection


# These lines ensure that pytests trigger breakpoints when assertions fail during debugging
def is_debugging():
return 'debugpy' in sys.modules

# enable_stop_on_exceptions if the debugger is running during a test
if is_debugging():
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact(call):
raise call.excinfo.value

@pytest.hookimpl(tryfirst=True)
def pytest_internalerror(excinfo):
raise excinfo.value
69 changes: 69 additions & 0 deletions fedn/tests/stores/helpers/database_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from unittest.mock import patch

from fedn.network.storage.dbconnection import DatabaseConnection
from fedn.tests.stores.helpers.mongo_docker import start_mongodb_container, stop_mongodb_container
from fedn.tests.stores.helpers.postgres_docker import start_postgres_container, stop_postgres_container


def network_id():
return "test_network"

@pytest.fixture(scope="package")
def mongo_connection():
already_running, _, port = start_mongodb_container()

def mongo_config():
return {
"type": "MongoDB",
"mongo_config": {
"host": "localhost",
"port": port,
"username": "fedn_admin",
"password": "password"
}
}

with patch('fedn.network.storage.dbconnection.get_statestore_config', return_value=mongo_config()), \
patch('fedn.network.storage.dbconnection.get_network_config', return_value=network_id()):
yield DatabaseConnection(force_create_new=True)
if not already_running:
stop_mongodb_container()

@pytest.fixture(scope="package")
def sql_connection():
def sql_config():
return {
"type": "SQLite",
"sqlite_config": {
"dbname": ":memory:",
}
}

with patch('fedn.network.storage.dbconnection.get_statestore_config', return_value=sql_config()), \
patch('fedn.network.storage.dbconnection.get_network_config', return_value=network_id()):
return DatabaseConnection(force_create_new=True)

@pytest.fixture(scope="package")
def postgres_connection():
already_running, _, port = start_postgres_container()



def postgres_config():
return {
"type": "PostgreSQL",
"postgres_config": {
"username": "fedn_admin",
"password": "password",
"database": "fedn_db",
"host": "localhost",
"port": port
}
}

with patch('fedn.network.storage.dbconnection.get_statestore_config', return_value=postgres_config()), \
patch('fedn.network.storage.dbconnection.get_network_config', return_value=network_id()):
yield DatabaseConnection(force_create_new=True)
if not already_running:
stop_postgres_container()
Loading

0 comments on commit 2058d71

Please sign in to comment.