Refactor/parallel jobs #195

Open · wants to merge 12 commits into project/improve-queue
43 changes: 17 additions & 26 deletions DashAI/back/api/api_v1/endpoints/jobs.py
@@ -1,10 +1,8 @@
import logging
from typing import Callable, ContextManager

from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
from fastapi.exceptions import HTTPException
from sqlalchemy.orm import Session

from DashAI.back.api.api_v1.schemas.job_params import JobParams
from DashAI.back.containers import Container
@@ -104,9 +102,6 @@ async def get_job(
@inject
async def enqueue_job(
params: JobParams,
session_factory: Callable[..., ContextManager[Session]] = Depends(
Provide[Container.db.provided.session]
),
component_registry: ComponentRegistry = Depends(
Provide[Container.component_registry]
),
@@ -131,27 +126,23 @@ async def enqueue_job(
dict
dict with the new job in the database
"""
with session_factory() as db:
params.db = db
job: BaseJob = component_registry[params.job_type]["class"](
**params.model_dump()
)
try:
job.set_status_as_delivered()
except JobError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Job not delivered",
) from e
try:
job_queue.put(job)
except JobQueueError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Job not enqueued",
) from e
job: BaseJob = component_registry[params.job_type]["class"](**params.model_dump())
try:
job.deliver_job()
except JobError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Job not delivered",
) from e
try:
job_queue.put(job)
except JobQueueError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Job not enqueued",
) from e
return job
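
Note: the session factory argument is gone from the endpoint because deliver_job (defined below in base_job.py) now resolves the database provider itself through dependency-injector wiring. A minimal, self-contained sketch of that pattern, using a stand-in provider rather than DashAI's real db container:

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject


class Container(containers.DeclarativeContainer):
    # Stand-in for DashAI's real database provider.
    db = providers.Object("fake-session-factory")


@inject
def deliver(session_factory=Provide["db"]):
    # session_factory arrives from the container; callers pass nothing.
    print(session_factory)


container = Container()
container.wire(modules=[__name__])
deliver()  # prints "fake-session-factory"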


61 changes: 51 additions & 10 deletions DashAI/back/dependencies/job_queues/job_queue.py
@@ -1,7 +1,8 @@
import asyncio
import logging
from multiprocessing import Pipe, Process

from dependency_injector.wiring import Provide, inject
from sqlalchemy import exc

from DashAI.back.containers import Container
from DashAI.back.dependencies.job_queues import BaseJobQueue
@@ -22,17 +23,57 @@ async def job_queue_loop(

Parameters
----------
stop_when_queue_empties: bool
Boolean that sets the while-loop condition.
job_queue : BaseJobQueue
The current app job queue.

"""
while not job_queue.is_empty() if stop_when_queue_empties else True:
try:
try:
while not (stop_when_queue_empties and job_queue.is_empty()):
# Get the job from the queue
job: BaseJob = await job_queue.async_get()
job.run()
except exc.SQLAlchemyError as e:
logger.exception(e)
except JobError as e:
logger.exception(e)
parent_conn, child_conn = Pipe()

# Get the job arguments
job_args = job.get_args()
job_args["pipe"] = child_conn

# Create the process to run the job
job_process = Process(target=job.run, kwargs=job_args, daemon=True)

# Mark the run as started in the DB
job.start_job()

# Process management
job_process.start()

# Wait until the job fails or the child sends the results
while job_process.is_alive() and not parent_conn.poll():
logger.debug(f"Awaiting {job.id} process for 1 second.")
await asyncio.sleep(1)

# Check if the job sent its results
if parent_conn.poll():
job_results = parent_conn.recv()
parent_conn.close()
job_process.join()
# Check if the job failed
elif not job_process.is_alive() and job_process.exitcode != 0:
job.terminate_job()
continue

# Finish the job
job.finish_job()

# Store the results of the job
job.store_results(**job_results)

except JobError as e:
logger.exception(e)
raise RuntimeError(
f"Error in the training execution loop: {e}.\nShutting down the app."
) from e
except KeyboardInterrupt:
logger.info("Shutting down the app")
return
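
For reference, the parent/child contract the loop relies on, as one runnable sketch (fake_run is hypothetical; real jobs implement run themselves): the child process sends a single picklable results dict through the pipe, which is exactly what parent_conn.poll() and recv() above observe.

import time
from multiprocessing import Pipe, Process


def fake_run(pipe, run_id):
    # Stand-in for a job's run(): do the work, then report back.
    results = {"run_id": run_id, "accuracy": 0.9}  # illustrative values
    pipe.send(results)  # makes parent_conn.poll() return True
    pipe.close()


if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    job_process = Process(target=fake_run, kwargs={"pipe": child_conn, "run_id": 1}, daemon=True)
    job_process.start()
    while job_process.is_alive() and not parent_conn.poll():
        time.sleep(1)  # the real loop uses await asyncio.sleep(1) instead
    print(parent_conn.recv())  # {'run_id': 1, 'accuracy': 0.9}
    job_process.join()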
88 changes: 86 additions & 2 deletions DashAI/back/job/base_job.py
@@ -1,8 +1,17 @@
"""Base Job abstract class."""

import logging
from abc import ABCMeta, abstractmethod
from typing import Final

from dependency_injector.wiring import Provide, inject
from sqlalchemy import exc

from DashAI.back.dependencies.database.models import Run

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)


class BaseJob(metaclass=ABCMeta):
"""Abstract class for all Jobs."""
@@ -20,9 +29,84 @@ def __init__(self, **kwargs):
job_kwargs = kwargs.pop("kwargs", {})
self.kwargs = {**kwargs, **job_kwargs}

@abstractmethod
def set_status_as_delivered(self) -> None:
@inject
def deliver_job(self, session_factory=Provide["db"]) -> None:
"""Set the status of the job as delivered."""
with session_factory.session() as db:
run_id: int = self.kwargs["run_id"]
run: Run = db.get(Run, run_id)
if not run:
raise JobError(
f"Cannot deliver job: Run {run_id} does not exist in DB."
)
try:
run.set_status_as_delivered()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError(
"Internal database error",
) from e

@inject
def start_job(self, session_factory=Provide["db"]) -> None:
"""Set the status of the job as started."""
with session_factory.session() as db:
run_id: int = self.kwargs["run_id"]
run: Run = db.get(Run, run_id)
if not run:
raise JobError(f"Cannot start job: Run {run_id} does not exist in DB.")
try:
run.set_status_as_started()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError(
"Internal database error",
) from e

@inject
def finish_job(self, session_factory=Provide["db"]) -> None:
"""Set the status of the job as finished."""
with session_factory.session() as db:
run_id: int = self.kwargs["run_id"]
run: Run = db.get(Run, run_id)
if not run:
raise JobError(f"Cannot finish job: Run {run_id} does not exist in DB.")
try:
run.set_status_as_finished()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError(
"Internal database error",
) from e

@inject
def terminate_job(self, session_factory=Provide["db"]) -> None:
"""Set the status of the job as error."""
with session_factory.session() as db:
run_id: int = self.kwargs["run_id"]
run: Run = db.get(Run, run_id)
if not run:
raise JobError(f"Cannot finish job: Run {run_id} does not exist in DB.")
try:
run.set_status_as_error()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError(
"Internal database error",
) from e
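
Note: the four status methods above differ only in which Run transition they call. A purely illustrative sketch (not part of this PR) of how the shared session and error handling could be factored into one helper:

@inject
def _set_run_status(self, transition: str, session_factory=Provide["db"]) -> None:
    # transition names a Run method, e.g. "set_status_as_started".
    with session_factory.session() as db:
        run_id: int = self.kwargs["run_id"]
        run: Run = db.get(Run, run_id)
        if not run:
            raise JobError(f"Cannot update job: Run {run_id} does not exist in DB.")
        try:
            getattr(run, transition)()
            db.commit()
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise JobError("Internal database error") from e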

@abstractmethod
def get_args(self) -> dict:
"""Get the arguements to pass to run method."""
raise NotImplementedError

@abstractmethod
def store_results(self, results: dict) -> None:
"""Store the results of the job."""
raise NotImplementedError

@abstractmethod
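
Taken together, the new BaseJob contract asks each job for three hooks: get_args() returns picklable kwargs for the child process, run() does the work and reports through the pipe, and store_results() persists whatever comes back. A hypothetical minimal subclass (not part of this PR) that satisfies the contract:

class EchoJob(BaseJob):
    """Toy job that round-trips its kwargs through the pipe."""

    def get_args(self) -> dict:
        # Passed to run() in the child process, so it must be picklable.
        return dict(self.kwargs)

    def run(self, pipe=None, **kwargs) -> None:
        results = {"results": kwargs}  # stand-in for real work
        if pipe is not None:
            pipe.send(results)  # picked up by the queue loop's parent_conn
            pipe.close()

    def store_results(self, results: dict) -> None:
        log.info("Storing results: %s", results)  # a real job would write to the DB

Because the queue loop calls job.store_results(**job_results), the dict sent through the pipe must use keys matching store_results' parameters (here, "results").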