From 5b1287e41ccd72c83282a30e60049de38aae2eee Mon Sep 17 00:00:00 2001
From: Lurrea <rodrigo.urrea@ug.uchile.cl>
Date: Fri, 15 Mar 2024 10:07:49 -0300
Subject: [PATCH 01/10] new process for jobs

---
 .../back/dependencies/job_queues/job_queue.py | 36 ++++++++++++++-----
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/DashAI/back/dependencies/job_queues/job_queue.py b/DashAI/back/dependencies/job_queues/job_queue.py
index 1e63528aa..cfd68acd6 100644
--- a/DashAI/back/dependencies/job_queues/job_queue.py
+++ b/DashAI/back/dependencies/job_queues/job_queue.py
@@ -1,7 +1,8 @@
+import asyncio
 import logging
+from multiprocessing import Process
 
 from dependency_injector.wiring import Provide, inject
-from sqlalchemy import exc
 
 from DashAI.back.containers import Container
 from DashAI.back.dependencies.job_queues import BaseJobQueue
@@ -13,7 +14,9 @@
 
 @inject
 async def job_queue_loop(
-    stop_when_queue_empties: bool,
+    component_registry=Provide[Container.component_registry],
+    session_factory=Provide[Container.db.provided.session],
+    config=Provide[Container.config],
     job_queue: BaseJobQueue = Provide[Container.job_queue],
 ):
     """Loop function to execute all the pending jobs in the job queue.
@@ -28,11 +31,26 @@
         boolean to set the while loop condition.
 
     """
-    while not job_queue.is_empty() if stop_when_queue_empties else True:
-        try:
+    try:
+        while True:
             job: BaseJob = await job_queue.async_get()
-            job.run()
-        except exc.SQLAlchemyError as e:
-            logger.exception(e)
-        except JobError as e:
-            logger.exception(e)
+            job_process = Process(
+                target=job.run,
+                daemon=True,
+            )
+            job_process.start()
+
+            while job_process.is_alive():
+                logger.debug(f"Awaiting {job.id} process for 1 second.")
+                await asyncio.sleep(1)
+
+            job_process.join()
+
+    except JobError as e:
+        logger.exception(e)
+        raise RuntimeError(
+            f"Error in the training execution loop: {e}.\nShutting down the app."
+        ) from e
+    except KeyboardInterrupt:
+        logger.info("Shutting down the app")
+        return
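Note: the essence of patch 01 is that each job now runs in its own worker process
while the asyncio event loop polls it, so the server stays responsive during long
trainings. Stripped of DashAI specifics, the pattern reduces to this (a minimal
illustrative sketch, not code from the series; `wait_for_process` is a made-up
helper name):

    import asyncio
    from multiprocessing import Process

    async def wait_for_process(proc: Process, poll_seconds: float = 1.0) -> int:
        """Start a child process and poll it without blocking the event loop."""
        proc.start()
        while proc.is_alive():
            # Yield control so other coroutines (e.g. HTTP handlers) keep running.
            await asyncio.sleep(poll_seconds)
        proc.join()
        return proc.exitcode

At this stage the parent has no channel to receive results or errors back from the
child; the following patches introduce a Pipe for exactly that.
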
""" try: while True: job: BaseJob = await job_queue.async_get() - job_process = Process( - target=job.run, - daemon=True, - ) + parent_conn, child_conn = Pipe() + + # Get Job Arguments + job_args = job.get_args() + job_args["pipe"] = child_conn + + # Create the Proccess to run the job + job_process = Process(target=job.run, kwargs=job_args, daemon=True) + + # Launch the job + job.start_job() + + # Proccess managment job_process.start() while job_process.is_alive(): logger.debug(f"Awaiting {job.id} process for 1 second.") await asyncio.sleep(1) + job_results = parent_conn.recv() + parent_conn.close() job_process.join() + # Finish the job + job.finish_job() + + # Store the results of the job + job.store_results(**job_results) + + # TODO: Fix tests, test loops when trying to run with multiple jobs + if stop_when_queue_empties: + return + except JobError as e: logger.exception(e) raise RuntimeError( diff --git a/DashAI/back/job/base_job.py b/DashAI/back/job/base_job.py index 06b41d070..a75988479 100644 --- a/DashAI/back/job/base_job.py +++ b/DashAI/back/job/base_job.py @@ -1,7 +1,16 @@ """Base Job abstract class.""" +import logging from abc import ABCMeta, abstractmethod from typing import Final +from dependency_injector.wiring import Provide, inject +from sqlalchemy import exc + +from DashAI.back.dependencies.database.models import Run + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + class BaseJob(metaclass=ABCMeta): """Abstract class for all Jobs.""" @@ -19,9 +28,67 @@ def __init__(self, **kwargs): job_kwargs = kwargs.pop("kwargs", {}) self.kwargs = {**kwargs, **job_kwargs} - @abstractmethod - def set_status_as_delivered(self) -> None: + @inject + def deliver_job(self, session_factory=Provide["db"]) -> None: """Set the status of the job as delivered.""" + with session_factory.session() as db: + run_id: int = self.kwargs["run_id"] + run: Run = db.get(Run, run_id) + if not run: + raise JobError( + f"Cannot deliver job: Run {run_id} does not exist in DB." 
diff --git a/DashAI/back/job/base_job.py b/DashAI/back/job/base_job.py
index 06b41d070..a75988479 100644
--- a/DashAI/back/job/base_job.py
+++ b/DashAI/back/job/base_job.py
@@ -1,7 +1,16 @@
 """Base Job abstract class."""
+import logging
 from abc import ABCMeta, abstractmethod
 from typing import Final
 
+from dependency_injector.wiring import Provide, inject
+from sqlalchemy import exc
+
+from DashAI.back.dependencies.database.models import Run
+
+logging.basicConfig(level=logging.DEBUG)
+log = logging.getLogger(__name__)
+
 
 class BaseJob(metaclass=ABCMeta):
     """Abstract class for all Jobs."""
@@ -19,9 +28,67 @@ def __init__(self, **kwargs):
         job_kwargs = kwargs.pop("kwargs", {})
         self.kwargs = {**kwargs, **job_kwargs}
 
-    @abstractmethod
-    def set_status_as_delivered(self) -> None:
+    @inject
+    def deliver_job(self, session_factory=Provide["db"]) -> None:
         """Set the status of the job as delivered."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(
+                    f"Cannot deliver job: Run {run_id} does not exist in DB."
+                )
+            try:
+                run.set_status_as_delivered()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @inject
+    def start_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as started."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot start job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_started()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @inject
+    def finish_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as finished."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot finish job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_finished()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
+    @abstractmethod
+    def get_args(self) -> dict:
+        """Get the arguments to pass to the run method."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def store_results(self, results: dict) -> None:
+        """Store the results of the job."""
         raise NotImplementedError
 
     @abstractmethod
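Note: the lifecycle methods above are meant to be driven from outside the job, in a
fixed order. Reconstructed from this patch together with the queue loop, the
intended sequence is (order-of-calls sketch; the comments mark where each call
happens):

    job.deliver_job()               # API endpoint, at enqueue time
    job_args = job.get_args()       # parent: collect picklable kwargs for run()
    job.start_job()                 # parent: mark the Run as started
    # child process: job.run(**job_args, pipe=child_conn) -> pipe.send(results)
    job.finish_job()                # parent: mark the Run as finished
    job.store_results(**results)    # parent: persist whatever the child sent

The failure branch, terminate_job, is added in patch 05.
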
diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py
index 1ef5825c5..a139615c2 100644
--- a/DashAI/back/job/model_job.py
+++ b/DashAI/back/job/model_job.py
@@ -1,10 +1,10 @@
 import logging
 import os
-from typing import List
+from multiprocessing.connection import PipeConnection
+from typing import List, Type
 
 from dependency_injector.wiring import Provide, inject
 from sqlalchemy import exc
-from sqlalchemy.orm import Session
 
 from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset, load_dataset
 from DashAI.back.dependencies.database.models import Dataset, Experiment, Run
@@ -20,80 +20,29 @@
 class ModelJob(BaseJob):
     """ModelJob class to run the model training."""
 
-    def set_status_as_delivered(self) -> None:
-        """Set the status of the job as delivered."""
-        run_id: int = self.kwargs["run_id"]
-        db: Session = self.kwargs["db"]
-
-        run: Run = db.get(Run, run_id)
-        if not run:
-            raise JobError(f"Run {run_id} does not exist in DB.")
-        try:
-            run.set_status_as_delivered()
-            db.commit()
-        except exc.SQLAlchemyError as e:
-            log.exception(e)
-            raise JobError(
-                "Internal database error",
-            ) from e
-
     @inject
-    def run(
+    def get_args(
         self,
         component_registry=Provide["component_registry"],
-        config=Provide["config"],
-    ) -> None:
+        session_factory=Provide["db"],
+    ) -> dict:
         from DashAI.back.api.api_v1.endpoints.components import (
             _intersect_component_lists,
         )
 
-        run_id: int = self.kwargs["run_id"]
-        db: Session = self.kwargs["db"]
-
-        run: Run = db.get(Run, run_id)
-        try:
+        with session_factory.session() as db:
+            run: Run = db.get(Run, self.kwargs["run_id"])
+            if not run:
+                raise JobError(f"Run {self.kwargs['run_id']} does not exist in DB.")
             experiment: Experiment = db.get(Experiment, run.experiment_id)
             if not experiment:
                 raise JobError(f"Experiment {run.experiment_id} does not exist in DB.")
             dataset: Dataset = db.get(Dataset, experiment.dataset_id)
             if not dataset:
                 raise JobError(f"Dataset {experiment.dataset_id} does not exist in DB.")
-
-            try:
-                loaded_dataset: DashAIDataset = load_dataset(
-                    f"{dataset.file_path}/dataset"
-                )
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Can not load dataset from path {dataset.file_path}",
-                ) from e
-
-            try:
-                run_model_class = component_registry[run.model_name]["class"]
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to find Model with name {run.model_name} in registry.",
-                ) from e
-
-            try:
-                model: BaseModel = run_model_class(**run.parameters)
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to instantiate model using run {run_id}",
-                ) from e
-
-            try:
-                task: BaseTask = component_registry[experiment.task_name]["class"]()
-            except Exception as e:
-                log.exception(e)
-                raise JobError(
-                    f"Unable to find Task with name {experiment.task_name} in registry",
-                ) from e
             try:
+                model_class = component_registry[run.model_name]["class"]
+                task_class = component_registry[experiment.task_name]["class"]
                 selected_metrics = {
                     component_dict["name"]: component_dict
                     for component_dict in component_registry.get_components_by_types(
                         select="Metric"
                     )
                 }
                 selected_metrics = _intersect_component_lists(
                     selected_metrics,
                     component_registry.get_related_components(experiment.task_name),
                 )
-            metrics: List[BaseMetric] = [
-                metric["class"] for metric in selected_metrics.values()
-            ]
-        except Exception as e:
+            except (KeyError, ValueError) as e:
                 log.exception(e)
                 raise JobError(
-                "Unable to find metrics associated with"
-                f"Task {experiment.task_name} in registry",
+                    f"Unable to find component classes for run {run.id}"
                 ) from e
+            metrics_classes = [metric["class"] for metric in selected_metrics.values()]
+            return {
+                "dataset_id": dataset.id,
+                "dataset_file_path": dataset.file_path,
+                "model_class": model_class,
+                "model_kwargs": run.parameters,
+                "task_class": task_class,
+                "metrics_classes": metrics_classes,
+            }
 
-        try:
-            prepared_dataset = task.prepare_for_task(loaded_dataset)
-        except Exception as e:
-            log.exception(e)
-            raise JobError(
-                f"""Can not prepare Dataset {dataset.id}
-                for Task {experiment.task_name}""",
-            ) from e
+    @inject
+    def run(
+        self,
+        dataset_id: int,
+        dataset_file_path: str,
+        model_class: Type[BaseModel],
+        model_kwargs: dict,
+        task_class: Type[BaseTask],
+        metrics_classes: List[Type[BaseMetric]],
+        pipe: PipeConnection,
+    ) -> None:
+        try:
+            loaded_dataset: DashAIDataset = load_dataset(f"{dataset_file_path}/dataset")
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"Cannot load dataset from path {dataset_file_path}",
+            ) from e
 
-        try:
-            run.set_status_as_started()
-            db.commit()
-        except exc.SQLAlchemyError as e:
-            log.exception(e)
-            raise JobError(
-                "Connection with the database failed",
-            ) from e
+        try:
+            model: BaseModel = model_class(**model_kwargs)
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"Unable to instantiate model using run {self.kwargs['run_id']}",
+            ) from e
 
-        try:
-            # Training
-            model.fit(prepared_dataset["train"])
-        except Exception as e:
-            log.exception(e)
-            raise JobError(
-                "Model training failed",
-            ) from e
+        try:
+            prepared_dataset = task_class().prepare_for_task(loaded_dataset)
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                f"""Cannot prepare Dataset {dataset_id}
+                for Task {task_class.__name__}""",
+            ) from e
 
-        try:
-            run.set_status_as_finished()
-            db.commit()
-        except exc.SQLAlchemyError as e:
-            log.exception(e)
-            raise JobError(
-                "Connection with the database failed",
-            ) from e
+        try:
+            # Training
+            model.fit(prepared_dataset["train"])
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                "Model training failed",
+            ) from e
 
-        try:
-            model_metrics = {
-                split: {
-                    metric.__name__: metric.score(
-                        prepared_dataset[split],
-                        model.predict(prepared_dataset[split]),
-                    )
-                    for metric in metrics
-                }
-                for split in ["train", "validation", "test"]
-            }
-        except Exception as e:
-            log.exception(e)
-            raise JobError(
-                "Metrics calculation failed",
-            ) from e
+        try:
+            model_metrics = {
+                split: {
+                    metric.__name__: metric.score(
+                        prepared_dataset[split],
+                        model.predict(prepared_dataset[split]),
+                    )
+                    for metric in metrics_classes
+                }
+                for split in ["train", "validation", "test"]
+            }
+        except Exception as e:
+            log.exception(e)
+            raise JobError(
+                "Metrics calculation failed",
+            ) from e
 
-        run.train_metrics = model_metrics["train"]
-        run.validation_metrics = model_metrics["validation"]
-        run.test_metrics = model_metrics["test"]
+        pipe.send(
+            {
+                "model": model,
+                "metrics": model_metrics,
+            }
+        )
+
+    @inject
+    def store_results(
+        self,
+        model: BaseModel,
+        metrics: dict,
+        session_factory=Provide["db"],
+        config=Provide["config"],
+    ) -> None:
+        with session_factory.session() as db:
+            run: Run = db.get(Run, self.kwargs["run_id"])
 
             try:
                 run_path = os.path.join(config["RUNS_PATH"], str(run.id))
@@ -180,8 +157,11 @@
                 "Model saving failed",
             ) from e
 
+            run.train_metrics = metrics["train"]
+            run.validation_metrics = metrics["validation"]
+            run.test_metrics = metrics["test"]
+            run.run_path = run_path
             try:
-                run.run_path = run_path
                 db.commit()
             except exc.SQLAlchemyError as e:
                 log.exception(e)
@@ -190,7 +170,3 @@
             raise JobError(
                 "Connection with the database failed",
             ) from e
-        except Exception as e:
-            run.set_status_as_error()
-            db.commit()
-            raise e
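Note: the refactor above splits ModelJob into three phases: get_args (read-only DB
access in the parent, returning only picklable values), run (pure computation in
the child), and store_results (DB writes back in the parent). Everything crossing
the process boundary goes through pickle, which is why get_args returns classes and
plain values rather than live sessions (sketch of the data flow):

    args = job.get_args()        # {"dataset_file_path": ..., "model_class": ..., ...}
    # child: job.run(**args, pipe=child_conn)
    #        -> pipe.send({"model": fitted_model, "metrics": model_metrics})
    job.store_results(**results) # parent: save the model to disk, write metrics to DB

Sending the fitted model object itself through the pipe also requires the model to
be picklable; patch 09 sidesteps this by sending model.save() bytes instead.
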
From 44836fc378d098e92a6d8181fd53d31db9e7b414 Mon Sep 17 00:00:00 2001
From: Lurrea <rodrigo.urrea@ug.uchile.cl>
Date: Tue, 26 Mar 2024 17:05:10 -0300
Subject: [PATCH 03/10] refactor job api to use new Job interface

---
 DashAI/back/api/api_v1/endpoints/jobs.py | 43 ++++++++++--------------
 1 file changed, 17 insertions(+), 26 deletions(-)

diff --git a/DashAI/back/api/api_v1/endpoints/jobs.py b/DashAI/back/api/api_v1/endpoints/jobs.py
index 47277a564..72af05fcb 100644
--- a/DashAI/back/api/api_v1/endpoints/jobs.py
+++ b/DashAI/back/api/api_v1/endpoints/jobs.py
@@ -1,10 +1,8 @@
 import logging
-from typing import Callable, ContextManager
 
 from dependency_injector.wiring import Provide, inject
 from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
 from fastapi.exceptions import HTTPException
-from sqlalchemy.orm import Session
 
 from DashAI.back.api.api_v1.schemas.job_params import JobParams
 from DashAI.back.containers import Container
@@ -104,9 +102,6 @@ async def get_job(
 @inject
 async def enqueue_job(
     params: JobParams,
-    session_factory: Callable[..., ContextManager[Session]] = Depends(
-        Provide[Container.db.provided.session]
-    ),
     component_registry: ComponentRegistry = Depends(
         Provide[Container.component_registry]
     ),
@@ -131,27 +126,23 @@
     dict
         dict with the new job on the database
     """
-    with session_factory() as db:
-        params.db = db
-        job: BaseJob = component_registry[params.job_type]["class"](
-            **params.model_dump()
-        )
-        try:
-            job.set_status_as_delivered()
-        except JobError as e:
-            logger.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail="Job not delivered",
-            ) from e
-        try:
-            job_queue.put(job)
-        except JobQueueError as e:
-            logger.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail="Job not enqueued",
-            ) from e
+    job: BaseJob = component_registry[params.job_type]["class"](**params.model_dump())
+    try:
+        job.deliver_job()
+    except JobError as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Job not delivered",
+        ) from e
+    try:
+        job_queue.put(job)
+    except JobQueueError as e:
+        logger.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Job not enqueued",
+        ) from e
     return job
From 2618e8932f867cdf6eba879c1ccef505a3677c77 Mon Sep 17 00:00:00 2001
From: Lurrea <rodrigo.urrea@ug.uchile.cl>
Date: Tue, 26 Mar 2024 17:07:33 -0300
Subject: [PATCH 04/10] remove tricky part of test, must be re-added later when
 the bug is solved

---
 tests/back/api/test_jobs.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/tests/back/api/test_jobs.py b/tests/back/api/test_jobs.py
index 6c6818919..faa41817e 100644
--- a/tests/back/api/test_jobs.py
+++ b/tests/back/api/test_jobs.py
@@ -305,12 +305,7 @@ def test_execute_jobs(client: TestClient, run_id: int, failed_run_id: int):
     assert data["start_time"] is not None
     assert data["end_time"] is not None
 
-    response = client.get(f"/api/v1/run/{failed_run_id}")
-    data = response.json()
-    assert data["status"] == 4
-    assert data["delivery_time"] is not None
-    assert data["start_time"] is not None
-    assert data["end_time"] is None
+    # TODO: Add test that checks the failed run job fails


def test_job_with_wrong_run(client: TestClient):

From 128414a595e134cc991ee97b67e7b06d42efa62c Mon Sep 17 00:00:00 2001
From: Lurrea <rodrigo.urrea@ug.uchile.cl>
Date: Tue, 7 May 2024 17:17:11 -0400
Subject: [PATCH 05/10] fix loop in queue

---
 .../back/dependencies/job_queues/job_queue.py | 20 +++++++++++++------
 DashAI/back/job/base_job.py                   | 17 ++++++++++++++++
 tests/back/api/test_jobs.py                   |  7 ++++++-
 tests/back/job_queue/__init__.py              |  0
 tests/back/job_queue/test_simple_job_queue.py |  9 ++++++---
 5 files changed, 43 insertions(+), 10 deletions(-)
 create mode 100644 tests/back/job_queue/__init__.py

diff --git a/DashAI/back/dependencies/job_queues/job_queue.py b/DashAI/back/dependencies/job_queues/job_queue.py
index 8a7cb94b5..bf2ab2aa6 100644
--- a/DashAI/back/dependencies/job_queues/job_queue.py
+++ b/DashAI/back/dependencies/job_queues/job_queue.py
@@ -30,7 +30,8 @@ async def job_queue_loop(
 
     """
     try:
-        while True:
+        while not (stop_when_queue_empties and job_queue.is_empty()):
+            # Get the job from the queue
             job: BaseJob = await job_queue.async_get()
             parent_conn, child_conn = Pipe()
 
@@ -47,12 +48,23 @@ async def job_queue_loop(
             # Process management
             job_process.start()
 
-            while job_process.is_alive():
+            # Wait until the job fails or the child sends the results
+            while job_process.is_alive() and not parent_conn.poll():
                 logger.debug(f"Awaiting {job.id} process for 1 second.")
                 await asyncio.sleep(1)
 
+            # Check if the job fails
+            if not job_process.is_alive() and job_process.exitcode != 0:
+                job.terminate_job()
+                continue
+
             job_results = parent_conn.recv()
             parent_conn.close()
+
+            # Wait until the job returns
+            while job_process.is_alive():
+                logger.debug(f"Awaiting {job.id} process for 1 second.")
+                await asyncio.sleep(1)
             job_process.join()
 
             # Finish the job
@@ -61,10 +73,6 @@ async def job_queue_loop(
             # Store the results of the job
             job.store_results(**job_results)
 
-            # TODO: Fix tests, test loops when trying to run with multiple jobs
-            if stop_when_queue_empties:
-                return
-
     except JobError as e:
         logger.exception(e)
         raise RuntimeError(
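Note: the supervision logic above separates "child crashed" from "child reported"
by combining poll() and the exit code. Condensed from the hunk above (sketch of the
parent-side decision, in this patch's order):

    # Wait until either the child reports results or it stops running.
    while job_process.is_alive() and not parent_conn.poll():
        await asyncio.sleep(1)

    if not job_process.is_alive() and job_process.exitcode != 0:
        job.terminate_job()               # child died: mark the Run as error
        continue
    job_results = parent_conn.recv()      # otherwise the results should be ready

As written here the failure check runs before the poll() check, so a child that
managed to send results but then exited non-zero would be treated as a failure and
its results dropped; patch 10 reorders the two branches to check poll() first.
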
diff --git a/DashAI/back/job/base_job.py b/DashAI/back/job/base_job.py
index a80e36f4a..190aa6378 100644
--- a/DashAI/back/job/base_job.py
+++ b/DashAI/back/job/base_job.py
@@ -82,6 +82,23 @@ def finish_job(self, session_factory=Provide["db"]) -> None:
                     "Internal database error",
                 ) from e
 
+    @inject
+    def terminate_job(self, session_factory=Provide["db"]) -> None:
+        """Set the status of the job as error."""
+        with session_factory.session() as db:
+            run_id: int = self.kwargs["run_id"]
+            run: Run = db.get(Run, run_id)
+            if not run:
+                raise JobError(f"Cannot terminate job: Run {run_id} does not exist in DB.")
+            try:
+                run.set_status_as_error()
+                db.commit()
+            except exc.SQLAlchemyError as e:
+                log.exception(e)
+                raise JobError(
+                    "Internal database error",
+                ) from e
+
     @abstractmethod
     def get_args(self) -> dict:
         """Get the arguments to pass to the run method."""

diff --git a/tests/back/api/test_jobs.py b/tests/back/api/test_jobs.py
index f9a833bb0..2ecc5c86f 100644
--- a/tests/back/api/test_jobs.py
+++ b/tests/back/api/test_jobs.py
@@ -314,7 +314,12 @@ def test_execute_jobs(client: TestClient, run_id: int, failed_run_id: int):
     assert data["start_time"] is not None
     assert data["end_time"] is not None
 
-    # TODO: Add test that checks the failed run job fails
+    response = client.get(f"/api/v1/run/{failed_run_id}")
+    data = response.json()
+    assert data["status"] == 4
+    assert data["delivery_time"] is not None
+    assert data["start_time"] is not None
+    assert data["end_time"] is None


def test_job_with_wrong_run(client: TestClient):

diff --git a/tests/back/job_queue/__init__.py b/tests/back/job_queue/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/back/job_queue/test_simple_job_queue.py b/tests/back/job_queue/test_simple_job_queue.py
index fe2b35ac2..6703d0696 100644
--- a/tests/back/job_queue/test_simple_job_queue.py
+++ b/tests/back/job_queue/test_simple_job_queue.py
@@ -6,15 +6,18 @@
 
 
 class DummyJob(BaseJob):
-    def run(self) -> None:
+    def get_args(self) -> dict:
+        return {}
+
+    def store_results(self, results: dict) -> None:
         return None
 
-    def set_status_as_delivered(self) -> None:
+    def run(self) -> None:
         return None
 
 
 @pytest.fixture(name="job_queue")
-def fixture_job_queue() -> BaseJobQueue:
+def fixture_job_queue():
     queue = SimpleJobQueue()
     yield queue
     while not queue.is_empty():

From 14705234545f0dfa0864de6aae519da2b92c4220 Mon Sep 17 00:00:00 2001
From: Lurrea <rodrigo.urrea@ug.uchile.cl>
Date: Tue, 7 May 2024 17:27:51 -0400
Subject: [PATCH 06/10] change pipe class to pass Linux test

---
 DashAI/back/job/model_job.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py
index bf4bca066..7dfd4ba79 100644
--- a/DashAI/back/job/model_job.py
+++ b/DashAI/back/job/model_job.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import os
-from multiprocessing.connection import PipeConnection
+from multiprocessing.connection import Connection
 from typing import List, Type
 
 from dependency_injector.wiring import Provide, inject
@@ -90,7 +90,7 @@ def run(
         model_kwargs: dict,
         task_class: Type[BaseTask],
         metrics_classes: List[Type[BaseMetric]],
-        pipe: PipeConnection,
+        pipe: Connection,
     ) -> None:
         try:
             model: BaseModel = model_class(**model_kwargs)
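Note on patch 06: multiprocessing.connection.PipeConnection only exists on Windows,
so importing it crashes on Linux; Connection is defined on every platform, which
makes it the portable choice for type hints (small demonstration, not code from the
series; on Windows the actual pipe objects are PipeConnection instances, but an
annotation is only a hint):

    from multiprocessing import Pipe
    from multiprocessing.connection import Connection  # importable everywhere

    def report(pipe: Connection) -> None:
        pipe.send("done")

    if __name__ == "__main__":
        parent_conn, child_conn = Pipe()
        report(child_conn)           # same-process demo; normally the child calls this
        print(parent_conn.recv())    # -> "done"
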
From 81bbef0e71530a239278fa5271cae5584ca515a6 Mon Sep 17 00:00:00 2001
From: Riul1999 <rodrigo.urrea@ug.uchile.cl>
Date: Fri, 17 May 2024 14:00:47 -0400
Subject: [PATCH 07/10] fix test jobs, update save and load methods

---
 tests/back/api/test_jobs.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/tests/back/api/test_jobs.py b/tests/back/api/test_jobs.py
index 2ecc5c86f..e47581df9 100644
--- a/tests/back/api/test_jobs.py
+++ b/tests/back/api/test_jobs.py
@@ -1,5 +1,6 @@
 import json
 import os
+import pickle
 
 import joblib
 import pytest
@@ -28,11 +29,16 @@ class DummyModel(BaseModel):
     def get_schema(cls):
         return {}
 
-    def save(self, filename):
-        joblib.dump(self, filename)
+    def save(self, filename=None):
+        if filename:
+            joblib.dump(self, filename)
+        else:
+            return pickle.dumps(self)
 
-    def load(self, filename):
-        return
+    @staticmethod
+    def load(filename=None, byte_array=None):
+        if byte_array:
+            return pickle.loads(byte_array)
 
     def predict(self, x):
         return {}
@@ -51,7 +57,8 @@ def get_schema(cls):
     def save(self, filename):
         return
 
-    def load(self, filename):
+    @staticmethod
+    def load(filename, byte_array):
         return
 
     def predict(self, x):

From 30e11e8911a833ca025dd6c5d8d03aea12fb8b42 Mon Sep 17 00:00:00 2001
From: Riul1999 <rodrigo.urrea@ug.uchile.cl>
Date: Fri, 17 May 2024 14:01:37 -0400
Subject: [PATCH 08/10] update save and load methods, add documentation

---
 DashAI/back/models/base_model.py              | 40 ++++++++++++++-----
 .../models/scikit_learn/sklearn_like_model.py | 27 +++++++++----
 2 files changed, 49 insertions(+), 18 deletions(-)

diff --git a/DashAI/back/models/base_model.py b/DashAI/back/models/base_model.py
index b5a8529c8..0b7ce46d0 100644
--- a/DashAI/back/models/base_model.py
+++ b/DashAI/back/models/base_model.py
@@ -3,7 +3,7 @@
 import json
 import os
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Final
+from typing import Any, Dict, Final, Optional
 
 from DashAI.back.config_object import ConfigObject
 
@@ -17,25 +17,43 @@ class BaseModel(ConfigObject, metaclass=ABCMeta):
 
     TYPE: Final[str] = "Model"
 
-    # TODO implement a check_params method to check the params
-    # using the JSON schema file.
-    # TODO implement a method to check the initialization of TASK
-    # an task params variables.
-
     @abstractmethod
-    def save(self, filename: str) -> None:
+    def save(self, filename: Optional[str] = None) -> Optional[bytes]:
         """Store an instance of a model.
 
-        filename (Str): Indicates where to store the model,
-        if filename is None, this method returns a bytes array with the model.
+        Parameters
+        ----------
+        filename : Optional[str]
+            Indicates where to store the model; if filename is None, this method
+            returns a byte array with the model.
+
+        Returns
+        -------
+        Optional[bytes]
+            If the filename is None, returns the byte array associated with the
+            model.
         """
         raise NotImplementedError
 
     @abstractmethod
-    def load(self, filename: str) -> Any:
+    def load(
+        self,
+        filename: Optional[str] = None,
+        byte_array: Optional[bytes] = None,
+    ) -> "BaseModel":
         """Restores an instance of a model.
 
-        filename (Str): Indicates where the model was stored.
+        Parameters
+        ----------
+        filename: Optional[str]
+            Indicates where the model was stored.
+        byte_array: Optional[bytes]
+            The bytes associated with the model.
+
+        Returns
+        -------
+        BaseModel
+            The loaded model.
""" raise NotImplementedError diff --git a/DashAI/back/models/scikit_learn/sklearn_like_model.py b/DashAI/back/models/scikit_learn/sklearn_like_model.py index f8c485610..da4c2d3a7 100644 --- a/DashAI/back/models/scikit_learn/sklearn_like_model.py +++ b/DashAI/back/models/scikit_learn/sklearn_like_model.py @@ -1,4 +1,5 @@ -from typing import Type +import pickle +from typing import Optional, Type import joblib @@ -9,14 +10,26 @@ class SklearnLikeModel(BaseModel): """Abstract class to define the way to save sklearn like models.""" - def save(self, filename: str) -> None: - """Save the model in the specified path.""" - joblib.dump(self, filename) + def save(self, filename: Optional[str] = None) -> Optional[bytes]: + """Save the model in the specified path or return the associated bytes.""" + if filename: + joblib.dump(self, filename) + else: + return pickle.dumps(self) @staticmethod - def load(filename: str) -> None: - """Load the model of the specified path.""" - model = joblib.load(filename) + def load( + filename: Optional[str] = None, byte_array: Optional[bytes] = None + ) -> "SklearnLikeModel": + """Load the model of the specified path or from the byte array.""" + if filename: + model = joblib.load(filename) + elif byte_array: + model = pickle.loads(byte_array) + else: + raise ValueError( + "Must pass filename or byte_array yo load method, none of both passed." + ) return model # --- Methods for process the data for sklearn models --- From de12b38308224f46e3dda590d4b45665c93ee30e Mon Sep 17 00:00:00 2001 From: Riul1999 <rodrigo.urrea@ug.uchile.cl> Date: Fri, 17 May 2024 14:02:56 -0400 Subject: [PATCH 09/10] update ModelJob, now pass bytes train results and stores it later into a file --- DashAI/back/job/model_job.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py index 7dfd4ba79..d4f58cd83 100644 --- a/DashAI/back/job/model_job.py +++ b/DashAI/back/job/model_job.py @@ -160,26 +160,40 @@ def run( "Metrics calculation failed", ) from e - pipe.send( - { - "model": model, - "metrics": model_metrics, - } - ) + try: + pipe.send( + { + "model_bytes": model.save(), + "metrics": model_metrics, + } + ) + except ValueError as e: + log.exception(e) + raise JobError( + "Sending results failed", + ) from e @inject def store_results( self, - model: BaseModel, + model_bytes: bytes, metrics: dict, session_factory=Provide["db"], + component_registry=Provide["component_registry"], config=Provide["config"], ) -> None: with session_factory.session() as db: run: Run = db.get(Run, self.kwargs["run_id"]) + try: + model_class: BaseModel = component_registry[run.model_name]["class"] + except KeyError as e: + log.exception(e) + raise JobError(f"Unable to find model class for run {run.id}") from e + try: run_path = os.path.join(config["RUNS_PATH"], str(run.id)) + model = model_class.load(byte_array=model_bytes) model.save(run_path) except Exception as e: log.exception(e) From 767427b527740d9c9a0d442fbe5d2efa8da9f698 Mon Sep 17 00:00:00 2001 From: Riul1999 <rodrigo.urrea@ug.uchile.cl> Date: Fri, 17 May 2024 14:03:43 -0400 Subject: [PATCH 10/10] fix jobQueue, now it ask first if the child send a message and later if an error ocurrs --- DashAI/back/dependencies/job_queues/job_queue.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/DashAI/back/dependencies/job_queues/job_queue.py b/DashAI/back/dependencies/job_queues/job_queue.py index bf2ab2aa6..ce7f74863 100644 --- 
From 767427b527740d9c9a0d442fbe5d2efa8da9f698 Mon Sep 17 00:00:00 2001
From: Riul1999 <rodrigo.urrea@ug.uchile.cl>
Date: Fri, 17 May 2024 14:03:43 -0400
Subject: [PATCH 10/10] fix job queue: now it first asks whether the child sent
 a message and only then whether an error occurred

---
 DashAI/back/dependencies/job_queues/job_queue.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/DashAI/back/dependencies/job_queues/job_queue.py b/DashAI/back/dependencies/job_queues/job_queue.py
index bf2ab2aa6..ce7f74863 100644
--- a/DashAI/back/dependencies/job_queues/job_queue.py
+++ b/DashAI/back/dependencies/job_queues/job_queue.py
@@ -53,20 +53,16 @@ async def job_queue_loop(
                 logger.debug(f"Awaiting {job.id} process for 1 second.")
                 await asyncio.sleep(1)
 
+            # Check if the job sent the results
+            if parent_conn.poll():
+                job_results = parent_conn.recv()
+                parent_conn.close()
+                job_process.join()
             # Check if the job fails
-            if not job_process.is_alive() and job_process.exitcode != 0:
+            elif not job_process.is_alive() and job_process.exitcode != 0:
                 job.terminate_job()
                 continue
 
-            job_results = parent_conn.recv()
-            parent_conn.close()
-
-            # Wait until the job returns
-            while job_process.is_alive():
-                logger.debug(f"Awaiting {job.id} process for 1 second.")
-                await asyncio.sleep(1)
-            job_process.join()
-
             # Finish the job
             job.finish_job()
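For reference, after all ten patches the parent-side loop behaves as follows
(condensed sketch of the final state of job_queue_loop; logging and the surrounding
try/except are omitted):

    while not (stop_when_queue_empties and job_queue.is_empty()):
        job = await job_queue.async_get()
        parent_conn, child_conn = Pipe()

        job_args = job.get_args()
        job_args["pipe"] = child_conn
        job_process = Process(target=job.run, kwargs=job_args, daemon=True)

        job.start_job()
        job_process.start()

        # Poll until the child reports results or stops running.
        while job_process.is_alive() and not parent_conn.poll():
            await asyncio.sleep(1)

        if parent_conn.poll():                 # child reported results
            job_results = parent_conn.recv()
            parent_conn.close()
            job_process.join()
        elif job_process.exitcode != 0:        # child died without reporting
            job.terminate_job()
            continue

        job.finish_job()
        job.store_results(**job_results)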