12 changes: 2 additions & 10 deletions .github/ci/gpu-tests.sky.yaml
@@ -20,20 +20,12 @@ setup: |
https://github.com/primatrix/pallas-kernel.git ~/pallas-kernel

cd ~/pallas-kernel
# Use uv's managed Python 3.12 (VM ships Python 3.10, project requires >=3.12)
uv venv --python 3.12 .venv
source .venv/bin/activate
uv pip install -e '.[gpu]'
uv pip install pytest
uv sync --extra gpu

run: |
set -ex
cd ~/pallas-kernel
source .venv/bin/activate

# Run only GPU tests: CPU ref vs FLA Triton (simple_gla only)
# -k triton skips CPU-only cross-validation tests
python -m pytest \
uv run python -m pytest \
-o "addopts=--strict-markers" \
tests/ref/simple_gla/ \
-k triton \
7 changes: 4 additions & 3 deletions CLAUDE.md
@@ -10,9 +10,10 @@ High-performance JAX/Pallas TPU and GPU kernels for modern neural network archit

```bash
# Install dependencies
uv sync # Base install
uv sync --extra gpu # With GPU support (CUDA 12.6)
uv sync --extra tpu # With TPU support
uv sync # Base install (with pytest)
uv sync --extra dev # With dev tools (ruff, pre-commit)
uv sync --extra gpu # GPU development
uv sync --extra tpu # TPU development

# Run tests
uv run pytest tests/ -v # All tests
13 changes: 3 additions & 10 deletions pyproject.toml
@@ -5,36 +5,30 @@ build-backend = "hatchling.build"
[project]
name = "tops"
dynamic = ["version"]
description = "JAX/Flax Pallas kernels for Gated Linear Attention (GLA)"
description = "High-performance JAX/Pallas kernels for modern neural network architectures"
readme = "README.md"
requires-python = ">=3.12"
license = "Apache-2.0"
dependencies = [
"einops>=0.8.2",
"flax>=0.11.2",
"numpy>=2.4.2",
"pytest>=9.0.2",
]

[project.optional-dependencies]
dev = [
"pre-commit>=4.5.1",
"pytest>=9.0.2",
"ruff>=0.15.0",
]
gpu = [
"torch",
"torchvision",
"jax[cuda12,gpu]>=0.6.2",
"flash-linear-attention[gpu]>=0.4.1",
"jax[cuda12]>=0.6.2",
"flash-linear-attention>=0.4.1",
]
tpu = [
"jax[tpu]>=0.8.1",
"torch",
]
profile = [
"xprof==2.22.0",
]

[[tool.uv.index]]
name = "pytorch-cu126"
@@ -59,7 +53,6 @@ torch = [
{ index = "pytorch-cu126", extra = "gpu" },
{ index = "pytorch-cpu", extra = "tpu" },
]
torchvision = [{ index = "pytorch-cu126", extra = "gpu" }]


[tool.pytest.ini_options]
13 changes: 8 additions & 5 deletions scripts/launch-tpuv7.yml
@@ -22,19 +22,22 @@ spec:
- |
set -ex

# 1. 准备 SSH 环境 (Secret 挂载通常是只读的,所以需要 cp 到家目录)
# 1. SSH setup (Secret mount is read-only, copy to home)
mkdir -p ~/.ssh
cp /etc/ssh-key/id_ed25519 ~/.ssh/id_ed25519
chmod 600 ~/.ssh/id_ed25519

# 2. 扫描 GitHub 指纹避免交互式确认
# 2. Scan GitHub fingerprint to avoid interactive confirmation
ssh-keyscan github.com >> ~/.ssh/known_hosts

# pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# 3. Install uv
pip install uv

# git clone https://github.com/primatrix/ant-pretrain.git
# cd ant-pretrain && git checkout feat/kda-with-kernel
# 4. Clone and setup project
git clone git@github.com:primatrix/pallas-kernel.git ~/pallas-kernel
cd ~/pallas-kernel
uv sync --extra tpu || pip install -e '.[tpu]'
Comment on lines +33 to +39

⚠️ Potential issue | 🟠 Major

Consider adding the `--no-dev` flag, or ensure `uv` creates the venv in a predictable location.

The `uv sync --extra tpu` command creates a virtual environment (typically `.venv/` in the project directory). However, the subsequent `python -c 'import jax; ...'` on line 41 runs without activating this venv or using `uv run`, so it uses the system Python, which won't have the dependencies installed.

🐛 Proposed fix to use uv run for the verification command
          # 4. Clone and setup project
          git clone git@github.com:primatrix/pallas-kernel.git ~/pallas-kernel
          cd ~/pallas-kernel
          uv sync --extra tpu || pip install -e '.[tpu]'
 
-          python -c 'import jax; print("Total TPU devices (cores):", jax.device_count())'
+          uv run python -c 'import jax; print("Total TPU devices (cores):", jax.device_count())'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/launch-tpuv7.yml` around lines 33 - 39, The uv sync step currently
creates a project virtualenv but later verification runs use system Python;
update the workflow so dependencies are executed inside the venv: run uv sync
--extra tpu (optionally add --no-dev to avoid dev deps) and replace the
standalone python verification command with uv run python -c '...' (or ensure uv
creates a predictable venv path and activate it before running verification).
Target the uv invocation lines (uv sync and the subsequent verification python
command) and ensure you either add --no-dev to uv sync or use uv run so the
verification uses the virtualenv created by uv.
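
One way to confirm which interpreter a step actually uses is a small check like the sketch below. It is illustrative only and not part of this PR: the function names and the file name `check_env.py` mentioned afterwards are hypothetical. It relies on the standard behavior that, inside a virtual environment, `sys.prefix` differs from `sys.base_prefix`.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a virtual environment."""
    # Inside a venv, sys.prefix points at the environment, while
    # sys.base_prefix still points at the interpreter it was created from.
    return sys.prefix != sys.base_prefix


def describe_interpreter() -> str:
    """Summarize which interpreter is running, e.g. for a CI log line."""
    kind = "virtualenv" if in_virtualenv() else "system"
    return f"{kind} interpreter: {sys.executable} (prefix={sys.prefix})"


if __name__ == "__main__":
    print(describe_interpreter())
```

Saved as the hypothetical `check_env.py`, running it with `uv run python check_env.py` after a successful `uv sync` would be expected to report a virtualenv interpreter, while a bare `python check_env.py` in the same shell would report whatever interpreter is first on `PATH`.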


python -c 'import jax; print("Total TPU devices (cores):", jax.device_count())'

sleep infinity