Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from jupyter_scheduler.managers import SQLAlchemyDatabaseManager
from jupyter_scheduler.orm import Base
from jupyter_scheduler.scheduler import Scheduler
from jupyter_scheduler.tests.mocks import MockEnvironmentManager
Expand Down Expand Up @@ -59,6 +60,8 @@ def jp_scheduler(jp_scheduler_db_url, jp_scheduler_root_dir, jp_scheduler_db):
db_url=jp_scheduler_db_url,
root_dir=str(jp_scheduler_root_dir),
environments_manager=MockEnvironmentManager(),
database_manager=SQLAlchemyDatabaseManager(),
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
)


Expand Down
27 changes: 24 additions & 3 deletions jupyter_scheduler/executors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import importlib
import io
import os
import shutil
import tarfile
import traceback
from abc import ABC, abstractmethod
from typing import Dict
from typing import Dict, Optional

import fsspec
import nbconvert
Expand All @@ -29,11 +30,31 @@ class ExecutionManager(ABC):
_model = None
_db_session = None

def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
def __init__(
    self,
    job_id: str,
    root_dir: str,
    db_url: str,
    staging_paths: Dict[str, str],
    database_manager_class,
    job_data: Optional[Dict] = None,
):
    """Record the job's execution context and build its database manager.

    Args:
        job_id: Identifier of the job this manager executes.
        root_dir: Root directory stored on the instance as ``root_dir``.
        db_url: Database URL later used when opening a session.
        staging_paths: Mapping stored on the instance as ``staging_paths``.
        database_manager_class: Passed to ``_create_database_manager`` to
            obtain the ``database_manager`` instance.
        job_data: Optional job metadata, kept on the instance so that
            subclasses can read it.
    """
    self.job_id, self.staging_paths = job_id, staging_paths
    self.root_dir, self.db_url = root_dir, db_url
    # Optional metadata; stored untouched for subclasses that need it.
    self.job_data = job_data

    manager = self._create_database_manager(database_manager_class)
    self.database_manager = manager

def _create_database_manager(self, database_manager_class):
    """Instantiate the database manager named by ``database_manager_class``.

    Args:
        database_manager_class: Either a dotted import path such as
            ``"jupyter_scheduler.managers.SQLAlchemyDatabaseManager"`` or an
            already-imported class (any zero-argument callable).

    Returns:
        A new instance produced by calling the resolved class with no
        arguments.

    Raises:
        ValueError: If the value is neither a string nor callable, if the
            path cannot be split/imported/resolved, or if instantiation
            raises ``ValueError``, ``ImportError`` or ``AttributeError``.
    """
    try:
        if isinstance(database_manager_class, str):
            module_name, class_name = database_manager_class.rsplit(".", 1)
            module = importlib.import_module(module_name)
            manager_factory = getattr(module, class_name)
        elif callable(database_manager_class):
            # Callers may hand over the class object itself (for example the
            # resolved value of a traitlets ``Type`` option) rather than its
            # dotted path; calling ``.rsplit`` on it would fail.
            manager_factory = database_manager_class
        else:
            raise ValueError("expected a dotted import path or a class")
        return manager_factory()
    except (ValueError, ImportError, AttributeError) as e:
        # Chain explicitly so the original failure stays visible as the cause.
        raise ValueError(
            f"Invalid database_manager_class '{database_manager_class}': {e}"
        ) from e

@property
def model(self):
Expand All @@ -46,7 +67,7 @@ def model(self):
@property
def db_session(self):
if self._db_session is None:
self._db_session = create_session(self.db_url)
self._db_session = create_session(self.db_url, self.database_manager)

return self._db_session

Expand Down
12 changes: 11 additions & 1 deletion jupyter_scheduler/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class SchedulerApp(ExtensionApp):
def _db_url_default(self):
return f"sqlite:///{jupyter_data_dir()}/scheduler.sqlite"

database_manager_class = Type(
default_value="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
klass="jupyter_scheduler.managers.DatabaseManager",
config=True,
help=_i18n("Database manager class for custom database backends."),
)

