diff --git a/DashAI/back/api/api_v1/endpoints/jobs.py b/DashAI/back/api/api_v1/endpoints/jobs.py
index 47277a564..72af05fcb 100644
--- a/DashAI/back/api/api_v1/endpoints/jobs.py
+++ b/DashAI/back/api/api_v1/endpoints/jobs.py
@@ -1,10 +1,8 @@
 import logging
-from typing import Callable, ContextManager
 
 from dependency_injector.wiring import Provide, inject
 from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
 from fastapi.exceptions import HTTPException
-from sqlalchemy.orm import Session
 
 from DashAI.back.api.api_v1.schemas.job_params import JobParams
 from DashAI.back.containers import Container
@@ -104,9 +102,6 @@ async def get_job(
 @inject
 async def enqueue_job(
     params: JobParams,
-    session_factory: Callable[..., ContextManager[Session]] = Depends(
-        Provide[Container.db.provided.session]
-    ),
     component_registry: ComponentRegistry = Depends(
         Provide[Container.component_registry]
     ),
@@ -131,27 +126,23 @@ async def enqueue_job(
     dict
         dict with the new job on the database
     """
-    with session_factory() as db:
-        params.db = db
-        job: BaseJob = component_registry[params.job_type]["class"](
-            **params.model_dump()
-        )
-        try:
-            job.set_status_as_delivered()
-        except JobError as e:
-            logger.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail="Job not delivered",
-            ) from e
-        try:
-            job_queue.put(job)
-        except JobQueueError as e:
-            logger.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail="Job not enqueued",
-            ) from e
+    job: BaseJob = component_registry[params.job_type]["class"](**params.model_dump())
+    try:
+        job.deliver_job()
+    except JobError as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Job not delivered",
+        ) from e
+    try:
+        job_queue.put(job)
+    except JobQueueError as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Job not enqueued",
+        ) from e
     return job
diff --git a/DashAI/back/dependencies/job_queues/job_queue.py b/DashAI/back/dependencies/job_queues/job_queue.py
index 1e63528aa..ce7f74863 100644
--- a/DashAI/back/dependencies/job_queues/job_queue.py
+++ b/DashAI/back/dependencies/job_queues/job_queue.py
@@ -1,7 +1,8 @@
+import asyncio
 import logging
+from multiprocessing import Pipe, Process
 
 from dependency_injector.wiring import Provide, inject
-from sqlalchemy import exc
 
 from DashAI.back.containers import Container
 from DashAI.back.dependencies.job_queues import BaseJobQueue
@@ -22,17 +23,57 @@ async def job_queue_loop(
 
     Parameters
     ----------
-    job_queue : BaseJobQueue
-        The current app job queue.
     stop_when_queue_empties: bool
         boolean to set the while loop condition.
+    job_queue : BaseJobQueue
+        The current app job queue.
""" - while not job_queue.is_empty() if stop_when_queue_empties else True: - try: + try: + while not (stop_when_queue_empties and job_queue.is_empty()): + # Get the job from the queue job: BaseJob = await job_queue.async_get() - job.run() - except exc.SQLAlchemyError as e: - logger.exception(e) - except JobError as e: - logger.exception(e) + parent_conn, child_conn = Pipe() + + # Get Job Arguments + job_args = job.get_args() + job_args["pipe"] = child_conn + + # Create the Proccess to run the job + job_process = Process(target=job.run, kwargs=job_args, daemon=True) + + # Launch the job + job.start_job() + + # Proccess managment + job_process.start() + + # Wait until the job fails or the child send the resutls + while job_process.is_alive() and not parent_conn.poll(): + logger.debug(f"Awaiting {job.id} process for 1 second.") + await asyncio.sleep(1) + + # Check if the job send the results: + if parent_conn.poll(): + job_results = parent_conn.recv() + parent_conn.close() + job_process.join() + # Check if the job fails + elif not job_process.is_alive() and job_process.exitcode != 0: + job.terminate_job() + continue + + # Finish the job + job.finish_job() + + # Store the results of the job + job.store_results(**job_results) + + except JobError as e: + logger.exception(e) + raise RuntimeError( + f"Error in the training execution loop: {e}.\nShutting down the app." + ) from e + except KeyboardInterrupt: + logger.info("Shutting down the app") + return diff --git a/DashAI/back/job/base_job.py b/DashAI/back/job/base_job.py index d9d506bb6..190aa6378 100644 --- a/DashAI/back/job/base_job.py +++ b/DashAI/back/job/base_job.py @@ -1,8 +1,17 @@ """Base Job abstract class.""" +import logging from abc import ABCMeta, abstractmethod from typing import Final +from dependency_injector.wiring import Provide, inject +from sqlalchemy import exc + +from DashAI.back.dependencies.database.models import Run + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + class BaseJob(metaclass=ABCMeta): """Abstract class for all Jobs.""" @@ -20,9 +29,84 @@ def __init__(self, **kwargs): job_kwargs = kwargs.pop("kwargs", {}) self.kwargs = {**kwargs, **job_kwargs} - @abstractmethod - def set_status_as_delivered(self) -> None: + @inject + def deliver_job(self, session_factory=Provide["db"]) -> None: """Set the status of the job as delivered.""" + with session_factory.session() as db: + run_id: int = self.kwargs["run_id"] + run: Run = db.get(Run, run_id) + if not run: + raise JobError( + f"Cannot deliver job: Run {run_id} does not exist in DB." 
+                )
+            try:
+                run.set_status_as_delivered()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @inject
+    def start_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as started."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot start job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_started()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @inject
+    def finish_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as finished."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot finish job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_finished()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @inject
+    def terminate_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as error."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot terminate job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_error()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @abstractmethod
+    def get_args(self) -> dict:
+        """Get the arguments to pass to the run method."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def store_results(self, results: dict) -> None:
+        """Store the results of the job."""
         raise NotImplementedError
 
     @abstractmethod
diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py
index 679c66564..d4f58cd83 100644
--- a/DashAI/back/job/model_job.py
+++ b/DashAI/back/job/model_job.py
@@ -1,11 +1,11 @@
 import json
 import logging
 import os
-from typing import List
+from multiprocessing.connection import Connection
+from typing import List, Type
 
 from dependency_injector.wiring import Provide, inject
 from sqlalchemy import exc
-from sqlalchemy.orm import Session
 
 from DashAI.back.dataloaders.classes.dashai_dataset import (
     DashAIDataset,
@@ -26,80 +26,29 @@
 class ModelJob(BaseJob):
     """ModelJob class to run the model training."""
 
-    def set_status_as_delivered(self) -> None:
-        """Set the status of the job as delivered."""
-        run_id: int = self.kwargs["run_id"]
-        db: Session = self.kwargs["db"]
-
-        run: Run = db.get(Run, run_id)
-        if not run:
-            raise JobError(f"Run {run_id} does not exist in DB.")
-        try:
-            run.set_status_as_delivered()
-            db.commit()
-        except exc.SQLAlchemyError as e:
-            log.exception(e)
-            raise JobError(
-                "Internal database error",
-            ) from e
-
     @inject
-    def run(
+    def get_args(
         self,
         component_registry=Provide["component_registry"],
-        config=Provide["config"],
-    ) -> None:
+        session_factory=Provide["db"],
+    ) -> dict:
         from DashAI.back.api.api_v1.endpoints.components import (
             _intersect_component_lists,
         )
 
-        run_id: int = self.kwargs["run_id"]
-        db: Session = self.kwargs["db"]
-
-        run: Run = db.get(Run, run_id)
-        try:
+        with session_factory.session() as db:
+            run: Run = db.get(Run, self.kwargs["run_id"])
+            if not run:
+                raise JobError(f"Run {self.kwargs['run_id']} does not exist in DB.")
             experiment: Experiment = db.get(Experiment, run.experiment_id)
             if not experiment:
                 raise JobError(f"Experiment {run.experiment_id} does not exist in DB.")
             dataset: Dataset = db.get(Dataset, experiment.dataset_id)
             if not dataset:
                 raise JobError(f"Dataset {experiment.dataset_id} does not exist in DB.")
-
-            try:
-                loaded_dataset: DashAIDataset = load_dataset(
-                    f"{dataset.file_path}/dataset"
-                )
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Can not load dataset from path {dataset.file_path}",
-                ) from e
-
-            try:
-                run_model_class = component_registry[run.model_name]["class"]
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to find Model with name {run.model_name} in registry.",
-                ) from e
-
-            try:
-                model: BaseModel = run_model_class(**run.parameters)
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to instantiate model using run {run_id}",
-                ) from e
-
-            try:
-                task: BaseTask = component_registry[experiment.task_name]["class"]()
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to find Task with name {experiment.task_name} in registry",
-                ) from e
-
             try:
+                model_class = component_registry[run.model_name]["class"]
+                task_class = component_registry[experiment.task_name]["class"]
                 selected_metrics = {
                     component_dict["name"]: component_dict
                     for component_dict in component_registry.get_components_by_types(
@@ -110,92 +59,141 @@ def run(
                     selected_metrics,
                     component_registry.get_related_components(experiment.task_name),
                 )
-                metrics: List[BaseMetric] = [
-                    metric["class"] for metric in selected_metrics.values()
-                ]
-            except Exception as e:
+            except (KeyError, ValueError) as e:
                 log.exception(e)
                 raise JobError(
-                    "Unable to find metrics associated with"
-                    f"Task {experiment.task_name} in registry",
+                    f"Unable to find component classes for run {run.id}"
                 ) from e
+            metrics_classes = [metric["class"] for metric in selected_metrics.values()]
+            return {
+                "dataset_id": dataset.id,
+                "dataset_file_path": dataset.file_path,
+                "experiment_splits": experiment.splits,
+                "experiment_columns": {
+                    "input": experiment.input_columns,
+                    "output": experiment.output_columns,
+                },
+                "model_class": model_class,
+                "model_kwargs": run.parameters,
+                "task_class": task_class,
+                "metrics_classes": metrics_classes,
+            }
 
-            try:
-                splits = json.loads(experiment.splits)
-                if splits["has_changed"]:
-                    new_splits = {
-                        "train": splits["train"],
-                        "test": splits["test"],
-                        "validation": splits["validation"],
-                    }
-                    loaded_dataset = update_dataset_splits(
-                        loaded_dataset, new_splits, splits["is_random"]
-                    )
-                prepared_dataset = task.prepare_for_task(
-                    loaded_dataset, experiment.output_columns
-                )
-                x, y = select_columns(
-                    prepared_dataset,
-                    experiment.input_columns,
-                    experiment.output_columns,
+    @inject
+    def run(
+        self,
+        dataset_id: int,
+        dataset_file_path: str,
+        experiment_splits: str,
+        experiment_columns: dict,
+        model_class: Type[BaseModel],
+        model_kwargs: dict,
+        task_class: Type[BaseTask],
+        metrics_classes: List[Type[BaseMetric]],
+        pipe: Connection,
+    ) -> None:
+        try:
+            model: BaseModel = model_class(**model_kwargs)
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"Unable to instantiate model using run {self.kwargs['run_id']}",
+            ) from e
+
+        try:
+            loaded_dataset: DashAIDataset = load_dataset(f"{dataset_file_path}/dataset")
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"Can not load dataset from path {dataset_file_path}",
+            ) from e
+
+        try:
+            splits = json.loads(experiment_splits)
+            if splits["has_changed"]:
+                new_splits = {
+                    "train": splits["train"],
+                    "test": splits["test"],
+                    "validation": splits["validation"],
+                }
+                loaded_dataset = update_dataset_splits(
+                    loaded_dataset, new_splits, splits["is_random"]
                 )
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"""Can not prepare Dataset {dataset.id}
-                    for Task {experiment.task_name}""",
-                ) from e
+            prepared_dataset = task_class().prepare_for_task(
+                loaded_dataset, experiment_columns["output"]
+            )
+            x, y = select_columns(
+                prepared_dataset,
+                experiment_columns["input"],
+                experiment_columns["output"],
+            )
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"""Can not prepare Dataset {dataset_id}
+                for Task {task_class.__name__}""",
+            ) from e
 
-            try:
-                run.set_status_as_started()
-                db.commit()
-            except exc.SQLAlchemyError as e:
-                log.exception(e)
-                raise JobError(
-                    "Connection with the database failed",
-                ) from e
+        try:
+            # Training
+            model.fit(x["train"], y["train"])
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                "Model training failed",
+            ) from e
 
-            try:
-                # Training
-                model.fit(x["train"], y["train"])
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    "Model training failed",
-                ) from e
+        try:
+            model_metrics = {
+                split: {
+                    metric.__name__: metric.score(
+                        y[split],
+                        model.predict(x[split]),
+                    )
+                    for metric in metrics_classes
+                }
+                for split in ["train", "validation", "test"]
+            }
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                "Metrics calculation failed",
+            ) from e
 
-            try:
-                run.set_status_as_finished()
-                db.commit()
-            except exc.SQLAlchemyError as e:
-                log.exception(e)
-                raise JobError(
-                    "Connection with the database failed",
-                ) from e
+        try:
+            pipe.send(
+                {
+                    "model_bytes": model.save(),
+                    "metrics": model_metrics,
+                }
+            )
+        except ValueError as e:
+            log.exception(e)
+            raise JobError(
+                "Sending results failed",
+            ) from e
+
+    @inject
+    def store_results(
+        self,
+        model_bytes: bytes,
+        metrics: dict,
+        session_factory=Provide["db"],
+        component_registry=Provide["component_registry"],
+        config=Provide["config"],
+    ) -> None:
+        with session_factory.session() as db:
+            run: Run = db.get(Run, self.kwargs["run_id"])
 
             try:
-                model_metrics = {
-                    split: {
-                        metric.__name__: metric.score(
-                            y[split],
-                            model.predict(x[split]),
-                        )
-                        for metric in metrics
-                    }
-                    for split in ["train", "validation", "test"]
-                }
-            except Exception as e:
+                model_class: BaseModel = component_registry[run.model_name]["class"]
+            except KeyError as e:
                 log.exception(e)
-                raise JobError(
-                    "Metrics calculation failed",
-                ) from e
-
-            run.train_metrics = model_metrics["train"]
-            run.validation_metrics = model_metrics["validation"]
-            run.test_metrics = model_metrics["test"]
+                raise JobError(f"Unable to find model class for run {run.id}") from e
 
             try:
                 run_path = os.path.join(config["RUNS_PATH"], str(run.id))
+                model = model_class.load(byte_array=model_bytes)
                 model.save(run_path)
             except Exception as e:
                 log.exception(e)
@@ -203,8 +201,11 @@ def run(
                 "Model saving failed",
             ) from e
 
+            run.train_metrics = metrics["train"]
+            run.validation_metrics = metrics["validation"]
+            run.test_metrics = metrics["test"]
+            run.run_path = run_path
             try:
-                run.run_path = run_path
                 db.commit()
             except exc.SQLAlchemyError as e:
                 log.exception(e)
@@ -213,7 +214,3 @@ def run(
             raise JobError(
                 "Connection with the database failed",
             ) from e
-        except Exception as e:
-            run.set_status_as_error()
-            db.commit()
-            raise e
diff --git a/DashAI/back/models/base_model.py b/DashAI/back/models/base_model.py
index b5a8529c8..0b7ce46d0 100644
--- a/DashAI/back/models/base_model.py
+++ b/DashAI/back/models/base_model.py
@@ -3,7 +3,7 @@
 import json
 import os
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Final
+from typing import Any, Dict, Final, Optional
 
 from DashAI.back.config_object import ConfigObject
@@ -17,25 +17,43 @@ class BaseModel(ConfigObject, metaclass=ABCMeta):
     TYPE: Final[str] = "Model"
 
-    # TODO implement a check_params method to check the params
-    # using the JSON schema file.
-    # TODO implement a method to check the initialization of TASK
-    # an task params variables.
-
     @abstractmethod
-    def save(self, filename: str) -> None:
+    def save(self, filename: Optional[str] = None) -> Optional[bytes]:
         """Store an instance of a model.
 
-        filename (Str): Indicates where to store the model,
-        if filename is None, this method returns a bytes array with the model.
+        Parameters
+        ----------
+        filename : Optional[str]
+            Indicates where to store the model. If filename is None, this method
+            returns a byte array with the model.
+
+        Returns
+        -------
+        Optional[bytes]
+            If filename is None, returns the byte array associated with the model.
         """
         raise NotImplementedError
 
     @abstractmethod
-    def load(self, filename: str) -> Any:
+    def load(
+        self,
+        filename: Optional[str] = None,
+        byte_array: Optional[bytes] = None,
+    ) -> "BaseModel":
         """Restores an instance of a model.
 
-        filename (Str): Indicates where the model was stored.
+        Parameters
+        ----------
+        filename: Optional[str]
+            Indicates where the model was stored.
+        byte_array: Optional[bytes]
+            The bytes associated with the model.
+
+        Returns
+        -------
+        BaseModel
+            The loaded model.
         """
         raise NotImplementedError
diff --git a/DashAI/back/models/scikit_learn/sklearn_like_model.py b/DashAI/back/models/scikit_learn/sklearn_like_model.py
index f8c485610..da4c2d3a7 100644
--- a/DashAI/back/models/scikit_learn/sklearn_like_model.py
+++ b/DashAI/back/models/scikit_learn/sklearn_like_model.py
@@ -1,4 +1,5 @@
-from typing import Type
+import pickle
+from typing import Optional, Type
 
 import joblib
 
@@ -9,14 +10,26 @@
 class SklearnLikeModel(BaseModel):
     """Abstract class to define the way to save sklearn like models."""
 
-    def save(self, filename: str) -> None:
-        """Save the model in the specified path."""
-        joblib.dump(self, filename)
+    def save(self, filename: Optional[str] = None) -> Optional[bytes]:
+        """Save the model in the specified path or return the associated bytes."""
+        if filename:
+            joblib.dump(self, filename)
+        else:
+            return pickle.dumps(self)
 
     @staticmethod
-    def load(filename: str) -> None:
-        """Load the model of the specified path."""
-        model = joblib.load(filename)
+    def load(
+        filename: Optional[str] = None, byte_array: Optional[bytes] = None
+    ) -> "SklearnLikeModel":
+        """Load the model from the specified path or from the byte array."""
+        if filename:
+            model = joblib.load(filename)
+        elif byte_array:
+            model = pickle.loads(byte_array)
+        else:
+            raise ValueError(
+                "Must pass filename or byte_array to the load method; neither was passed."
+            )
         return model
 
     # --- Methods for process the data for sklearn models ---
diff --git a/tests/back/api/test_jobs.py b/tests/back/api/test_jobs.py
index 2ecc5c86f..e47581df9 100644
--- a/tests/back/api/test_jobs.py
+++ b/tests/back/api/test_jobs.py
@@ -1,5 +1,6 @@
 import json
 import os
+import pickle
 
 import joblib
 import pytest
@@ -28,11 +29,16 @@ class DummyModel(BaseModel):
     def get_schema(cls):
         return {}
 
-    def save(self, filename):
-        joblib.dump(self, filename)
+    def save(self, filename=None):
+        if filename:
+            joblib.dump(self, filename)
+        else:
+            return pickle.dumps(self)
 
-    def load(self, filename):
-        return
+    @staticmethod
+    def load(filename=None, byte_array=None):
+        if byte_array:
+            return pickle.loads(byte_array)
 
     def predict(self, x):
         return {}
@@ -51,7 +57,8 @@ def get_schema(cls):
     def save(self, filename):
         return
 
-    def load(self, filename):
+    @staticmethod
+    def load(filename, byte_array):
         return
 
     def predict(self, x):
diff --git a/tests/back/job_queue/__init__.py b/tests/back/job_queue/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/back/job_queue/test_simple_job_queue.py b/tests/back/job_queue/test_simple_job_queue.py
index fe2b35ac2..6703d0696 100644
--- a/tests/back/job_queue/test_simple_job_queue.py
+++ b/tests/back/job_queue/test_simple_job_queue.py
@@ -6,15 +6,18 @@
 class DummyJob(BaseJob):
-    def run(self) -> None:
+    def get_args(self) -> dict:
+        return {}
+
+    def store_results(self, results: dict) -> None:
         return None
 
-    def set_status_as_delivered(self) -> None:
+    def run(self) -> None:
         return None
 
 
 @pytest.fixture(name="job_queue")
-def fixture_job_queue() -> BaseJobQueue:
+def fixture_job_queue():
     queue = SimpleJobQueue()
     yield queue
     while not queue.is_empty():