Skip to content

Commit

Permalink
resolve comment
Browse files Browse the repository at this point in the history
  • Loading branch information
vadiklyutiy committed Jan 20, 2025
1 parent 010cca5 commit dddf37c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
17 changes: 15 additions & 2 deletions apps/compile_server/resources/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

from .compile_worker import CompilationWorkers

'''
The compilation server will launch as many Flask applications as there are vCPUs (using gunicorn).
Each Flask application (i.e., our compilation server process) will handle the requests,
and there will be at most vCPU number of requests being processed at the same time.
Each process will maintain a pool of compilation workers with max_workers=5 (i.e., independent processes),
with a specific version of hidet that has been imported in every process.
The job will (try to) be dispatched to a worker with the same hidet version first.
If no such worker exists, then a new one will be created to replace an existing one.
Increasing the `max_workers` in `CompilationWorkers` init will (potentially) consume more memory
(thanks to fork, this problem will not get severe) and create more processes (max_workers * vCPU in total).
Reducing the max_workers will reduce the opportunity to avoid importing hidet with the same version in nearby jobs.
'''

lock = threading.Lock()
logger = logging.Logger(__name__)

Expand Down Expand Up @@ -161,8 +175,7 @@ def post(self):
with lock: # Only one thread can access the following code at the same time
print('[{}] Start compiling: {}'.format(pid, job_id[:16]), flush=True)
start_time = time.time()
compilation_workers.submit_job(job_id, version_path)
compilation_workers.wait_all_jobs_finished()
compilation_workers.run_and_wait_job(job_id, version_path)
end_time = time.time()

# respond to the client
Expand Down
11 changes: 4 additions & 7 deletions apps/compile_server/resources/compile_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def compile_job(job_id: str):
# Worker process function to handle compilation jobs using a specific version of the 'hidet' module.
def worker_process(version, job_queue, result_queue, parent_pid):
sys.path.insert(0, os.path.join(version, 'python')) # Ensure the version path is first in sys.path
hidet = importlib.import_module("hidet") # Load the specific version
importlib.reload(hidet) # Reload to ensure correct version
print(f"[{parent_pid}] Worker loaded hidet version from {version}", flush=True)

while True:
Expand Down Expand Up @@ -137,13 +135,12 @@ def _get_or_create_worker(self, version_path):
self.workers[version_path] = (worker, job_queue)
return self.workers[version_path]

def submit_job(self, job_id, version_path):

def run_and_wait_job(self, job_id, version_path):
# Run the job and wait until it is finished
_, job_queue = self._get_or_create_worker(version_path)
job_queue.put(job_id)

def wait_all_jobs_finished(self):
# multiprocessing.Queue.get() waits until a new item is available
self.result_queue.get()
self.result_queue.get() # multiprocessing.Queue.get() waits until a new item is available

def shutdown(self):
for _, (worker, job_queue) in self.workers.items():
Expand Down

0 comments on commit dddf37c

Please sign in to comment.