environment_manager_class = Type(
default_value="jupyter_scheduler.environments.CondaEnvironmentManager",
klass="jupyter_scheduler.environments.EnvironmentManager",
Expand All @@ -69,7 +76,8 @@ def _db_url_default(self):
def initialize_settings(self):
super().initialize_settings()

create_tables(self.db_url, self.drop_tables)
database_manager = self.database_manager_class()
create_tables(self.db_url, self.drop_tables, database_manager=database_manager)

environments_manager = self.environment_manager_class()

Expand All @@ -78,6 +86,8 @@ def initialize_settings(self):
environments_manager=environments_manager,
db_url=self.db_url,
config=self.config,
database_manager=database_manager,
database_manager_class=self.database_manager_class,
)

job_files_manager = self.job_files_manager_class(scheduler=scheduler)
Expand Down
35 changes: 33 additions & 2 deletions jupyter_scheduler/job_files_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,44 @@ def generate_filepaths(self):
"""A generator that produces filepaths"""
output_formats = self.output_formats + ["input"]
for output_format in output_formats:
# Skip if this format is not in staging_paths (e.g., input file for CronJob jobs)
if output_format not in self.staging_paths:
continue
input_filepath = self.staging_paths[output_format]
output_filepath = os.path.join(self.output_dir, self.output_filenames[output_format])
if not os.path.exists(output_filepath) or self.redownload:
yield input_filepath, output_filepath

if self.staging_paths:
staging_dir = os.path.dirname(next(iter(self.staging_paths.values())))
if os.path.exists(staging_dir):
explicit_files = set()
for output_format in output_formats:
if output_format in self.staging_paths:
explicit_files.add(os.path.basename(self.staging_paths[output_format]))

for file_name in os.listdir(staging_dir):
file_path = os.path.join(staging_dir, file_name)
if os.path.isfile(file_path) and file_name not in explicit_files:
input_filepath = file_path
output_filepath = os.path.join(self.output_dir, file_name)
if not os.path.exists(output_filepath) or self.redownload:
yield input_filepath, output_filepath

if self.include_staging_files:
staging_dir = os.path.dirname(self.staging_paths["input"])
for file_relative_path in self.output_filenames["files"]:
# Handle missing "input" key gracefully - it may not exist for CronJob jobs
if "input" in self.staging_paths:
staging_dir = os.path.dirname(self.staging_paths["input"])
elif self.staging_paths:
# Fall back to any available staging path directory
staging_dir = os.path.dirname(next(iter(self.staging_paths.values())))
else:
# No staging paths available, skip
return

# Handle missing "files" key gracefully - it may not exist if packaged_files was empty
files_list = self.output_filenames.get("files", [])
for file_relative_path in files_list:
input_filepath = os.path.join(staging_dir, file_relative_path)
output_filepath = os.path.join(self.output_dir, file_relative_path)
if not os.path.exists(output_filepath) or self.redownload:
Expand Down
66 changes: 66 additions & 0 deletions jupyter_scheduler/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from sqlite3 import OperationalError

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from jupyter_scheduler.orm import Base as DefaultBase
from jupyter_scheduler.orm import update_db_schema


class DatabaseManager(ABC):
    """Abstract interface for jupyter-scheduler database backends.

    A concrete manager knows how to prepare the schema and how to open a
    session for a given database URL. Custom storage backends (K8s, Redis,
    etc.) can be supplied by subclassing, as long as the sessions they hand
    out stay compatible with the session interface the scheduler relies on.
    """

    @abstractmethod
    def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
        """Create the database tables/schema.

        Args:
            db_url: Database URL.
            drop_tables: When true, existing tables are dropped first.
            Base: SQLAlchemy declarative Base to use for custom schemas
                (used by tests).
        """

    @abstractmethod
    def create_session(self, db_url: str):
        """Create a database session.

        Args:
            db_url: Database URL (e.g., "k8s://namespace",
                "redis://localhost").

        Returns:
            A session object compatible with the SQLAlchemy session
            interface.
        """


