Skip to content

Commit

Permalink
Merge pull request #14 from taxe10/main
Browse files Browse the repository at this point in the history
Adding extra parameters for slurm jobs
  • Loading branch information
taxe10 authored Jun 5, 2024
2 parents fb254a6 + 5898c70 commit 46fc0ec
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ MAX_TIME_CPU="1:00:00"
PARTITIONS_GPU='["p_gpu1", "p_gpu2"]'
RESERVATIONS_GPU='["r_gpu1", "r_gpu2"]'
MAX_TIME_GPU="1:00:00"
SUBMISSION_SSH_KEY="~/.ssh/id_rsa"
FORWARD_PORTS='["8888:8888"]'
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ services:
PARTITIONS_GPU: "${PARTITIONS_GPU}"
RESERVATIONS_GPU: "${RESERVATIONS_GPU}"
MAX_TIME_GPU: "${MAX_TIME_GPU}"
SUBMISSION_SSH_KEY: "${SUBMISSION_SSH_KEY}"
FORWARD_PORTS: "${FORWARD_PORTS}"
volumes:
- $READ_DIR:/app/work/data
- $WRITE_DIR:/app/work/mlex_store
Expand Down
21 changes: 14 additions & 7 deletions src/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@
# Execution backend for prefect flows: "podman", "conda", or "slurm".
FLOW_TYPE = os.getenv("FLOW_TYPE", "podman")

# Slurm
# PARTITIONS_*/RESERVATIONS_*/FORWARD_PORTS are JSON-encoded lists in the
# environment (e.g. '["p_gpu1", "p_gpu2"]').  The json.loads default must be
# the JSON *string* "[]" — a bare [] default would make json.loads raise
# TypeError whenever the variable is unset.
PARTITIONS_CPU = json.loads(os.getenv("PARTITIONS_CPU", "[]"))
RESERVATIONS_CPU = json.loads(os.getenv("RESERVATIONS_CPU", "[]"))
MAX_TIME_CPU = os.getenv("MAX_TIME_CPU", "1:00:00")
# Fix: the GPU settings previously read the *_CPU variables (copy-paste bug),
# so the *_GPU values defined in .env.example were silently ignored.
PARTITIONS_GPU = json.loads(os.getenv("PARTITIONS_GPU", "[]"))
RESERVATIONS_GPU = json.loads(os.getenv("RESERVATIONS_GPU", "[]"))
MAX_TIME_GPU = os.getenv("MAX_TIME_GPU", "1:00:00")
# SSH key used to submit slurm jobs; empty string when not configured.
SUBMISSION_SSH_KEY = os.getenv("SUBMISSION_SSH_KEY", "")
FORWARD_PORTS = json.loads(os.getenv("FORWARD_PORTS", "[]"))

# Mlex content api
CONTENT_API_URL = os.getenv("CONTENT_API_URL", "http://localhost:8000/api/v0/models")
Expand Down Expand Up @@ -99,6 +101,8 @@
"reservations": RESERVATIONS_CPU,
"max_time": MAX_TIME_CPU,
"conda_env_name": "mlex_dimension_reduction_pca",
"submission_ssh_key": SUBMISSION_SSH_KEY,
"forward_ports": FORWARD_PORTS,
"params": {
"io_parameters": {"uid_save": "uid0001", "uid_retrieve": ""}
},
Expand Down Expand Up @@ -366,18 +370,21 @@ def submit_dimension_reduction_job(
}
elif FLOW_TYPE == "conda":
autoencoder_params = {
"conda_env_name": "pytorch_autoencoders",
"conda_env_name": "mlex_pytorch_autoencoders",
"params": auto_params,
"python_file_name": "mlex_pytorch_autoencoders/src/predict_model.py",
}
else:
else: # slurm
autoencoder_params = {
"job_name": "latent_space_explorer",
"num_nodes": 1,
"partitions": PARTITIONS_GPU,
"reservations": RESERVATIONS_GPU,
"max_time": MAX_TIME_GPU,
"conda_env_name": "pytorch_autoencoders",
"conda_env_name": "mlex_pytorch_autoencoders",
"python_file_name": "mlex_pytorch_autoencoders/src/predict_model.py",
"submission_ssh_key": SUBMISSION_SSH_KEY,
"forward_ports": FORWARD_PORTS,
"params": auto_params,
}
job_params["params_list"].insert(0, autoencoder_params)
Expand Down

0 comments on commit 46fc0ec

Please sign in to comment.