@@ -207,60 +207,94 @@ jobs:
207207 import pathlib
208208 import shutil
209209 import sys
210+ from typing import Iterable, Optional
210211
211212 try:
212213 import mlx # type: ignore
213214 except ImportError:
214215 print("::error::The 'mlx' Python package is not installed; cannot locate mlx.metallib.")
215216 sys.exit(1)
216217
217- kernels_dir = None
218+ search_dirs: list[pathlib.Path] = []
219+ package_dir: Optional[pathlib.Path] = None
220+
221+ try:
222+ spec = importlib.util.find_spec("mlx.backend.metal.kernels")
223+ except ModuleNotFoundError:
224+ spec = None
218225
219- spec = importlib.util.find_spec("mlx.backend.metal.kernels")
220226 if spec and spec.origin:
221- candidate_origin = spec.origin
222- if candidate_origin:
223- candidate = pathlib.Path(candidate_origin).resolve().parent
224- if candidate.exists():
225- kernels_dir = candidate
227+ candidate = pathlib.Path(spec.origin).resolve().parent
228+ if candidate.exists():
229+ search_dirs.append(candidate)
226230
227- if kernels_dir is None:
231+ try:
232+ resource = resources.files("mlx.backend.metal") / "kernels"
233+ except (ModuleNotFoundError, AttributeError):
234+ resource = None
235+ else:
228236 try:
229- resource = resources.files("mlx.backend.metal") / "kernels"
230- except (ModuleNotFoundError, AttributeError):
231- resource = None
232- if resource is not None:
233- try:
234- with resources.as_file(resource) as extracted:
235- if extracted is not None:
236- extracted_path = pathlib.Path(extracted)
237- if extracted_path.exists():
238- kernels_dir = extracted_path
239- except FileNotFoundError:
240- pass
241-
242- if kernels_dir is None:
243- package_file = getattr(mlx, "__file__", None)
244- if package_file:
245- package_dir = pathlib.Path(package_file).resolve().parent
246- candidate = package_dir / "backend" / "metal" / "kernels"
247- if candidate.exists():
248- kernels_dir = candidate
249-
250- if kernels_dir is None or not kernels_dir.exists():
251- print("::error::Could not locate the MLX metal kernels directory; checked module spec and importlib resources.")
237+ with resources.as_file(resource) as extracted:
238+ if extracted:
239+ extracted_path = pathlib.Path(extracted).resolve()
240+ if extracted_path.exists():
241+ search_dirs.append(extracted_path)
242+ except (FileNotFoundError, RuntimeError):
243+ pass
244+
245+ package_file = getattr(mlx, "__file__", None)
246+ if package_file:
247+ package_dir = pathlib.Path(package_file).resolve().parent
248+ search_dirs.extend(
249+ [
250+ package_dir / "backend" / "metal" / "kernels",
251+ package_dir / "backend" / "metal",
252+ package_dir,
253+ ]
254+ )
255+
256+ ordered_dirs: list[pathlib.Path] = []
257+ seen: set[pathlib.Path] = set()
258+ for candidate in search_dirs:
259+ if not candidate:
260+ continue
261+ candidate = candidate.resolve()
262+ if candidate in seen:
263+ continue
264+ seen.add(candidate)
265+ ordered_dirs.append(candidate)
266+
267+ def iter_metallibs(dirs: Iterable[pathlib.Path]):
268+ for directory in dirs:
269+ if not directory.exists():
270+ continue
271+ preferred = directory / "mlx.metallib"
272+ if preferred.exists():
273+ yield preferred
274+ continue
275+ for alternative in sorted(directory.glob("*.metallib")):
276+ yield alternative
277+
278+ src = next(iter_metallibs(ordered_dirs), None)
279+
280+ if src is None and package_dir and package_dir.exists():
281+ for candidate in package_dir.rglob("mlx.metallib"):
282+ src = candidate
283+ print(f"::warning::Resolved metallib via recursive search under {package_dir}")
284+ break
285+
286+ if src is None and package_dir and package_dir.exists():
287+ for candidate in sorted(package_dir.rglob("*.metallib")):
288+ src = candidate
289+ print(f"::warning::Using metallib {candidate.name} discovered via package-wide search")
290+ break
291+
292+ if src is None:
293+ print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.")
252294 sys.exit(1)
253295
254- preferred = kernels_dir / "mlx.metallib"
255- if preferred.exists():
256- src = preferred
257- else:
258- candidates = sorted(kernels_dir.glob("*.metallib"))
259- if not candidates:
260- print(f"::error::No metallib files were found under {kernels_dir}")
261- sys.exit(1)
262- src = candidates[0]
263- print(f"::warning::Defaulting to metallib file {src.name}")
296+ if src.name != "mlx.metallib":
297+ print(f"::warning::Using metallib {src.name} from {src.parent}")
264298
265299 dest = pathlib.Path("artifacts/native/osx-arm64/mlx.metallib").resolve()
266300 dest.parent.mkdir(parents=True, exist_ok=True)
0 commit comments