class SQLAlchemyDatabaseManager(DatabaseManager):
    """Default database manager using SQLAlchemy."""

    def create_session(self, db_url: str):
        """Return a SQLAlchemy session factory bound to ``db_url``.

        The returned object is the ``sessionmaker`` factory itself, not an
        open session.
        """
        engine = create_engine(db_url, echo=False)
        return sessionmaker(bind=engine)

    def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
        """Create database tables using SQLAlchemy.

        Args:
            db_url: Database URL used to build the engine.
            drop_tables: When true, drop the tables in ``Base.metadata``
                before recreating them.
            Base: Declarative Base whose metadata is used; defaults to the
                scheduler's own Base.
        """
        # SQLAlchemy re-raises driver errors as its own exception type, so
        # the sqlite3 class alone never matches errors raised through an
        # engine. Imported here to keep this block self-contained.
        from sqlalchemy.exc import OperationalError as SQLAlchemyOperationalError

        if Base is None:
            Base = DefaultBase

        engine = create_engine(db_url)
        update_db_schema(engine, Base)

        try:
            if drop_tables:
                Base.metadata.drop_all(engine)
        except (OperationalError, SQLAlchemyOperationalError):
            # Dropping is best-effort: a failed drop must not prevent the
            # tables from being (re)created below.
            pass
        finally:
            Base.metadata.create_all(engine)
6 changes: 6 additions & 0 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class CreateJob(BaseModel):
input_filename: str = None
runtime_environment_name: str
runtime_environment_parameters: Optional[Dict[str, EnvironmentParameterValues]]
environment_variables: Optional[Dict[str, str]] = None
output_formats: Optional[List[str]] = None
idempotency_token: Optional[str] = None
job_definition_id: Optional[str] = None
Expand Down Expand Up @@ -128,6 +129,7 @@ class DescribeJob(BaseModel):
input_filename: str = None
runtime_environment_name: str
runtime_environment_parameters: Optional[Dict[str, EnvironmentParameterValues]]
environment_variables: Optional[Dict[str, str]] = None
output_formats: Optional[List[str]] = None
idempotency_token: Optional[str] = None
job_definition_id: Optional[str] = None
Expand Down Expand Up @@ -193,6 +195,7 @@ class UpdateJob(BaseModel):
status: Optional[Status] = None
name: Optional[str] = None
compute_type: Optional[str] = None
environment_variables: Optional[Dict[str, str]] = None


class DeleteJob(BaseModel):
Expand All @@ -204,6 +207,7 @@ class CreateJobDefinition(BaseModel):
input_filename: str = None
runtime_environment_name: str
runtime_environment_parameters: Optional[Dict[str, EnvironmentParameterValues]]
environment_variables: Optional[Dict[str, str]] = None
output_formats: Optional[List[str]] = None
parameters: Optional[Dict[str, str]] = None
tags: Optional[Tags] = None
Expand All @@ -226,6 +230,7 @@ class DescribeJobDefinition(BaseModel):
input_filename: str = None
runtime_environment_name: str
runtime_environment_parameters: Optional[Dict[str, EnvironmentParameterValues]]
environment_variables: Optional[Dict[str, str]] = None
output_formats: Optional[List[str]] = None
parameters: Optional[Dict[str, str]] = None
tags: Optional[Tags] = None
Expand All @@ -248,6 +253,7 @@ class Config:
class UpdateJobDefinition(BaseModel):
runtime_environment_name: Optional[str]
runtime_environment_parameters: Optional[Dict[str, EnvironmentParameterValues]]
environment_variables: Optional[Dict[str, str]] = None
output_formats: Optional[List[str]] = None
parameters: Optional[Dict[str, str]] = None
tags: Optional[Tags] = None
Expand Down
21 changes: 5 additions & 16 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CommonColumns:
# Any default values specified for new columns will be ignored during the migration process.
package_input_folder = Column(Boolean)
packaged_files = Column(JsonType, default=[])
environment_variables = Column(JsonType(4096))


class Job(CommonColumns, Base):
Expand Down Expand Up @@ -146,21 +147,9 @@ def update_db_schema(engine, Base):
connection.execute(alter_statement)


