speed up comp server
vadiklyutiy committed Jan 20, 2025
1 parent d3f976e commit 010cca5
Showing 3 changed files with 112 additions and 53 deletions.
4 changes: 3 additions & 1 deletion apps/compile_server/Dockerfile
@@ -1,6 +1,7 @@
 FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
 
 COPY ./run.py /app/run.py
+COPY ./requirements.txt /app/requirements.txt
 WORKDIR /app
 
 ENV TZ=America/Toronto
@@ -16,7 +17,8 @@ RUN apt-get update && apt-get install -y \
     && rm -rf /var/lib/apt/lists/* \
     && ln -s /usr/bin/python3 /usr/bin/python \
     && python -m pip install --upgrade pip \
-    && python -m pip install filelock requests gunicorn flask cmake
+    && python -m pip install filelock requests gunicorn flask cmake \
+    && python -m pip install -r ./requirements.txt
 
 EXPOSE 3281

50 changes: 27 additions & 23 deletions apps/compile_server/resources/compilation.py
@@ -1,11 +1,9 @@
-from typing import Dict, Any, List, Tuple
+from typing import Dict, Any, Tuple
 import time
 import re
-import sys
 import os
 import traceback
 import threading
 import requests
-import subprocess
 import zipfile
 import logging
@@ -17,16 +15,18 @@
 from hashlib import sha256
 from filelock import FileLock
 
+from .compile_worker import CompilationWorkers
+
 lock = threading.Lock()
 logger = logging.Logger(__name__)
 
 pid = os.getpid()
-jobs_dir = os.path.join(os.getcwd(), 'jobs')
-repos_dir = os.path.join(os.getcwd(), 'repos')
-commits_dir = os.path.join(os.getcwd(), 'commits')
-results_dir = os.path.join(os.getcwd(), 'results')
+JOBS_DIR = os.path.join(os.getcwd(), 'jobs')
+REPOS_DIR = os.path.join(os.getcwd(), 'repos')
+COMMITS_DIR = os.path.join(os.getcwd(), 'commits')
+RESULTS_DIR = os.path.join(os.getcwd(), 'results')
 
-compile_script = os.path.join(os.path.dirname(__file__), 'compile_worker.py')
+compilation_workers = CompilationWorkers(max_workers=5)
 
 
 def should_update(repo_timestamp) -> bool:
@@ -39,10 +39,10 @@ def should_update(repo_timestamp) -> bool:


 def clone_github_repo(owner: str, repo: str, version: str) -> str:
-    repo_dir = os.path.join(repos_dir, "{}_{}".format(owner, repo))
-    repo_timestamp = os.path.join(repos_dir, "{}_{}_timestamp".format(owner, repo))
+    repo_dir = os.path.join(REPOS_DIR, "{}_{}".format(owner, repo))
+    repo_timestamp = os.path.join(REPOS_DIR, "{}_{}_timestamp".format(owner, repo))
     os.makedirs(repo_dir, exist_ok=True)
-    with FileLock(os.path.join(repos_dir, '{}_{}.lock'.format(owner, repo))):
+    with FileLock(os.path.join(REPOS_DIR, '{}_{}.lock'.format(owner, repo))):
         if not os.path.exists(os.path.join(repo_dir, '.git')):
             repo = git.Repo.clone_from(
                 url="https://github.com/{}/{}.git".format(owner, repo),
@@ -76,12 +76,12 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:
         repo.git.checkout(version)
         commit_id = repo.head.commit.hexsha
 
-    commit_dir = os.path.join(commits_dir, commit_id)
+    commit_dir = os.path.join(COMMITS_DIR, commit_id)
     if os.path.exists(commit_dir):
         return commit_id
-    with FileLock(os.path.join(commits_dir, commit_id + '.lock')):
-        repo.git.archive(commit_id, format='zip', output=os.path.join(commits_dir, f'{commit_id}.zip'))
-        with zipfile.ZipFile(os.path.join(commits_dir, f'{commit_id}.zip'), 'r') as zip_ref:
+    with FileLock(os.path.join(COMMITS_DIR, commit_id + '.lock')):
+        repo.git.archive(commit_id, format='zip', output=os.path.join(COMMITS_DIR, f'{commit_id}.zip'))
+        with zipfile.ZipFile(os.path.join(COMMITS_DIR, f'{commit_id}.zip'), 'r') as zip_ref:
             os.makedirs(commit_dir, exist_ok=True)
             zip_ref.extractall(commit_dir)
     # build the hidet
@@ -139,8 +139,8 @@ def post(self):
             }
 
         job_id: str = sha256(commit_id.encode() + workload).hexdigest()
-        job_path = os.path.join(jobs_dir, job_id + '.pickle')
-        job_response_path = os.path.join(jobs_dir, job_id + '.response')
+        job_path = os.path.join(JOBS_DIR, job_id + '.pickle')
+        job_response_path = os.path.join(JOBS_DIR, job_id + '.response')
 
         print('[{}] Received a job: {}'.format(pid, job_id[:16]))

@@ -151,22 +151,26 @@ def post(self):
                 return pickle.load(f)
 
         # write the job to the disk
-        job_lock = os.path.join(jobs_dir, job_id + '.lock')
+        job_lock = os.path.join(JOBS_DIR, job_id + '.lock')
         with FileLock(job_lock):
             if not os.path.exists(job_path):
                 with open(job_path, 'wb') as f:
                     pickle.dump(job, f)
 
+        version_path = os.path.join(COMMITS_DIR, commit_id)
         with lock:  # Only one thread can access the following code at the same time
-            print('[{}] Start compiling: {}'.format(pid, job_id[:16]))
-            ret = subprocess.run([sys.executable, compile_script, '--job_id', job_id])
+            print('[{}] Start compiling: {}'.format(pid, job_id[:16]), flush=True)
+            start_time = time.time()
+            compilation_workers.submit_job(job_id, version_path)
+            compilation_workers.wait_all_jobs_finished()
+            end_time = time.time()
 
         # respond to the client
-        response_path = os.path.join(jobs_dir, job_id + '.response')
+        response_path = os.path.join(JOBS_DIR, job_id + '.response')
         if not os.path.exists(response_path):
-            raise RuntimeError('Can not find the response file:\n{}{}'.format(ret.stderr, ret.stdout))
+            raise RuntimeError('Cannot find the response file')
         else:
-            print('[{}] Finish compiling: {}'.format(pid, job_id[:16]))
+            print(f'[{pid}] Finish compiling: {job_id[:16]} in {end_time - start_time:.2f}s', flush=True)
             with open(response_path, 'rb') as f:
                 response: Tuple[Dict, int] = pickle.load(f)
             return response
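
For readers skimming the diff: the endpoint deduplicates work by content-addressing each job and persisting it exactly once before handing it to the worker pool. A minimal standalone sketch of that pattern follows; the helper name persist_job_once is hypothetical and not part of the commit.

import os
import pickle
from hashlib import sha256
from filelock import FileLock

JOBS_DIR = os.path.join(os.getcwd(), 'jobs')

def persist_job_once(commit_id: str, workload: bytes, job: dict) -> str:
    # Content-addressed id: identical (commit, workload) pairs map to the same job.
    job_id = sha256(commit_id.encode() + workload).hexdigest()
    job_path = os.path.join(JOBS_DIR, job_id + '.pickle')
    # The file lock makes the write idempotent across gunicorn worker processes.
    with FileLock(os.path.join(JOBS_DIR, job_id + '.lock')):
        if not os.path.exists(job_path):
            with open(job_path, 'wb') as f:
                pickle.dump(job, f)
    return job_id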
111 changes: 82 additions & 29 deletions apps/compile_server/resources/compile_worker.py
@@ -1,23 +1,21 @@
-from typing import Dict, Any, List, Tuple, Sequence, Union
+from typing import Dict, Any, Sequence, Union
 import os
 import traceback
-import argparse
 import sys
-import re
-import subprocess
 import zipfile
 import logging
 import pickle
-import git
 from hashlib import sha256
 from filelock import FileLock
+import multiprocessing
+import importlib
 
 logger = logging.Logger(__name__)
 
-jobs_dir = os.path.join(os.getcwd(), 'jobs')
-repos_dir = os.path.join(os.getcwd(), 'repos')
-commits_dir = os.path.join(os.getcwd(), 'commits')
-results_dir = os.path.join(os.getcwd(), 'results')
+JOBS_DIR = os.path.join(os.getcwd(), 'jobs')
+REPOS_DIR = os.path.join(os.getcwd(), 'repos')
+COMMITS_DIR = os.path.join(os.getcwd(), 'commits')
+RESULTS_DIR = os.path.join(os.getcwd(), 'results')


 def save_response(response, response_file: str):
@@ -27,14 +25,14 @@ def save_response(response, response_file: str):

 def compile_job(job_id: str):
     try:
-        job_file = os.path.join(jobs_dir, job_id + '.pickle')
+        job_file = os.path.join(JOBS_DIR, job_id + '.pickle')
         if not os.path.exists(job_file):
             # job not found
             return 1
 
-        job_lock = os.path.join(jobs_dir, job_id + '.lock')
+        job_lock = os.path.join(JOBS_DIR, job_id + '.lock')
         with FileLock(job_lock):
-            response_file = os.path.join(jobs_dir, job_id + '.response')
+            response_file = os.path.join(JOBS_DIR, job_id + '.response')
             if os.path.exists(response_file):
                 # job already compiled
                 return 0
@@ -45,11 +43,8 @@ def compile_job(job_id: str):

         # import the hidet from the commit
         commit_id: str = job['commit_id']
-        commit_dir = os.path.join(commits_dir, commit_id)
-        sys.path.insert(0, os.path.join(commit_dir, 'python'))
-        import hidet  # import the hidet from the commit
-
-        # load the workload
+        import hidet
+        # load the workload
         workload: Dict[str, Any] = pickle.loads(job['workload'])
         ir_module: Union[hidet.ir.IRModule, Sequence[hidet.ir.IRModule]] = workload['ir_module']
         target: str = workload['target']
@@ -59,10 +54,10 @@ def compile_job(job_id: str):
         module_string = str(ir_module)
         key = module_string + target + output_kind + commit_id
         hash_digest: str = sha256(key.encode()).hexdigest()
-        zip_file_path: str = os.path.join(results_dir, hash_digest + '.zip')
+        zip_file_path: str = os.path.join(RESULTS_DIR, hash_digest + '.zip')
         if not os.path.exists(zip_file_path):
-            output_dir: str = os.path.join(results_dir, hash_digest)
-            with FileLock(os.path.join(results_dir, f'{hash_digest}.lock')):
+            output_dir: str = os.path.join(RESULTS_DIR, hash_digest)
+            with FileLock(os.path.join(RESULTS_DIR, f'{hash_digest}.lock')):
                 if not os.path.exists(os.path.join(output_dir, 'lib.so')):
                     hidet.drivers.build_ir_module(
                         ir_module,
@@ -88,12 +83,70 @@ def compile_job(job_id: str):
         return 0
 
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--job_id', type=str, required=True)
-    args = parser.parse_args()
-    exit(compile_job(args.job_id))
-
-
-if __name__ == '__main__':
-    main()
+# Worker process function: handles compilation jobs using a specific version of the 'hidet' module.
+def worker_process(version, job_queue, result_queue, parent_pid):
+    sys.path.insert(0, os.path.join(version, 'python'))  # Ensure the version path is first in sys.path
+    hidet = importlib.import_module("hidet")  # Load the specific version
+    importlib.reload(hidet)  # Reload to make sure the correct version is used
+    print(f"[{parent_pid}] Worker loaded hidet version from {version}", flush=True)
+
+    while True:
+        job = job_queue.get()
+        if job == "STOP":
+            print(f"[{parent_pid}] Shutting down worker for version: {version}", flush=True)
+            break
+
+        # Compile
+        job_id = job
+        print(f"[{parent_pid}] Worker processing job {job_id[:16]} with hidet version {version}", flush=True)
+        compile_job(job_id)
+        result_queue.put((job_id, 'DONE'))
+
+
+class CompilationWorkers:
+    """
+    Manages a pool of compilation worker processes.
+    The pool avoids the overhead of loading the hidet module for every job:
+    each worker compiles with one fixed version of hidet (a fixed commit hash),
+    one worker per version. Only one worker compiles at a time; there is no
+    concurrent compilation here. Concurrency is handled at the upper level.
+    """
+    def __init__(self, max_workers: int = 5):
+        self.max_workers = max_workers
+        self.workers = {}  # {version_path: (worker_process, job_queue)}
+        self.result_queue = multiprocessing.Queue()
+
+    def _get_or_create_worker(self, version_path):
+        # If a worker for the version exists, return it
+        if version_path in self.workers:
+            return self.workers[version_path]
+
+        # If the worker pool is full, evict a worker (dict.popitem removes the most recently added entry)
+        if len(self.workers) >= self.max_workers:
+            _, (worker, job_queue) = self.workers.popitem()
+            job_queue.put("STOP")  # Send a shutdown signal to the worker being removed
+            worker.join()  # Wait for it to exit
+
+        # Create a new worker for the version
+        job_queue = multiprocessing.Queue()
+        worker = multiprocessing.Process(target=worker_process,
+                                         args=(version_path, job_queue, self.result_queue, os.getpid())
+                                         )
+        worker.start()
+        self.workers[version_path] = (worker, job_queue)
+        return self.workers[version_path]
+
+    def submit_job(self, job_id, version_path):
+        _, job_queue = self._get_or_create_worker(version_path)
+        job_queue.put(job_id)
+
+    def wait_all_jobs_finished(self):
+        # multiprocessing.Queue.get() blocks until an item is available;
+        # with at most one job in flight, this waits for that job to finish.
+        self.result_queue.get()
+
+    def shutdown(self):
+        for _, (worker, job_queue) in self.workers.items():
+            job_queue.put("STOP")
+            worker.join()
+        print(f"[{os.getpid()}] All compilation workers are shut down.", flush=True)

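A minimal usage sketch of the new worker pool (standalone and hypothetical, not part of the commit; the job id, version path, and import path are assumptions based on the module layout above):

from resources.compile_worker import CompilationWorkers  # import path assumed

workers = CompilationWorkers(max_workers=5)
try:
    # version_path points at an extracted commit tree, e.g. commits/<commit_id>
    workers.submit_job(job_id='deadbeef' * 8, version_path='commits/deadbeef')  # hypothetical values
    workers.wait_all_jobs_finished()  # blocks until the single in-flight job reports DONE
finally:
    workers.shutdown()  # stop every worker process

Keeping one long-lived process per hidet commit means the expensive hidet import happens once per version rather than once per request, which is where the speedup in this commit comes from.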