diff --git a/src/codeflare_sdk/ray/cluster/generate_yaml.py b/src/codeflare_sdk/ray/cluster/generate_yaml.py index 0b174650a..95d3ba11d 100755 --- a/src/codeflare_sdk/ray/cluster/generate_yaml.py +++ b/src/codeflare_sdk/ray/cluster/generate_yaml.py @@ -18,6 +18,7 @@ """ import json +import sys import typing import yaml import os @@ -31,6 +32,11 @@ ) import codeflare_sdk +SUPPORTED_PYTHON_VERSIONS = { + "3.9": "quay.io/modh/ray@sha256:0d715f92570a2997381b7cafc0e224cfa25323f18b9545acfd23bc2b71576d06", + "3.11": "quay.io/modh/ray:2.35.0-py311-cu121", +} + def read_template(template): with open(template, "r") as stream: @@ -88,9 +94,15 @@ def update_names( def update_image(spec, image): containers = spec.get("containers") - if image != "": - for container in containers: - container["image"] = image + if not image: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + try: + if python_version in SUPPORTED_PYTHON_VERSIONS: + image = SUPPORTED_PYTHON_VERSIONS[python_version] + except Exception: # pragma: no cover + print(f"Python version '{python_version}' is not supported. Only {', '.join(SUPPORTED_PYTHON_VERSIONS.keys())} are supported.") + for container in containers: + container["image"] = image def update_image_pull_secrets(spec, image_pull_secrets):