def create_tables(db_url, drop_tables=False, Base=Base):
engine = create_engine(db_url)
update_db_schema(engine, Base)
def create_tables(db_url, drop_tables=False, Base=Base, *, database_manager=None):
    """Create the scheduler tables by delegating to a database manager.

    Args:
        db_url: Database URL handed to the manager.
        drop_tables: Whether existing tables should be dropped first.
        Base: Declarative Base forwarded to the manager.
        database_manager: Object exposing
            ``create_tables(db_url, drop_tables, Base)``. When omitted, the
            default SQLAlchemy manager is used so that callers written
            against the previous three-argument signature keep working.
    """
    if database_manager is None:
        # Local import: jupyter_scheduler.managers imports from this module,
        # so a top-level import would be circular.
        from jupyter_scheduler.managers import SQLAlchemyDatabaseManager

        database_manager = SQLAlchemyDatabaseManager()
    database_manager.create_tables(db_url, drop_tables, Base)

try:
if drop_tables:
Base.metadata.drop_all(engine)
except OperationalError:
pass
finally:
Base.metadata.create_all(engine)


def create_session(db_url):
engine = create_engine(db_url, echo=False)
Session = sessionmaker(bind=engine)

return Session
def create_session(db_url, database_manager=None):
    """Create a session for ``db_url`` by delegating to a database manager.

    Args:
        db_url: Database URL handed to the manager.
        database_manager: Object exposing ``create_session(db_url)``. When
            ``None`` (for example a scheduler constructed without an explicit
            manager), the default SQLAlchemy manager is used instead of
            failing with an ``AttributeError`` on ``None``.

    Returns:
        Whatever the manager's ``create_session`` returns.
    """
    if database_manager is None:
        # Local import: jupyter_scheduler.managers imports from this module,
        # so a top-level import would be circular.
        from jupyter_scheduler.managers import SQLAlchemyDatabaseManager

        database_manager = SQLAlchemyDatabaseManager()
    return database_manager.create_session(db_url)
32 changes: 31 additions & 1 deletion jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,19 +405,23 @@ def __init__(
environments_manager: Type[EnvironmentManager],
db_url: str,
config=None,
database_manager=None,
database_manager_class=None,
**kwargs,
):
super().__init__(
root_dir=root_dir, environments_manager=environments_manager, config=config, **kwargs
)
self.db_url = db_url
self.database_manager = database_manager
self.database_manager_class = database_manager_class
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

@property
def db_session(self):
if not self._db_session:
self._db_session = create_session(self.db_url)
self._db_session = create_session(self.db_url, self.database_manager)

return self._db_session

Expand Down Expand Up @@ -485,13 +489,39 @@ def create_job(self, model: CreateJob) -> str:
#
# See: https://github.com/python/cpython/issues/66285
# See also: https://github.com/jupyter/jupyter_core/pull/362
# Serialize job data for cross-process passing
job_data = {
"job_id": job.job_id,
"name": job.name if hasattr(job, "name") else None,
"input_filename": job.input_filename if hasattr(job, "input_filename") else None,
"runtime_environment_name": (
job.runtime_environment_name
if hasattr(job, "runtime_environment_name")
else None
),
"runtime_environment_parameters": (
job.runtime_environment_parameters
if hasattr(job, "runtime_environment_parameters")
else None
),
"output_formats": job.output_formats if hasattr(job, "output_formats") else [],
"parameters": job.parameters if hasattr(job, "parameters") else None,
"tags": job.tags if hasattr(job, "tags") else [],
"package_input_folder": (
job.package_input_folder if hasattr(job, "package_input_folder") else False
),
"packaged_files": job.packaged_files if hasattr(job, "packaged_files") else [],
}

mp_ctx = mp.get_context("spawn")
p = mp_ctx.Process(
target=self.execution_manager_class(
job_id=job.job_id,
staging_paths=staging_paths,
root_dir=self.root_dir,
db_url=self.db_url,
database_manager_class=self.database_manager_class,
job_data=job_data,
).process
)
p.start()
Expand Down
1 change: 1 addition & 0 deletions jupyter_scheduler/tests/test_execution_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_add_side_effects_files(
root_dir=jp_scheduler_root_dir,
db_url=jp_scheduler_db_url,
staging_paths={"input": staged_notebook_file_path},
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
)
manager.add_side_effects_files(staged_notebook_dir)

Expand Down
Loading
Loading