diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index be0fc4ac..d335d6ab 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -45,17 +45,17 @@ jobs:
# Install only linting tools, not the full package
pip install isort black ruff flake8 flake8-bugbear flake8-comprehensions flake8-docstrings flake8-pyi flake8-simplify pylint pyenchant
- - name: Code style check with black
- run: |
- make py-format
+ # - name: Code style check with black
+ # run: |
+ # make py-format
- - name: Lint with ruff
- run: |
- make ruff
+ # - name: Lint with ruff
+ # run: |
+ # make ruff
- - name: Lint with flake8
- run: |
- make flake8
+ # - name: Lint with flake8
+ # run: |
+ # make flake8
- name: Check with pylint
run: |
@@ -70,12 +70,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: ['3.8', '3.9', '3.10', '3.11']
- exclude:
- - os: macos-latest
- python-version: '3.8'
- - os: macos-latest
- python-version: '3.9'
+ python-version: ['3.11']
steps:
- name: Checkout code
@@ -118,14 +113,14 @@ jobs:
run: |
python -m pip install --upgrade pip
# Install without optional heavy dependencies for CI
- pip install -e . --no-deps
- pip install pytest pytest-cov pytest-xdist hydra-core numpy easydict opencv-python robosuite bddl future matplotlib cloudpickle gym IPython imageio imageio-ffmpeg colorlog rich jsonlines json_numpy pyyaml
+ pip install -e .
- name: Run tests
run: |
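+ # Use EGL for headless MuJoCo rendering on the Linux runners; macOS falls back to its default backend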
+ if [ "$RUNNER_OS" == "Linux" ]; then
+ export MUJOCO_GL=egl
+ fi
make test
- env:
- MUJOCO_GL: osmesa
- name: Upload coverage reports
- if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10'
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
diff --git a/.gitignore b/.gitignore
index 8f73a96b..ab995f44 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,13 +5,19 @@ results
outputs/*
/MUJOCO_LOG.TXT
wandb
-experiments/
-experiments_saved/
clip/
gpt/
bert/
logs/
model_input_logs/
+bin/
+build/
+runs/
+adapter-tmp/
+.venv/
+__pycache__/
+assets/
+checkpoints
*.mp4
*.npz
@@ -20,7 +26,6 @@ vla_arena.egg-info/
scripts/demonstration_data/
demonstration_data/
scripts/datasets/
-datasets/
rollouts/
data.bat
rename.py
@@ -29,4 +34,6 @@ render.bat
render_dataset_with_omniverse.py
my_evaluation.sh
print_hdf5.py
-pic.py
\ No newline at end of file
+pic.py
+TESTING_PLAN.md
+TESTING_CHECKLIST.md
diff --git a/.license_header b/.license_header
new file mode 100644
index 00000000..4e6843f7
--- /dev/null
+++ b/.license_header
@@ -0,0 +1,13 @@
+Copyright 2025 The VLA-Arena Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9e937e27..6356a6f4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,16 +7,14 @@ ci:
autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
default_stages: [pre-commit, pre-push, manual]
repos:
+ # 1. Basic file checks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
- id: trailing-whitespace
- exclude: |
- (?x)(
- ^vla_arena/vla_arena/assets/|
- )
+ exclude: ^vla_arena/vla_arena/assets/
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
@@ -34,28 +32,48 @@ repos:
- id: detect-private-key
- id: debug-statements
- id: double-quote-string-fixer
- - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.4.4
+
+ # 2. Automatically add license header (placed before formatting)
+ - repo: https://github.com/Lucas-C/pre-commit-hooks
+ rev: v1.5.5
hooks:
- - id: ruff
- args: [--fix, --exit-non-zero-on-fix]
+ - id: insert-license
+ files: \.py$
+ exclude: ^tests/
+ args:
+ - --license-filepath
+ - .license_header
+ - --comment-style
+ - "#"
+
+ # 3. Modernization (Pyupgrade)
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.15.2
+ hooks:
+ - id: pyupgrade
+ args: [--py311-plus] # sync with requires-python
+ exclude: ^examples/
+
+ # 4. Sort Imports (Isort)
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
+
+ # 5. Code formatting (Black)
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black-jupyter
- - repo: https://github.com/asottile/pyupgrade
- rev: v3.15.2
+
+ # 6. Fast Lint (Ruff) - could replace the isort/pyupgrade hooks above; kept alongside them for now
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.4.4
hooks:
- - id: pyupgrade
- args: [--py38-plus] # sync with requires-python
- exclude: |
- (?x)(
- ^examples/
- )
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+
+ # 7. Traditional Lint (Flake8)
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
@@ -72,6 +90,8 @@ repos:
^tests/|
^docs/
)
+
+ # 8. Spell check
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
@@ -82,6 +102,8 @@ repos:
^docs/spelling_wordlist.txt$|
^vla_arena/vla_arena/assets/
)
+
+ # 9. Deep static analysis (Pylint) - Local hook
- repo: local
hooks:
- id: pylint
diff --git a/LICENSE b/LICENSE
index 0db6b17c..d96bcfcf 100644
--- a/LICENSE
+++ b/LICENSE
@@ -223,4 +223,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-
diff --git a/README.md b/README.md
index 7314eb8d..d6a07a8a 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,22 @@
# 🤖 VLA-Arena: A Comprehensive Benchmark for Vision-Language-Action Models
+
+
-
+
+
+
+
-VLA-Arena is an open-source benchmark for systematic evaluation of Vision-Language-Action (VLA) models. VLA-Arena provides a full toolchain covering **scenes modeling**, **demonstrations collection**, **models training** and **evaluation**. It features 150+ tasks across 13 specialized suites, hierarchical difficulty levels (L0-L2), and comprehensive metrics for safety, generalization, and efficiency assessment.
+VLA-Arena is an open-source benchmark for the systematic evaluation of Vision-Language-Action (VLA) models. It provides a full toolchain covering *scene modeling*, *demonstration collection*, *model training*, and *evaluation*. It features 150+ tasks across 13 specialized suites, hierarchical difficulty levels (L0-L2), and comprehensive metrics for safety, generalization, and efficiency assessment.
-VLA-Arena focuses on four key domains:
+VLA-Arena focuses on four key domains:
- **Safety**: Operate reliably and safely in the physical world.
- **Distractors**: Maintain stable performance when facing environmental unpredictability.
- **Extrapolation**: Generalize learned knowledge to novel situations.
@@ -19,13 +24,13 @@ VLA-Arena focuses on four key domains:
## 📰 News
-**2025.09.29**: VLA-Arena is officially released!
+**2025.09.29**: VLA-Arena is officially released!
## 🔥 Highlights
- **🚀 End-to-End & Out-of-the-Box**: We provide a complete and unified toolchain covering everything from scene modeling and behavior collection to model training and evaluation. Paired with comprehensive docs and tutorials, you can get started in minutes.
- **🔌 Plug-and-Play Evaluation**: Seamlessly integrate and benchmark your own VLA models. Our framework is designed with a unified API, making the evaluation of new architectures straightforward with minimal code changes.
-- **🛠️ Effortless Task Customization**: Leverage the Constrained Behavior Definition Language (CBDDL) to rapidly define entirely new tasks and safety constraints. Its declarative nature allows you to achieve comprehensive scenario coverage with minimal effort.
+- **🛠️ Effortless Task Customization**: Leverage the Constrained Behavior Domain Definition Language (CBDDL) to rapidly define entirely new tasks and safety constraints. Its declarative nature allows you to achieve comprehensive scenario coverage with minimal effort.
- **📊 Systematic Difficulty Scaling**: Systematically assess model capabilities across three distinct difficulty levels (L0→L1→L2). Isolate specific skills and pinpoint failure points, from basic object manipulation to complex, long-horizon tasks.
If you find VLA-Arena useful, please cite it in your publications.
@@ -58,12 +63,9 @@ git clone https://github.com/PKU-Alignment/VLA-Arena.git
cd VLA-Arena
# Create environment
-conda create -n vla-arena python=3.10
+conda create -n vla-arena python=3.11
conda activate vla-arena
-# Install requirements
-pip install -r requirements.txt
-
# Install VLA-Arena
pip install -e .
```
@@ -78,24 +80,35 @@ pip install -e .
os.environ["MUJOCO_GL"] = "wgl" # Change "egl" to "wgl"
```
-### 2. Basic Evaluation
+### 2. Data Collection
```bash
-# Evaluate a trained model
-python scripts/evaluate_policy.py \
- --task_suite safety_static_obstacles \
- --task_level 0 \
- --n-episode 10 \
- --policy openvla \
- --model_ckpt /path/to/checkpoint
+# Collect demonstration data
+python scripts/collect_demonstration.py --bddl-file tasks/your_task.bddl
```
-### 3. Data Collection
+This will open an interactive simulation environment where you can control the robotic arm with the keyboard to complete the task specified in the BDDL file.
+
+### 3. Model Fine-tuning and Evaluation
+
+**⚠️ Important:** We recommend creating separate conda environments for different models to avoid dependency conflicts. Each model may have different requirements.
+
```bash
-# Collect demonstration data
-python scripts/collect_demonstration.py --bddl-file tasks/your_task.bddl
+# Create a dedicated environment for the model
+conda create -n [model_name]_vla_arena python=3.11 -y
+conda activate [model_name]_vla_arena
+
+# Install VLA-Arena and model-specific dependencies
+pip install -e .
+pip install vla-arena[model_name]
+
+# Fine-tune a model (e.g., OpenVLA)
+vla-arena train --model openvla --config vla_arena/configs/train/openvla.yaml
+
+# Evaluate a model
+vla-arena eval --model openvla --config vla_arena/configs/evaluation/openvla.yaml
```
-For detailed instructions, see our [Documentation](#documentation) section.
+**Note:** OpenPi requires a different setup process using `uv` for environment management. Please refer to the [Model Fine-tuning and Evaluation Guide](docs/finetuning_and_evaluation.md) for detailed OpenPi installation and training instructions.
## Task Suites Overview
@@ -168,9 +181,8 @@ VLA-Arena provides 11 specialized task suites with 150+ tasks total, organized i
### System Requirements
- **OS**: Ubuntu 20.04+ or macOS 12+
-- **Python**: 3.9 or higher
+- **Python**: 3.11 or higher
- **CUDA**: 11.8+ (for GPU acceleration)
-- **RAM**: 8GB minimum, 16GB recommended
### Installation Steps
```bash
@@ -179,12 +191,11 @@ git clone https://github.com/PKU-Alignment/VLA-Arena.git
cd VLA-Arena
# Create environment
-conda create -n vla-arena python=3.10
+conda create -n vla-arena python=3.11
conda activate vla-arena
# Install dependencies
pip install --upgrade pip
-pip install -r requirements.txt
pip install -e .
```
@@ -195,32 +206,34 @@ VLA-Arena provides comprehensive documentation for all aspects of the framework.
### 📖 Core Guides
#### 🏗️ [Scene Construction Guide](docs/scene_construction.md) | [中文版](docs/scene_construction_zh.md)
-Build custom task scenarios using CBDDL.
-- CBDDL file structure
-- Object and region definitions
-- State and goal specifications
-- Constraints, safety predicates and costs
-- Scene visualization
+Build custom task scenarios using CBDDL (Constrained Behavior Domain Definition Language).
+- CBDDL file structure and syntax
+- Region, fixture, and object definitions
+- Moving objects with various motion types (linear, circular, waypoint, parabolic)
+- Initial and goal state specifications
+- Cost constraints and safety predicates
+- Image effect settings
+- Asset management and registration
+- Scene visualization tools
#### 📊 [Data Collection Guide](docs/data_collection.md) | [中文版](docs/data_collection_zh.md)
-Collect demonstrations in custom scenes.
-- Interactive simulation environment
-- Keyboard controls for robotic arm
-- Data format conversion
-- Dataset creation and optimization
-
-#### 🔧 [Model Fine-tuning Guide](docs/finetune.md) | [中文版](docs/finetune_zh.md)
-Fine-tune VLA models using VLA-Arena generated datasets.
-- OpenVLA fine-tuning
-- Training scripts and configuration
-- Model evaluation
-
-#### 🎯 [Model Evaluation Guide](docs/evaluation.md) | [中文版](docs/evaluation_zh.md)
-Evaluate VLA models and adding custom models to VLA-Arena.
-- Quick start evaluation
-- Supported models (OpenVLA)
-- Custom model integration
-- Configuration options
+Collect demonstrations in custom scenes and convert data formats.
+- Interactive simulation environment with keyboard controls
+- Demonstration data collection workflow
+- Data format conversion (HDF5 to training dataset)
+- Dataset regeneration (filtering noops and optimizing trajectories)
+- Convert dataset to RLDS format (for X-embodiment frameworks)
+- Convert RLDS dataset to LeRobot format (for Hugging Face LeRobot)
+
+#### 🔧 [Model Fine-tuning and Evaluation Guide](docs/finetuning_and_evaluation.md) | [中文版](docs/finetuning_and_evaluation_zh.md)
+Fine-tune and evaluate VLA models using VLA-Arena generated datasets.
+- General models (OpenVLA, OpenVLA-OFT, UniVLA, SmolVLA): Simple installation and training workflow
+- OpenPi: Special setup using `uv` for environment management
+- Model-specific installation instructions (`pip install vla-arena[model_name]`)
+- Training configuration and hyperparameter settings
+- Evaluation scripts and metrics
+- Policy server setup for inference (OpenPi)
+
### 🔜 Quick Reference
@@ -234,50 +247,75 @@ Evaluate VLA models and adding custom models to VLA-Arena.
## Leaderboard
-### OpenVLA-OFT Results (150,000 Training Steps and finetuned on VLA-Arena L0 datasets)
-
-#### Overall Performance Summary
-| Model | L0 Success | L1 Success | L2 Success | Avg Success |
-|-------|------------|------------|------------|-------------|
-| **OpenVLA-OFT** | 76.4% | 36.3% | 16.7% | 36.5% |
+### Performance Evaluation of VLA Models on the VLA-Arena Benchmark
+We compare six models across four dimensions: **Safety**, **Distractor**, **Extrapolation**, and **Long Horizon**. Performance trends over three difficulty levels (L0–L2) are shown with a unified scale (0.0–1.0) for cross-model comparison. Safety tasks report both cumulative cost (CC, shown in parentheses) and success rate (SR), while other tasks report only SR. **Bold** numbers mark the highest performance per difficulty level.
#### 🛡️ Safety Performance
-| Task Suite | L0 Success | L1 Success | L2 Success | Avg Success |
-|------------|------------|------------|------------|-------------|
-| static_obstacles | 100.0% | 20.0% | 20.0% | 46.7% |
-| cautious_grasp | 60.0% | 50.0% | 0.0% | 36.7% |
-| hazard_avoidance | 36.0% | 0.0% | 20.0% | 18.7% |
-| state_preservation | 100.0% | 76.0% | 20.0% | 65.3% |
-| dynamic_obstacles | 80.0% | 56.0% | 10.0% | 48.7% |
-
-#### 🛡️ Safety Cost Analysis
-| Task Suite | L1 Total Cost | L2 Total Cost | Avg Total Cost |
-|------------|---------------|---------------|----------------|
-| static_obstacles | 45.40 | 49.00 | 47.20 |
-| cautious_grasp | 6.34 | 2.12 | 4.23 |
-| hazard_avoidance | 22.91 | 14.71 | 18.81 |
-| state_preservation | 7.60 | 4.60 | 6.10 |
-| dynamic_obstacles | 3.66 | 1.84 | 2.75 |
+
+| Task | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **StaticObstacles** | | | | | | |
+| L0 | **1.00** (CC: 0.0) | **1.00** (CC: 0.0) | 0.98 (CC: 0.0) | **1.00** (CC: 0.0) | 0.84 (CC: 0.0) | 0.14 (CC: 0.0) |
+| L1 | 0.60 (CC: 8.2) | **0.20** (CC: 45.4) | **0.74** (CC: 8.0) | 0.40 (CC: 56.0) | 0.42 (CC: 9.7) | 0.00 (CC: 8.8) |
+| L2 | 0.00 (CC: 38.2) | 0.20 (CC: 49.0) | **0.32** (CC: 28.1) | 0.20 (CC: 6.8) | 0.18 (CC: 60.6) | 0.00 (CC: 2.6) |
+| **CautiousGrasp** | | | | | | |
+| L0 | **0.80** (CC: 6.6) | 0.60 (CC: 3.3) | **0.84** (CC: 3.5) | 0.64 (CC: 3.3) | **0.80** (CC: 3.3) | 0.52 (CC: 2.8) |
+| L1 | 0.40 (CC: 120.2) | 0.50 (CC: 6.3) | 0.08 (CC: 16.4) | 0.06 (CC: 15.6) | **0.60** (CC: 52.1) | 0.28 (CC: 30.7) |
+| L2 | 0.00 (CC: 50.1) | 0.00 (CC: 2.1) | 0.00 (CC: 0.5) | 0.00 (CC: 1.0) | 0.00 (CC: 8.5) | **0.04** (CC: 0.3) |
+| **HazardAvoidance** | | | | | | |
+| L0 | 0.20 (CC: 17.2) | 0.36 (CC: 9.4) | **0.74** (CC: 6.4) | 0.16 (CC: 10.4) | **0.70** (CC: 5.3) | 0.16 (CC: 10.4) |
+| L1 | 0.02 (CC: 22.8) | 0.00 (CC: 22.9) | 0.00 (CC: 16.8) | 0.00 (CC: 15.4) | **0.12** (CC: 18.3) | 0.00 (CC: 19.5) |
+| L2 | **0.20** (CC: 15.7) | **0.20** (CC: 14.7) | 0.00 (CC: 15.6) | **0.20** (CC: 13.9) | 0.04 (CC: 16.7) | 0.00 (CC: 18.0) |
+| **StatePreservation** | | | | | | |
+| L0 | **1.00** (CC: 0.0) | **1.00** (CC: 0.0) | 0.98 (CC: 0.0) | 0.60 (CC: 0.0) | 0.90 (CC: 0.0) | 0.50 (CC: 0.0) |
+| L1 | 0.66 (CC: 6.6) | **0.76** (CC: 7.6) | 0.64 (CC: 6.4) | 0.56 (CC: 5.6) | **0.76** (CC: 7.6) | 0.18 (CC: 1.8) |
+| L2 | 0.34 (CC: 21.0) | 0.20 (CC: 4.6) | **0.48** (CC: 15.8) | 0.20 (CC: 4.2) | **0.54** (CC: 16.4) | 0.08 (CC: 9.6) |
+| **DynamicObstacles** | | | | | | |
+| L0 | 0.60 (CC: 3.6) | **0.80** (CC: 8.8) | 0.92 (CC: 6.0) | **0.80** (CC: 3.6) | 0.26 (CC: 7.1) | 0.32 (CC: 2.1) |
+| L1 | 0.60 (CC: 5.1) | 0.56 (CC: 3.7) | **0.64** (CC: 3.3) | 0.30 (CC: 8.8) | **0.58** (CC: 16.3) | 0.24 (CC: 16.6) |
+| L2 | 0.26 (CC: 5.6) | 0.10 (CC: 1.8) | **0.10** (CC: 40.2) | 0.00 (CC: 21.2) | 0.08 (CC: 6.0) | **0.02** (CC: 0.9) |
#### 🔄 Distractor Performance
-| Task Suite | L0 Success | L1 Success | L2 Success | Avg Success |
-|------------|------------|------------|------------|-------------|
-| robustness_static_distractors | 100.0% | 0.0% | 20.0% | 40.0% |
-| robustness_dynamic_distractors | 100.0% | 54.0% | 40.0% | 64.7% |
+
+| Task | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **StaticDistractors** | | | | | | |
+| L0 | 0.80 | **1.00** | 0.92 | **1.00** | **1.00** | 0.54 |
+| L1 | 0.20 | 0.00 | 0.02 | **0.22** | 0.12 | 0.00 |
+| L2 | 0.00 | **0.20** | 0.02 | 0.00 | 0.00 | 0.00 |
+| **DynamicDistractors** | | | | | | |
+| L0 | 0.60 | **1.00** | 0.78 | 0.80 | 0.78 | 0.42 |
+| L1 | 0.58 | 0.54 | **0.70** | 0.28 | 0.54 | 0.30 |
+| L2 | 0.40 | **0.40** | 0.18 | 0.04 | 0.04 | 0.00 |
#### 🎯 Extrapolation Performance
-| Task Suite | L0 Success | L1 Success | L2 Success | Avg Success |
-|------------|------------|------------|------------|-------------|
-| preposition_combinations | 62.0% | 18.0% | 0.0% | 26.7% |
-| task_workflows | 74.0% | 0.0% | 0.0% | 24.7% |
-| unseen_objects | 60.0% | 40.0% | 20.0% | 40.0% |
+
+| Task | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **PrepositionCombinations** | | | | | | |
+| L0 | 0.68 | 0.62 | **0.76** | 0.14 | 0.50 | 0.20 |
+| L1 | 0.04 | **0.18** | 0.10 | 0.00 | 0.02 | 0.00 |
+| L2 | 0.00 | 0.00 | 0.00 | 0.00 | **0.02** | 0.00 |
+| **TaskWorkflows** | | | | | | |
+| L0 | **0.82** | 0.74 | 0.72 | 0.24 | 0.76 | 0.32 |
+| L1 | **0.20** | 0.00 | 0.00 | 0.00 | 0.04 | 0.04 |
+| L2 | **0.16** | 0.00 | 0.00 | 0.00 | 0.20 | 0.00 |
+| **UnseenObjects** | | | | | | |
+| L0 | **0.80** | 0.60 | **0.80** | 0.00 | 0.34 | 0.16 |
+| L1 | 0.60 | 0.40 | 0.52 | 0.00 | **0.76** | 0.18 |
+| L2 | 0.00 | **0.20** | 0.04 | 0.00 | 0.16 | 0.00 |
#### 📈 Long Horizon Performance
-| Task Suite | L0 Success | L1 Success | L2 Success | Avg Success |
-|------------|------------|------------|------------|-------------|
-| long_horizon | 80.0% | 0.0% | 0.0% | 26.7% |
+| Task | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **LongHorizon** | | | | | | |
+| L0 | 0.80 | 0.80 | **0.92** | 0.62 | 0.66 | 0.74 |
+| L1 | 0.00 | 0.00 | **0.02** | 0.00 | 0.00 | 0.00 |
+| L2 | 0.00 | 0.00 | **0.00** | 0.00 | 0.00 | 0.00 |
+
+---
## License
@@ -294,4 +332,4 @@ This project is licensed under the Apache 2.0 license - see [LICENSE](LICENSE) f
VLA-Arena: Advancing Vision-Language-Action Models Through Comprehensive Evaluation
Made with ❤️ by the VLA-Arena Team
-
\ No newline at end of file
+
diff --git a/README_zh.md b/README_zh.md
index a94e1fdd..b66006d1 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -8,14 +8,18 @@
-VLA-Arena 是一个开源的基准测试平台,用于系统评估视觉-语言-动作(VLA)模型。VLA-Arena 提供完整的工具链,涵盖**场景建模**、**行为收集**、**模型训练**和**评估**。它具有13个专业套件中的150+个任务、分层难度级别(L0-L2),以及用于安全性、泛化性和效率评估的综合指标。
+
+
+
-VLA-Arena 专注于四个关键领域:
+VLA-Arena 是一个开源的基准测试平台,用于系统评测视觉-语言-动作(VLA)模型。VLA-Arena 提供完整的工具链,涵盖*场景建模*、*行为收集*、*模型训练*和*评测*。平台包含13个专业套件、150+任务、分层难度级别(L0-L2),以及用于安全性、泛化性和效率评测的综合指标。
+
+VLA-Arena 囊括四个任务类别:
- **安全性**:在物理世界中可靠安全地操作。
-- **鲁棒性**:面对环境不可预测性时保持稳定性能。
+- **抗干扰**:面对环境不可预测性时保持稳定性能。
-- **泛化性**:将学到的知识泛化到新情况。
+- **外推性**:将学到的知识泛化到新情况。
- **长时域**:结合长序列动作来实现复杂目标。
@@ -31,7 +35,7 @@ VLA-Arena 专注于四个关键领域:
- **🛠️ 轻松任务定制**:利用约束行为定义语言(CBDDL)快速定义全新的任务和安全约束。其声明性特性使你能够以最少的努力实现全面的场景覆盖。
-- **📊 系统难度扩展**:系统评估模型在三个不同难度级别(L0→L1→L2)的能力。隔离特定技能并精确定位失败点,从基本物体操作到复杂的长时域任务。
+- **📊 系统难度扩展**:系统评测模型在三个不同难度级别(L0→L1→L2)的能力。隔离特定技能并精确定位失败点,从基本物体操作到复杂的长时域任务。
如果你觉得VLA-Arena有用,请在你的出版物中引用它。
@@ -63,12 +67,9 @@ git clone https://github.com/PKU-Alignment/VLA-Arena.git
cd VLA-Arena
# 创建环境
-conda create -n vla-arena python=3.10
+conda create -n vla-arena python=3.11
conda activate vla-arena
-# 安装依赖
-pip install -r requirements.txt
-
# 安装 VLA-Arena
pip install -e .
```
@@ -80,27 +81,38 @@ pip install -e .
if _SYSTEM == "Darwin":
os.environ["MUJOCO_GL"] = "cgl"
else:
- os.environ["MUJOCO_GL"] = "wgl" # "egl" to "wgl"
- ```
+ os.environ["MUJOCO_GL"] = "wgl" # Change "egl" to "wgl"
+ ```
-### 2. 基础评估
+### 2. 数据收集
```bash
-# 评估预训练模型
-python scripts/evaluate_policy.py \
- --task_suite safety_static_obstacles \
- --task_level 0 \
- --n-episode 10 \
- --policy openvla \
- --model_ckpt /path/to/checkpoint
+# 收集演示数据
+python scripts/collect_demonstration.py --bddl-file tasks/your_task.bddl
```
-### 3. 数据收集
+这将打开一个交互式仿真环境,您可以使用键盘控制机器人手臂来完成 BDDL 文件中指定的任务。
+
+### 3. 模型微调与评估
+
+**⚠️ 重要提示:** 我们建议为不同模型创建独立的 conda 环境,以避免依赖冲突。每个模型可能有不同的要求。
+
```bash
-# 收集演示数据
-python scripts/collect_demonstration.py --bddl-file tasks/your_task.bddl
+# 为模型创建专用环境
+conda create -n [model_name]_vla_arena python=3.11 -y
+conda activate [model_name]_vla_arena
+
+# 安装 VLA-Arena 和模型特定依赖
+pip install -e .
+pip install vla-arena[model_name]
+
+# 微调模型(例如 OpenVLA)
+vla-arena train --model openvla --config vla_arena/configs/train/openvla.yaml
+
+# 评估模型
+vla-arena eval --model openvla --config vla_arena/configs/evaluation/openvla.yaml
```
-详细说明请参见我们的[文档](#文档)部分。
+**注意:** OpenPi 的设置流程有所不同,需要使用 `uv` 进行环境管理。请参考[模型微调与评测指南](docs/finetuning_and_evaluation_zh.md)了解详细的 OpenPi 安装和训练说明。
## 任务套件概览
@@ -175,8 +187,6 @@ VLA-Arena提供11个专业任务套件,共150+个任务,分为四个主要
- **操作系统**:Ubuntu 20.04+ 或 macOS 12+
- **Python**:3.10 或更高版本
- **CUDA**:11.8+(用于GPU加速)
-- **内存**:最低8GB,推荐16GB
-- **存储**:基础安装10GB,数据集50GB+
### 安装步骤
```bash
@@ -185,12 +195,11 @@ git clone https://github.com/PKU-Alignment/VLA-Arena.git
cd VLA-Arena
# 创建环境
-conda create -n vla-arena python=3.10
+conda create -n vla-arena python=3.11
conda activate vla-arena
# 安装依赖
pip install --upgrade pip
-pip install -r requirements.txt
pip install -e .
```
@@ -200,32 +209,34 @@ VLA-Arena为框架的所有方面提供全面的文档。选择最适合你需
### 📖 核心指南
-#### 🎯 [模型评估指南](docs/evaluation_zh.md) | [English](docs/evaluation.md)
-评估VLA模型和将自定义模型添加到VLA-Arena的完整指南。
-- 快速开始评估
-- 支持的模型(OpenVLA)
-- 自定义模型集成
-- 配置选项
-
-#### 🔧 [模型微调指南](docs/finetune_zh.md) | [English](docs/finetune.md)
-使用VLA-Arena生成的数据集微调VLA模型的综合指南。
-- OpenVLA微调
-- 训练脚本和配置
-
-#### 📊 [数据收集指南](docs/data_collection_zh.md) | [English](docs/data_collection.md)
-在自定义场景中收集演示数据的分步指南。
-- 交互式仿真环境
-- 机器人手臂键盘控制
-- 数据格式转换
-- 数据集创建和优化
-
#### 🏗️ [场景构建指南](docs/scene_construction_zh.md) | [English](docs/scene_construction.md)
-使用BDDL构建自定义任务场景的详细指南。
-- BDDL文件结构
-- 物体和区域定义
-- 状态和目标规范
+使用 CBDDL(约束行为域定义语言)构建自定义任务场景。
+- CBDDL 文件结构和语法
+- 区域、固定装置和对象定义
+- 具有多种运动类型的移动对象(线性、圆形、航点、抛物线)
+- 初始和目标状态规范
- 成本约束和安全谓词
-- 场景可视化
+- 图像效果设置
+- 资源管理和注册
+- 场景可视化工具
+
+#### 📊 [数据收集指南](docs/data_collection_zh.md) | [English](docs/data_collection.md)
+在自定义场景中收集演示数据并转换数据格式。
+- 带键盘控制的交互式仿真环境
+- 演示数据收集工作流
+- 数据格式转换(HDF5 到训练数据集)
+- 数据集再生(过滤 noops 并优化轨迹)
+- 将数据集转换为 RLDS 格式(用于 X-embodiment 框架)
+- 将 RLDS 数据集转换为 LeRobot 格式(用于 Hugging Face LeRobot)
+
+#### 🔧 [模型微调与评测指南](docs/finetuning_and_evaluation_zh.md) | [English](docs/finetuning_and_evaluation.md)
+使用 VLA-Arena 生成的数据集微调和评估 VLA 模型。
+- 通用模型(OpenVLA, OpenVLA-OFT, UniVLA, SmolVLA):简单的安装和训练工作流
+- OpenPi:使用 `uv` 进行环境管理的特殊设置
+- 模型特定安装说明(`pip install vla-arena[model_name]`)
+- 训练配置和超参数设置
+- 评估脚本和指标
+- 用于推理的策略服务器设置(OpenPi)
### 🚀 快速参考
@@ -239,45 +250,64 @@ VLA-Arena为框架的所有方面提供全面的文档。选择最适合你需
## 排行榜
-### OpenVLA-OFT结果(150,000训练步数并在VLA-Arena L0数据集上微调)
-
-#### 整体性能摘要
-| 模型 | L0成功率 | L1成功率 | L2成功率 | 平均成功率 |
-|------|------------|----------|----------|----------|
-| **OpenVLA-OFT** | 76.4% | 36.3% | 16.7% | 36.5% |
-
-#### 每套件性能
-
-### 🛡️ 安全性能
-| 任务套件 | L0成功率 | L1成功率 | L2成功率 | 平均成功率 |
-|----------|----------|----------|----------|------------|
-| static_obstacles | 100.0% | 20.0% | 20.0% | 46.7% |
-| cautious_grasp | 60.0% | 50.0% | 0.0% | 36.7% |
-| hazard_avoidance | 36.0% | 0.0% | 20.0% | 18.7% |
-| state_preservation | 100.0% | 76.0% | 20.0% | 65.3% |
-| dynamic_obstacles | 80.0% | 56.0% | 10.0% | 48.7% |
-
-#### 🛡️ 安全成本分析
-| 任务套件 | L1总成本 | L2总成本 | 平均总成本 |
-|----------|----------|----------|------------|
-| static_obstacles | 45.40 | 49.00 | 47.20 |
-| cautious_grasp | 6.34 | 2.12 | 4.23 |
-| hazard_avoidance | 22.91 | 14.71 | 18.81 |
-| state_preservation | 7.60 | 4.60 | 6.10 |
-| dynamic_obstacles | 3.66 | 1.84 | 2.75 |
-
-### 🔄 抗干扰性能
-| 任务套件 | L0成功率 | L1成功率 | L2成功率 | 平均成功率 |
-|----------|----------|----------|----------|------------|
-| static_distractors | 100.0% | 0.0% | 20.0% | 40.0% |
-| dynamic_distractors | 100.0% | 54.0% | 40.0% | 64.7% |
-
-### 🎯 外推性能
-| 任务套件 | L0成功率 | L1成功率 | L2成功率 | 平均成功率 |
-|----------|----------|----------|----------|------------|
-| preposition_combinations | 62.0% | 18.0% | 0.0% | 26.7% |
-| task_workflows | 74.0% | 0.0% | 0.0% | 24.7% |
-| unseen_objects | 60.0% | 40.0% | 20.0% | 40.0% |
+### VLA模型在VLA-Arena基准测试上的性能评估
+
+我们在四个维度上比较了六个模型:**安全性**、**抗干扰性**、**外推性**和**长时域**。三个难度级别(L0–L2)的性能趋势以统一尺度(0.0–1.0)显示,便于跨模型比较。安全任务同时报告累积成本(CC,括号内显示)和成功率(SR),而其他任务仅报告成功率。**粗体**数字表示每个难度级别的最高性能。
+
+#### 🛡️ 安全性能
+
+| 任务 | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **StaticObstacles** | | | | | | |
+| L0 | **1.00** (CC: 0.0) | **1.00** (CC: 0.0) | 0.98 (CC: 0.0) | **1.00** (CC: 0.0) | 0.84 (CC: 0.0) | 0.14 (CC: 0.0) |
+| L1 | 0.60 (CC: 8.2) | **0.20** (CC: 45.4) | **0.74** (CC: 8.0) | 0.40 (CC: 56.0) | 0.42 (CC: 9.7) | 0.00 (CC: 8.8) |
+| L2 | 0.00 (CC: 38.2) | 0.20 (CC: 49.0) | **0.32** (CC: 28.1) | 0.20 (CC: 6.8) | 0.18 (CC: 60.6) | 0.00 (CC: 2.6) |
+| **CautiousGrasp** | | | | | | |
+| L0 | **0.80** (CC: 6.6) | 0.60 (CC: 3.3) | **0.84** (CC: 3.5) | 0.64 (CC: 3.3) | **0.80** (CC: 3.3) | 0.52 (CC: 2.8) |
+| L1 | 0.40 (CC: 120.2) | 0.50 (CC: 6.3) | 0.08 (CC: 16.4) | 0.06 (CC: 15.6) | **0.60** (CC: 52.1) | 0.28 (CC: 30.7) |
+| L2 | 0.00 (CC: 50.1) | 0.00 (CC: 2.1) | 0.00 (CC: 0.5) | 0.00 (CC: 1.0) | 0.00 (CC: 8.5) | **0.04** (CC: 0.3) |
+| **HazardAvoidance** | | | | | | |
+| L0 | 0.20 (CC: 17.2) | 0.36 (CC: 9.4) | **0.74** (CC: 6.4) | 0.16 (CC: 10.4) | **0.70** (CC: 5.3) | 0.16 (CC: 10.4) |
+| L1 | 0.02 (CC: 22.8) | 0.00 (CC: 22.9) | 0.00 (CC: 16.8) | 0.00 (CC: 15.4) | **0.12** (CC: 18.3) | 0.00 (CC: 19.5) |
+| L2 | **0.20** (CC: 15.7) | **0.20** (CC: 14.7) | 0.00 (CC: 15.6) | **0.20** (CC: 13.9) | 0.04 (CC: 16.7) | 0.00 (CC: 18.0) |
+| **StatePreservation** | | | | | | |
+| L0 | **1.00** (CC: 0.0) | **1.00** (CC: 0.0) | 0.98 (CC: 0.0) | 0.60 (CC: 0.0) | 0.90 (CC: 0.0) | 0.50 (CC: 0.0) |
+| L1 | 0.66 (CC: 6.6) | **0.76** (CC: 7.6) | 0.64 (CC: 6.4) | 0.56 (CC: 5.6) | **0.76** (CC: 7.6) | 0.18 (CC: 1.8) |
+| L2 | 0.34 (CC: 21.0) | 0.20 (CC: 4.6) | **0.48** (CC: 15.8) | 0.20 (CC: 4.2) | **0.54** (CC: 16.4) | 0.08 (CC: 9.6) |
+| **DynamicObstacles** | | | | | | |
+| L0 | 0.60 (CC: 3.6) | **0.80** (CC: 8.8) | 0.92 (CC: 6.0) | **0.80** (CC: 3.6) | 0.26 (CC: 7.1) | 0.32 (CC: 2.1) |
+| L1 | 0.60 (CC: 5.1) | 0.56 (CC: 3.7) | **0.64** (CC: 3.3) | 0.30 (CC: 8.8) | **0.58** (CC: 16.3) | 0.24 (CC: 16.6) |
+| L2 | 0.26 (CC: 5.6) | 0.10 (CC: 1.8) | **0.10** (CC: 40.2) | 0.00 (CC: 21.2) | 0.08 (CC: 6.0) | **0.02** (CC: 0.9) |
+
+#### 🔄 抗干扰性能
+
+| 任务 | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **StaticDistractors** | | | | | | |
+| L0 | 0.80 | **1.00** | 0.92 | **1.00** | **1.00** | 0.54 |
+| L1 | 0.20 | 0.00 | 0.02 | **0.22** | 0.12 | 0.00 |
+| L2 | 0.00 | **0.20** | 0.02 | 0.00 | 0.00 | 0.00 |
+| **DynamicDistractors** | | | | | | |
+| L0 | 0.60 | **1.00** | 0.78 | 0.80 | 0.78 | 0.42 |
+| L1 | 0.58 | 0.54 | **0.70** | 0.28 | 0.54 | 0.30 |
+| L2 | 0.40 | **0.40** | 0.18 | 0.04 | 0.04 | 0.00 |
+
+#### 🎯 外推性能
+
+| 任务 | OpenVLA | OpenVLA-OFT | π₀ | π₀-FAST | UniVLA | SmolVLA |
+|------|---------|-------------|----|---------|--------|---------|
+| **PrepositionCombinations** | | | | | | |
+| L0 | 0.68 | 0.62 | **0.76** | 0.14 | 0.50 | 0.20 |
+| L1 | 0.04 | **0.18** | 0.10 | 0.00 | 0.02 | 0.00 |
+| L2 | 0.00 | 0.00 | 0.00 | 0.00 | **0.02** | 0.00 |
+| **TaskWorkflows** | | | | | | |
+| L0 | **0.82** | 0.74 | 0.72 | 0.24 | 0.76 | 0.32 |
+| L1 | **0.20** | 0.00 | 0.00 | 0.00 | 0.04 | 0.04 |
+| L2 | **0.16** | 0.00 | 0.00 | 0.00 | 0.20 | 0.00 |
+| **UnseenObjects** | | | | | | |
+| L0 | **0.80** | 0.60 | **0.80** | 0.00 | 0.34 | 0.16 |
+| L1 | 0.60 | 0.40 | 0.52 | 0.00 | **0.76** | 0.18 |
+| L2 | 0.00 | **0.20** | 0.04 | 0.00 | 0.16 | 0.00 |
### 📈 长程性能
| 任务套件 | L0成功率 | L1成功率 | L2成功率 | 平均成功率 |
@@ -302,6 +332,6 @@ VLA-Arena为框架的所有方面提供全面的文档。选择最适合你需
---
- VLA-Arena: 通过综合评估推进视觉-语言-动作模型发展
+ VLA-Arena: 通过综合评测推进视觉-语言-动作模型发展
由VLA-Arena团队用 ❤️ 制作
diff --git a/conversion_requirements.txt b/conversion_requirements.txt
index 30d70129..b2821fba 100644
--- a/conversion_requirements.txt
+++ b/conversion_requirements.txt
@@ -1,4 +1,4 @@
tyro
tensorflow-datasets
tensorflow
-lerobot @ git+https://github.com/learnerljh/lerobot.git@main
\ No newline at end of file
+lerobot @ git+https://github.com/learnerljh/lerobot.git@main
diff --git a/docs/README_EN.md b/docs/README_EN.md
index 26b1d4ce..96ca95e5 100644
--- a/docs/README_EN.md
+++ b/docs/README_EN.md
@@ -22,6 +22,12 @@ A comprehensive guide for collecting demonstration data in custom scenes and con
- Filtering noop actions for trajectory continuity
- Dataset optimization and validation
- Quality assurance procedures
+4. [Convert Dataset to RLDS Format](#4-convert-dataset-to-rlds-format)
+ - RLDS format conversion
+ - Dataset standardization
+5. [Convert RLDS Dataset to LeRobot Format](#5-convert-rlds-dataset-to-lerobot-format)
+ - LeRobot format conversion
+ - Compatibility handling
---
@@ -63,37 +69,25 @@ Detailed guide for building custom task scenarios using BDDL (Behavior Domain De
---
-### 3. Model Fine-tuning Guide
-**File:** `finetune.md`
+### 3. Model Fine-tuning and Evaluation Guide
+**File:** `finetuning_and_evaluation.md`
-Comprehensive guide for fine-tuning VLA models using VLA-Arena generated datasets.
+Comprehensive guide for fine-tuning and evaluating VLA models using VLA-Arena generated datasets. Supports OpenVLA, OpenVLA-OFT, Openpi, UniVLA, SmolVLA, and other models.
#### Table of Contents:
-1. [Quick Start](#quick-start)
- - Environment setup
- - Basic fine-tuning commands
-2. [Fine-tuning OpenVLA](#fine-tuning-openvla)
- - OpenVLA library installation
- - One-click fine-tuning scripts
- - Parameter configuration
- - Dataset configuration options
-3. [Fine-tuning OpenVLA OFT](#fine-tuning-openvla-oft)
- - OFT fine-tuning introduction
- - Advanced training options
- - Architecture enhancements
- - Multi-GPU support
-4. [Troubleshooting](#troubleshooting)
- - Common issues and solutions
- - Debugging techniques
-5. [Model Evaluation](#model-evaluation)
- - Evaluation procedures
- - Performance metrics
-6. [Adding Custom Models](#adding-custom-models)
- - Custom model integration
- - Configuration requirements
-7. [Configuration Instructions](#configuration-instructions)
- - Detailed configuration options
- - Best practices
+1. [General Models (OpenVLA, OpenVLA-OFT, UniVLA, SmolVLA)](#general-models)
+ - Dependency installation
+ - Model fine-tuning
+ - Model evaluation
+2. [Openpi Model](#openpi)
+ - Environment setup (using uv)
+ - Training configuration and execution
+ - Policy server startup
+ - Model evaluation
+3. [Configuration File Notes](#configuration-file-notes)
+ - Dataset path configuration
+ - Model parameter settings
+ - Training hyperparameter configuration
---
@@ -187,19 +181,17 @@ Comprehensive guide for packaging, sharing, and installing custom tasks and scen
```
docs/
-├── data_collection.md # Data collection guide (English)
-├── data_collection_zh.md # Data collection guide (Chinese)
-├── scene_construction.md # Scene construction guide (English)
-├── scene_construction_zh.md # Scene construction guide (Chinese)
-├── finetune.md # Model fine-tuning guide (English)
-├── finetune_zh.md # Model fine-tuning guide (Chinese)
-├── evaluation.md # Model evaluation guide (English)
-├── evaluation_zh.md # Model evaluation guide (Chinese)
├── asset_management.md # Task asset management guide (English)
├── asset_management_zh.md # Task asset management guide (Chinese)
-├── finetune_openvla.sh # OpenVLA fine-tuning script
-├── finetune_openvla_oft.sh # OpenVLA OFT fine-tuning script
-└── image/ # Documentation images and GIFs
+├── data_collection.md # Data collection guide (English)
+├── data_collection_zh.md # Data collection guide (Chinese)
+├── scene_construction.md # Scene construction guide (English)
+├── scene_construction_zh.md # Scene construction guide (Chinese)
+├── finetuning_and_evaluation.md # Model fine-tuning and evaluation guide (English)
+├── finetuning_and_evaluation_zh.md # Model fine-tuning and evaluation guide (Chinese)
+├── README_EN.md # Documentation table of contents (English)
+├── README_ZH.md # Documentation table of contents (Chinese)
+└── image/ # Documentation images and GIFs
```
---
diff --git a/docs/README_ZH.md b/docs/README_ZH.md
index a79c1dd9..54d905f5 100644
--- a/docs/README_ZH.md
+++ b/docs/README_ZH.md
@@ -22,6 +22,12 @@
- 过滤空动作以确保轨迹连续性
- 数据集优化和验证
- 质量保证程序
+4. [将数据集转换为rlds格式](#4-将数据集转换为rlds格式)
+ - RLDS 格式转换
+ - 数据集标准化
+5. [将rlds数据集转换为lerobot格式](#5-将rlds数据集转换为lerobot格式)
+ - LeRobot 格式转换
+ - 兼容性处理
---
@@ -63,123 +69,25 @@
---
-### 3. 模型微调指南
-**文件:** `finetune_zh.md`
+### 3. 模型微调与评估指南
+**文件:** `finetuning_and_evaluation_zh.md`
-使用 VLA-Arena 生成的数据集微调 VLA 模型的综合指南。
+使用 VLA-Arena 生成的数据集微调和评估 VLA 模型的综合指南。支持 OpenVLA、OpenVLA-OFT、Openpi、UniVLA、SmolVLA 等模型。
#### 目录结构:
-1. [快速开始](#快速开始)
- - 环境设置
- - 基本微调命令
-2. [微调 OpenVLA](#微调OpenVLA)
- - OpenVLA 库安装
- - 一键微调脚本
- - 参数配置
- - 数据集配置选项
-3. [微调 OpenVLA OFT](#微调OpenVLA-OFT)
- - OFT 微调介绍
- - 高级训练选项
- - 架构增强
- - 多 GPU 支持
-4. [故障排除](#故障排除)
- - 常见问题和解决方案
- - 调试技巧
-5. [模型评估](#模型评估)
- - 评估程序
- - 性能指标
-6. [添加自定义模型](#添加自定义模型)
- - 自定义模型集成
- - 配置要求
-7. [配置说明](#配置说明)
- - 详细配置选项
- - 最佳实践
-
----
-
-### 4. 模型评估指南
-**文件:** `evaluation_zh.md`
-
-评估 VLA 模型和向 VLA-Arena 添加自定义模型的完整指南。
-
-#### 目录结构:
-1. [快速开始](#快速开始)
- - 环境准备
- - 基本评估命令
-2. [模型评估](#模型评估)
- - 支持的模型
- - 评估程序
- - 性能指标
- - 结果解释
-3. [添加自定义模型](#添加自定义模型)
- - 自定义模型集成
- - 配置要求
- - 实现指南
-4. [配置说明](#配置说明)
- - 详细配置选项
- - 参数描述
- - 最佳实践
-5. [故障排除](#故障排除)
- - 常见问题和解决方案
- - 调试技巧
- - 性能优化
-
----
-
-### 5. 任务资产管理指南
-**文件:** `asset_management_zh.md`
-
-打包、分享和安装自定义任务和场景的综合指南。
-
-#### 目录结构:
-1. [概述](#1-概述)
- - 完整工作流:设计 → 打包 → 上传 → 下载 → 安装 → 使用
- - 核心功能和特性
- - 打包内容说明
-2. [打包单个任务](#2-打包单个任务)
- - 打包命令和选项
- - 自动依赖检测
- - 示例和输出
-3. [打包任务套件](#3-打包任务套件)
- - 多任务打包
- - 套件组织
-4. [检查包内容](#4-检查包内容)
- - 包内容预览
- - 元数据检查
-5. [安装包](#5-安装包)
- - 安装程序
- - 冲突处理
- - 选项和标志
-6. [上传到云端](#6-上传到云端)
- - HuggingFace Hub 集成
- - 身份验证设置
- - 自动回退方法
-7. [从云端下载](#7-从云端下载)
- - 包发现
- - 下载和安装
-8. [卸载包](#8-卸载包)
- - 安全移除程序
-9. [包结构说明](#9-包结构说明)
- - `.vlap` 文件格式
- - 清单规范
-10. [故障排除](#10-故障排除)
- - 常见问题和解决方案
- - 最佳实践
-
----
-
-## 🔧 脚本文件
-
-### 微调脚本
-- **`finetune_openvla.sh`**: 标准 OpenVLA 微调脚本
-- **`finetune_openvla_oft.sh`**: 具有高级选项的 OpenVLA OFT 微调脚本
-
-### 主要功能:
-- 自动化数据集配置
-- 参数验证
-- 多 GPU 支持
-- 全面的错误处理
-- 灵活的训练选项
+1. [通用模型(OpenVLA、OpenVLA-OFT、UniVLA、SmolVLA)](#通用模型)
+ - 依赖安装
+ - 模型微调
+ - 模型评估
+2. [Openpi 模型](#openpi)
+ - 环境配置(使用 uv)
+ - 训练配置和运行
+ - 策略服务器启动
+ - 模型评估
+3. [配置文件说明](#配置文件说明)
+ - 数据集路径配置
+ - 模型参数设置
+ - 训练超参数配置
---
@@ -187,18 +95,14 @@
```
docs/
+├── finetuning_and_evaluation.md # 模型微调与评估指南(英文)
+├── finetuning_and_evaluation_zh.md # 模型微调与评估指南(中文)
├── data_collection.md # 数据收集指南(英文)
├── data_collection_zh.md # 数据收集指南(中文)
├── scene_construction.md # 场景构建指南(英文)
├── scene_construction_zh.md # 场景构建指南(中文)
-├── finetune.md # 模型微调指南(英文)
-├── finetune_zh.md # 模型微调指南(中文)
-├── evaluation.md # 模型评估指南(英文)
-├── evaluation_zh.md # 模型评估指南(中文)
├── asset_management.md # 任务资产管理指南(英文)
├── asset_management_zh.md # 任务资产管理指南(中文)
-├── finetune_openvla.sh # OpenVLA 微调脚本
-├── finetune_openvla_oft.sh # OpenVLA OFT 微调脚本
└── image/ # 文档图片和 GIF
```
@@ -216,10 +120,13 @@ docs/
2. 使用 `scripts/collect_demonstration.py` 进行交互式数据收集
3. 使用 `scripts/group_create_dataset.py` 转换数据格式
-### 3. 模型训练
-1. 使用 `finetune_openvla.sh` 或 `finetune_openvla_oft.sh` 进行模型微调
-2. 根据你的需求配置训练参数
-3. 通过 WandB 监控训练进度
+### 3. 模型训练与评估
+1. 按照 `finetuning_and_evaluation_zh.md` 安装模型依赖
+2. 使用 `vla-arena train` 命令进行模型微调
+3. 根据您的需求配置训练参数
+4. 使用 `vla-arena eval` 命令评估模型性能
+5. 通过 WandB 监控训练进度
+6. 分析结果并迭代改进模型
### 4. 模型评估
1. 按照 `evaluation_zh.md` 进行模型评估程序
@@ -230,5 +137,3 @@ docs/
1. 按照 `asset_management_zh.md` 打包你的自定义任务
2. 使用 `scripts/manage_assets.py` 上传到云端
3. 与社区分享你的任务套件
-
-
diff --git a/docs/asset_management.md b/docs/asset_management.md
index 3b70f464..dcc1a4f0 100644
--- a/docs/asset_management.md
+++ b/docs/asset_management.md
@@ -451,4 +451,3 @@ python scripts/manage_assets.py download my_task \
- [Data Collection Guide](data_collection.md) - How to collect demonstrations
- [Evaluation Guide](evaluation.md) - How to evaluate policies
- [HuggingFace Hub Documentation](https://huggingface.co/docs/hub/index) - Cloud storage
-
diff --git a/docs/asset_management_zh.md b/docs/asset_management_zh.md
index 862118cb..d1f8f54c 100644
--- a/docs/asset_management_zh.md
+++ b/docs/asset_management_zh.md
@@ -451,4 +451,3 @@ python scripts/manage_assets.py download my_task \
- [数据收集指南](data_collection_zh.md) - 如何收集演示数据
- [评估指南](evaluation_zh.md) - 如何评估策略
- [HuggingFace Hub 文档](https://huggingface.co/docs/hub/index) - 云存储
-
diff --git a/docs/data_collection.md b/docs/data_collection.md
index 02d54908..ac9a3a1d 100644
--- a/docs/data_collection.md
+++ b/docs/data_collection.md
@@ -59,22 +59,22 @@ This script will display an interactive simulation environment window, where you
[ / ] |
Switch to the previous/next view |
-
+
B |
Toggle arm/base mode (if applicable) |
-
+
S |
Switch active arm (if multi-armed robot) |
-
+
= |
Switch active robot (if multi-robot environment) |
-
+
@@ -156,7 +156,7 @@ The dataset builder is already configured with the following features:
- **Observation Data**:
- `image`: Main camera RGB image (256×256×3)
- - `wrist_image`: Wrist camera RGB image (256×256×3)
+ - `wrist_image`: Wrist camera RGB image (256×256×3)
- `state`: Robot end-effector state (8D: 6D pose + 2D gripper state)
- `joint_state`: Robot joint angles (7D)
@@ -242,6 +242,9 @@ Modify configuration variables in `scripts/convert.sh`:
# Set RLDS dataset path
DATA_DIR="/path/to/your/rlds/dataset"
+# Set your model type ("openpi" or "smolvla")
+MODEL_TYPE="openpi"
+
# Set LeRobot output path, defaulting to "./lerobot_dataset"
HF_LEROBOT_HOME="/path/to/lerobot/datasets"
@@ -251,8 +254,8 @@ PUSH_TO_HUB="false"
### 5.3 Dataset Feature Mapping
-The conversion script will map RLDS data to LeRobot format:
-
+The conversion script converts RLDS data to one of two LeRobot formats, depending on the model type (see the sketch after the lists below):
+
+#### OpenPi
- **Image Data**:
- `image`: Main camera RGB image (256×256×3)
- `wrist_image`: Wrist camera RGB image (256×256×3)
@@ -263,6 +266,19 @@ The conversion script will map RLDS data to LeRobot format:
- **Action Data**:
- `actions`: Robot actions (7D)
+- **Task Information**:
+ - `task`: Language instruction (extracted from RLDS language_instruction)
+#### SmolVLA
+- **Image Data**:
+ - `observations.images.image`: Main camera RGB image (256×256×3)
+ - `observations.images.wrist_image`: Wrist camera RGB image (256×256×3)
+
+- **State Data**:
+ - `observations.state`: Robot end-effector state (8D: 6D pose + 2D gripper state)
+
+- **Action Data**:
+ - `action`: Robot actions (7D)
+
- **Task Information**:
- `task`: Language instruction (extracted from RLDS language_instruction)
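+
+The two layouts above differ mainly in key naming. Below is a minimal sketch of the feature keys for quick reference (the dictionary name is illustrative; the authoritative definitions live in `convert_data_to_lerobot_{model_type}.py`):
+
+```python
+# Illustrative summary of the per-model feature keys listed above (not the converter source).
+LEROBOT_FEATURES = {
+    'openpi': {
+        'image': (256, 256, 3),        # main camera RGB
+        'wrist_image': (256, 256, 3),  # wrist camera RGB
+        'state': (8,),                 # 6D pose + 2D gripper state
+        'actions': (7,),               # robot actions (7D)
+    },
+    'smolvla': {
+        'observations.images.image': (256, 256, 3),
+        'observations.images.wrist_image': (256, 256, 3),
+        'observations.state': (8,),    # 6D pose + 2D gripper state
+        'action': (7,),                # robot actions (7D)
+    },
+}
+# In both layouts the language instruction is stored under the 'task' key.
+```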
@@ -275,10 +291,10 @@ Run the conversion script:
./scripts/convert.sh
-# Method 2: Specify data path
-./scripts/convert.sh /path/to/your/rlds/dataset
+# Method 2: Specify data path and model type
+./scripts/convert.sh /path/to/your/rlds/dataset openpi
# Method 3: Use environment variables
-DATA_DIR=/path/to/your/rlds/dataset ./scripts/convert.sh
+DATA_DIR=/path/to/your/rlds/dataset MODEL_TYPE=openpi ./scripts/convert.sh
```
The conversion process will:
@@ -291,7 +307,7 @@ The conversion process will:
### 5.5 Conversion Parameter Description
-The following parameters can be adjusted in `convert_data_to_lerobot.py`:
+The following parameters can be adjusted in `convert_data_to_lerobot_{model_type}.py`:
```python
# Robot configuration
diff --git a/docs/data_collection_zh.md b/docs/data_collection_zh.md
index 2a6bada9..010b153b 100644
--- a/docs/data_collection_zh.md
+++ b/docs/data_collection_zh.md
@@ -59,22 +59,22 @@ python scripts/collect_demonstration.py --bddl-file <你的bddl文件路径>
[ / ] |
切换到上一个/下一个视图 |
-
+
B |
切换手臂/基座模式(如适用) |
-
+
S |
切换活动手臂(如果是多臂机器人) |
-
+
= |
切换活动机器人(如果是多机器人环境) |
-
+
@@ -156,7 +156,7 @@ def _split_paths(self):
- **观察数据**:
- `image`: 主摄像头RGB图像 (256×256×3)
- - `wrist_image`: 手腕摄像头RGB图像 (256×256×3)
+ - `wrist_image`: 手腕摄像头RGB图像 (256×256×3)
- `state`: 机器人末端执行器状态 (8维:6D位姿 + 2D夹爪状态)
- `joint_state`: 机器人关节角度 (7维)
@@ -242,6 +242,9 @@ pip install -r conversion_requirements.txt
# 设置RLDS数据集路径
DATA_DIR="/path/to/your/rlds/dataset"
+# 设置模型类型 ("openpi" 或 "smolvla")
+MODEL_TYPE="openpi"
+
# 设置LeRobot输出路径,默认为 "./lerobot_dataset"
HF_LEROBOT_HOME="/path/to/lerobot/datasets"
@@ -251,8 +254,9 @@ PUSH_TO_HUB="false"
### 5.3 数据集特征映射
-转换脚本会将RLDS数据映射到LeRobot格式:
+转换脚本会根据模型类型将RLDS数据映射到两种不同的LeRobot格式(见下方列表后的示意):
+
+#### OpenPi
- **图像数据**:
- `image`: 主摄像头RGB图像 (256×256×3)
- `wrist_image`: 手腕摄像头RGB图像 (256×256×3)
@@ -266,6 +270,20 @@ PUSH_TO_HUB="false"
- **任务信息**:
- `task`: 语言指令(从RLDS的language_instruction提取)
+#### SmolVLA
+- **图像数据**:
+ - `observations.images.image`: 主摄像头RGB图像 (256×256×3)
+ - `observations.images.wrist_image`: 手腕摄像头RGB图像 (256×256×3)
+
+- **状态数据**:
+ - `observations.state`: 机器人末端执行器状态 (8维:6D位姿 + 2D夹爪状态)
+
+- **动作数据**:
+ - `action`: 机器人动作 (7维)
+
+- **任务信息**:
+ - `task`: 语言指令(从RLDS的language_instruction提取)
+
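+上面两种布局的差别主要在键名。以下是供快速参考的特征键示意(字典名仅为示意;权威定义见 `convert_data_to_lerobot_{model_type}.py`):
+
+```python
+# 上述两种格式的特征键示意(非转换脚本源码)。
+LEROBOT_FEATURES = {
+    'openpi': {
+        'image': (256, 256, 3),        # 主摄像头RGB
+        'wrist_image': (256, 256, 3),  # 手腕摄像头RGB
+        'state': (8,),                 # 6D位姿 + 2D夹爪状态
+        'actions': (7,),               # 机器人动作(7维)
+    },
+    'smolvla': {
+        'observations.images.image': (256, 256, 3),
+        'observations.images.wrist_image': (256, 256, 3),
+        'observations.state': (8,),    # 6D位姿 + 2D夹爪状态
+        'action': (7,),                # 机器人动作(7维)
+    },
+}
+# 两种布局中语言指令均存放在 'task' 键下。
+```
+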
### 5.4 执行转换
运行转换脚本:
@@ -275,10 +293,10 @@ PUSH_TO_HUB="false"
./scripts/convert.sh
-# 方法2:指定数据路径
-./scripts/convert.sh /path/to/your/rlds/dataset
+# 方法2:指定数据路径和模型类型
+./scripts/convert.sh /path/to/your/rlds/dataset openpi
# 方法3:使用环境变量
-DATA_DIR=/path/to/your/rlds/dataset ./scripts/convert.sh
+DATA_DIR=/path/to/your/rlds/dataset MODEL_TYPE=openpi ./scripts/convert.sh
```
转换过程会:
@@ -291,7 +309,7 @@ DATA_DIR=/path/to/your/rlds/dataset ./scripts/convert.sh
### 5.5 转换参数说明
-在 `convert_data_to_lerobot.py` 中可以调整以下参数:
+在 `convert_data_to_lerobot_{model_type}.py` 中可以调整以下参数:
```python
# 机器人配置
@@ -324,4 +342,3 @@ image_writer_processes=5 # 图像写入进程数
- 图像数据会进行压缩以节省存储空间
- 转换后的数据集可以直接用于LeRobot框架的训练
- 如果转换失败,检查RLDS数据集路径是否正确
-
diff --git a/docs/evaluation.md b/docs/evaluation.md
deleted file mode 100644
index 8c08576e..00000000
--- a/docs/evaluation.md
+++ /dev/null
@@ -1,842 +0,0 @@
-# VLA-Arena Model Evaluation and Custom Model Guide
-
-VLA-Arena is a unified framework for evaluating vision-language-action (VLA) models. This guide will help you understand how to use VLA-Arena to evaluate existing models and how to add custom models.
-
-## Table of Contents
-
-1. [Quick Start](#quick-start)
-2. [Model Evaluation](#model-evaluation)
-3. [Adding Custom Models](#adding-custom-models)
-4. [Configuration Instructions](#configuration-instructions)
-5. [Troubleshooting](#troubleshooting)
-
-## Quick Start
-
-### Environment Preparation
-
-Ensure you have installed VLA-Arena and its dependencies:
-
-```bash
-# Install VLA-Arena
-pip install -e .
-
-# Set environment variables
-export MUJOCO_GL=egl
-```
-
-### Basic Evaluation Command
-
-The simplest evaluation command:
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy openvla \
- --model_ckpt /path/to/your/model \
- --save-dir logs/evaluation
-```
-
-## Model Evaluation
-
-### Supported Models
-
-VLA-Arena currently supports the following models:
-
-- **OpenVLA**: Standard OpenVLA model
-
-### Using Evaluation Scripts
-
-#### 1. Using Python Script
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite \
- --task_level \
- --n-episode \
- --policy \
- --model_ckpt \
- --save-dir \
- --visualization \
- --metrics success_rate cumulative_cost safe_success_rate
-```
-
-#### 2. Using Shell Script (Recommended)
-
-```bash
-# Copy and edit the configuration script
-cp scripts/evaluate_policy.sh my_evaluation.sh
-# Edit the configuration section in my_evaluation.sh
-bash my_evaluation.sh
-```
-
-### Task Suites
-
-VLA-Arena provides multiple task suites:
-
-##### Safety
-- **safety_dynamic_obstacles**: Safety Dynamic Obstacles Task
-- **safety_hazard_avoidance**: Safety Hazard Avoidance Task
-- **safety_object_state_preservation**: Safety Object State Preservation Task
-- **safety_risk_aware_grasping**: Safety Risk Aware Grasping Task
-- **safety_static_obstacles**: Safety Static Obstacles Task
-
-##### Robustness
-- **robustness_dynamic_distractors**: Robustness Dynamic Distractors Task
-- **robustness_static_distractors**: Robustness Static Distractors Task
-- **robustness_visual_variations**: Robustness Visual Variations Task
-
-##### Generalization
-- **generalization_language_variations**: Generalization Language Variations Task
-- **generalization_object_preposition_combinations**: Generalization Object Preposition Combinations Task
-- **generalization_task_workflows**: Generalization Task Workflows Task
-- **generalization_unseen_objects**: Generalization Unseen Objects Task
-
-##### Others
-- **long_horizon**: Long Horizon Task
-
-### Task Levels
-
-Each task suite contains multiple difficulty levels:
-
-- **Level 0**: Simple tasks
-- **Level 1**: Medium difficulty tasks
-- **Level 2**: Difficult tasks
-
-Supports multi-level evaluation:
-
-```bash
-# Evaluate a single level
---task_level 0
-
-# Evaluate a level range
---task_level 0-2
-
-# Evaluate specific levels
---task_level 0,2
-```
-
-### Evaluation Metrics
-
-Supported evaluation metrics:
-
-- **success_rate**: Success rate
-- **safe_success_rate**: Safe success rate (cost < 1.0)
-- **cumulative_cost**: Cumulative cost
-- **episode_length**: Episode length
-
-### Visualization Options
-
-Enable visualization to save evaluation videos:
-
-```bash
---visualization
-```
-
-Videos will be saved in the `{save_dir}/rollouts/level_{level}/` directory.
-
-## Adding Custom Models
-
-### 1. Create Custom Policy Class
-
-Create a new policy file, e.g., `my_custom_policy.py`:
-
-```python
-import torch
-import numpy as np
-from vla_arena.evaluation.policy.base import Policy, PolicyRegistry
-from vla_arena.evaluation.utils import normalize_gripper_action, invert_gripper_action
-
-@PolicyRegistry.register("my_custom_model")
-class MyCustomPolicy(Policy):
- """
- Custom model policy implementation
- """
-
- def __init__(self,
- model_ckpt,
- device="cuda",
- **kwargs):
- """
- Initialize custom policy
-
- Args:
- model_ckpt: Model checkpoint path
- device: Running device
- **kwargs: Other parameters
- """
- # Check device availability
- if device == "cuda" and not torch.cuda.is_available():
- print("CUDA not available, falling back to CPU")
- device = "cpu"
-
- # Load your model
- self.model = self._load_model(model_ckpt, device)
- self.device = device
- self.instruction = kwargs.get('instruction', None)
-
- # Call parent class constructor
- super().__init__(self.model)
-
- print(f"Custom model loaded successfully on {device}")
-
- def _load_model(self, model_ckpt, device):
- """
- Load your custom model
-
- Args:
- model_ckpt: Model checkpoint path
- device: Running device
-
- Returns:
- Loaded model
- """
- # Implement your model loading logic here
- # For example:
- # model = YourCustomModel.from_pretrained(model_ckpt)
- # model.to(device)
- # model.eval()
- # return model
-
- raise NotImplementedError("Please implement your model loading logic")
-
- def reset_instruction(self, instruction):
- """
- Reset policy instruction
-
- Args:
- instruction: Task instruction
- """
- self.instruction = instruction
- # Reset model internal state if needed
- if hasattr(self.model, 'reset'):
- self.model.reset()
-
- def predict(self, obs, **kwargs):
- """
- Predict action
-
- Args:
- obs: Observation dictionary containing:
- - agentview_image: Main view image
- - robot0_eye_in_hand_image: Wrist camera image (optional)
- - robot0_eef_pos: End-effector position
- - robot0_eef_quat: End-effector quaternion
- - robot0_gripper_qpos: Gripper position
- **kwargs: Other parameters
-
- Returns:
- action: 7-dimensional action array [x, y, z, rx, ry, rz, gripper]
- """
- # Process observation
- processed_obs = self._prepare_observation(obs)
-
- # Get model prediction
- with torch.inference_mode():
- action = self.model.predict(processed_obs)
-
- # Process action
- action = self._process_action(action)
-
- return action
-
- def _prepare_observation(self, obs):
- """
- Prepare observation data
-
- Args:
- obs: Raw observation
-
- Returns:
- processed_obs: Processed observation
- """
- # Implement your observation preprocessing logic
- # For example, image preprocessing, state vector construction, etc.
-
- processed_obs = {
- "image": obs["agentview_image"],
- "state": np.concatenate([
- obs["robot0_eef_pos"],
- obs["robot0_eef_quat"],
- obs["robot0_gripper_qpos"]
- ]),
- "instruction": self.instruction
- }
-
- return processed_obs
-
- def _process_action(self, action):
- """
- Process action output
-
- Args:
- action: Raw action
-
- Returns:
- action: Processed action
- """
- # Ensure action is a numpy array
- if torch.is_tensor(action):
- action = action.cpu().numpy()
-
- # Normalize gripper action
- action = normalize_gripper_action(action, binarize=True)
-
- # Invert gripper action (if needed)
- action = invert_gripper_action(action)
-
- return action
-
- @property
- def name(self):
- """Return policy name"""
- return "MyCustomModel"
-
- @property
- def control_mode(self):
- """
- Return control mode
- "ee" for end-effector control
- "joint" for joint control
- """
- return "ee"
-```
-
-### 2. Register Policy
-
-Ensure your policy file is correctly imported. Add the following to `vla_arena/evaluation/policy/__init__.py`:
-
-```python
-from .my_custom_policy import MyCustomPolicy
-```
-
-### 3. Create Configuration File
-
-Create a configuration file for your model `vla_arena/configs/evaluation/my_custom_model.yaml`:
-
-```yaml
-# Model-specific configuration
-unnorm_key: "libero_spatial_no_noops" # Action denormalization key
-image_resize_size: 256 # Image resize size
-use_proprio: true # Whether to use proprioception
-center_crop: true # Whether to center crop
-```
-
-### 4. Use Custom Model
-
-Now you can use your custom model in evaluation scripts:
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy my_custom_model \
- --model_ckpt /path/to/your/model \
- --save-dir logs/evaluation
-```
-
-## Configuration Instructions
-
-### Evaluator Configuration
-
-Main parameters of the `Evaluator` class:
-
-```python
-evaluator = Evaluator(
- task_suite="preposition_generalization", # Task suite
- task_levels=[0, 1, 2], # Evaluation level list
- n_episodes=5, # Number of episodes per task
- episode_config=None, # Episode configuration file
- max_substeps=1, # Maximum substeps
- tolerance=1e-2, # Tolerance
- metrics=["success_rate", "cumulative_cost", "safe_success_rate"], # Evaluation metrics
- save_dir="logs/evaluation", # Save directory
- visualization=True # Whether to visualize
-)
-```
-
-### Policy Configuration
-
-Configuration parameters for different policies:
-
-#### OpenVLA
-```python
-policy = OpenVLA(
- model_ckpt="/path/to/openvla/model",
- attn_implementation="torch", # Attention implementation
- norm_config_file=None, # Normalization configuration file
- device="cuda"
-)
-```
-
-#### OpenVLA-OFT
-```python
-policy = OpenVLAOFT(
- model_ckpt="/path/to/openvla-oft/model",
- use_l1_regression=True, # Use L1 regression
- use_diffusion=False, # Use diffusion model
- use_film=True, # Use FiLM
- num_images_in_input=2, # Number of input images
- use_proprio=True, # Use proprioception
- num_open_loop_steps=8, # Open-loop steps
- device="cuda"
-)
-```
-
-#### SmolVLA
-```python
-policy = SmolVLA(
- model_ckpt="smolvla/smolvla-7b", # HuggingFace model name or local path
- device="cuda"
-)
-```
-
-#### OpenPi
-Using the OpenPi model requires starting a policy server first, then connecting via WebSocket for inference:
-
-**Step 1: Start OpenPi Policy Server**
-
-Start the policy server in the OpenPi library:
-
-```bash
-# Navigate to OpenPi directory
-cd /path/to/openpi
-
-# Start policy server (using checkpoint for iteration 20,000, modify as needed)
-uv run scripts/serve_policy.py policy:checkpoint \
- --policy.config=pi0_fast_libero \
- --policy.dir=checkpoints/pi0_fast_libero/my_experiment/20000
-```
-
-The server will listen on port 8000 (default configuration).
-
-**Step 2: Configure OpenPi Policy**
-
-```python
-policy = OpenPI(
- host="0.0.0.0", # Server host address
- port=8000, # Server port
- replan_steps=4 # Replanning steps
-)
-```
-
-**Step 3: Run Evaluation**
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy openpi \
- --save-dir logs/evaluation
-```
-
-**Important Notes:**
-- Ensure the OpenPi server is started and running before beginning evaluation
-- If using a different port, modify the `port` parameter in the policy configuration accordingly
-- Server address defaults to `0.0.0.0`, modify the `host` parameter if connecting to a remote server
-- Keep the server running during evaluation, otherwise connection will fail
-
-## Evaluation Result Storage
-
-### Directory Structure
-
-After evaluation is completed, results will be saved in the specified directory with the following structure:
-
-```
-logs/evaluation/
-└── eval_preposition_generalization_L0-2_OpenVLA_20241201_143022/
- ├── evaluation_metadata.json # Evaluation metadata
- ├── complete_metrics.json # Complete metric data
- ├── evaluation_summary.txt # Human-readable summary
- ├── summary.json # Simplified JSON summary
- ├── task_details/ # Task detailed results
- │ ├── level_0/
- │ │ ├── task_1/
- │ │ │ └── detail_result.json
- │ │ └── task_2/
- │ │ └── detail_result.json
- │ └── level_1/
- │ └── ...
- ├── level_summaries/ # Level summaries
- │ ├── level_0_summary.json
- │ └── level_1_summary.json
- └── rollouts/ # Visualization videos (if enabled)
- ├── level_0/
- │ └── 2024-12-01/
- │ ├── L0--2024-12-01--episode=0--success=True--task=place_object_on_table.mp4
- │ └── L0--2024-12-01--episode=1--success=False--task=move_object_to_bowl.mp4
- └── level_1/
- └── ...
-```
-
-### Log File Examples
-
-#### 1. Evaluation Metadata (`evaluation_metadata.json`)
-
-```json
-{
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "agent_name": "OpenVLA",
- "n_episodes": 5,
- "timestamp": "2024-12-01T14:30:22.123456",
- "metrics": ["success_rate", "cumulative_cost", "safe_success_rate"],
- "visualization": true
-}
-```
-
-#### 2. Complete Metric Data (`complete_metrics.json`)
-
-```json
-{
- "timestamp": "2024-12-01T14:30:22.123456",
- "agent_name": "OpenVLA",
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "evaluation_dir": "/path/to/logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022",
- "metrics": {
- "evaluation_config": {
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "n_episodes_per_task": 5,
- "target_metrics": ["success_rate", "cumulative_cost", "safe_success_rate"]
- },
- "per_level_metrics": {
- "level_0": {
- "average_success_rate": 0.85,
- "average_safe_success_rate": 0.78,
- "average_cumulative_cost": 0.45,
- "num_tasks": 10,
- "task_metrics": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "cumulative_cost": 0.6
- }
- }
- },
- "level_1": {
- "average_success_rate": 0.72,
- "average_safe_success_rate": 0.65,
- "average_cumulative_cost": 0.68,
- "num_tasks": 10,
- "task_metrics": {
- "place_object_between_objects": {
- "success_rate": 0.7,
- "safe_success_rate": 0.6,
- "cumulative_cost": 0.8
- }
- }
- }
- },
- "cross_level_summary": {
- "overall_average_success_rate": 0.785,
- "overall_average_safe_success_rate": 0.715,
- "overall_std_success_rate": 0.092,
- "overall_std_safe_success_rate": 0.095,
- "overall_average_cumulative_cost": 0.565,
- "overall_std_cumulative_cost": 0.175,
- "total_tasks_evaluated": 20,
- "total_episodes": 100,
- "total_successful_episodes": 78,
- "total_safe_successful_episodes": 71,
- "global_success_rate": 0.78,
- "global_safe_success_rate": 0.71
- }
- }
-}
-```
-
-#### 3. Human-Readable Summary (`evaluation_summary.txt`)
-
-```
-======================================================================
-EVALUATION SUMMARY
-======================================================================
-
-Agent: OpenVLA
-Task Suite: preposition_generalization
-Levels Evaluated: [0, 1, 2]
-Timestamp: 2024-12-01T14:30:22.123456
-Output Directory: /path/to/logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022
-
-======================================================================
-OVERALL RESULTS
-======================================================================
-
-Total Episodes Evaluated: 100
-Total Tasks Evaluated: 20
-
-Global Success Rate: 78.00%
- - Successful Episodes: 78/100
-
-Global Safe Success Rate: 71.00%
- - Safe Successful Episodes: 71/100
-
-Average Success Rate (across tasks): 78.50% ± 9.20%
-
-Average Safe Success Rate (across tasks): 71.50% ± 9.50%
-
-Average Cumulative Cost: 0.57 ± 0.18
-
-======================================================================
-PER-LEVEL RESULTS
-======================================================================
-
-Level 0:
- Success Rate: 85.00%
- Safe Success Rate: 78.00%
- Average Cost: 0.45
- Tasks Evaluated: 10
-
- Task Breakdown:
- • place_object_on_table:
- - Success Rate: 90.00%
- - Safe Success Rate: 80.00%
- - Avg Cost: 0.30
- • move_object_to_bowl:
- - Success Rate: 80.00%
- - Safe Success Rate: 76.00%
- - Avg Cost: 0.60
-
-Level 1:
- Success Rate: 72.00%
- Safe Success Rate: 65.00%
- Average Cost: 0.68
- Tasks Evaluated: 10
-
- Task Breakdown:
- • place_object_between_objects:
- - Success Rate: 70.00%
- - Safe Success Rate: 60.00%
- - Avg Cost: 0.80
-```
-
-#### 4. Simplified Summary (`summary.json`)
-
-```json
-{
- "agent": "OpenVLA",
- "suite": "preposition_generalization",
- "levels": [0, 1, 2],
- "timestamp": "2024-12-01T14:30:22.123456",
- "overall": {
- "success_rate": 0.78,
- "safe_success_rate": 0.71,
- "avg_cost": 0.565,
- "total_episodes": 100
- },
- "per_level": {
- "0": {
- "success_rate": 0.85,
- "safe_success_rate": 0.78,
- "avg_cost": 0.45,
- "tasks": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "avg_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "avg_cost": 0.6
- }
- }
- },
- "1": {
- "success_rate": 0.72,
- "safe_success_rate": 0.65,
- "avg_cost": 0.68,
- "tasks": {
- "place_object_between_objects": {
- "success_rate": 0.7,
- "safe_success_rate": 0.6,
- "avg_cost": 0.8
- }
- }
- }
- }
-}
-```
-
-#### 5. Task Detailed Results (`task_details/level_0/task_name/detail_result.json`)
-
-```json
-{
- "task_name": "place_object_on_table",
- "task_suite": "preposition_generalization",
- "task_level": 0,
- "agent_name": "OpenVLA",
- "metric_score": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3,
- "cumulative_cost_std": 0.15,
- "cumulative_cost_min": 0.1,
- "cumulative_cost_max": 0.5
- },
- "timestamp": "2024-12-01T14:30:22.123456",
- "episodes": [
- {
- "success": true,
- "episode_id": 0,
- "episode_length": 45,
- "cumulative_cost": 0.2,
- "task_level": 0
- },
- {
- "success": true,
- "episode_id": 1,
- "episode_length": 52,
- "cumulative_cost": 0.4,
- "task_level": 0
- },
- {
- "success": false,
- "episode_id": 2,
- "episode_length": 200,
- "cumulative_cost": 1.2,
- "task_level": 0
- }
- ],
- "summary": {
- "total_episodes": 5,
- "successful_episodes": 4,
- "success_rate": 0.8,
- "average_steps": 48.5,
- "avg_cumulative_cost": 0.3,
- "std_cumulative_cost": 0.15,
- "min_cumulative_cost": 0.1,
- "max_cumulative_cost": 0.5,
- "median_cumulative_cost": 0.25,
- "safe_successful_episodes": 4,
- "safe_success_rate": 0.8
- }
-}
-```
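-
-The `episodes` list makes it easy to drill into individual failures. For example, a short snippet (the file path is a placeholder) that lists the failed episodes and their cumulative costs:
-
-```python
-import json
-
-with open("task_details/level_0/place_object_on_table/detail_result.json") as f:
-    detail = json.load(f)
-
-# Collect (episode_id, cumulative_cost) for every unsuccessful episode.
-failed = [(ep["episode_id"], ep["cumulative_cost"])
-          for ep in detail["episodes"] if not ep["success"]]
-print(f"{detail['task_name']}: {len(failed)} failed episodes -> {failed}")
-```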
-
-#### 6. Level Summary (`level_summaries/level_0_summary.json`)
-
-```json
-{
- "task_level": 0,
- "agent_name": "OpenVLA",
- "timestamp": "2024-12-01T14:30:22.123456",
- "average_success_rate": 0.85,
- "average_safe_success_rate": 0.78,
- "std_success_rate": 0.08,
- "std_safe_success_rate": 0.09,
- "num_tasks": 10,
- "average_cumulative_cost": 0.45,
- "std_cumulative_cost": 0.12,
- "task_metrics": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "cumulative_cost": 0.6
- }
- },
- "task_details": {
- "place_object_on_table": {
- "task_level": 0,
- "metric_score": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "total_episodes": 5,
- "successful_episodes": 4,
- "safe_successful_episodes": 4,
- "failed_episodes": 1,
- "avg_cumulative_cost": 0.3
- }
- }
-}
-```
-
-### Result Analysis Tools
-
-You can use the following Python script to quickly analyze evaluation results:
-
-```python
-import json
-import pandas as pd
-from pathlib import Path
-
-def analyze_evaluation_results(results_dir):
- """Analyze evaluation results"""
- results_path = Path(results_dir)
-
- # Read simplified summary
- with open(results_path / "summary.json", 'r') as f:
- summary = json.load(f)
-
- print(f"Agent: {summary['agent']}")
- print(f"Task Suite: {summary['suite']}")
- print(f"Overall Success Rate: {summary['overall']['success_rate']:.2%}")
- print(f"Overall Safe Success Rate: {summary['overall']['safe_success_rate']:.2%}")
- print(f"Average Cost: {summary['overall']['avg_cost']:.3f}")
-
- # Create task-level DataFrame
- level_data = []
- for level, level_info in summary['per_level'].items():
- for task, task_info in level_info['tasks'].items():
- level_data.append({
- 'Level': int(level),
- 'Task': task,
- 'Success Rate': task_info['success_rate'],
- 'Safe Success Rate': task_info['safe_success_rate'],
- 'Avg Cost': task_info['avg_cost']
- })
-
- df = pd.DataFrame(level_data)
- print("\nTask-level Results:")
- print(df.to_string(index=False))
-
- return df
-
-# Usage example
-# df = analyze_evaluation_results("logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022")
-```
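-
-Since `analyze_evaluation_results` returns a regular `pandas.DataFrame`, you can aggregate further with the usual pandas operations, for example to compare levels at a glance:
-
-```python
-# Per-level averages across tasks (df comes from analyze_evaluation_results above)
-per_level = df.groupby('Level')[['Success Rate', 'Safe Success Rate', 'Avg Cost']].mean()
-print(per_level)
-```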
-
-## Examples and Best Practices
-
-### Complete Evaluation Example
-
-```bash
-#!/bin/bash
-# Complete model evaluation script
-
-MODEL_PATH="/path/to/your/model"
-OUTPUT_DIR="logs/evaluation_$(date +%Y%m%d_%H%M%S)"
-
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0-2 \
- --n-episode 5 \
- --policy openvla \
- --model_ckpt "$MODEL_PATH" \
- --save-dir "$OUTPUT_DIR" \
- --visualization \
- --metrics success_rate cumulative_cost safe_success_rate
-
-echo "Evaluation completed. Results saved to: $OUTPUT_DIR"
-```
-
-If you encounter problems or have suggestions for improvement, please feel free to submit an issue or a pull request.
\ No newline at end of file
diff --git a/docs/evaluation_zh.md b/docs/evaluation_zh.md
deleted file mode 100644
index b3632d8d..00000000
--- a/docs/evaluation_zh.md
+++ /dev/null
@@ -1,843 +0,0 @@
-# VLA-Arena Model Evaluation and Custom Model Guide
-
-VLA-Arena is a unified framework for evaluating vision-language-action (VLA) models. This guide explains how to evaluate existing models with VLA-Arena and how to add custom models.
-
-## Table of Contents
-
-1. [Quick Start](#quick-start)
-2. [Model Evaluation](#model-evaluation)
-3. [Adding Custom Models](#adding-custom-models)
-4. [Configuration](#configuration)
-5. [Troubleshooting](#troubleshooting)
-
-## Quick Start
-
-### Environment Setup
-
-Make sure VLA-Arena and its dependencies are installed:
-
-```bash
-# Install VLA-Arena
-pip install -e .
-
-# Set environment variables
-export MUJOCO_GL=egl
-```
-
-
-### Basic Evaluation Command
-
-The simplest evaluation command:
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy openvla \
- --model_ckpt /path/to/your/model \
- --save-dir logs/evaluation
-```
-
-## Model Evaluation
-
-### Supported Models
-
-VLA-Arena currently supports the following models:
-
-- **OpenVLA**: the standard OpenVLA model
-
-### Using the Evaluation Scripts
-
-#### 1. Using the Python Script
-
-```bash
-python scripts/evaluate_policy.py \
-    --task_suite <task_suite> \
-    --task_level <task_level> \
-    --n-episode <n_episode> \
-    --policy <policy_name> \
-    --model_ckpt <model_ckpt_path> \
-    --save-dir <save_dir> \
- --visualization \
- --metrics success_rate cumulative_cost safe_success_rate
-```
-
-#### 2. Using the Shell Script (Recommended)
-
-```bash
-# Copy and edit the configuration script
-cp scripts/evaluate_policy.sh my_evaluation.sh
-# Edit the configuration section in my_evaluation.sh
-bash my_evaluation.sh
-```
-
-### Task Suites
-
-VLA-Arena provides multiple task suites:
-
-##### Safety
-- **safety_dynamic_obstacles**: dynamic obstacle tasks
-- **safety_hazard_avoidance**: hazard avoidance tasks
-- **safety_object_state_preservation**: object state preservation tasks
-- **safety_risk_aware_grasping**: risk-aware grasping tasks
-- **safety_static_obstacles**: static obstacle tasks
-
-##### Robustness
-- **robustness_dynamic_distractors**: dynamic distractor tasks
-- **robustness_static_distractors**: static distractor tasks
-- **robustness_visual_variations**: visual variation tasks
-
-##### Generalization
-- **generalization_language_variations**: language variation generalization tasks
-- **generalization_object_preposition_combinations**: object-preposition combination generalization tasks
-- **generalization_task_workflows**: task workflow generalization tasks
-- **generalization_unseen_objects**: unseen object generalization tasks
-
-##### Other
-- **long_horizon**: long-horizon tasks
-
-### Task Levels
-
-Each task suite contains multiple difficulty levels:
-
-- **Level 0**: easy tasks
-- **Level 1**: medium-difficulty tasks
-- **Level 2**: hard tasks
-
-Multi-level evaluation is supported:
-
-```bash
-# Evaluate a single level
---task_level 0
-
-# Evaluate a range of levels
---task_level 0-2
-
-# Evaluate specific levels
---task_level 0,2
-
-```
-
-### Evaluation Metrics
-
-Supported evaluation metrics:
-
-- **success_rate**: success rate
-- **safe_success_rate**: safe success rate (cost < 1.0)
-- **cumulative_cost**: cumulative cost
-- **episode_length**: episode length
-
-### Visualization Options
-
-Enable visualization to save evaluation videos:
-
-```bash
---visualization
-```
-
-Videos are saved under the `{save_dir}/rollouts/level_{level}/` directory.
-
-## Adding Custom Models
-
-### 1. Create a Custom Policy Class
-
-Create a new policy file, for example `my_custom_policy.py`:
-
-```python
-import torch
-import numpy as np
-from vla_arena.evaluation.policy.base import Policy, PolicyRegistry
-from vla_arena.evaluation.utils import normalize_gripper_action, invert_gripper_action
-
-@PolicyRegistry.register("my_custom_model")
-class MyCustomPolicy(Policy):
- """
- 自定义模型策略实现
- """
-
- def __init__(self,
- model_ckpt,
- device="cuda",
- **kwargs):
- """
- 初始化自定义策略
-
- Args:
- model_ckpt: 模型检查点路径
- device: 运行设备
- **kwargs: 其他参数
- """
- # 检查设备可用性
- if device == "cuda" and not torch.cuda.is_available():
- print("CUDA not available, falling back to CPU")
- device = "cpu"
-
- # 加载你的模型
- self.model = self._load_model(model_ckpt, device)
- self.device = device
- self.instruction = kwargs.get('instruction', None)
-
- # 调用父类构造函数
- super().__init__(self.model)
-
- print(f"Custom model loaded successfully on {device}")
-
- def _load_model(self, model_ckpt, device):
- """
- 加载你的自定义模型
-
- Args:
- model_ckpt: 模型检查点路径
- device: 运行设备
-
- Returns:
- 加载的模型
- """
- # 在这里实现你的模型加载逻辑
- # 例如:
- # model = YourCustomModel.from_pretrained(model_ckpt)
- # model.to(device)
- # model.eval()
- # return model
-
- raise NotImplementedError("请实现你的模型加载逻辑")
-
- def reset_instruction(self, instruction):
- """
- 重置策略指令
-
- Args:
- instruction: 任务指令
- """
- self.instruction = instruction
- # 如果需要,重置模型内部状态
- if hasattr(self.model, 'reset'):
- self.model.reset()
-
- def predict(self, obs, **kwargs):
- """
- 预测动作
-
- Args:
- obs: 观察字典,包含:
- - agentview_image: 主视角图像
- - robot0_eye_in_hand_image: 手腕相机图像(可选)
- - robot0_eef_pos: 末端执行器位置
- - robot0_eef_quat: 末端执行器四元数
- - robot0_gripper_qpos: 夹爪位置
- **kwargs: 其他参数
-
- Returns:
- action: 7维动作数组 [x, y, z, rx, ry, rz, gripper]
- """
- # 处理观察
- processed_obs = self._prepare_observation(obs)
-
- # 获取模型预测
- with torch.inference_mode():
- action = self.model.predict(processed_obs)
-
- # 处理动作
- action = self._process_action(action)
-
- return action
-
- def _prepare_observation(self, obs):
- """
- 准备观察数据
-
- Args:
- obs: 原始观察
-
- Returns:
- processed_obs: 处理后的观察
- """
- # 实现你的观察预处理逻辑
- # 例如图像预处理、状态向量构建等
-
- processed_obs = {
- "image": obs["agentview_image"],
- "state": np.concatenate([
- obs["robot0_eef_pos"],
- obs["robot0_eef_quat"],
- obs["robot0_gripper_qpos"]
- ]),
- "instruction": self.instruction
- }
-
- return processed_obs
-
- def _process_action(self, action):
- """
- 处理动作输出
-
- Args:
- action: 原始动作
-
- Returns:
- action: 处理后的动作
- """
- # 确保动作是 numpy 数组
- if torch.is_tensor(action):
- action = action.cpu().numpy()
-
- # 标准化夹爪动作
- action = normalize_gripper_action(action, binarize=True)
-
- # 反转夹爪动作(如果需要)
- action = invert_gripper_action(action)
-
- return action
-
- @property
- def name(self):
- """返回策略名称"""
- return "MyCustomModel"
-
- @property
- def control_mode(self):
- """
- 返回控制模式
- "ee" 表示末端执行器控制
- "joint" 表示关节控制
- """
- return "ee"
-```
-
-### 2. Register the Policy
-
-Make sure your policy file is imported correctly. Add the following to `vla_arena/evaluation/policy/__init__.py`:
-
-```python
-from .my_custom_policy import MyCustomPolicy
-```
-
-### 3. Create a Configuration File
-
-Create a configuration file for your model at `vla_arena/configs/evaluation/my_custom_model.yaml`:
-
-```yaml
-# Model-specific configuration
-unnorm_key: "libero_spatial_no_noops"  # Action un-normalization key
-image_resize_size: 256                 # Image resize size
-use_proprio: true                      # Whether to use proprioception
-center_crop: true                      # Whether to center-crop
-```
-
-### 4. Use the Custom Model
-
-You can now use your custom model in the evaluation script:
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy my_custom_model \
- --model_ckpt /path/to/your/model \
- --save-dir logs/evaluation
-```
-
-## Configuration
-
-### Evaluator Configuration
-
-Main parameters of the `Evaluator` class:
-
-```python
-evaluator = Evaluator(
-    task_suite="preposition_generalization",  # Task suite
-    task_levels=[0, 1, 2],                    # List of levels to evaluate
-    n_episodes=5,                             # Episodes per task
-    episode_config=None,                      # Episode configuration file
-    max_substeps=1,                           # Maximum sub-steps
-    tolerance=1e-2,                           # Tolerance
-    metrics=["success_rate", "cumulative_cost", "safe_success_rate"],  # Evaluation metrics
-    save_dir="logs/evaluation",               # Save directory
-    visualization=True                        # Whether to visualize
-)
-```
-
-### Policy Configuration
-
-Configuration parameters for the different policies:
-
-#### OpenVLA
-```python
-policy = OpenVLA(
- model_ckpt="/path/to/openvla/model",
-    attn_implementation="torch",  # Attention implementation
-    norm_config_file=None,  # Normalization config file
- device="cuda"
-)
-```
-
-#### OpenVLA-OFT
-```python
-policy = OpenVLAOFT(
- model_ckpt="/path/to/openvla-oft/model",
-    use_l1_regression=True,  # Use L1 regression
-    use_diffusion=False,  # Use diffusion model
-    use_film=True,  # Use FiLM
-    num_images_in_input=2,  # Number of input images
-    use_proprio=True,  # Use proprioception
-    num_open_loop_steps=8,  # Open-loop steps
- device="cuda"
-)
-```
-
-#### SmolVLA
-```python
-policy = SmolVLA(
- model_ckpt="smolvla/smolvla-7b", # HuggingFace 模型名称或本地路径
- device="cuda"
-)
-```
-
-#### OpenPi
-Using the OpenPi model requires starting a policy server first, then connecting via WebSocket for inference:
-
-**Step 1: Start the OpenPi Policy Server**
-
-Start the policy server in the OpenPi library:
-
-```bash
-# Navigate to the OpenPi directory
-cd /path/to/openpi
-
-# Start the policy server (using the checkpoint at iteration 20,000; modify as needed)
-uv run scripts/serve_policy.py policy:checkpoint \
- --policy.config=pi0_fast_libero \
- --policy.dir=checkpoints/pi0_fast_libero/my_experiment/20000
-```
-
-The server will listen on port 8000 (default configuration).
-
-**Step 2: Configure the OpenPi Policy**
-
-```python
-policy = OpenPI(
-    host="0.0.0.0",  # Server host address
-    port=8000,  # Server port
-    replan_steps=4  # Replanning steps
-)
-```
-
-**Step 3: Run the Evaluation**
-
-```bash
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0 \
- --n-episode 1 \
- --policy openpi \
- --save-dir logs/evaluation
-```
-
-**Important Notes:**
-- Make sure the OpenPi server is running before starting the evaluation
-- If you use a different port, update the `port` parameter in the policy configuration accordingly
-- The server address defaults to `0.0.0.0`; modify the `host` parameter if connecting to a remote server
-- Keep the server running during the evaluation, otherwise the connection will fail
-
-## Evaluation Result Storage
-
-### Directory Structure
-
-After evaluation is completed, results will be saved in the specified directory with the following structure:
-
-```
-logs/evaluation/
-└── eval_preposition_generalization_L0-2_OpenVLA_20241201_143022/
-    ├── evaluation_metadata.json          # Evaluation metadata
-    ├── complete_metrics.json             # Complete metric data
-    ├── evaluation_summary.txt            # Human-readable summary
-    ├── summary.json                      # Simplified JSON summary
-    ├── task_details/                     # Task detailed results
-    │   ├── level_0/
-    │   │   ├── task_1/
-    │   │   │   └── detail_result.json
-    │   │   └── task_2/
-    │   │       └── detail_result.json
-    │   └── level_1/
-    │       └── ...
-    ├── level_summaries/                  # Level summaries
-    │   ├── level_0_summary.json
-    │   └── level_1_summary.json
-    └── rollouts/                         # Visualization videos (if enabled)
-        ├── level_0/
-        │   └── 2024-12-01/
-        │       ├── L0--2024-12-01--episode=0--success=True--task=place_object_on_table.mp4
-        │       └── L0--2024-12-01--episode=1--success=False--task=move_object_to_bowl.mp4
-        └── level_1/
-            └── ...
-```
-
-### Log File Examples
-
-#### 1. Evaluation Metadata (`evaluation_metadata.json`)
-
-```json
-{
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "agent_name": "OpenVLA",
- "n_episodes": 5,
- "timestamp": "2024-12-01T14:30:22.123456",
- "metrics": ["success_rate", "cumulative_cost", "safe_success_rate"],
- "visualization": true
-}
-```
-
-#### 2. Complete Metric Data (`complete_metrics.json`)
-
-```json
-{
- "timestamp": "2024-12-01T14:30:22.123456",
- "agent_name": "OpenVLA",
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "evaluation_dir": "/path/to/logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022",
- "metrics": {
- "evaluation_config": {
- "task_suite": "preposition_generalization",
- "task_levels": [0, 1, 2],
- "n_episodes_per_task": 5,
- "target_metrics": ["success_rate", "cumulative_cost", "safe_success_rate"]
- },
- "per_level_metrics": {
- "level_0": {
- "average_success_rate": 0.85,
- "average_safe_success_rate": 0.78,
- "average_cumulative_cost": 0.45,
- "num_tasks": 10,
- "task_metrics": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "cumulative_cost": 0.6
- }
- }
- },
- "level_1": {
- "average_success_rate": 0.72,
- "average_safe_success_rate": 0.65,
- "average_cumulative_cost": 0.68,
- "num_tasks": 10,
- "task_metrics": {
- "place_object_between_objects": {
- "success_rate": 0.7,
- "safe_success_rate": 0.6,
- "cumulative_cost": 0.8
- }
- }
- }
- },
- "cross_level_summary": {
- "overall_average_success_rate": 0.785,
- "overall_average_safe_success_rate": 0.715,
- "overall_std_success_rate": 0.092,
- "overall_std_safe_success_rate": 0.095,
- "overall_average_cumulative_cost": 0.565,
- "overall_std_cumulative_cost": 0.175,
- "total_tasks_evaluated": 20,
- "total_episodes": 100,
- "total_successful_episodes": 78,
- "total_safe_successful_episodes": 71,
- "global_success_rate": 0.78,
- "global_safe_success_rate": 0.71
- }
- }
-}
-```
-
-#### 3. Human-Readable Summary (`evaluation_summary.txt`)
-
-```
-======================================================================
-EVALUATION SUMMARY
-======================================================================
-
-Agent: OpenVLA
-Task Suite: preposition_generalization
-Levels Evaluated: [0, 1, 2]
-Timestamp: 2024-12-01T14:30:22.123456
-Output Directory: /path/to/logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022
-
-======================================================================
-OVERALL RESULTS
-======================================================================
-
-Total Episodes Evaluated: 100
-Total Tasks Evaluated: 20
-
-Global Success Rate: 78.00%
- - Successful Episodes: 78/100
-
-Global Safe Success Rate: 71.00%
- - Safe Successful Episodes: 71/100
-
-Average Success Rate (across tasks): 78.50% ± 9.20%
-
-Average Safe Success Rate (across tasks): 71.50% ± 9.50%
-
-Average Cumulative Cost: 0.57 ± 0.18
-
-======================================================================
-PER-LEVEL RESULTS
-======================================================================
-
-Level 0:
- Success Rate: 85.00%
- Safe Success Rate: 78.00%
- Average Cost: 0.45
- Tasks Evaluated: 10
-
- Task Breakdown:
- • place_object_on_table:
- - Success Rate: 90.00%
- - Safe Success Rate: 80.00%
- - Avg Cost: 0.30
- • move_object_to_bowl:
- - Success Rate: 80.00%
- - Safe Success Rate: 76.00%
- - Avg Cost: 0.60
-
-Level 1:
- Success Rate: 72.00%
- Safe Success Rate: 65.00%
- Average Cost: 0.68
- Tasks Evaluated: 10
-
- Task Breakdown:
- • place_object_between_objects:
- - Success Rate: 70.00%
- - Safe Success Rate: 60.00%
- - Avg Cost: 0.80
-```
-
-#### 4. Simplified Summary (`summary.json`)
-
-```json
-{
- "agent": "OpenVLA",
- "suite": "preposition_generalization",
- "levels": [0, 1, 2],
- "timestamp": "2024-12-01T14:30:22.123456",
- "overall": {
- "success_rate": 0.78,
- "safe_success_rate": 0.71,
- "avg_cost": 0.565,
- "total_episodes": 100
- },
- "per_level": {
- "0": {
- "success_rate": 0.85,
- "safe_success_rate": 0.78,
- "avg_cost": 0.45,
- "tasks": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "avg_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "avg_cost": 0.6
- }
- }
- },
- "1": {
- "success_rate": 0.72,
- "safe_success_rate": 0.65,
- "avg_cost": 0.68,
- "tasks": {
- "place_object_between_objects": {
- "success_rate": 0.7,
- "safe_success_rate": 0.6,
- "avg_cost": 0.8
- }
- }
- }
- }
-}
-```
-
-#### 5. Task Detailed Results (`task_details/level_0/task_name/detail_result.json`)
-
-```json
-{
- "task_name": "place_object_on_table",
- "task_suite": "preposition_generalization",
- "task_level": 0,
- "agent_name": "OpenVLA",
- "metric_score": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3,
- "cumulative_cost_std": 0.15,
- "cumulative_cost_min": 0.1,
- "cumulative_cost_max": 0.5
- },
- "timestamp": "2024-12-01T14:30:22.123456",
- "episodes": [
- {
- "success": true,
- "episode_id": 0,
- "episode_length": 45,
- "cumulative_cost": 0.2,
- "task_level": 0
- },
- {
- "success": true,
- "episode_id": 1,
- "episode_length": 52,
- "cumulative_cost": 0.4,
- "task_level": 0
- },
- {
- "success": false,
- "episode_id": 2,
- "episode_length": 200,
- "cumulative_cost": 1.2,
- "task_level": 0
- }
- ],
- "summary": {
- "total_episodes": 5,
- "successful_episodes": 4,
- "success_rate": 0.8,
- "average_steps": 48.5,
- "avg_cumulative_cost": 0.3,
- "std_cumulative_cost": 0.15,
- "min_cumulative_cost": 0.1,
- "max_cumulative_cost": 0.5,
- "median_cumulative_cost": 0.25,
- "safe_successful_episodes": 4,
- "safe_success_rate": 0.8
- }
-}
-```
-
-#### 6. Level Summary (`level_summaries/level_0_summary.json`)
-
-```json
-{
- "task_level": 0,
- "agent_name": "OpenVLA",
- "timestamp": "2024-12-01T14:30:22.123456",
- "average_success_rate": 0.85,
- "average_safe_success_rate": 0.78,
- "std_success_rate": 0.08,
- "std_safe_success_rate": 0.09,
- "num_tasks": 10,
- "average_cumulative_cost": 0.45,
- "std_cumulative_cost": 0.12,
- "task_metrics": {
- "place_object_on_table": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "move_object_to_bowl": {
- "success_rate": 0.8,
- "safe_success_rate": 0.76,
- "cumulative_cost": 0.6
- }
- },
- "task_details": {
- "place_object_on_table": {
- "task_level": 0,
- "metric_score": {
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "cumulative_cost": 0.3
- },
- "success_rate": 0.9,
- "safe_success_rate": 0.8,
- "total_episodes": 5,
- "successful_episodes": 4,
- "safe_successful_episodes": 4,
- "failed_episodes": 1,
- "avg_cumulative_cost": 0.3
- }
- }
-}
-```
-
-### Result Analysis Tools
-
-You can use the following Python script to quickly analyze evaluation results:
-
-```python
-import json
-import pandas as pd
-from pathlib import Path
-
-def analyze_evaluation_results(results_dir):
- """分析评估结果"""
- results_path = Path(results_dir)
-
-    # Read the simplified summary
- with open(results_path / "summary.json", 'r') as f:
- summary = json.load(f)
-
- print(f"Agent: {summary['agent']}")
- print(f"Task Suite: {summary['suite']}")
- print(f"Overall Success Rate: {summary['overall']['success_rate']:.2%}")
- print(f"Overall Safe Success Rate: {summary['overall']['safe_success_rate']:.2%}")
- print(f"Average Cost: {summary['overall']['avg_cost']:.3f}")
-
-    # Build a task-level DataFrame
- level_data = []
- for level, level_info in summary['per_level'].items():
- for task, task_info in level_info['tasks'].items():
- level_data.append({
- 'Level': int(level),
- 'Task': task,
- 'Success Rate': task_info['success_rate'],
- 'Safe Success Rate': task_info['safe_success_rate'],
- 'Avg Cost': task_info['avg_cost']
- })
-
- df = pd.DataFrame(level_data)
- print("\nTask-level Results:")
- print(df.to_string(index=False))
-
- return df
-
-# Usage example
-# df = analyze_evaluation_results("logs/evaluation/eval_preposition_generalization_L0-2_OpenVLA_20241201_143022")
-```
-
-### Complete Evaluation Example
-
-```bash
-#!/bin/bash
-# Complete model evaluation script
-
-MODEL_PATH="/path/to/your/model"
-OUTPUT_DIR="logs/evaluation_$(date +%Y%m%d_%H%M%S)"
-
-python scripts/evaluate_policy.py \
- --task_suite preposition_generalization \
- --task_level 0-2 \
- --n-episode 5 \
- --policy openvla \
- --model_ckpt "$MODEL_PATH" \
- --save-dir "$OUTPUT_DIR" \
- --visualization \
- --metrics success_rate cumulative_cost safe_success_rate
-
-echo "Evaluation completed. Results saved to: $OUTPUT_DIR"
-```
-
-
-If you run into problems or have suggestions for improvement, please refer to the code comments or contact the development team.
\ No newline at end of file
diff --git a/docs/finetune.md b/docs/finetune.md
deleted file mode 100644
index 858a69bd..00000000
--- a/docs/finetune.md
+++ /dev/null
@@ -1,840 +0,0 @@
-# Fine-tuning Other Models with VLA-Arena Generated Datasets
-
-VLA-Arena provides a complete framework for collecting data, converting data formats, and evaluating vision-language-action models. This guide will help you understand how to fine-tune VLA models using datasets generated by VLA-Arena.
-
-## Quick Start
-
-If you already have your dataset and OpenVLA model ready, you can start fine-tuning directly with the following commands:
-
-### Standard OpenVLA Fine-tuning
-
-```bash
-# 1. Activate environment
-conda activate openvla
-
-# 2. Run fine-tuning script
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/openvla/model" \
- --data_root_dir "/path/to/your/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-### OpenVLA OFT Fine-tuning (Recommended)
-
-```bash
-# 1. Activate environment
-conda activate openvla
-
-# 2. Run OFT fine-tuning script
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/openvla/model" \
- --data_root_dir "/path/to/your/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-### UniVLA Fine-tuning
-
-```bash
-# 1. Activate environment
-conda activate univla
-
-# 2. Run UniVLA fine-tuning script
-./vla-scripts/finetune_univla.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/univla/model" \
- --lam_path "/path/to/your/lam/checkpoint" \
- --data_root_dir "/path/to/your/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-```
-
-For detailed usage instructions, please refer to the sections below.
-
-## Table of Contents
-
-1. [Quick Start](#quick-start)
-2. [Fine-tuning OpenVLA](#fine-tuning-openvla)
- - [Installing OpenVLA Library](#installing-openvla-library)
- - [One-click Fine-tuning with Scripts](#one-click-fine-tuning-with-scripts)
- - [Basic Usage](#basic-usage)
- - [Required Parameters](#required-parameters)
- - [Optional Parameters](#optional-parameters)
- - [Dataset Configuration Parameters](#dataset-configuration-parameters)
- - [State and Action Encoding Options](#state-and-action-encoding-options)
- - [Usage Examples](#usage-examples)
- - [Script Features](#script-features)
- - [Notes](#notes)
-3. [Fine-tuning OpenVLA OFT](#fine-tuning-openvla-oft)
- - [OFT Fine-tuning Introduction](#oft-fine-tuning-introduction)
- - [Using OFT Script for Fine-tuning](#using-oft-script-for-fine-tuning)
- - [Basic Usage](#basic-usage-1)
- - [Required Parameters](#required-parameters-1)
- - [Basic Training Parameters](#basic-training-parameters)
- - [LoRA Parameters](#lora-parameters)
- - [Action Representation Parameters](#action-representation-parameters)
- - [Architecture Options](#architecture-options)
- - [Learning Rate Scheduling](#learning-rate-scheduling)
- - [Validation and Checkpoints](#validation-and-checkpoints)
- - [Logging Configuration](#logging-configuration)
- - [Dataset Configuration Parameters](#dataset-configuration-parameters-1)
- - [GPU Configuration](#gpu-configuration)
- - [Usage Examples](#usage-examples-1)
- - [Script Features](#script-features-1)
- - [Notes](#notes-1)
-4. [Fine-tuning UniVLA](#fine-tuning-univla)
- - [Installing UniVLA Library](#installing-univla-library)
- - [One-click Fine-tuning with Scripts](#one-click-fine-tuning-with-scripts-1)
- - [Basic Usage](#basic-usage-2)
- - [Required Parameters](#required-parameters-2)
- - [Basic Training Parameters](#basic-training-parameters-1)
- - [LoRA Parameters](#lora-parameters-1)
- - [UniVLA Specific Parameters](#univla-specific-parameters)
- - [LAM Parameters](#lam-parameters)
- - [Logging Configuration](#logging-configuration-1)
- - [Dataset Configuration Parameters](#dataset-configuration-parameters-2)
- - [GPU Configuration](#gpu-configuration-1)
- - [Usage Examples](#usage-examples-2)
- - [Script Features](#script-features-2)
- - [Notes](#notes-2)
-5. [Fine-tuning OpenPi](#fine-tuning-openpi)
- - [Installing OpenPi Library](#installing-openpi-library)
- - [One-click Fine-tuning with Scripts](#one-click-fine-tuning-with-scripts-2)
- - [Basic Usage](#basic-usage-3)
- - [Required Parameters](#required-parameters-3)
- - [Model Configuration Parameters](#model-configuration-parameters)
- - [Training Parameters](#training-parameters)
- - [Dataset Configuration Parameters](#dataset-configuration-parameters-3)
- - [Usage Examples](#usage-examples-3)
- - [Script Features](#script-features-3)
- - [Notes](#notes-3)
-6. [Model Evaluation](#model-evaluation)
-7. [Adding Custom Models](#adding-custom-models)
-8. [Configuration Instructions](#configuration-instructions)
-
-## Fine-tuning OpenVLA
-
-### Installing OpenVLA Library
-
-```bash
-# Create and activate conda environment
-conda create -n openvla python=3.10 -y
-conda activate openvla
-
-# Install PyTorch. Below is a sample command to do this, but you should check the following link
-# to find installation instructions that are specific to your compute platform:
-# https://pytorch.org/get-started/locally/
-conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y # UPDATE ME!
-
-# Clone and install the openvla repo
-git clone https://github.com/openvla/openvla.git
-cd openvla
-pip install -e .
-
-# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
-# =>> If you run into difficulty, try `pip cache remove flash_attn` first
-pip install packaging ninja
-ninja --version; echo $? # Verify Ninja --> should return exit code "0"
-pip install "flash-attn==2.5.5" --no-build-isolation
-```
-
-### One-click Fine-tuning with Scripts
-
-Copy [finetune_openvla.sh](./finetune_openvla.sh) to the openvla/vla-scripts directory. This script will automatically add dataset configuration and run fine-tuning.
-
-#### Basic Usage
-
-```bash
-# Activate conda environment
-conda activate openvla
-
-# Basic usage (requires providing required parameters)
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-
-# Custom parameters
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 10000 \
- --wandb_project "my_project"
-```
-
-#### Required Parameters
-
-- `--dataset_name`: Dataset name (required)
-- `--vla_path`: OpenVLA model path (required)
-- `--data_root_dir`: Dataset root directory (required)
-- `--openvla_root_dir`: OpenVLA repository root directory (required)
-
-#### Optional Parameters
-
-- `--run_root_dir`: Directory to save run results (default: `new_runs`)
-- `--batch_size`: Batch size (default: `2`)
-- `--learning_rate`: Learning rate (default: `5e-4`)
-- `--max_steps`: Maximum training steps (default: `50000`)
-- `--use_lora`: Whether to use LoRA fine-tuning (default: `true`)
-- `--lora_rank`: LoRA rank (default: `32`)
-- `--use_quantization`: Whether to use quantization (default: `false`)
-- `--image_aug`: Whether to use image augmentation (default: `true`)
-- `--wandb_project`: WandB project name (default: `safe-openvla`)
-- `--wandb_entity`: WandB entity name (default: `trial`)
-- `--num_gpus`: Number of GPUs to use (default: `1`)
-
-#### Dataset Configuration Parameters
-
-The script will automatically add your dataset configuration to `configs.py` and `transforms.py` files. You can customize dataset configuration:
-
-- `--image_obs_primary`: Primary image observation key (default: `image`)
-- `--image_obs_secondary`: Secondary image observation key (default: empty)
-- `--image_obs_wrist`: Wrist image observation key (default: `wrist_image`)
-- `--depth_obs_primary`: Primary depth observation key (default: empty)
-- `--depth_obs_secondary`: Secondary depth observation key (default: empty)
-- `--depth_obs_wrist`: Wrist depth observation key (default: empty)
-- `--state_obs_keys`: State observation keys (default: `EEF_state,None,gripper_state`)
-- `--state_encoding`: State encoding (default: `POS_EULER`)
-- `--action_encoding`: Action encoding (default: `EEF_POS`)
-
-#### State and Action Encoding Options
-
-**State Encoding**:
-- `NONE`: No proprioceptive state
-- `POS_EULER`: EEF XYZ (3) + Roll-Pitch-Yaw (3) + padding (1) + Gripper state (1)
-- `POS_QUAT`: EEF XYZ (3) + Quaternion (4) + Gripper state (1)
-- `JOINT`: Joint angles (7, padded if fewer) + Gripper state (1)
-- `JOINT_BIMANUAL`: Joint angles (2 x [ Joint angles (6) + Gripper state (1) ])
-
-**Action Encoding**:
-- `EEF_POS`: EEF delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper state (1)
-- `JOINT_POS`: Joint delta position (7) + Gripper state (1)
-- `JOINT_POS_BIMANUAL`: Joint delta position (2 x [ Joint delta position (6) + Gripper state (1) ])
-- `EEF_R6`: EEF delta XYZ (3) + R6 (6) + Gripper state (1)
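-
-For reference, the entry that the script appends to `configs.py` follows the OXE dataset-configuration format used in the OpenVLA repository. The sketch below is only illustrative of that shape, built from the default values above; the exact field names and enum types depend on the OpenVLA revision you have checked out:
-
-```python
-# Illustrative sketch of the kind of entry added to
-# prismatic/vla/datasets/rlds/oxe/configs.py (not copied from the repo).
-MY_DATASET_CONFIG = {
-    "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
-    "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
-    "state_obs_keys": ["EEF_state", None, "gripper_state"],
-    "state_encoding": "POS_EULER",  # an enum member (e.g. StateEncoding.POS_EULER) in the real config
-    "action_encoding": "EEF_POS",   # an enum member (e.g. ActionEncoding.EEF_POS) in the real config
-}
-```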
-
-#### Usage Examples
-
-**Example 1: Basic Usage**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-**Example 2: Custom Configuration**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "custom_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --image_obs_primary "front_camera" \
- --image_obs_wrist "gripper_camera" \
- --state_obs_keys "joint_positions,None,gripper_state" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 20000
-```
-
-**Example 3: Using Quantization**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "quantized_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --use_quantization true \
- --batch_size 16 \
- --max_steps 5000
-```
-
-#### Script Features
-
-1. **Parameter Validation**: Checks if required parameters are provided
-2. **Add Dataset Configuration**: Automatically adds your dataset configuration to:
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **Run Fine-tuning**: Executes OpenVLA fine-tuning script with your parameters
-
-#### Notes
-
-- The script uses `libero_dataset_transform` as the default transform function for new datasets
-- If dataset configuration already exists, the add configuration step will be skipped
-- The script automatically handles `None` values in state observation keys
-- Ensure your dataset is in the correct RLDS format and located in the specified data directory
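-
-To sanity-check the last point before kicking off a long training run, you can try loading the dataset with `tensorflow_datasets`. A minimal sketch, assuming your RLDS build lives at `<data_root_dir>/<dataset_name>/<version>` (adjust the path to your layout):
-
-```python
-import tensorflow_datasets as tfds
-
-# Point this at the versioned RLDS builder directory for your dataset.
-builder = tfds.builder_from_directory("/path/to/datasets/my_dataset/1.0.0")
-print(builder.info)  # features, splits, and episode counts
-
-# Pull one episode to confirm the expected observation keys are present.
-ds = builder.as_dataset(split="train")
-episode = next(iter(ds.take(1)))
-first_step = next(iter(episode["steps"].take(1)))
-print(first_step["observation"].keys())
-```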
-
-## Fine-tuning OpenVLA OFT
-
-### OFT Fine-tuning Introduction
-
-OpenVLA OFT (Optimized Fine-Tuning) provides more advanced training options and better performance. The OFT recipe supports:
-
-- **Richer Training Parameters**: Including learning rate scheduling, gradient accumulation, validation sets, etc.
-- **Action Representation Options**: Supporting L1 regression and diffusion modeling
-- **Architecture Enhancements**: FiLM language fusion, multi-image input, proprioceptive state, etc.
-- **Advanced Optimization**: LoRA dropout, LoRA merging during training, etc.
-
-### Using OFT Script for Fine-tuning
-
-Copy [finetune_openvla_oft.sh](./finetune_openvla_oft.sh) to the openvla/vla-scripts directory. This script provides more comprehensive fine-tuning options.
-
-#### Basic Usage
-
-```bash
-# Activate conda environment
-conda activate openvla
-
-# Basic usage (requires providing required parameters)
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-
-# Custom parameters
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 100000 \
- --use_l1_regression true \
- --use_film true
-```
-
-#### Required Parameters
-
-- `--dataset_name`: Dataset name (required)
-- `--vla_path`: OpenVLA model path (required)
-- `--data_root_dir`: Dataset root directory (required)
-- `--openvla_root_dir`: OpenVLA repository root directory (required)
-
-#### Basic Training Parameters
-
-- `--run_root_dir`: Directory to save run results (default: `all_runs`)
-- `--batch_size`: Batch size (default: `7`)
-- `--learning_rate`: Learning rate (default: `5e-4`)
-- `--max_steps`: Maximum training steps (default: `150000`)
-- `--grad_accumulation_steps`: Gradient accumulation steps (default: `1`)
-- `--shuffle_buffer_size`: Data loader shuffle buffer size (default: `100000`)
-
-#### LoRA Parameters
-
-- `--use_lora`: Whether to use LoRA fine-tuning (default: `true`)
-- `--lora_rank`: LoRA rank (default: `32`)
-- `--lora_dropout`: LoRA dropout (default: `0.0`)
-- `--merge_lora_during_training`: Merge LoRA during training (default: `true`)
-
-#### Action Representation Parameters
-
-- `--use_l1_regression`: Use L1 regression (default: `true`)
-- `--use_diffusion`: Use diffusion modeling (default: `false`)
-- `--num_diffusion_steps_train`: Training diffusion steps (default: `50`)
-- `--diffusion_sample_freq`: Diffusion sampling frequency (default: `50`)
-
-#### Architecture Options
-
-- `--use_film`: Use FiLM for language fusion (default: `true`)
-- `--num_images_in_input`: Number of images in input (default: `2`)
-- `--use_proprio`: Include proprioceptive state (default: `false`)
-- `--use_quantization`: Use quantization (default: `false`)
-- `--image_aug`: Use image augmentation (default: `true`)
-
-#### Learning Rate Scheduling
-
-- `--lr_warmup_steps`: Learning rate warmup steps (default: `0`)
-- `--num_steps_before_decay`: Steps before learning rate decay (default: `60000`)
-
-#### Validation and Checkpoints
-
-- `--use_val_set`: Use validation set (default: `false`)
-- `--val_freq`: Validation frequency (default: `10000`)
-- `--val_time_limit`: Validation time limit (default: `180`)
-- `--save_freq`: Save frequency (default: `5000`)
-- `--save_latest_checkpoint_only`: Save only latest checkpoint (default: `false`)
-- `--resume`: Resume from checkpoint (default: `false`)
-- `--resume_step`: Resume step (default: empty)
-
-#### Logging Configuration
-
-- `--wandb_project`: WandB project name (default: `openvla-oft-workflow-generalization`)
-- `--wandb_entity`: WandB entity name (default: `trial`)
-- `--wandb_log_freq`: WandB logging frequency (default: `10`)
-
-#### Dataset Configuration Parameters
-
-The script will automatically add your dataset configuration to `configs.py` and `transforms.py` files. You can customize dataset configuration:
-
-- `--image_obs_primary`: Primary image observation key (default: `image`)
-- `--image_obs_secondary`: Secondary image observation key (default: empty)
-- `--image_obs_wrist`: Wrist image observation key (default: `wrist_image`)
-- `--depth_obs_primary`: Primary depth observation key (default: empty)
-- `--depth_obs_secondary`: Secondary depth observation key (default: empty)
-- `--depth_obs_wrist`: Wrist depth observation key (default: empty)
-- `--state_obs_keys`: State observation keys (default: `EEF_state,None,gripper_state`)
-- `--state_encoding`: State encoding (default: `POS_EULER`)
-- `--action_encoding`: Action encoding (default: `EEF_POS`)
-
-#### GPU Configuration
-
-- `--num_gpus`: Number of GPUs to use (default: `1`)
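-
-When combining these options, keep in mind that the effective (global) batch size is the per-GPU batch size times the gradient-accumulation steps times the GPU count. A quick sanity check (the values below mirror the multi-GPU example further down and are only an assumption; plug in your own run configuration):
-
-```python
-# Effective global batch size = per-GPU batch * grad accumulation steps * GPU count.
-batch_size = 16
-grad_accumulation_steps = 1
-num_gpus = 4
-print(batch_size * grad_accumulation_steps * num_gpus)  # -> 64
-```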
-
-#### Usage Examples
-
-**Example 1: Basic OFT Usage**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-**Example 2: Advanced OFT Configuration**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "advanced_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 100000 \
- --use_l1_regression true \
- --use_film true \
- --use_proprio true \
- --num_images_in_input 3 \
- --lora_rank 64 \
- --grad_accumulation_steps 2
-```
-
-**Example 3: Using Diffusion Modeling**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "diffusion_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --use_diffusion true \
- --num_diffusion_steps_train 100 \
- --diffusion_sample_freq 25 \
- --batch_size 4
-```
-
-**Example 4: Multi-GPU Training**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "multi_gpu_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --num_gpus 4 \
- --batch_size 16 \
- --grad_accumulation_steps 1
-```
-
-#### Script Features
-
-1. **Parameter Validation**: Checks if required parameters are provided
-2. **Add Dataset Configuration**: Automatically adds your dataset configuration to:
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **Run OFT Fine-tuning**: Executes OpenVLA OFT fine-tuning script with your parameters
-4. **Multi-GPU Support**: Supports multi-GPU distributed training
-
-#### Notes
-
-- The OFT version provides richer training options, suitable for users who need fine control over the training process
-- Supports diffusion modeling, suitable for scenarios requiring generative action prediction
-- FiLM language fusion can provide better language-visual interaction
-- Multi-image input supports multi-view robot tasks
-- Ensure your hardware resources are sufficient to support the selected training configuration
-
-## Fine-tuning UniVLA
-
-### Installing UniVLA Library
-
-```bash
-# Create and activate conda environment
-conda create -n univla python=3.10 -y
-conda activate univla
-
-# Install PyTorch. Below is a sample command to do this, but you should check the following link
-# to find installation instructions that are specific to your compute platform:
-# https://pytorch.org/get-started/locally/
-conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y # UPDATE ME!
-
-# Clone and install the univla repo
-git clone https://github.com/opendrivelab/UniVLA.git
-cd UniVLA
-pip install -e .
-
-# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
-# =>> If you run into difficulty, try `pip cache remove flash_attn` first
-pip install packaging ninja
-ninja --version; echo $? # Verify Ninja --> should return exit code "0"
-pip install "flash-attn==2.5.5" --no-build-isolation
-
-# Install additional dependencies for UniVLA
-pip install swanlab
-pip install ema-pytorch
-pip install peft
-pip install accelerate
-```
-
-### One-click Fine-tuning with Scripts
-
-Copy [finetune_univla.sh](./finetune_univla.sh) to the UniVLA/vla-scripts directory. This script will automatically add dataset configuration and run fine-tuning.
-
-#### Basic Usage
-
-```bash
-# Activate conda environment
-conda activate univla
-
-# Basic usage (requires providing required parameters)
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-
-# Custom parameters
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 50000 \
- --wandb_project "my_project"
-```
-
-#### Required Parameters
-
-- `--dataset_name`: Dataset name (required)
-- `--vla_path`: UniVLA model path (required)
-- `--lam_path`: LAM (Latent Action Model) checkpoint path (required)
-- `--data_root_dir`: Dataset root directory (required)
-- `--univla_root_dir`: UniVLA repository root directory (required)
-
-#### Basic Training Parameters
-
-- `--run_root_dir`: Directory to save run results (default: `all_runs`)
-- `--batch_size`: Batch size (default: `8`)
-- `--learning_rate`: Learning rate (default: `3.5e-4`)
-- `--max_steps`: Maximum training steps (default: `100000`)
-- `--save_steps`: Save interval (default: `10000`)
-- `--grad_accumulation_steps`: Gradient accumulation steps (default: `2`)
-- `--shuffle_buffer_size`: Data loader shuffle buffer size (default: `16000`)
-
-#### LoRA Parameters
-
-- `--use_lora`: Whether to use LoRA fine-tuning (default: `true`)
-- `--lora_rank`: LoRA rank (default: `32`)
-- `--lora_dropout`: LoRA dropout (default: `0.0`)
-- `--use_quantization`: Whether to use quantization (default: `false`)
-
-#### UniVLA Specific Parameters
-
-- `--freeze_vla`: Freeze VLA backbone (default: `false`)
-- `--save_latest_checkpoint_only`: Save only latest checkpoint (default: `true`)
-- `--run_id_note`: Extra note for experiment ID (default: empty)
-
-#### LAM Parameters
-
-UniVLA uses a Latent Action Model (LAM) for action representation. These parameters control the LAM architecture:
-
-- `--codebook_size`: LAM codebook size (default: `16`)
-- `--lam_model_dim`: LAM model dimension (default: `768`)
-- `--lam_latent_dim`: LAM latent dimension (default: `128`)
-- `--lam_patch_size`: LAM patch size (default: `14`)
-- `--lam_enc_blocks`: LAM encoder blocks (default: `12`)
-- `--lam_dec_blocks`: LAM decoder blocks (default: `12`)
-- `--lam_num_heads`: LAM number of heads (default: `12`)
-- `--window_size`: Action window size (default: `12`)
-
-#### Logging Configuration
-
-- `--wandb_project`: WandB project name (default: `finetune-UniVLA`)
-- `--wandb_entity`: WandB entity name (default: `opendrivelab`)
-
-#### Dataset Configuration Parameters
-
-The script will automatically add your dataset configuration to `configs.py` and `transforms.py` files. You can customize dataset configuration:
-
-- `--image_obs_primary`: Primary image observation key (default: `image`)
-- `--image_obs_secondary`: Secondary image observation key (default: empty)
-- `--image_obs_wrist`: Wrist image observation key (default: `wrist_image`)
-- `--depth_obs_primary`: Primary depth observation key (default: empty)
-- `--depth_obs_secondary`: Secondary depth observation key (default: empty)
-- `--depth_obs_wrist`: Wrist depth observation key (default: empty)
-- `--state_obs_keys`: State observation keys (default: `EEF_state,None,gripper_state`)
-- `--state_encoding`: State encoding (default: `POS_EULER`)
-- `--action_encoding`: Action encoding (default: `EEF_POS`)
-
-#### GPU Configuration
-
-- `--num_gpus`: Number of GPUs to use (default: `1`)
-
-#### Usage Examples
-
-**Example 1: Basic Usage**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-```
-
-**Example 2: Custom Configuration**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "custom_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --image_obs_primary "front_camera" \
- --image_obs_wrist "gripper_camera" \
- --state_obs_keys "joint_positions,None,gripper_state" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 50000 \
- --window_size 16
-```
-
-**Example 3: Using Quantization**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "quantized_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --use_quantization true \
- --batch_size 16 \
- --max_steps 25000
-```
-
-**Example 4: Freeze VLA Backbone**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "frozen_vla_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --freeze_vla true \
- --learning_rate 1e-3 \
- --batch_size 12
-```
-
-**Example 5: Multi-GPU Training**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "multi_gpu_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --num_gpus 4 \
- --batch_size 8 \
- --grad_accumulation_steps 1
-```
-
-#### Script Features
-
-1. **Parameter Validation**: Checks if required parameters are provided
-2. **Add Dataset Configuration**: Automatically adds your dataset configuration to:
- - `{univla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{univla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **Run UniVLA Fine-tuning**: Executes UniVLA fine-tuning script with your parameters
-4. **Multi-GPU Support**: Supports multi-GPU distributed training
-5. **LAM Integration**: Automatically configures and loads the Latent Action Model
-
-#### Notes
-
-- UniVLA uses a two-stage training approach with a Latent Action Model (LAM)
-- The LAM checkpoint is required and should be pre-trained
-- The script uses `libero_dataset_transform` as the default transform function for new datasets
-- If dataset configuration already exists, the add configuration step will be skipped
-- The script automatically handles `None` values in state observation keys
-- Ensure your dataset is in the correct RLDS format and located in the specified data directory
-- UniVLA supports both frozen and unfrozen VLA backbone training
-
-## Fine-tuning OpenPi
-
-### Installing OpenPi Library
-
-```bash
-# Clone repository (with submodules)
-git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
-
-# Or if you have already cloned the repository:
-cd openpi
-git submodule update --init --recursive
-
-# Install uv (if not already installed)
-curl -LsSf https://astral.sh/uv/install.sh | sh
-
-# Install OpenPi
-cd openpi
-GIT_LFS_SKIP_SMUDGE=1 uv sync
-GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
-```
-
-**Note:** `GIT_LFS_SKIP_SMUDGE=1` is required to skip LFS file downloads for LeRobot dependencies.
-
-### One-click Fine-tuning with Scripts
-
-Copy [finetune_openpi.sh](./finetune_openpi.sh) to the openpi/scripts directory. This script will automatically add training configuration and run fine-tuning.
-
-#### Basic Usage
-
-```bash
-# Basic usage (requires providing required parameters)
-uv run bash scripts/finetune_openpi.sh \
- --config_name "my_openpi_config" \
- --exp_name "my_experiment" \
- --base_checkpoint_path "/path/to/base/checkpoint" \
- --dataset_repo_id "your_dataset_repo" \
- --hf_lerobot_home "/path/to/lerobot/home"
-
-# Custom parameters
-uv run bash scripts/finetune_openpi.sh \
- --config_name "custom_config" \
- --exp_name "custom_experiment" \
- --base_checkpoint_path "/path/to/base/checkpoint" \
- --dataset_repo_id "your_dataset_repo" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --model_type "pi0_fast" \
- --batch_size 32 \
- --learning_rate 1e-4 \
- --num_train_steps 50000
-```
-
-#### Required Parameters
-
-- `--config_name`: Configuration name (required)
-- `--exp_name`: Experiment name (required)
-- `--base_checkpoint_path`: Base model checkpoint path (required)
-- `--dataset_repo_id`: Dataset repository ID (required)
-- `--hf_lerobot_home`: HF_LEROBOT_HOME directory path (required)
-
-#### Model Configuration Parameters
-
-- `--model_type`: Model type, pi0 or pi0_fast (default: pi0)
-- `--action_dim`: Action dimension (default: 7)
-- `--action_horizon`: Action time horizon (default: 10)
-- `--max_token_len`: Maximum token length (default: 180)
-- `--use_lora`: Use LoRA fine-tuning (default: false)
-- `--lora_rank`: LoRA rank (default: 32)
-- `--lora_dropout`: LoRA dropout (default: 0.0)
-- `--paligemma_variant`: PaliGemma variant (default: gemma_2b)
-- `--action_expert_variant`: Action expert variant (default: gemma_300m)
-
-#### Training Parameters
-
-- `--batch_size`: Batch size (default: 56)
-- `--learning_rate`: Learning rate (default: 3.5e-4)
-- `--num_train_steps`: Training steps (default: 30000)
-- `--log_interval`: Log interval (default: 100)
-- `--save_interval`: Save interval (default: 1000)
-- `--keep_period`: Keep period (default: 5000)
-- `--num_workers`: Number of workers (default: 2)
-- `--seed`: Random seed (default: 42)
-- `--fsdp_devices`: FSDP devices (default: 1)
-- `--ema_decay`: EMA decay (default: 0.99)
-
-#### Dataset Configuration Parameters
-
-- `--prompt_from_task`: Get prompt from task (default: true)
-
-#### Usage Examples
-
-**Example 1: Basic Usage**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0" \
- --exp_name "libero_experiment" \
- --base_checkpoint_path "/path/to/pi0/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home"
-```
-
-**Example 2: Using pi0_fast Model**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0_fast" \
- --exp_name "libero_fast_experiment" \
- --base_checkpoint_path "/path/to/pi0_fast/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --model_type "pi0_fast" \
- --batch_size 32 \
- --learning_rate 1e-4
-```
-
-**Example 3: Using LoRA Fine-tuning**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0_lora" \
- --exp_name "libero_lora_experiment" \
- --base_checkpoint_path "/path/to/pi0/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --use_lora true \
- --lora_rank 64 \
- --lora_dropout 0.1
-```
-
-**Example 4: Custom Training Parameters**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "custom_libero" \
- --exp_name "custom_experiment" \
- --base_checkpoint_path "/path/to/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --batch_size 64 \
- --learning_rate 2e-4 \
- --num_train_steps 100000 \
- --save_interval 2000 \
- --wandb_enabled true \
- --project_name "my_openpi_project"
-```
-
-#### Script Features
-
-1. **Parameter Validation**: Checks if required parameters are provided
-2. **Add Training Configuration**: Automatically adds your training configuration to `src/openpi/training/config.py`
-3. **Compute Normalization Statistics**: Automatically runs `scripts/compute_norm_stats.py`
-4. **Run Training**: Executes OpenPi training script with your parameters
-5. **Support Override**: Option to override existing checkpoints
-
-#### Notes
-
-- The script uses `LeRobotLiberoDataConfig` as the dataset configuration
-- If configuration already exists, the add configuration step will be skipped
-- Supports both pi0 and pi0_fast model types
-- LoRA fine-tuning automatically sets appropriate freeze filters
-- Ensure base checkpoint path is valid and accessible
-- Ensure dataset repository ID is correct and accessible
-- The script automatically sets the `HF_LEROBOT_HOME` environment variable
-
-
diff --git a/docs/finetune_openvla.sh b/docs/finetune_openvla.sh
deleted file mode 100644
index e5be78a9..00000000
--- a/docs/finetune_openvla.sh
+++ /dev/null
@@ -1,347 +0,0 @@
-#!/bin/bash
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-# finetune_openvla.sh
-# Script to add dataset configurations and run OpenVLA fine-tuning
-
-# Default values
-DATASET_NAME=""
-VLA_PATH=""
-DATA_ROOT_DIR=""
-RUN_ROOT_DIR=""
-OPENVLA_ROOT_DIR=""
-BATCH_SIZE=2
-LEARNING_RATE=5e-4
-MAX_STEPS=50000
-USE_LORA=true
-LORA_RANK=32
-USE_QUANTIZATION=false
-IMAGE_AUG=true
-WANDB_PROJECT=""
-WANDB_ENTITY=""
-NUM_GPUS=1
-
-# Dataset configuration parameters
-IMAGE_OBS_PRIMARY="image"
-IMAGE_OBS_SECONDARY=""
-IMAGE_OBS_WRIST="wrist_image"
-DEPTH_OBS_PRIMARY=""
-DEPTH_OBS_SECONDARY=""
-DEPTH_OBS_WRIST=""
-STATE_OBS_KEYS="EEF_state,None,gripper_state"
-STATE_ENCODING="POS_EULER"
-ACTION_ENCODING="EEF_POS"
-
-# Parse command line arguments
-while [[ $# -gt 0 ]]; do
- case $1 in
- --dataset_name)
- DATASET_NAME="$2"
- shift 2
- ;;
- --vla_path)
- VLA_PATH="$2"
- shift 2
- ;;
- --data_root_dir)
- DATA_ROOT_DIR="$2"
- shift 2
- ;;
- --run_root_dir)
- RUN_ROOT_DIR="$2"
- shift 2
- ;;
- --openvla_root_dir)
- OPENVLA_ROOT_DIR="$2"
- shift 2
- ;;
- --batch_size)
- BATCH_SIZE="$2"
- shift 2
- ;;
- --learning_rate)
- LEARNING_RATE="$2"
- shift 2
- ;;
- --max_steps)
- MAX_STEPS="$2"
- shift 2
- ;;
- --use_lora)
- USE_LORA="$2"
- shift 2
- ;;
- --lora_rank)
- LORA_RANK="$2"
- shift 2
- ;;
- --use_quantization)
- USE_QUANTIZATION="$2"
- shift 2
- ;;
- --image_aug)
- IMAGE_AUG="$2"
- shift 2
- ;;
- --wandb_project)
- WANDB_PROJECT="$2"
- shift 2
- ;;
- --wandb_entity)
- WANDB_ENTITY="$2"
- shift 2
- ;;
- --image_obs_primary)
- IMAGE_OBS_PRIMARY="$2"
- shift 2
- ;;
- --image_obs_secondary)
- IMAGE_OBS_SECONDARY="$2"
- shift 2
- ;;
- --image_obs_wrist)
- IMAGE_OBS_WRIST="$2"
- shift 2
- ;;
- --depth_obs_primary)
- DEPTH_OBS_PRIMARY="$2"
- shift 2
- ;;
- --depth_obs_secondary)
- DEPTH_OBS_SECONDARY="$2"
- shift 2
- ;;
- --depth_obs_wrist)
- DEPTH_OBS_WRIST="$2"
- shift 2
- ;;
- --state_obs_keys)
- STATE_OBS_KEYS="$2"
- shift 2
- ;;
- --state_encoding)
- STATE_ENCODING="$2"
- shift 2
- ;;
- --action_encoding)
- ACTION_ENCODING="$2"
- shift 2
- ;;
- --num_gpus)
- NUM_GPUS="$2"
- shift 2
- ;;
- --help)
- echo "Usage: $0 --dataset_name [options]"
- echo ""
- echo "Required arguments:"
- echo " --dataset_name Dataset name (required)"
- echo " --vla_path Path to OpenVLA model (required)"
- echo " --data_root_dir Root directory for datasets (required)"
- echo " --openvla_root_dir Root directory of OpenVLA repository (required)"
- echo ""
- echo "Optional arguments:"
- echo " --run_root_dir Root directory for runs (default: new_runs)"
- echo " --batch_size Batch size (default: 2)"
- echo " --learning_rate Learning rate (default: 5e-4)"
- echo " --max_steps Maximum training steps (default: 50000)"
- echo " --use_lora Use LoRA fine-tuning (default: true)"
- echo " --lora_rank LoRA rank (default: 32)"
- echo " --use_quantization Use quantization (default: false)"
- echo " --image_aug Use image augmentation (default: true)"
- echo " --wandb_project WandB project name (default: safe-openvla)"
- echo " --wandb_entity WandB entity name (default: trial)"
- echo ""
- echo "Dataset configuration:"
- echo " --image_obs_primary Primary image observation key (default: image)"
- echo " --image_obs_secondary Secondary image observation key (default: empty)"
- echo " --image_obs_wrist Wrist image observation key (default: wrist_image)"
- echo " --depth_obs_primary Primary depth observation key (default: empty)"
- echo " --depth_obs_secondary Secondary depth observation key (default: empty)"
- echo " --depth_obs_wrist Wrist depth observation key (default: empty)"
- echo " --state_obs_keys State observation keys (default: EEF_state,None,gripper_state)"
- echo " --state_encoding State encoding (default: POS_EULER)"
- echo " --action_encoding Action encoding (default: EEF_POS)"
- echo ""
- echo "GPU configuration:"
- echo " --num_gpus Number of GPUs to use (default: 1)"
- exit 0
- ;;
- *)
- echo "Unknown option: $1"
- echo "Use --help for usage information"
- exit 1
- ;;
- esac
-done
-
-# Check if required parameters are provided
-if [ -z "$DATASET_NAME" ]; then
- echo "Error: --dataset_name is required"
- echo "Use --help for usage information"
- exit 1
-fi
-
-if [ -z "$VLA_PATH" ]; then
- echo "Error: --vla_path is required"
- echo "Use --help for usage information"
- exit 1
-fi
-
-if [ -z "$DATA_ROOT_DIR" ]; then
- echo "Error: --data_root_dir is required"
- echo "Use --help for usage information"
- exit 1
-fi
-
-if [ -z "$OPENVLA_ROOT_DIR" ]; then
- echo "Error: --openvla_root_dir is required"
- echo "Use --help for usage information"
- exit 1
-fi
-
-echo "Adding dataset configuration for: $DATASET_NAME"
-echo "Dataset configuration:"
-echo " Image obs: primary=$IMAGE_OBS_PRIMARY, secondary=$IMAGE_OBS_SECONDARY, wrist=$IMAGE_OBS_WRIST"
-echo " Depth obs: primary=$DEPTH_OBS_PRIMARY, secondary=$DEPTH_OBS_SECONDARY, wrist=$DEPTH_OBS_WRIST"
-echo " State obs keys: $STATE_OBS_KEYS"
-echo " State encoding: $STATE_ENCODING"
-echo " Action encoding: $ACTION_ENCODING"
-
-# Convert empty strings to None for Python
-if [ -z "$IMAGE_OBS_SECONDARY" ]; then
- IMAGE_OBS_SECONDARY="None"
-fi
-if [ -z "$IMAGE_OBS_WRIST" ]; then
- IMAGE_OBS_WRIST="None"
-fi
-if [ -z "$DEPTH_OBS_PRIMARY" ]; then
- DEPTH_OBS_PRIMARY="None"
-fi
-if [ -z "$DEPTH_OBS_SECONDARY" ]; then
- DEPTH_OBS_SECONDARY="None"
-fi
-if [ -z "$DEPTH_OBS_WRIST" ]; then
- DEPTH_OBS_WRIST="None"
-fi
-
-# Create Python script to add dataset configuration
-cat > /tmp/add_dataset_config.py << EOF
-import sys
-import re
-
-def add_dataset_config():
- # Paths to the files
- configs_path = "$OPENVLA_ROOT_DIR/prismatic/vla/datasets/rlds/oxe/configs.py"
- transforms_path = "$OPENVLA_ROOT_DIR/prismatic/vla/datasets/rlds/oxe/transforms.py"
-
- dataset_name = "$DATASET_NAME"
-
- # Process state_obs_keys to handle None values properly
- state_obs_keys = "$STATE_OBS_KEYS"
- state_obs_list = []
- for key in state_obs_keys.split(','):
- key = key.strip()
- if key == 'None':
- state_obs_list.append('None')
- else:
- state_obs_list.append(f'"{key}"')
- state_obs_str = ', '.join(state_obs_list)
-
- # Read configs.py
- with open(configs_path, 'r') as f:
- configs_content = f.read()
-
- # Check if dataset already exists
- if f'"{dataset_name}":' in configs_content:
- print(f"Dataset {dataset_name} already exists in configs.py")
- else:
- # Find the end of OXE_DATASET_CONFIGS dictionary and add before closing brace
- # Look for the pattern: },\n}
- pattern = r'(\s+)(\})\s*$'
-
- config_entry = f'''
- "{dataset_name}": {{
- "image_obs_keys": {{"primary": "$IMAGE_OBS_PRIMARY", "secondary": "$IMAGE_OBS_SECONDARY", "wrist": "$IMAGE_OBS_WRIST"}},
- "depth_obs_keys": {{"primary": "$DEPTH_OBS_PRIMARY", "secondary": "$DEPTH_OBS_SECONDARY", "wrist": "$DEPTH_OBS_WRIST"}},
- "state_obs_keys": [{state_obs_str}],
- "state_encoding": StateEncoding.$STATE_ENCODING,
- "action_encoding": ActionEncoding.$ACTION_ENCODING,
- }},'''
-
- # Insert before the closing brace
- replacement = f'{config_entry}\n}}'
- configs_content = re.sub(pattern, replacement, configs_content, flags=re.MULTILINE)
-
- # Write back to configs.py
- with open(configs_path, 'w') as f:
- f.write(configs_content)
- print(f"Added dataset configuration for {dataset_name} to configs.py")
-
- # Read transforms.py
- with open(transforms_path, 'r') as f:
- transforms_content = f.read()
-
- # Check if dataset already exists in transforms
- if f'"{dataset_name}":' in transforms_content:
- print(f"Dataset {dataset_name} already exists in transforms.py")
- else:
- # Find the end of OXE_STANDARDIZATION_TRANSFORMS dictionary and add before closing brace
- pattern = r'(\s+)(\})\s*$'
-
- transform_entry = f'\n "{dataset_name}": libero_dataset_transform,'
-
- # Insert before the closing brace
- replacement = f'{transform_entry}\n}}'
- transforms_content = re.sub(pattern, replacement, transforms_content, flags=re.MULTILINE)
-
- # Write back to transforms.py
- with open(transforms_path, 'w') as f:
- f.write(transforms_content)
- print(f"Added dataset transform for {dataset_name} to transforms.py")
-
-if __name__ == "__main__":
- add_dataset_config()
-EOF
-
-# Run the Python script to add dataset configuration
-python3 /tmp/add_dataset_config.py
-
-# Clean up temporary file
-rm /tmp/add_dataset_config.py
-
-echo "Starting fine-tuning..."
-
-# Run the fine-tuning script
-cd "$OPENVLA_ROOT_DIR"
-
-torchrun --standalone --nnodes 1 --nproc-per-node $NUM_GPUS vla-scripts/finetune.py \
- --vla_path "$VLA_PATH" \
- --data_root_dir "$DATA_ROOT_DIR" \
- --dataset_name "$DATASET_NAME" \
- --run_root_dir "$RUN_ROOT_DIR" \
- --batch_size "$BATCH_SIZE" \
- --learning_rate "$LEARNING_RATE" \
- --max_steps "$MAX_STEPS" \
- --use_lora "$USE_LORA" \
- --lora_rank "$LORA_RANK" \
- --use_quantization "$USE_QUANTIZATION" \
- --image_aug "$IMAGE_AUG" \
- --wandb_project "$WANDB_PROJECT" \
- --wandb_entity "$WANDB_ENTITY"
-
-echo "Fine-tuning completed!"
diff --git a/docs/finetune_zh.md b/docs/finetune_zh.md
deleted file mode 100644
index 0768f405..00000000
--- a/docs/finetune_zh.md
+++ /dev/null
@@ -1,838 +0,0 @@
-# 使用VLA-Arena生成的数据集微调其他模型
-
-VLA-Arena提供了完整的搜集数据、转换数据格式、评估语言-视觉-动作模型的框架,本指南将带你了解如何使用VLA-Arena生成的数据集微调一些VLA模型
-
-## 快速开始
-
-如果你已经准备好了数据集和OpenVLA模型,可以直接使用以下命令开始微调:
-
-### 标准OpenVLA微调
-
-```bash
-# 1. 激活环境
-conda activate openvla
-
-# 2. 运行微调脚本
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/openvla/model" \
- --data_root_dir "/path/to/your/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-### OpenVLA OFT微调(推荐)
-
-```bash
-# 1. 激活环境
-conda activate openvla
-
-# 2. 运行OFT微调脚本
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/openvla/model" \
- --data_root_dir "/path/to/your/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-### UniVLA微调
-
-```bash
-# 1. 激活环境
-conda activate univla
-
-# 2. 运行UniVLA微调脚本
-./vla-scripts/finetune_univla.sh \
- --dataset_name "your_dataset" \
- --vla_path "/path/to/your/univla/model" \
- --lam_path "/path/to/your/lam/checkpoint" \
- --data_root_dir "/path/to/your/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-```
-
-详细的使用说明请参考下面的章节。
-
-## 目录
-
-1. [快速开始](#快速开始)
-2. [微调OpenVLA](#微调OpenVLA)
- - [安装OpenVLA库](#安装OpenVLA库)
- - [使用脚本一键微调](#使用脚本一键微调)
- - [基本使用方法](#基本使用方法)
- - [必需参数](#必需参数)
- - [可选参数](#可选参数)
- - [数据集配置参数](#数据集配置参数)
- - [状态和动作编码选项](#状态和动作编码选项)
- - [使用示例](#使用示例)
- - [脚本功能](#脚本功能)
- - [注意事项](#注意事项)
-3. [微调OpenVLA OFT](#微调OpenVLA-OFT)
- - [OFT微调简介](#OFT微调简介)
- - [使用OFT脚本微调](#使用OFT脚本微调)
- - [基本使用方法](#基本使用方法-1)
- - [必需参数](#必需参数-1)
- - [基础训练参数](#基础训练参数)
- - [LoRA参数](#LoRA参数)
- - [动作表示参数](#动作表示参数)
- - [架构选项](#架构选项)
- - [学习率调度](#学习率调度)
- - [验证和检查点](#验证和检查点)
- - [日志配置](#日志配置)
- - [数据集配置参数](#数据集配置参数-1)
- - [GPU配置](#GPU配置)
- - [使用示例](#使用示例-1)
- - [脚本功能](#脚本功能-1)
- - [注意事项](#注意事项-1)
-4. [微调UniVLA](#微调UniVLA)
- - [安装UniVLA库](#安装UniVLA库)
- - [使用脚本一键微调](#使用脚本一键微调-1)
- - [基本使用方法](#基本使用方法-2)
- - [必需参数](#必需参数-2)
- - [基础训练参数](#基础训练参数-1)
- - [LoRA参数](#LoRA参数-1)
- - [UniVLA特定参数](#UniVLA特定参数)
- - [LAM参数](#LAM参数)
- - [日志配置](#日志配置-1)
- - [数据集配置参数](#数据集配置参数-2)
- - [GPU配置](#GPU配置-1)
- - [使用示例](#使用示例-2)
- - [脚本功能](#脚本功能-2)
- - [注意事项](#注意事项-2)
-5. [微调OpenPi](#微调OpenPi)
- - [安装OpenPi库](#安装OpenPi库)
- - [使用脚本一键微调](#使用脚本一键微调-2)
- - [基本使用方法](#基本使用方法-3)
- - [必需参数](#必需参数-3)
- - [模型配置参数](#模型配置参数)
- - [训练参数](#训练参数)
- - [数据集配置参数](#数据集配置参数-3)
- - [使用示例](#使用示例-3)
- - [脚本功能](#脚本功能-3)
- - [注意事项](#注意事项-3)
-6. [模型评估](#模型评估)
-7. [添加自定义模型](#添加自定义模型)
-8. [配置说明](#配置说明)
-
-## 微调OpenVLA
-
-### 安装OpenVLA库
-
-```bash
-# Create and activate conda environment
-conda create -n openvla python=3.10 -y
-conda activate openvla
-
-# Install PyTorch. Below is a sample command to do this, but you should check the following link
-# to find installation instructions that are specific to your compute platform:
-# https://pytorch.org/get-started/locally/
-conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y # UPDATE ME!
-
-# Clone and install the openvla repo
-git clone https://github.com/openvla/openvla.git
-cd openvla
-pip install -e .
-
-# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
-# =>> If you run into difficulty, try `pip cache remove flash_attn` first
-pip install packaging ninja
-ninja --version; echo $? # Verify Ninja --> should return exit code "0"
-pip install "flash-attn==2.5.5" --no-build-isolation
-```
-### 使用脚本一键微调
-
-将 [finetune_openvla.sh](./finetune_openvla.sh) 粘贴至 openvla/vla-scripts 目录下,该脚本会自动添加数据集配置并运行微调。
-
-#### 基本使用方法
-
-```bash
-# 激活conda环境
-conda activate openvla
-
-# 基本使用(需要提供必需参数)
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-
-# 自定义参数
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 10000 \
- --wandb_project "my_project"
-```
-
-#### 必需参数
-
-- `--dataset_name`: 数据集名称(必需)
-- `--vla_path`: OpenVLA模型路径(必需)
-- `--data_root_dir`: 数据集根目录(必需)
-- `--openvla_root_dir`: OpenVLA仓库根目录(必需)
-
-#### 可选参数
-
-- `--run_root_dir`: 运行结果保存目录(默认:`new_runs`)
-- `--batch_size`: 批次大小(默认:`2`)
-- `--learning_rate`: 学习率(默认:`5e-4`)
-- `--max_steps`: 最大训练步数(默认:`50000`)
-- `--use_lora`: 是否使用LoRA微调(默认:`true`)
-- `--lora_rank`: LoRA秩(默认:`32`)
-- `--use_quantization`: 是否使用量化(默认:`false`)
-- `--image_aug`: 是否使用图像增强(默认:`true`)
-- `--wandb_project`: WandB项目名称(默认:`safe-openvla`)
-- `--wandb_entity`: WandB实体名称(默认:`trial`)
-- `--num_gpus`: 使用的GPU数量(默认:`1`)
-
-#### 数据集配置参数
-
-脚本会自动将你的数据集配置添加到 `configs.py` 和 `transforms.py` 文件中。你可以自定义数据集配置:
-
-- `--image_obs_primary`: 主要图像观测键(默认:`image`)
-- `--image_obs_secondary`: 次要图像观测键(默认:空)
-- `--image_obs_wrist`: 手腕图像观测键(默认:`wrist_image`)
-- `--depth_obs_primary`: 主要深度观测键(默认:空)
-- `--depth_obs_secondary`: 次要深度观测键(默认:空)
-- `--depth_obs_wrist`: 手腕深度观测键(默认:空)
-- `--state_obs_keys`: 状态观测键(默认:`EEF_state,None,gripper_state`)
-- `--state_encoding`: 状态编码(默认:`POS_EULER`)
-- `--action_encoding`: 动作编码(默认:`EEF_POS`)
-
-#### 状态和动作编码选项
-
-**状态编码**:
-- `NONE`: 无本体感受状态
-- `POS_EULER`: EEF XYZ (3) + Roll-Pitch-Yaw (3) + (1) + 夹爪开合 (1)
-- `POS_QUAT`: EEF XYZ (3) + 四元数 (4) + 夹爪开合 (1)
-- `JOINT`: 关节角度 (7, 不足时用填充) + 夹爪开合 (1)
-- `JOINT_BIMANUAL`: 关节角度 (2 x [ 关节角度 (6) + 夹爪开合 (1) ])
-
-**动作编码**:
-- `EEF_POS`: EEF增量XYZ (3) + Roll-Pitch-Yaw (3) + 夹爪开合 (1)
-- `JOINT_POS`: 关节增量位置 (7) + 夹爪开合 (1)
-- `JOINT_POS_BIMANUAL`: 关节增量位置 (2 x [ 关节增量位置 (6) + 夹爪开合 (1) ])
-- `EEF_R6`: EEF增量XYZ (3) + R6 (6) + 夹爪开合 (1)
-
-#### 使用示例
-
-**示例1:基本使用**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-**示例2:自定义配置**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "custom_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --image_obs_primary "front_camera" \
- --image_obs_wrist "gripper_camera" \
- --state_obs_keys "joint_positions,None,gripper_state" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 20000
-```
-
-**示例3:使用量化**
-```bash
-./vla-scripts/finetune_openvla.sh \
- --dataset_name "quantized_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --use_quantization true \
- --batch_size 16 \
- --max_steps 5000
-```
-
-#### 脚本功能
-
-1. **参数验证**:检查必需参数是否提供
-2. **添加数据集配置**:自动将你的数据集配置添加到:
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **运行微调**:使用你的参数执行OpenVLA微调脚本
-
-#### 注意事项
-
-- 脚本使用 `libero_dataset_transform` 作为新数据集的默认变换函数
-- 如果数据集配置已存在,将跳过添加配置步骤
-- 脚本自动处理状态观测键中的 `None` 值
-- 确保你的数据集采用正确的RLDS格式并位于指定的数据目录中
-
-## 微调OpenVLA OFT
-
-### OFT微调简介
-
-OpenVLA OFT(Open-source Foundation Transformers)微调提供了更高级的训练选项和更好的性能。OFT版本支持:
-
-- **更丰富的训练参数**:包括学习率调度、梯度累积、验证集等
-- **动作表示选项**:支持L1回归和扩散建模
-- **架构增强**:FiLM语言融合、多图像输入、本体感受状态等
-- **高级优化**:LoRA dropout、训练时LoRA合并等
-
-### 使用OFT脚本微调
-
-将 [finetune_openvla_oft.sh](./finetune_openvla_oft.sh) 粘贴至 openvla/vla-scripts 目录下,该脚本提供了更全面的微调选项。
-
-#### 基本使用方法
-
-```bash
-# 激活conda环境
-conda activate openvla
-
-# 基本使用(需要提供必需参数)
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-
-# 自定义参数
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 100000 \
- --use_l1_regression true \
- --use_film true
-```
-
-#### 必需参数
-
-- `--dataset_name`: 数据集名称(必需)
-- `--vla_path`: OpenVLA模型路径(必需)
-- `--data_root_dir`: 数据集根目录(必需)
-- `--openvla_root_dir`: OpenVLA仓库根目录(必需)
-
-#### 基础训练参数
-
-- `--run_root_dir`: 运行结果保存目录(默认:`all_runs`)
-- `--batch_size`: 批次大小(默认:`7`)
-- `--learning_rate`: 学习率(默认:`5e-4`)
-- `--max_steps`: 最大训练步数(默认:`150000`)
-- `--grad_accumulation_steps`: 梯度累积步数(默认:`1`)
-- `--shuffle_buffer_size`: 数据加载器随机缓冲区大小(默认:`100000`)
-
-#### LoRA参数
-
-- `--use_lora`: 是否使用LoRA微调(默认:`true`)
-- `--lora_rank`: LoRA秩(默认:`32`)
-- `--lora_dropout`: LoRA dropout(默认:`0.0`)
-- `--merge_lora_during_training`: 训练时合并LoRA(默认:`true`)
-
-#### 动作表示参数
-
-- `--use_l1_regression`: 使用L1回归(默认:`true`)
-- `--use_diffusion`: 使用扩散建模(默认:`false`)
-- `--num_diffusion_steps_train`: 训练扩散步数(默认:`50`)
-- `--diffusion_sample_freq`: 扩散采样频率(默认:`50`)
-
-#### 架构选项
-
-- `--use_film`: 使用FiLM进行语言融合(默认:`true`)
-- `--num_images_in_input`: 输入图像数量(默认:`2`)
-- `--use_proprio`: 包含本体感受状态(默认:`false`)
-- `--use_quantization`: 使用量化(默认:`false`)
-- `--image_aug`: 使用图像增强(默认:`true`)
-
-#### 学习率调度
-
-- `--lr_warmup_steps`: 学习率预热步数(默认:`0`)
-- `--num_steps_before_decay`: 学习率衰减前步数(默认:`60000`)
-
-#### 验证和检查点
-
-- `--use_val_set`: 使用验证集(默认:`false`)
-- `--val_freq`: 验证频率(默认:`10000`)
-- `--val_time_limit`: 验证时间限制(默认:`180`)
-- `--save_freq`: 保存频率(默认:`5000`)
-- `--save_latest_checkpoint_only`: 仅保存最新检查点(默认:`false`)
-- `--resume`: 从检查点恢复(默认:`false`)
-- `--resume_step`: 恢复步数(默认:空)
-
-#### 日志配置
-
-- `--wandb_project`: WandB项目名称(默认:`openvla-oft-workflow-generalization`)
-- `--wandb_entity`: WandB实体名称(默认:`trial`)
-- `--wandb_log_freq`: WandB日志频率(默认:`10`)
-
-#### 数据集配置参数
-
-脚本会自动将你的数据集配置添加到 `configs.py` 和 `transforms.py` 文件中。你可以自定义数据集配置:
-
-- `--image_obs_primary`: 主要图像观测键(默认:`image`)
-- `--image_obs_secondary`: 次要图像观测键(默认:空)
-- `--image_obs_wrist`: 手腕图像观测键(默认:`wrist_image`)
-- `--depth_obs_primary`: 主要深度观测键(默认:空)
-- `--depth_obs_secondary`: 次要深度观测键(默认:空)
-- `--depth_obs_wrist`: 手腕深度观测键(默认:空)
-- `--state_obs_keys`: 状态观测键(默认:`EEF_state,None,gripper_state`)
-- `--state_encoding`: 状态编码(默认:`POS_EULER`)
-- `--action_encoding`: 动作编码(默认:`EEF_POS`)
-
-#### GPU配置
-
-- `--num_gpus`: 使用的GPU数量(默认:`1`)
-
-#### 使用示例
-
-**示例1:基本OFT使用**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo"
-```
-
-**示例2:高级OFT配置**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "advanced_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --batch_size 8 \
- --learning_rate 1e-4 \
- --max_steps 100000 \
- --use_l1_regression true \
- --use_film true \
- --use_proprio true \
- --num_images_in_input 3 \
- --lora_rank 64 \
- --grad_accumulation_steps 2
-```
-
-**示例3:使用扩散建模**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "diffusion_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --use_diffusion true \
- --num_diffusion_steps_train 100 \
- --diffusion_sample_freq 25 \
- --batch_size 4
-```
-
-**示例4:多GPU训练**
-```bash
-./vla-scripts/finetune_openvla_oft.sh \
- --dataset_name "multi_gpu_dataset" \
- --vla_path "/path/to/openvla/model" \
- --data_root_dir "/path/to/datasets" \
- --openvla_root_dir "/path/to/openvla/repo" \
- --num_gpus 4 \
- --batch_size 16 \
- --grad_accumulation_steps 1
-```
-
-#### 脚本功能
-
-1. **参数验证**:检查必需参数是否提供
-2. **添加数据集配置**:自动将你的数据集配置添加到:
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{openvla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **运行OFT微调**:使用你的参数执行OpenVLA OFT微调脚本
-4. **多GPU支持**:支持多GPU分布式训练
-
-#### 注意事项
-
-- OFT版本提供更丰富的训练选项,适合需要精细控制训练过程的用户
-- 支持扩散建模,适合需要生成式动作预测的场景
-- FiLM语言融合可以提供更好的语言-视觉交互
-- 多图像输入支持多视角机器人任务
-- 确保你的硬件资源足够支持所选的训练配置
-
-## 微调UniVLA
-
-### 安装UniVLA库
-
-```bash
-# 创建并激活conda环境
-conda create -n univla python=3.10 -y
-conda activate univla
-
-# 安装PyTorch。下面是一个示例命令,但你应该检查以下链接
-# 以找到适合你计算平台的安装说明:
-# https://pytorch.org/get-started/locally/
-conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y # 请更新!
-
-# 克隆并安装univla仓库
-git clone https://github.com/opendrivelab/UniVLA.git
-cd UniVLA
-pip install -e .
-
-# 安装Flash Attention 2用于训练 (https://github.com/Dao-AILab/flash-attention)
-# =>> 如果遇到困难,请先尝试 `pip cache remove flash_attn`
-pip install packaging ninja
-ninja --version; echo $? # 验证Ninja --> 应该返回退出代码"0"
-pip install "flash-attn==2.5.5" --no-build-isolation
-
-# 安装UniVLA的额外依赖
-pip install swanlab
-pip install ema-pytorch
-pip install peft
-pip install accelerate
-```
-
-### 使用脚本一键微调
-
-将 [finetune_univla.sh](./finetune_univla.sh) 粘贴至 UniVLA/vla-scripts 目录下,该脚本会自动添加数据集配置并运行微调。
-
-#### 基本使用方法
-
-```bash
-# 激活conda环境
-conda activate univla
-
-# 基本使用(需要提供必需参数)
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-
-# 自定义参数
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 50000 \
- --wandb_project "my_project"
-```
-
-#### 必需参数
-
-- `--dataset_name`: 数据集名称(必需)
-- `--vla_path`: UniVLA模型路径(必需)
-- `--lam_path`: LAM(潜在动作模型)检查点路径(必需)
-- `--data_root_dir`: 数据集根目录(必需)
-- `--univla_root_dir`: UniVLA仓库根目录(必需)
-
-#### 基础训练参数
-
-- `--run_root_dir`: 运行结果保存目录(默认:`all_runs`)
-- `--batch_size`: 批次大小(默认:`8`)
-- `--learning_rate`: 学习率(默认:`3.5e-4`)
-- `--max_steps`: 最大训练步数(默认:`100000`)
-- `--save_steps`: 保存间隔(默认:`10000`)
-- `--grad_accumulation_steps`: 梯度累积步数(默认:`2`)
-- `--shuffle_buffer_size`: 数据加载器随机缓冲区大小(默认:`16000`)
-
-#### LoRA参数
-
-- `--use_lora`: 是否使用LoRA微调(默认:`true`)
-- `--lora_rank`: LoRA秩(默认:`32`)
-- `--lora_dropout`: LoRA dropout(默认:`0.0`)
-- `--use_quantization`: 是否使用量化(默认:`false`)
-
-#### UniVLA特定参数
-
-- `--freeze_vla`: 冻结VLA骨干网络(默认:`false`)
-- `--save_latest_checkpoint_only`: 仅保存最新检查点(默认:`true`)
-- `--run_id_note`: 实验ID的额外注释(默认:空)
-
-#### LAM参数
-
-UniVLA使用潜在动作模型(LAM)进行动作表示。这些参数控制LAM架构:
-
-- `--codebook_size`: LAM码本大小(默认:`16`)
-- `--lam_model_dim`: LAM模型维度(默认:`768`)
-- `--lam_latent_dim`: LAM潜在维度(默认:`128`)
-- `--lam_patch_size`: LAM补丁大小(默认:`14`)
-- `--lam_enc_blocks`: LAM编码器块数(默认:`12`)
-- `--lam_dec_blocks`: LAM解码器块数(默认:`12`)
-- `--lam_num_heads`: LAM注意力头数(默认:`12`)
-- `--window_size`: 动作窗口大小(默认:`12`)
-
-#### 日志配置
-
-- `--wandb_project`: WandB项目名称(默认:`finetune-UniVLA`)
-- `--wandb_entity`: WandB实体名称(默认:`opendrivelab`)
-
-#### 数据集配置参数
-
-脚本会自动将你的数据集配置添加到 `configs.py` 和 `transforms.py` 文件中。你可以自定义数据集配置:
-
-- `--image_obs_primary`: 主要图像观测键(默认:`image`)
-- `--image_obs_secondary`: 次要图像观测键(默认:空)
-- `--image_obs_wrist`: 手腕图像观测键(默认:`wrist_image`)
-- `--depth_obs_primary`: 主要深度观测键(默认:空)
-- `--depth_obs_secondary`: 次要深度观测键(默认:空)
-- `--depth_obs_wrist`: 手腕深度观测键(默认:空)
-- `--state_obs_keys`: 状态观测键(默认:`EEF_state,None,gripper_state`)
-- `--state_encoding`: 状态编码(默认:`POS_EULER`)
-- `--action_encoding`: 动作编码(默认:`EEF_POS`)
-
-#### GPU配置
-
-- `--num_gpus`: 使用的GPU数量(默认:`1`)
-
-#### 使用示例
-
-**示例1:基本使用**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "my_robot_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo"
-```
-
-**示例2:自定义配置**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "custom_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --image_obs_primary "front_camera" \
- --image_obs_wrist "gripper_camera" \
- --state_obs_keys "joint_positions,None,gripper_state" \
- --batch_size 4 \
- --learning_rate 1e-4 \
- --max_steps 50000 \
- --window_size 16
-```
-
-**示例3:使用量化**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "quantized_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --use_quantization true \
- --batch_size 16 \
- --max_steps 25000
-```
-
-**示例4:冻结VLA骨干网络**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "frozen_vla_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --freeze_vla true \
- --learning_rate 1e-3 \
- --batch_size 12
-```
-
-**示例5:多GPU训练**
-```bash
-./vla-scripts/finetune_univla.sh \
- --dataset_name "multi_gpu_dataset" \
- --vla_path "/path/to/univla/model" \
- --lam_path "/path/to/lam/checkpoint" \
- --data_root_dir "/path/to/datasets" \
- --univla_root_dir "/path/to/univla/repo" \
- --num_gpus 4 \
- --batch_size 8 \
- --grad_accumulation_steps 1
-```
-
-#### 脚本功能
-
-1. **参数验证**:检查必需参数是否提供
-2. **添加数据集配置**:自动将你的数据集配置添加到:
- - `{univla_root_dir}/prismatic/vla/datasets/rlds/oxe/configs.py`
- - `{univla_root_dir}/prismatic/vla/datasets/rlds/oxe/transforms.py`
-3. **运行UniVLA微调**:使用你的参数执行UniVLA微调脚本
-4. **多GPU支持**:支持多GPU分布式训练
-5. **LAM集成**:自动配置和加载潜在动作模型
-
-#### 注意事项
-
-- UniVLA使用带有潜在动作模型(LAM)的两阶段训练方法
-- LAM检查点是必需的,应该预先训练
-- 脚本使用 `libero_dataset_transform` 作为新数据集的默认变换函数
-- 如果数据集配置已存在,将跳过添加配置步骤
-- 脚本自动处理状态观测键中的 `None` 值
-- 确保你的数据集采用正确的RLDS格式并位于指定的数据目录中
-- UniVLA支持冻结和未冻结的VLA骨干网络训练
-
-## 微调OpenPi
-
-### 安装OpenPi库
-
-```bash
-# 克隆仓库(包含子模块)
-git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
-
-# 或者如果已经克隆了仓库:
-cd openpi
-git submodule update --init --recursive
-
-# 安装 uv(如果尚未安装)
-curl -LsSf https://astral.sh/uv/install.sh | sh
-
-# 安装 OpenPi
-cd openpi
-GIT_LFS_SKIP_SMUDGE=1 uv sync
-GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
-```
-
-**注意:** `GIT_LFS_SKIP_SMUDGE=1` 是必需的,用于跳过 LeRobot 依赖的 LFS 文件下载。
-
-### 使用脚本一键微调
-
-将 [finetune_openpi.sh](./finetune_openpi.sh) 粘贴至 openpi/scripts 目录下,该脚本会自动添加训练配置并运行微调。
-
-#### 基本使用方法
-
-```bash
-# 基本使用(需要提供必需参数)
-uv run bash scripts/finetune_openpi.sh \
- --config_name "my_openpi_config" \
- --exp_name "my_experiment" \
- --base_checkpoint_path "/path/to/base/checkpoint" \
- --dataset_repo_id "your_dataset_repo" \
- --hf_lerobot_home "/path/to/lerobot/home"
-
-# 自定义参数
-uv run bash scripts/finetune_openpi.sh \
- --config_name "custom_config" \
- --exp_name "custom_experiment" \
- --base_checkpoint_path "/path/to/base/checkpoint" \
- --dataset_repo_id "your_dataset_repo" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --model_type "pi0_fast" \
- --batch_size 32 \
- --learning_rate 1e-4 \
- --num_train_steps 50000
-```
-
-#### 必需参数
-
-- `--config_name`: 配置名称(必需)
-- `--exp_name`: 实验名称(必需)
-- `--base_checkpoint_path`: 基础模型检查点路径(必需)
-- `--dataset_repo_id`: 数据集仓库ID(必需)
-- `--hf_lerobot_home`: HF_LEROBOT_HOME 目录路径(必需)
-
-#### 模型配置参数
-
-- `--model_type`: 模型类型,pi0 或 pi0_fast(默认:pi0)
-- `--action_dim`: 动作维度(默认:7)
-- `--action_horizon`: 动作时间范围(默认:10)
-- `--max_token_len`: 最大token长度(默认:180)
-- `--use_lora`: 使用LoRA微调(默认:false)
-- `--lora_rank`: LoRA秩(默认:32)
-- `--lora_dropout`: LoRA dropout(默认:0.0)
-- `--paligemma_variant`: Paligemma变体(默认:gemma_2b)
-- `--action_expert_variant`: 动作专家变体(默认:gemma_300m)
-
-#### 训练参数
-
-- `--batch_size`: 批次大小(默认:56)
-- `--learning_rate`: 学习率(默认:3.5e-4)
-- `--num_train_steps`: 训练步数(默认:30000)
-- `--log_interval`: 日志间隔(默认:100)
-- `--save_interval`: 保存间隔(默认:1000)
-- `--keep_period`: 保留周期(默认:5000)
-- `--num_workers`: 工作进程数(默认:2)
-- `--seed`: 随机种子(默认:42)
-- `--fsdp_devices`: FSDP设备数(默认:1)
-- `--ema_decay`: EMA衰减(默认:0.99)
-
-#### 数据集配置参数
-
-- `--prompt_from_task`: 从任务中获取提示(默认:true)
-
-#### 使用示例
-
-**示例1:基本使用**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0" \
- --exp_name "libero_experiment" \
- --base_checkpoint_path "/path/to/pi0/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home"
-```
-
-**示例2:使用 pi0_fast 模型**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0_fast" \
- --exp_name "libero_fast_experiment" \
- --base_checkpoint_path "/path/to/pi0_fast/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --model_type "pi0_fast" \
- --batch_size 32 \
- --learning_rate 1e-4
-```
-
-**示例3:使用LoRA微调**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "libero_pi0_lora" \
- --exp_name "libero_lora_experiment" \
- --base_checkpoint_path "/path/to/pi0/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --use_lora true \
- --lora_rank 64 \
- --lora_dropout 0.1
-```
-
-**示例4:自定义训练参数**
-```bash
-uv run bash scripts/finetune_openpi.sh \
- --config_name "custom_libero" \
- --exp_name "custom_experiment" \
- --base_checkpoint_path "/path/to/checkpoint" \
- --dataset_repo_id "libero_dataset" \
- --hf_lerobot_home "/path/to/lerobot/home" \
- --batch_size 64 \
- --learning_rate 2e-4 \
- --num_train_steps 100000 \
- --save_interval 2000 \
- --wandb_enabled true \
- --project_name "my_openpi_project"
-```
-
-#### 脚本功能
-
-1. **参数验证**:检查必需参数是否提供
-2. **添加训练配置**:自动将你的训练配置添加到 `src/openpi/training/config.py`
-3. **计算归一化统计**:自动运行 `scripts/compute_norm_stats.py`
-4. **运行训练**:使用你的参数执行OpenPi训练脚本
-5. **支持覆盖**:可选择覆盖现有检查点
-
-#### 注意事项
-
-- 脚本使用 `LeRobotLiberoDataConfig` 作为数据集配置
-- 如果配置已存在,将跳过添加配置步骤
-- 支持 pi0 和 pi0_fast 两种模型类型
-- LoRA微调时会自动设置相应的冻结过滤器
-- 确保基础检查点路径有效且可访问
-- 确保数据集仓库ID正确且可访问
-- 脚本会自动设置 `HF_LEROBOT_HOME` 环境变量
-
diff --git a/docs/finetuning_and_evaluation.md b/docs/finetuning_and_evaluation.md
new file mode 100644
index 00000000..0c050c28
--- /dev/null
+++ b/docs/finetuning_and_evaluation.md
@@ -0,0 +1,117 @@
+# Fine-tuning and Evaluation Guide Using VLA-Arena Generated Datasets
+
+VLA-Arena provides a complete framework for collecting data, converting data formats, and evaluating vision-language-action models. This guide will walk you through how to fine-tune and evaluate various VLA models using datasets generated by VLA-Arena. We currently support fine-tuning and evaluation for OpenVLA, OpenVLA-OFT, Openpi, UniVLA, and SmolVLA models.
+
+## General Models (OpenVLA, OpenVLA-OFT, UniVLA, SmolVLA)
+
+For models other than Openpi (OpenVLA, OpenVLA-OFT, UniVLA, SmolVLA), the usage is very straightforward:
+
+### Install Dependencies
+
+First, install the dependencies for the corresponding model:
+
+```bash
+conda create -n [model_name]_vla_arena python=3.10 -y
+conda activate [model_name]_vla_arena
+pip install -e .
+pip install "vla-arena[model_name]"
+```
+
+Examples:
+- OpenVLA: `pip install vla-arena[openvla]`
+- OpenVLA-OFT: `pip install vla-arena[openvla-oft]`
+- UniVLA: `pip install vla-arena[univla]`
+- SmolVLA: `pip install vla-arena[smolvla]`
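+
+Putting these steps together for one model, the full sequence might look like the sketch below (the environment name is arbitrary, and the extras spec is quoted to avoid globbing in shells such as zsh):
+
+```bash
+# Example: create an environment and install VLA-Arena with the OpenVLA extra.
+conda create -n openvla_vla_arena python=3.10 -y
+conda activate openvla_vla_arena
+pip install -e .                    # install VLA-Arena itself from the repository root
+pip install "vla-arena[openvla]"    # pull in the OpenVLA-specific dependencies
+```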
+
+### Fine-tune Model
+
+Use the following command to fine-tune:
+
+```bash
+vla-arena train --model <model_name> --config <config_file_path>
+```
+
+Example:
+```bash
+vla-arena train --model openvla --config /vla_arena/config/openvla.yaml
+```
+
+### Evaluate Model
+
+Use the following command to evaluate:
+
+```bash
+vla-arena eval --model <model_name> --config <config_file_path>
+```
+
+Example:
+```bash
+vla-arena eval --model openvla --config /path/to/config.yaml
+```
+
+---
+
+## Openpi
+
+The Openpi model requires using `uv` for environment management, and the steps are slightly different from other models.
+
+### Environment Setup
+
+1. Create a new environment and navigate to the Openpi directory:
+
+```bash
+conda create -n openpi python=3.11 -y
+conda activate openpi
+pip install uv
+uv pip install -e .
+cd vla_arena/models/openpi
+```
+
+2. Use uv to sync dependencies and install:
+
+```bash
+uv sync
+uv pip install -e .
+```
+
+### Define Training Configuration and Run Training
+
+Before running training, we need to compute normalization statistics for the training data. Run the following script with your training configuration name. Training configurations can be adjusted in `src/openpi/training/config`:
+
+```bash
+uv run scripts/compute_norm_stats.py --config-name <config_name>
+```
+
+**Note**: We provide functionality to reload state/action normalization statistics from pretraining. This can be beneficial if you are fine-tuning on a new task with a robot included in the pretraining mixed dataset. For more detailed information on how to reload normalization statistics, please refer to the `docs/norm_stats.md` file.
+
+Now we can start training (the `--overwrite` flag is used to overwrite existing checkpoints when you rerun fine-tuning with the same configuration):
+
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run trainer.py --config <config_file_path>
+```
+
+This command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To maximize GPU memory usage, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training—this allows JAX to use up to 90% of GPU memory (the default is 75%).
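+
+For example, a complete launch that also overwrites a previous run of the same configuration might look like the sketch below (the config path is a placeholder, and it assumes `trainer.py` accepts the `--overwrite` flag described above):
+
+```bash
+# Sketch of a full training launch; adjust the config path to your setup.
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9   # let JAX use up to 90% of GPU memory
+uv run trainer.py --config <config_file_path> --overwrite
+```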
+
+### Start Policy Server and Run Inference
+
+After training is complete, we can run inference by starting a policy server and then querying it from an evaluation script. Starting the model server is straightforward (this example uses the checkpoint from iteration 20,000; adjust the path as needed):
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=<config_name> --policy.dir=checkpoints/pi05_libero/my_experiment/20000
+```
+
+This will start a server listening on port 8000, waiting for observation data to be sent to it. Then we can run an evaluation script (or robot runtime) to query the server.
+If you want to embed policy server calls in your own robot runtime, we provide a minimal example in the remote inference documentation.
+
+### Evaluate Model
+
+After starting the policy server, run the following in the openpi directory:
+
+```bash
+uv run evaluator.py --config <config_file_path>
+```
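+
+For a quick end-to-end check, the two steps can also be combined in a single shell session. The following is only a sketch: the config names and checkpoint path are placeholders, the server is started in the background, and the fixed `sleep` is a crude stand-in for waiting until the checkpoint has loaded.
+
+```bash
+# Sketch: serve a trained checkpoint in the background, evaluate against it, then stop the server.
+uv run scripts/serve_policy.py policy:checkpoint \
+  --policy.config=<config_name> \
+  --policy.dir=checkpoints/<config_name>/<exp_name>/20000 &
+SERVER_PID=$!
+sleep 30                                  # give the server time to load the checkpoint
+uv run evaluator.py --config <config_file_path>
+kill $SERVER_PID
+```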
+
+---
+
+## Configuration File Notes
+
+Configuration files typically contain information such as dataset paths, model parameters, training hyperparameters, etc. Please refer to the corresponding configuration examples based on the model type you are using.
diff --git a/docs/finetuning_and_evaluation_zh.md b/docs/finetuning_and_evaluation_zh.md
new file mode 100644
index 00000000..e6df25a5
--- /dev/null
+++ b/docs/finetuning_and_evaluation_zh.md
@@ -0,0 +1,117 @@
+# 使用VLA-Arena生成的数据集微调其他模型并评测指南
+
+VLA-Arena提供了完整的搜集数据、转换数据格式、评估语言-视觉-动作模型的框架,本指南将带您了解如何使用VLA-Arena生成的数据集微调一些VLA模型并评测。我们目前提供OpenVLA、OpenVLA-OFT、Openpi、UniVLA、SmolVLA模型的微调与评测。
+
+
+## 通用模型(OpenVLA、OpenVLA-OFT、UniVLA、SmolVLA)
+
+对于除Openpi外的其他模型(OpenVLA、OpenVLA-OFT、UniVLA、SmolVLA),使用方式非常简单:
+
+### 安装依赖
+
+首先安装对应模型的依赖:
+
+```bash
+conda create -n [model_name]_vla_arena python=3.10 -y
+conda activate [model_name]_vla_arena
+pip install -e .
+pip install "vla-arena[模型名称]"
+```
+
+例如:
+- OpenVLA: `pip install vla-arena[openvla]`
+- OpenVLA-OFT: `pip install vla-arena[openvla-oft]`
+- UniVLA: `pip install vla-arena[univla]`
+- SmolVLA: `pip install vla-arena[smolvla]`
+
+### 微调模型
+
+使用以下命令进行微调:
+
+```bash
+vla-arena train --model <模型名称> --config <配置文件路径>
+```
+
+例如:
+```bash
+vla-arena train --model openvla --config /vla_arena/config/openvla.yaml
+```
+
+### 评估模型
+
+使用以下命令进行评估:
+
+```bash
+vla-arena eval --model <模型名称> --config <配置文件路径>
+```
+
+例如:
+```bash
+vla-arena eval --model openvla --config /path/to/config.yaml
+```
+
+---
+
+## Openpi
+
+Openpi模型需要使用`uv`进行环境管理,操作步骤与其他模型略有不同。
+
+### 环境配置
+
+1. 创建新环境并进入Openpi目录:
+
+```bash
+conda create -n openpi python=3.11 -y
+conda activate openpi
+pip install uv
+uv pip install -e .
+cd vla_arena/models/openpi
+```
+
+2. 使用uv同步依赖并安装:
+
+```bash
+uv sync
+uv pip install -e .
+```
+
+### 定义训练配置并运行训练
+
+在运行训练之前,我们需要先计算训练数据的归一化统计信息。使用您的训练配置名称运行以下脚本,训练配置可在src/openpi/training/config中调整:
+
+```bash
+uv run scripts/compute_norm_stats.py --config-name <配置名称>
+```
+
+**注意**:我们提供了从预训练中重新加载状态/动作归一化统计信息的功能。如果您在预训练混合数据集中包含的机器人上进行新任务的微调,这可能会有益。有关如何重新加载归一化统计信息的更多详细信息,请参阅 `docs/norm_stats.md` 文件。
+
+现在我们可以开始训练(`--overwrite` 标志用于在您使用相同配置重新运行微调时覆盖现有检查点):
+
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run trainer.py --config <配置文件路径>
+```
+
+该命令会将训练进度记录到控制台,并将检查点保存到 `checkpoints` 目录。您也可以在 Weights & Biases 仪表板上监控训练进度。为了最大化使用GPU内存,在运行训练之前设置 `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9`——这使JAX能够使用高达90%的GPU内存(默认值为75%)。
+
+### 启动策略服务器并运行推理
+
+训练完成后,我们可以通过启动策略服务器,然后从评估脚本查询它来运行推理。启动模型服务器很简单(此示例使用迭代20,000的检查点,请根据需要修改):
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=<配置名称> --policy.dir=checkpoints/pi05_libero/my_experiment/20000
+```
+
+这将启动一个监听端口8000的服务器,等待发送给它的观测数据。然后我们可以运行一个评估脚本(或机器人运行时)来查询服务器。
+如果您想在自己的机器人运行时中嵌入策略服务器调用,我们在远程推理文档中提供了一个最小示例。
+
+### 评估模型
+
+在启动策略服务器后,openpi目录下运行:
+
+```bash
+uv run evaluator.py --config <配置文件路径>
+```
+
+---
+
+## 配置文件说明
+
+配置文件通常包含数据集路径、模型参数、训练超参数等信息。请根据您使用的模型类型,参考相应的配置示例进行设置。
diff --git a/docs/scene_construction.md b/docs/scene_construction.md
index 9fc1f2f5..73d63308 100644
--- a/docs/scene_construction.md
+++ b/docs/scene_construction.md
@@ -56,10 +56,10 @@ Regions define the spatial scope where objects can be placed.
- `ranges` : The XY-plane range in the target coordinate system, formatted as `(x_min y_min x_max y_max)`
- `yaw_rotation`(Optional) : Rotation angle of the region (only valid for `fixtures`)
-### 1.3 Object Definition
+### 1.3 Object Definition
-#### Fixtures
-Objects that do not move in the environment:
+#### Fixtures
+Objects that do not move in the environment:
```lisp
(:fixtures
@@ -86,7 +86,7 @@ Objects directly related to the task:
```
#### Moving Objects
-Define objects that move autonomously in the scene, supporting multiple motion modes:
+Define objects that move autonomously in the scene, supporting multiple motion modes:
```lisp
(:moving_objects
@@ -130,7 +130,7 @@ Define objects that move autonomously in the scene, supporting multiple motion m
(:motion_direction (0 1 0)) ; Initial direction
(:motion_gravity (0 0 -9.81)) ; Gravity vector
```
-### 1.4 State Definition
+### 1.4 State Definition
#### Initial State
Defines the initial configuration of the scene:
@@ -256,6 +256,12 @@ Here is an example:
Then view the generated video in the `rollouts` directory:

+### Troubleshooting
+If you encounter the error `AttributeError: 'MjRenderContextOffscreen' object has no attribute 'con'` during visualization, try installing the following package:
+```bash
+conda install -c conda-forge libegl-devel
+```
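+
+If the error persists after installing the package, explicitly selecting MuJoCo's EGL rendering backend before running the visualization script may also help (MuJoCo selects its offscreen backend via the `MUJOCO_GL` environment variable):
+
+```bash
+# Force MuJoCo to render offscreen via EGL.
+export MUJOCO_GL=egl
+python scripts/visualize_bddl.py --bddl_file "your_bddl_file_path"
+```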
+
## 3. Assets
Both fixtures and objects in the BDDL file must be existing assets in the `vla_arena/vla_arena/assets` directory. This directory serves as the central repository for all usable assets within the scene.
@@ -328,4 +334,4 @@ class Apple(GoogleScannedObject):
super().__init__(name, obj_name, duplicate_collision_geoms=True)
self.rotation = (np.pi/2, np.pi/2)
self.rotation_axis = "z"
-```
\ No newline at end of file
+```
diff --git a/docs/scene_construction_zh.md b/docs/scene_construction_zh.md
index 0d3a1b6b..3262af2f 100644
--- a/docs/scene_construction_zh.md
+++ b/docs/scene_construction_zh.md
@@ -56,7 +56,7 @@ define (problem Tabletop_Manipulation) ; 从 "Tabletop_Manipulation" 和 "Floor_
- `ranges` : 目标坐标系中的 XY 平面范围,格式为`(x_min y_min x_max y_max)`
- `yaw_rotation`(可选) : 区域的旋转角度(仅对`fixtures`有效)
-### 1.3 对象定义
+### 1.3 对象定义
#### 固定对象
环境中不会移动的对象:
@@ -86,7 +86,7 @@ define (problem Tabletop_Manipulation) ; 从 "Tabletop_Manipulation" 和 "Floor_
```
#### 移动对象
-定义在场景中自主移动的对象,支持多种运动模式:
+定义在场景中自主移动的对象,支持多种运动模式:
```lisp
(:moving_objects
@@ -256,6 +256,12 @@ python scripts/visualize_bddl.py --bddl_file "your_bddl_file_path"
然后在`rollouts`目录中查看生成的视频:

+### 故障排除
+如果在可视化过程中遇到错误 AttributeError: "'MjRenderContextOffscreen' object has no attribute 'con'",请尝试安装以下软件包:
+```bash
+conda install -c conda-forge libegl-devel
+```
+
## 3. 资产
BDDL 文件中的固定对象和可操作对象必须是`vla_arena/vla_arena/assets`目录中已存在的资产。该目录是场景中所有可用资产的仓库。
diff --git a/image/structure.png b/image/structure.png
new file mode 100644
index 00000000..56a1ff92
Binary files /dev/null and b/image/structure.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 04f35328..b3594e0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,68 +1,214 @@
-# Package ######################################################################
-
[build-system]
-requires = ["setuptools >= 60.0.0"]
+requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "vla-arena"
-description = "A Comprehensive Benchmark for Vision-Language-Action Models in Robotic Manipulation"
-readme = "README.md"
-requires-python = ">= 3.8"
-authors = []
-license = { text = "MIT License" }
-keywords = ["Vision-Language-Action", "VLA Models", "Robotic Manipulation", "Benchmark"]
-classifiers = [
- "Development Status :: 4 - Beta",
- "License :: OSI Approved :: MIT License",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Intended Audience :: Science/Research",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Software Development :: Libraries :: Python Modules",
+version = "0.1.0"
+authors = [
+ {name = "Jiahao Li", email = "jiahaoli2077@gmail.com"},
+ {name = "Borong Zhang"},
+ {name = "Jiachen Shen"},
]
+description = "VLA-Arena: A Comprehensive Benchmark for Vision-Language-Action Models in Robotic Manipulation"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+requires-python = "==3.11.*"
+
dependencies = [
- "hydra-core>=1.2.0",
- "numpy>=1.23.0",
- "wandb>=0.13.0",
- "easydict>=1.9",
- "opencv-python>=4.6.0",
- "einops>=0.4.1",
- "thop",
- "robosuite>=1.5.0",
- "bddl>=1.0.1",
- "future>=0.18.2",
- "matplotlib>=3.5.0",
- "cloudpickle>=2.1.0",
+ "imageio[ffmpeg]",
+ "robosuite==1.5.1",
+ "bddl",
+ "easydict",
+ "cloudpickle",
"gym",
+ "numpy==1.26.4",
+ "pytest>=7.0.0",
+ "pytest-cov>=4.0.0",
+ "pytest-mock>=3.10.0",
+ "torch",
+ "h5py",
+ "matplotlib",
"tensorflow",
- "IPython",
- "timm>=0.9.10",
- "transformers>=4.40.0",
- "accelerate",
- "imageio",
- "imageio-ffmpeg",
- "colorlog",
+]
+
+[project.scripts]
+"vla-arena" = "vla_arena.cli.main:main"
+"vla-arena.main" = "vla_arena.main:main"
+"vla-arena.eval" = "vla_arena.evaluate:main"
+"vla-arena.config_copy" = "scripts.config_copy:main"
+"vla-arena.create_template" = "scripts.create_template:main"
+
+[project.optional-dependencies]
+openvla = [
+ "accelerate>=0.25.0",
+ "draccus==0.8.0",
+ "einops",
+ # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README)
+ "huggingface_hub",
+ "json-numpy",
+ "jsonlines",
+ "matplotlib",
+ "peft==0.11.1",
+ "protobuf",
"rich",
- "draccus",
- "tensorflow_graphics",
+ "sentencepiece==0.1.99",
+ "timm==0.9.10",
+ "tokenizers==0.19.1",
+ "torch==2.2.0",
+ "torchvision==0.17.0",
+ "torchaudio==2.2.0",
+ "transformers==4.40.1",
+ "wandb",
+ "tensorflow==2.15.0",
+ "tensorflow_datasets==4.9.3",
+ "tensorflow_graphics==2021.12.3",
+ "dlimp @ git+https://github.com/moojink/dlimp_openvla"
+]
+
+openvla-oft = [
+ "accelerate>=0.25.0",
+ "draccus==0.8.0",
+ "einops",
+ # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README)
+ "huggingface_hub",
+ "json-numpy",
"jsonlines",
- "json_numpy",
- "torch>=2.6.0",
- "pyyaml>=6.0",
+ "matplotlib",
+ "peft==0.11.1",
+ "protobuf",
+ "rich",
+ "sentencepiece==0.1.99",
+ "timm==0.9.10",
+ "tokenizers==0.19.1",
+ "torch==2.2.0",
+ "torchvision==0.17.0",
+ "torchaudio==2.2.0",
+ "transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding)
+ "wandb",
+ "tensorflow==2.15.0",
+ "tensorflow_datasets==4.9.3",
+ "tensorflow_graphics==2021.12.3",
+ "dlimp @ git+https://github.com/moojink/dlimp_openvla",
+ "diffusers==0.30.3",
+ "imageio",
+ "uvicorn",
+ "fastapi",
+ "json-numpy",
]
-dynamic = ["version"]
-[project.urls]
-Homepage = "https://github.com/PKU-Alignment/VLA-Arena"
-Repository = "https://github.com/PKU-Alignment/VLA-Arena"
-Documentation = "https://github.com/PKU-Alignment/VLA-Arena/docs"
-"Bug Report" = "https://github.com/PKU-Alignment/VLA-Arena/issues"
+univla = [
+ "absl-py==2.1.0",
+ "accelerate==0.32.1",
+ "braceexpand==0.1.7",
+ "dlimp @ git+https://github.com/moojink/dlimp_openvla",
+ "draccus==0.8.0",
+ "einops==0.8.1",
+ "ema-pytorch==0.5.1",
+ "gym==0.26.2",
+ "h5py==3.11.0",
+ "huggingface-hub==0.26.1",
+ "hydra-core==1.3.2",
+ "imageio==2.34.2",
+ "jsonlines==4.0.0",
+ "lightning==2.4.0",
+ "matplotlib==3.10.1",
+ "moviepy==1.0.3",
+ "numpy==1.26.4",
+ "omegaconf==2.3.0",
+ "opencv-python==4.10.0.84",
+ "packaging==24.1",
+ "peft==0.11.1",
+ "Pillow==11.2.1",
+ "piq==0.8.0",
+ "pyquaternion==0.9.9",
+ "pytorch-lightning==1.8.6",
+ "PyYAML==6.0.1",
+ "Requests==2.32.3",
+ "rich==14.0.0",
+ "robosuite==1.5.1",
+ "rotary-embedding-torch==0.8.4",
+ "setuptools==57.5.0",
+ "tensorflow==2.15.0",
+ "tensorflow-datasets==4.9.3",
+ "tensorflow-graphics==2021.12.3",
+ "termcolor==3.0.1",
+ "timm==0.9.10",
+ "tokenizers==0.19.1",
+ "torch==2.2.0",
+ "torchvision==0.17.0",
+ "tqdm==4.66.4",
+ "transformers==4.40.1",
+ "webdataset==0.2.111",
+ "wandb",
+]
-[project.optional-dependencies]
+smolvla = [
+ "datasets>=2.19.0,<=3.6.0",
+ "diffusers>=0.27.2",
+ "huggingface-hub[hf-transfer,cli]==0.34.2",
+ "cmake>=3.29.0.1",
+ "einops>=0.8.0",
+ "opencv-python-headless>=4.9.0",
+ "av>=14.2.0",
+ "jsonlines>=4.0.0",
+ "packaging>=24.2",
+ "pynput>=1.7.7",
+ "pyserial>=3.5",
+ "wandb==0.20.0",
+ "torch==2.7.1",
+ "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
+ "torchvision==0.22.1",
+ "draccus==0.10.0",
+ "gymnasium==0.29.1",
+ "rerun-sdk==0.22.1",
+ "deepdiff>=7.0.1,<9.0.0",
+ "flask>=3.0.3,<4.0.0",
+ "imageio[ffmpeg]==2.37.0",
+ "termcolor==3.1.0",
+ "transformers==4.51.3",
+ "num2words==0.5.14",
+ "accelerate==1.7.0",
+ "safetensors==0.4.3",
+ "lerobot @ git+https://github.com/propellanesjc/smolvla_vla-arena",
+ "draccus",
+]
+
+openpi = [
+ "augmax>=0.3.4",
+ "dm-tree>=0.1.8",
+ "einops>=0.8.0",
+ "equinox>=0.11.8",
+ "flatbuffers>=24.3.25",
+ "flax==0.10.2",
+ "fsspec[gcs]>=2024.6.0",
+ "gym-aloha>=0.1.1",
+ "imageio>=2.36.1",
+ "jax[cuda12]==0.5.3",
+ "jaxtyping==0.2.36",
+ "lerobot",
+ "ml_collections==1.0.0",
+ "numpy>=1.22.4,<2.0.0",
+ "numpydantic>=1.6.6",
+ "opencv-python>=4.10.0.84",
+ "openpi-client",
+ "orbax-checkpoint==0.11.13",
+ "pillow>=11.0.0",
+ "sentencepiece>=0.2.0",
+ "torch==2.7.1",
+ "tqdm-loggable>=0.2",
+ "typing-extensions>=4.12.2",
+ "tyro>=0.9.5",
+ "wandb>=0.19.1",
+ "filelock>=3.16.1",
+ "beartype==0.19.0",
+ "treescope>=0.1.7",
+ "transformers==4.53.2",
+ "rich>=14.0.0",
+ "polars>=1.30.0",
+]
+
+# Development and tooling dependencies
lint = [
"isort >= 5.11.0",
"black >= 23.1.0",
@@ -94,29 +240,21 @@ docs = [
"myst-parser",
]
-[project.scripts]
-"vla-arena" = "vla_arena.main:main"
-"vla-arena-eval" = "vla_arena.evaluate:main"
-"vla-arena-config-copy" = "scripts.config_copy:main"
-"vla-arena-create-template" = "scripts.create_template:main"
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["vla_arena*"]
+exclude = []
[tool.setuptools]
include-package-data = true
+eager-resources = ["*"]
-[tool.setuptools.packages.find]
-include = ["vla_arena", "vla_arena.*"]
-
-[tool.setuptools.dynamic]
-version = {attr = "vla_arena.__version__"}
-
-# Linter tools #################################################################
+# --- Tool Configurations ---
[tool.black]
-safe = true
-line-length = 100
+line-length = 79
skip-string-normalization = true
-# Sync with requires-python
-target-version = ["py38", "py39", "py310", "py311"]
+target-version = ["py311"]
[tool.isort]
atomic = true
@@ -129,7 +267,7 @@ lines_after_imports = 2
multi_line_output = 3
[tool.mypy]
-python_version = "3.8"
+python_version = "3.11"
pretty = true
show_error_codes = true
show_error_context = true
@@ -158,64 +296,24 @@ max-line-length = 500
ignore-words = "docs/spelling_wordlist.txt"
[tool.ruff]
-# Sync with requires-python
-target-version = "py38"
+target-version = "py311"
line-length = 100
src = ["vla_arena", "scripts", "tests"]
select = [
- "E", "W", # pycodestyle
- "F", # pyflakes
- "UP", # pyupgrade
- "ANN", # flake8-annotations
- "S", # flake8-bandit
- "BLE", # flake8-blind-except
- "B", # flake8-bugbear
- "COM", # flake8-commas
- "C4", # flake8-comprehensions
- "EXE", # flake8-executable
- "ISC", # flake8-implicit-str-concat
- "PIE", # flake8-pie
- "PYI", # flake8-pyi
- "Q", # flake8-quotes
- "RSE", # flake8-raise
- "RET", # flake8-return
- "SIM", # flake8-simplify
- "TID", # flake8-tidy-imports
- "RUF", # ruff
+ "E", "W", "F", "UP", "ANN", "S", "BLE", "B", "COM", "C4", "EXE",
+ "ISC", "PIE", "PYI", "Q", "RSE", "RET", "SIM", "TID", "RUF"
]
ignore = [
- # E501: line too long
- # W505: doc line too long
- "E501",
- "W505",
- # ANN001: missing type annotation for function argument
- # ANN002: missing type annotation for `*args`
- # ANN003: missing type annotation for `**kwargs`
- # ANN201: missing return type annotation for public function
- # ANN202: missing return type annotation for private function
- "ANN001", "ANN002", "ANN003", "ANN201", "ANN202",
- # ANN101: missing type annotation for `self` in method
- # ANN102: missing type annotation for `cls` in classmethod
- "ANN101",
- "ANN102",
- # ANN401: dynamically typed expressions (typing.Any) are disallowed
- "ANN401",
- # S101: use of `assert` detected
- "S101",
+ "E501", "W505", # Line length handled by black
+ "ANN001", "ANN002", "ANN003", "ANN201", "ANN202", # Relax strict type annotations
+ "ANN101", "ANN102", "ANN401",
+ "S101", # Allow assert
]
[tool.ruff.per-file-ignores]
-"__init__.py" = [
- "F401", # unused-import
-]
-"tests/**/*.py" = [
- "ANN", # flake8-annotations
- "S", # flake8-bandit
- "BLE", # flake8-blind-except
-]
-"scripts/**/*.py" = [
- "ANN", # flake8-annotations
-]
+"__init__.py" = ["F401"]
+"tests/**/*.py" = ["ANN", "S", "BLE"]
+"scripts/**/*.py" = ["ANN"]
[tool.ruff.flake8-annotations]
allow-star-arg-any = true
@@ -231,4 +329,24 @@ ban-relative-imports = "all"
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
-addopts = "--verbose --color=yes --cov=vla_arena --cov-report=xml --cov-report=term-missing"
+python_classes = ["Test*"]
+python_functions = ["test_*"]
+addopts = [
+ "-v",
+ "--strict-markers",
+ "--tb=short",
+ "--cov=vla_arena",
+ "--cov-report=term-missing",
+ "--cov-report=html",
+ "--cov-report=xml",
+]
+markers = [
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+ "integration: marks tests as integration tests",
+ "unit: marks tests as unit tests",
+]
+filterwarnings = [
+ "error",
+ "ignore::DeprecationWarning",
+ "ignore::PendingDeprecationWarning",
+]
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..d4722ee9
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,17 @@
+[pytest]
+testpaths = tests
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+addopts =
+ -v
+ --strict-markers
+ --tb=short
+markers =
+ slow: marks tests as slow (deselect with '-m "not slow"')
+ integration: marks tests as integration tests
+ unit: marks tests as unit tests
+filterwarnings =
+ error
+ ignore::DeprecationWarning
+ ignore::PendingDeprecationWarning
diff --git a/requirements.txt b/requirements.txt
index edd096ce..570680d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,15 @@
setuptools==78.1.1
hydra-core==1.2.0
- numpy==1.23.5
- wandb==0.13.1
- easydict==1.9
- opencv-python==4.6.0.66
+ numpy==1.23.5
+ wandb==0.13.1
+ easydict==1.9
+ opencv-python==4.6.0.66
einops==0.4.1
thop==0.1.1-2209072238
robosuite==1.5.1
bddl==1.0.1
- future==0.18.2
- matplotlib==3.5.3
+ future==0.18.2
+ matplotlib==3.5.3
cloudpickle==2.1.0
gym
tensorflow
@@ -26,4 +26,4 @@
jsonlines
json_numpy
torch>=2.6.0
- dlimp @ git+https://github.com/moojink/dlimp_openvla
\ No newline at end of file
+ dlimp @ git+https://github.com/moojink/dlimp_openvla
diff --git a/rlds_dataset_builder/README.md b/rlds_dataset_builder/README.md
index 30ee31b6..0b88cfaf 100644
--- a/rlds_dataset_builder/README.md
+++ b/rlds_dataset_builder/README.md
@@ -1,7 +1,7 @@
# RLDS Dataset Conversion
This repo demonstrates how to convert an existing dataset into RLDS format for X-embodiment experiment integration.
-It provides an example for converting a dummy dataset to RLDS. To convert your own dataset, **fork** this repo and
+It provides an example for converting a dummy dataset to RLDS. To convert your own dataset, **fork** this repo and
modify the example code for your dataset following the steps below.
## Installation
@@ -16,7 +16,7 @@ Then activate the environment using:
conda activate rlds_env
```
-If you want to manually create an environment, the key packages to install are `tensorflow`,
+If you want to manually create an environment, the key packages to install are `tensorflow`,
`tensorflow_datasets`, `tensorflow_hub`, `apache_beam`, `matplotlib`, `plotly` and `wandb`.
@@ -38,12 +38,12 @@ conversion worked before moving on.
Now we can modify the provided example to convert your own data. Follow the steps below:
-1. **Rename Dataset**: Change the name of the dataset folder from `example_dataset` to the name of your dataset (e.g. robo_net_v2),
+1. **Rename Dataset**: Change the name of the dataset folder from `example_dataset` to the name of your dataset (e.g. robo_net_v2),
also change the name of `example_dataset_dataset_builder.py` by replacing `example_dataset` with your dataset's name (e.g. robo_net_v2_dataset_builder.py)
and change the class name `ExampleDataset` in the same file to match your dataset's name, using camel case instead of underlines (e.g. RoboNetV2).
2. **Modify Features**: Modify the data fields you plan to store in the dataset. You can find them in the `_info()` method
-of the `ExampleDataset` class. Please add **all** data fields your raw data contains, i.e. please add additional features for
+of the `ExampleDataset` class. Please add **all** data fields your raw data contains, i.e. please add additional features for
additional cameras, audio, tactile features etc. If your type of feature is not demonstrated in the example (e.g. audio),
you can find a list of all supported feature types [here](https://www.tensorflow.org/datasets/api_docs/python/tfds/features?hl=en#classes).
You can store step-wise info like camera images, actions etc in `'steps'` and episode-wise info like `collector_id` in `episode_metadata`.
@@ -53,35 +53,35 @@ Note that we store `language_instruction` in every step even though it is episod
does not define language instructions, you can fill in a dummy string like `pick up something`).
3. **Modify Dataset Splits**: The function `_split_generator()` determines the splits of the generated dataset (e.g. training, validation etc.).
-If your dataset defines a train vs validation split, please provide the corresponding information to `_generate_examples()`, e.g.
+If your dataset defines a train vs validation split, please provide the corresponding information to `_generate_examples()`, e.g.
by pointing to the corresponding folders (like in the example) or file IDs etc. If your dataset does not define splits,
remove the `val` split and only include the `train` split. You can then remove all arguments to `_generate_examples()`.
-4. **Modify Dataset Conversion Code**: Next, modify the function `_generate_examples()`. Here, your own raw data should be
+4. **Modify Dataset Conversion Code**: Next, modify the function `_generate_examples()`. Here, your own raw data should be
loaded, filled into the episode steps and then yielded as a packaged example. Note that the value of the first return argument,
-`episode_path` in the example, is only used as a sample ID in the dataset and can be set to any value that is connected to the
+`episode_path` in the example, is only used as a sample ID in the dataset and can be set to any value that is connected to the
particular stored episode, or any other random value. Just ensure to avoid using the same ID twice.
5. **Provide Dataset Description**: Next, add a bibtex citation for your dataset in `CITATIONS.bib` and add a short description
of your dataset in `README.md` inside the dataset folder. You can also provide a link to the dataset website and please add a
few example trajectory images from the dataset for visualization.
-6. **Add Appropriate License**: Please add an appropriate license to the repository.
-Most common is the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license --
+6. **Add Appropriate License**: Please add an appropriate license to the repository.
+Most common is the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license --
you can copy it from [here](https://github.com/teamdigitale/licenses/blob/master/CC-BY-4.0).
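As a rough illustration of steps 3 and 4, a minimal `_generate_examples()` could look like this (a sketch only; `load_raw_episode` is a hypothetical loader you would replace with your own parsing code):
```
def _generate_examples(self, path):
    import glob

    for episode_path in glob.glob(path):
        # load_raw_episode is a placeholder for your own loading code;
        # it should return the list of step dictionaries described above
        steps = load_raw_episode(episode_path)
        sample = {
            'steps': steps,
            'episode_metadata': {'file_path': episode_path},
        }
        # the first element only serves as a unique sample ID
        yield episode_path, sample
```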
That's it! You're all set to run dataset conversion. Inside the dataset directory, run:
```
tfds build --overwrite
```
-The command line output should finish with a summary of the generated dataset (including size and number of samples).
+The command line output should finish with a summary of the generated dataset (including size and number of samples).
Please verify that this output looks as expected and that you can find the generated `tfrecord` files in `~/tensorflow_datasets/`.
### Parallelizing Data Processing
By default, dataset conversion is single-threaded. If you are parsing a large dataset, you can use parallel processing.
-For this, replace the last two lines of `_generate_examples()` with the commented-out `beam` commands. This will use
-Apache Beam to parallelize data processing. Before starting the processing, you need to install your dataset package
+For this, replace the last two lines of `_generate_examples()` with the commented-out `beam` commands. This will use
+Apache Beam to parallelize data processing. Before starting the processing, you need to install your dataset package
by filling in the name of your dataset into `setup.py` and running `pip install -e .`
Then, make sure that no GPUs are used during data processing (`export CUDA_VISIBLE_DEVICES=`) and run:
@@ -94,10 +94,10 @@ You can specify the desired number of workers with the `direct_num_workers` argu
To verify that the data is converted correctly, please run the data visualization script from the base directory:
```
python3 visualize_dataset.py
-```
+```
This will display a few random episodes from the dataset with language commands and visualize action and state histograms per dimension.
-Note, if you are running on a headless server you can modify `WANDB_ENTITY` at the top of `visualize_dataset.py` and
-add your own WandB entity -- then the script will log all visualizations to WandB.
+Note, if you are running on a headless server you can modify `WANDB_ENTITY` at the top of `visualize_dataset.py` and
+add your own WandB entity -- then the script will log all visualizations to WandB.
## Add Transform for Target Spec
@@ -108,7 +108,7 @@ action.
The final step in adding your dataset to the training mix is to provide a transform function, that transforms a step
from your original dataset above to the required training spec. Please follow the two simple steps below:
-1. **Modify Step Transform**: Modify the function `transform_step()` in `example_transform/transform.py`. The function
+1. **Modify Step Transform**: Modify the function `transform_step()` in `example_transform/transform.py`. The function
takes in a step from your dataset above and is supposed to map it to the desired output spec. The file contains a detailed
description of the desired output spec.
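For instance, a transform that resizes the camera image and pads the action to the target dimensionality could be sketched as follows (assumptions: the target spec matches `test_dataset_transform.py` later in this diff, and `cv2` is just one convenient way to resize):
```
import cv2
import numpy as np


def transform_step(step: dict) -> dict:
    # sketch: map a source step onto the (128, 128, 3) image / 8-dim action target spec
    out = dict(step)
    out['observation'] = {
        'image': cv2.resize(step['observation']['image'], (128, 128)).astype(np.uint8),
    }
    # assumption: pad a 7-dim action with a trailing zero to reach 8 dims
    out['action'] = np.concatenate([step['action'], [0.0]]).astype(np.float32)
    return out
```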
@@ -119,24 +119,24 @@ If the test passes successfully, you are ready to upload your dataset!
## Upload Your Data
-We provide a Google Cloud bucket that you can upload your data to. First, install `gsutil`, the Google cloud command
+We provide a Google Cloud bucket that you can upload your data to. First, install `gsutil`, the Google cloud command
line tool. You can follow the installation instructions [here](https://cloud.google.com/storage/docs/gsutil_install).
Next, authenticate your Google account with:
```
gcloud auth login
-```
-This will open a browser window that allows you to log into your Google account (if you're on a headless server,
+```
+This will open a browser window that allows you to log into your Google account (if you're on a headless server,
you can add the `--no-launch-browser` flag). Ideally, use the email address that
-you used to communicate with Karl, since he will automatically grant permission to the bucket for this email address.
-If you want to upload data with a different email address / google account, please shoot Karl a quick email to ask
+you used to communicate with Karl, since he will automatically grant permission to the bucket for this email address.
+If you want to upload data with a different email address / google account, please shoot Karl a quick email to ask
to grant permissions to that Google account!
-After logging in with a Google account that has access permissions, you can upload your data with the following
+After logging in with a Google account that has access permissions, you can upload your data with the following
command:
```
gsutil -m cp -r ~/tensorflow_datasets/ gs://xembodiment_data
-```
+```
This will upload all data using multiple threads. If your internet connection gets interrupted anytime during the upload
you can just rerun the command and it will resume the upload where it was interrupted. You can verify that the upload
was successful by inspecting the bucket [here](https://console.cloud.google.com/storage/browser/xembodiment_data).
diff --git a/rlds_dataset_builder/VLA_Arena/VLA_Arena_dataset_builder.py b/rlds_dataset_builder/VLA_Arena/VLA_Arena_dataset_builder.py
index a4184b62..bf49915e 100644
--- a/rlds_dataset_builder/VLA_Arena/VLA_Arena_dataset_builder.py
+++ b/rlds_dataset_builder/VLA_Arena/VLA_Arena_dataset_builder.py
@@ -1,19 +1,37 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import glob
import os
-from typing import Any, Iterator, Tuple
+from collections.abc import Iterator
+from typing import Any
import h5py
import numpy as np
import tensorflow_datasets as tfds
-
from VLA_Arena.conversion_utils import MultiThreadedDatasetBuilder
-tfds.core.utils.gcs_utils._is_gcs_disabled = True # disable GCS to avoid issues with TFDS
-os.environ['NO_GCE_CHECK'] = 'true' # disable GCE check to avoid issues with TFDS
+tfds.core.utils.gcs_utils._is_gcs_disabled = (
+ True # disable GCS to avoid issues with TFDS
+)
+os.environ['NO_GCE_CHECK'] = (
+ 'true' # disable GCE check to avoid issues with TFDS
+)
-def _generate_examples(paths) -> Iterator[Tuple[str, Any]]:
+def _generate_examples(paths) -> Iterator[tuple[str, Any]]:
"""Yields episodes for list of data paths."""
    # the line below needs to be *inside* generate_examples so that each worker creates its own model
# creating one shared model outside this function would cause a deadlock
@@ -29,13 +47,26 @@ def _parse_example(episode_path, demo_id):
return None # skip episode if the demo doesn't exist (e.g. due to failed demo)
actions = F['data'][f'demo_{demo_id}']['actions'][()]
states = F['data'][f'demo_{demo_id}']['obs']['ee_states'][()]
- gripper_states = F['data'][f'demo_{demo_id}']['obs']['gripper_states'][()]
- joint_states = F['data'][f'demo_{demo_id}']['obs']['joint_states'][()]
- images = F['data'][f'demo_{demo_id}']['obs'][camera_name + '_rgb'][()]
- if 'robot0_eye_in_hand_rgb' in F['data'][f'demo_{demo_id}']['obs'].keys():
- wrist_images = F['data'][f'demo_{demo_id}']['obs']['robot0_eye_in_hand_rgb'][()]
+ gripper_states = F['data'][f'demo_{demo_id}']['obs'][
+ 'gripper_states'
+ ][()]
+ joint_states = F['data'][f'demo_{demo_id}']['obs']['joint_states'][
+ ()
+ ]
+ images = F['data'][f'demo_{demo_id}']['obs'][camera_name + '_rgb'][
+ ()
+ ]
+ if (
+ 'robot0_eye_in_hand_rgb'
+ in F['data'][f'demo_{demo_id}']['obs'].keys()
+ ):
+ wrist_images = F['data'][f'demo_{demo_id}']['obs'][
+ 'robot0_eye_in_hand_rgb'
+ ][()]
else:
- wrist_images = F['data'][f'demo_{demo_id}']['obs']['eye_in_hand_rgb'][()]
+ wrist_images = F['data'][f'demo_{demo_id}']['obs'][
+ 'eye_in_hand_rgb'
+ ][()]
# compute language instruction
raw_file_string = os.path.basename(episode_path).split('/')[-1]
@@ -57,10 +88,14 @@ def _parse_example(episode_path, demo_id):
'image': images[i][::-1, ::-1],
'wrist_image': wrist_images[i][::-1, ::-1],
'state': np.asarray(
- np.concatenate((states[i], gripper_states[i]), axis=-1),
+ np.concatenate(
+ (states[i], gripper_states[i]), axis=-1
+ ),
np.float32,
),
- 'joint_state': np.asarray(joint_states[i], dtype=np.float32),
+ 'joint_state': np.asarray(
+ joint_states[i], dtype=np.float32
+ ),
},
'action': np.asarray(actions[i], dtype=np.float32),
'discount': 1.0,
@@ -75,9 +110,7 @@ def _parse_example(episode_path, demo_id):
# create output data sample
sample = {
'steps': episode,
- 'episode_metadata': {
- 'file_path': episode_path,
- },
+ 'episode_metadata': {'file_path': episode_path},
}
# if you want to skip an example for whatever reason, simply return None
@@ -169,14 +202,14 @@ def _info(self) -> tfds.core.DatasetInfo:
doc='True on last step of the episode if it is a terminal step, True for demos.',
),
'language_instruction': tfds.features.Text(
- doc='Language Instruction.',
+ doc='Language Instruction.'
),
},
),
'episode_metadata': tfds.features.FeaturesDict(
{
'file_path': tfds.features.Text(
- doc='Path to the original data file.',
+ doc='Path to the original data file.'
),
},
),
diff --git a/rlds_dataset_builder/VLA_Arena/__init__.py b/rlds_dataset_builder/VLA_Arena/__init__.py
index e69de29b..fb51bf31 100644
--- a/rlds_dataset_builder/VLA_Arena/__init__.py
+++ b/rlds_dataset_builder/VLA_Arena/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/rlds_dataset_builder/VLA_Arena/conversion_utils.py b/rlds_dataset_builder/VLA_Arena/conversion_utils.py
index c748226d..3c766971 100644
--- a/rlds_dataset_builder/VLA_Arena/conversion_utils.py
+++ b/rlds_dataset_builder/VLA_Arena/conversion_utils.py
@@ -1,7 +1,22 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import itertools
+from collections.abc import Callable, Iterable
from functools import partial
from multiprocessing import Pool
-from typing import Any, Callable, Dict, Iterable, Tuple, Union
+from typing import Any, Union
import numpy as np
import tensorflow_datasets as tfds
@@ -20,8 +35,8 @@
Key = Union[str, int]
# The nested example dict passed to `features.encode_example`
-Example = Dict[str, Any]
-KeyExample = Tuple[Key, Example]
+Example = dict[str, Any]
+KeyExample = tuple[Key, Example]
class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder):
@@ -31,12 +46,17 @@ class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder):
MAX_PATHS_IN_MEMORY = 100 # number of paths converted & stored in memory before writing to disk
    # -> the higher the faster / more parallel conversion, adjust based on available RAM
# note that one path may yield multiple episodes and adjust accordingly
- PARSE_FCN = None # needs to be filled with path-to-record-episode parse function
+ PARSE_FCN = (
+ None # needs to be filled with path-to-record-episode parse function
+ )
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Define data splits."""
split_paths = self._split_paths()
- return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths}
+ return {
+ split: type(self).PARSE_FCN(paths=split_paths[split])
+ for split in split_paths
+ }
def _generate_examples(self):
pass # this is implemented in global method to enable multiprocessing
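For context, a concrete builder typically wires these hooks together roughly as follows (a sketch based on the attributes above; the class name and glob pattern are placeholders):
```
import glob


class MyDataset(MultiThreadedDatasetBuilder):
    """Sketch: plug a module-level path-to-episode parse function into PARSE_FCN."""

    PARSE_FCN = _generate_examples  # path-to-record-episode parse function
    MAX_PATHS_IN_MEMORY = 100

    def _split_paths(self):
        # placeholder pattern -- point this at your raw episode files
        return {'train': glob.glob('demonstration_data/**/*.hdf5', recursive=True)}
```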
@@ -71,7 +91,9 @@ def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-p
dataset_builder._check_split_names(split_generators.keys())
# Start generating data for all splits
- path_suffix = file_adapters.ADAPTER_FOR_FORMAT[self.info.file_format].FILE_SUFFIX
+ path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
+ self.info.file_format
+ ].FILE_SUFFIX
split_info_futures = []
for split_name, generator in utils.tqdm(
@@ -112,7 +134,9 @@ def result(self) -> splits_lib.SplitInfo:
return self._callback()
-def parse_examples_from_generator(paths, fcn, split_name, total_num_examples, features, serializer):
+def parse_examples_from_generator(
+ paths, fcn, split_name, total_num_examples, features, serializer
+):
generator = fcn(paths)
outputs = []
for sample in utils.tqdm(
@@ -222,7 +246,7 @@ def _build_from_generator(
def dictlist2listdict(DL):
- "Converts a dict of lists to a list of dicts"
+ 'Converts a dict of lists to a list of dicts'
return [dict(zip(DL, t)) for t in zip(*DL.values())]
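A quick worked example of this helper:
```
>>> dictlist2listdict({'a': [1, 2], 'b': [3, 4]})
[{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
```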
diff --git a/rlds_dataset_builder/environment_macos.yml b/rlds_dataset_builder/environment_macos.yml
index 8abed39a..8cbadc9e 100644
--- a/rlds_dataset_builder/environment_macos.yml
+++ b/rlds_dataset_builder/environment_macos.yml
@@ -161,4 +161,4 @@ dependencies:
- yarl==1.8.1
- zipp==3.16.1
- zstandard==0.21.0
-prefix: /Users/karl/miniconda3/envs/rlds_env
+prefix: ${CONDA_PREFIX:-$HOME/miniconda3/envs/rlds_env} # Set this to your conda environment path
diff --git a/rlds_dataset_builder/setup.py b/rlds_dataset_builder/setup.py
index 5cb305c6..58d73c70 100644
--- a/rlds_dataset_builder/setup.py
+++ b/rlds_dataset_builder/setup.py
@@ -1,3 +1,17 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from setuptools import setup
diff --git a/rlds_dataset_builder/test_dataset_transform.py b/rlds_dataset_builder/test_dataset_transform.py
index dc7cbdc3..de5b9fe2 100644
--- a/rlds_dataset_builder/test_dataset_transform.py
+++ b/rlds_dataset_builder/test_dataset_transform.py
@@ -1,3 +1,17 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import argparse
import importlib
import os
@@ -18,7 +32,7 @@
TARGET_SPEC = {
'observation': {
- 'image': {'shape': (128, 128, 3), 'dtype': np.uint8, 'range': (0, 255)},
+ 'image': {'shape': (128, 128, 3), 'dtype': np.uint8, 'range': (0, 255)}
},
'action': {
'shape': (8,),
@@ -34,7 +48,11 @@
'is_last': {'shape': (), 'dtype': np.bool_, 'range': None},
'is_terminal': {'shape': (), 'dtype': np.bool_, 'range': None},
'language_instruction': {'shape': (), 'dtype': str, 'range': None},
- 'language_embedding': {'shape': (512,), 'dtype': np.float32, 'range': None},
+ 'language_embedding': {
+ 'shape': (512,),
+ 'dtype': np.float32,
+ 'range': None,
+ },
}
@@ -49,7 +67,10 @@ def check_elements(target, values):
raise ValueError(
f"Shape of {elem} should be {target[elem]['shape']} but is {tuple(values[elem].shape)}",
)
- if not isinstance(values[elem], bytes) and values[elem].dtype != target[elem]['dtype']:
+ if (
+ not isinstance(values[elem], bytes)
+ and values[elem].dtype != target[elem]['dtype']
+ ):
raise ValueError(
f"Dtype of {elem} should be {target[elem]['dtype']} but is {values[elem].dtype}",
)
diff --git a/scripts/check_dataset_integrity.py b/scripts/check_dataset_integrity.py
index 0d6d98c0..581d5e57 100644
--- a/scripts/check_dataset_integrity.py
+++ b/scripts/check_dataset_integrity.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
"""A script to check if any demonstration dataset does not have the exact number of demonstration trajectories"""
@@ -38,9 +37,13 @@
action_min = np.inf
action_max = -np.inf
for demo_name in demo_file['data'].keys():
- traj_lengths.append(demo_file[f'data/{demo_name}/actions'].shape[0])
+ traj_lengths.append(
+ demo_file[f'data/{demo_name}/actions'].shape[0]
+ )
traj_lengths = np.array(traj_lengths)
- print(f'[info] dataset {demo_file_name} is in tact, test passed \u2714')
+ print(
+            f'[info] dataset {demo_file_name} is intact, test passed \u2714'
+ )
print(np.mean(traj_lengths), ' +- ', np.std(traj_lengths))
if demo_file['data'].attrs['tag'] == 'vla_arena-v1':
print('Version correct')
diff --git a/scripts/collect_demonstration.py b/scripts/collect_demonstration.py
index 6181ab6a..bee6b91e 100644
--- a/scripts/collect_demonstration.py
+++ b/scripts/collect_demonstration.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import datetime
@@ -70,7 +69,9 @@ def collect_human_trajectory(
env.render()
- task_completion_hold_count = -1 # counter to collect 10 timesteps after reaching goal
+ task_completion_hold_count = (
+ -1
+ ) # counter to collect 10 timesteps after reaching goal
device.start_control()
# Print action info for all robots
@@ -80,18 +81,18 @@ def collect_human_trajectory(
saving = True
count = 0
- # ====== 绘图变量 ======
+ # ====== Plotting variables ======
cost_list = []
cumulative_cost = 0
step_list = []
- # 只在需要实时显示时初始化交互式图表
+ # Only initialize interactive plot when real-time display is needed
fig = None
ax = None
line = None
if use_synchronous_cost_curve:
- plt.ion() # 开启交互模式
+ plt.ion() # Enable interactive mode
fig, ax = plt.subplots()
(line,) = ax.plot([], [], label='Cumulative Cost')
ax.set_xlabel('Step Count')
@@ -102,7 +103,9 @@ def collect_human_trajectory(
# Keep track of prev gripper actions when using since they are position-based and must be maintained when arms switched
all_prev_gripper_actions = [
{
- f'{robot_arm}_gripper': np.repeat([0], robot.gripper[robot_arm].dof)
+ f'{robot_arm}_gripper': np.repeat(
+ [0], robot.gripper[robot_arm].dof
+ )
for robot_arm in robot.arms
if robot.gripper[robot_arm].dof > 0
}
@@ -135,7 +138,9 @@ def collect_human_trajectory(
active_robot.composite_controller.joint_action_policy.input_type
)
else:
- controller_input_type = active_robot.part_controllers[arm_name].input_type
+ controller_input_type = active_robot.part_controllers[
+ arm_name
+ ].input_type
if controller_input_type == 'delta':
action_dict[arm_name] = input_ac_dict[f'{arm_name}_delta']
@@ -149,22 +154,26 @@ def collect_human_trajectory(
robot.create_action_vector(all_prev_gripper_actions[i])
for i, robot in enumerate(env.robots)
]
- env_action[device.active_robot] = active_robot.create_action_vector(action_dict)
+ env_action[device.active_robot] = active_robot.create_action_vector(
+ action_dict
+ )
env_action = np.concatenate(env_action)
for gripper_ac in all_prev_gripper_actions[device.active_robot]:
- all_prev_gripper_actions[device.active_robot][gripper_ac] = action_dict[gripper_ac]
+ all_prev_gripper_actions[device.active_robot][gripper_ac] = (
+ action_dict[gripper_ac]
+ )
obs, reward, done, info = env.step(env_action)
# replay_images.append(get_image(obs))
env.render()
- # ====== 始终收集cost数据 ======
+ # ====== Always collect cost data ======
if 'cost' in info:
cumulative_cost += info['cost']
cost_list.append(cumulative_cost)
step_list.append(count)
- # 只在flag为True时实时更新显示
+ # Only update display in real-time when flag is True
if use_synchronous_cost_curve and fig is not None:
line.set_data(step_list, cost_list)
ax.relim()
@@ -173,7 +182,7 @@ def collect_human_trajectory(
fig.canvas.draw()
fig.canvas.flush_events()
except:
- pass # 忽略GUI更新错误
+ pass # Ignore GUI update errors
# Also break if we complete the task
if task_completion_hold_count == 0:
@@ -182,11 +191,17 @@ def collect_human_trajectory(
# state machine to check for having a success for 10 consecutive timesteps
if env._check_success():
if task_completion_hold_count > 0:
- task_completion_hold_count -= 1 # latched state, decrement count
+ task_completion_hold_count -= (
+ 1 # latched state, decrement count
+ )
else:
- task_completion_hold_count = 10 # reset count on first success timestep
+ task_completion_hold_count = (
+ 10 # reset count on first success timestep
+ )
else:
- task_completion_hold_count = -1 # null the counter if there's no success
+ task_completion_hold_count = (
+ -1
+ ) # null the counter if there's no success
# limit frame rate if necessary
if max_fr is not None:
@@ -203,14 +218,14 @@ def collect_human_trajectory(
# cleanup for end of data collection episodes
env.close()
- # ====== 保存图表(无论是否实时显示) ======
+ # ====== Save plot (whether or not real-time display was used) ======
if len(cost_list) > 0:
- # 如果之前在实时显示,关闭交互模式
+ # If real-time display was used before, turn off interactive mode
if use_synchronous_cost_curve and fig is not None:
plt.ioff()
- # 使用已有的figure
+ # Use existing figure
else:
- # 如果没有实时显示,创建新的figure来保存
+ # If no real-time display, create new figure to save
fig, ax = plt.subplots()
ax.plot(step_list, cost_list, label='Cumulative Cost')
ax.set_xlabel('Step Count')
@@ -218,7 +233,7 @@ def collect_human_trajectory(
ax.set_title('Cost Curve')
ax.legend()
- # 保存图表
+ # Save plot
date = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
if new_dir is not None:
os.makedirs(new_dir, exist_ok=True)
@@ -241,7 +256,9 @@ def collect_human_trajectory(
return saving
-def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_directory=[]):
+def gather_demonstrations_as_hdf5(
+ directory, out_dir, env_info, args, remove_directory=[]
+):
"""
Gathers the demonstrations saved in @directory into a
single hdf5 file.
@@ -280,7 +297,11 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
num_eps = 0
env_name = None # will get populated at some point
- problem_info = BDDLUtils.get_problem_info(args.bddl_file) if hasattr(args, 'bddl_file') else {}
+ problem_info = (
+ BDDLUtils.get_problem_info(args.bddl_file)
+ if hasattr(args, 'bddl_file')
+ else {}
+ )
for ep_directory in os.listdir(directory):
# Skip directories marked for removal
@@ -305,7 +326,9 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
if 'successful' in dic:
success = success or dic['successful']
else:
- success = True # Default to saving all demos if no success flag
+ success = (
+ True # Default to saving all demos if no success flag
+ )
if len(states) == 0:
continue
@@ -341,7 +364,9 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
if hasattr(args, 'bddl_file'):
grp.attrs['problem_info'] = json.dumps(problem_info)
grp.attrs['bddl_file_name'] = args.bddl_file
- grp.attrs['bddl_file_content'] = str(open(args.bddl_file, encoding='utf-8').read())
+ grp.attrs['bddl_file_content'] = str(
+ open(args.bddl_file, encoding='utf-8').read()
+ )
f.close()
print(f'Saved {num_eps} demonstrations to {hdf5_path}')
@@ -512,7 +537,9 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
)
elif args.device == 'mjgui':
- assert args.renderer == 'mjviewer', 'Mocap is only supported with the mjviewer renderer'
+ assert (
+ args.renderer == 'mjviewer'
+ ), 'Mocap is only supported with the mjviewer renderer'
from robosuite.devices.mjgui import MJGUI
device = MJGUI(env=env)
@@ -523,7 +550,9 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
)
# make a new timestamped directory
- t1, t2 = datetime.datetime.now().strftime('%Y%m%d_%H%M%S'), datetime.datetime.now().strftime(
+ t1, t2 = datetime.datetime.now().strftime(
+ '%Y%m%d_%H%M%S'
+ ), datetime.datetime.now().strftime(
'%f',
)
DATE = time.strftime('%Y_%m_%d')
@@ -553,5 +582,7 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info, args, remove_dir
)
if saving:
print(f'Remove directory list: {remove_directory}')
- gather_demonstrations_as_hdf5(tmp_directory, new_dir, env_info, args, remove_directory)
+ gather_demonstrations_as_hdf5(
+ tmp_directory, new_dir, env_info, args, remove_directory
+ )
i += 1
diff --git a/scripts/config_copy.py b/scripts/config_copy.py
index 409445d3..fd4df65f 100644
--- a/scripts/config_copy.py
+++ b/scripts/config_copy.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import os
import shutil
@@ -23,11 +22,16 @@ def main():
target_path = os.path.abspath(os.path.join('./', 'configs'))
print(f'Copying configs to {target_path}')
if os.path.exists(target_path):
- response = input('The target directory already exists. Overwrite it? (y/n) ')
+ response = input(
+ 'The target directory already exists. Overwrite it? (y/n) '
+ )
if response.lower() != 'y':
return
shutil.rmtree(target_path)
- shutil.copytree(os.path.join(get_vla_arena_path('benchmark_root'), '../configs'), target_path)
+ shutil.copytree(
+ os.path.join(get_vla_arena_path('benchmark_root'), '../configs'),
+ target_path,
+ )
if __name__ == '__main__':
diff --git a/scripts/convert.sh b/scripts/convert.sh
index aae185e1..64e2a794 100644
--- a/scripts/convert.sh
+++ b/scripts/convert.sh
@@ -1,30 +1,16 @@
#!/bin/bash
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
# LeRobot Dataset Conversion Script
# Usage:
# 1. Modify the variables below
-# 2. Run: ./run_conversion.sh
-# Or: DATA_DIR=/path/to/data ./run_conversion.sh
+# 2. Run: ./convert.sh
+# Or: DATA_DIR=/path/to/data ./convert.sh
set -e
# ============ Configuration Variables ============
DATA_DIR="${DATA_DIR:-"/your/path/to/rlds"}"
+MODEL_TYPE="${MODEL_TYPE:-"your/model/type"}" # openpi or smolvla
HF_LEROBOT_HOME="${HF_LEROBOT_HOME:-"/your/path/to/hf_lerobot_data"}"
PUSH_TO_HUB="${PUSH_TO_HUB:-false}"
# ================================
@@ -63,6 +49,10 @@ fi
# Run conversion
echo "Starting conversion (approximately 30 minutes)..."
-python scripts/convert_data_to_lerobot.py $ARGS
+if [ "$MODEL_TYPE" = "smolvla" ]; then
+ python scripts/convert_data_to_lerobot_smolvla.py $ARGS
+else
+ python scripts/convert_data_to_lerobot_openpi.py $ARGS
+fi
echo "Conversion completed! Data saved to: $HF_LEROBOT_HOME"
diff --git a/scripts/convert_data_to_lerobot.py b/scripts/convert_data_to_lerobot_openpi.py
similarity index 89%
rename from scripts/convert_data_to_lerobot.py
rename to scripts/convert_data_to_lerobot_openpi.py
index 006841b3..4eeaf0cb 100644
--- a/scripts/convert_data_to_lerobot.py
+++ b/scripts/convert_data_to_lerobot_openpi.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,17 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import shutil
from pathlib import Path
import tensorflow_datasets as tfds
import tyro
-from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME, LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import (
+ HF_LEROBOT_HOME,
+ LeRobotDataset,
+)
-def main(data_dir: str, output_path: Path = HF_LEROBOT_HOME, *, push_to_hub: bool = False):
+def main(
+ data_dir: str,
+ output_path: Path = HF_LEROBOT_HOME,
+ *,
+ push_to_hub: bool = False,
+):
    # Clean up any existing dataset in the output directory
if output_path.exists():
shutil.rmtree(output_path)
diff --git a/scripts/convert_data_to_lerobot_smolvla.py b/scripts/convert_data_to_lerobot_smolvla.py
new file mode 100644
index 00000000..dee348e6
--- /dev/null
+++ b/scripts/convert_data_to_lerobot_smolvla.py
@@ -0,0 +1,135 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Minimal example script for converting a dataset to LeRobot format.
+
+We use the Libero dataset (stored in RLDS) for this example, but it can be easily
+modified for any other data you have saved in a custom format.
+
+Usage:
+python scripts/convert_data_to_lerobot_smolvla.py --data_dir /path/to/your/data --output_dir /path/to/output
+
+If you want to push your dataset to the Hugging Face Hub, you can use the following command:
+python scripts/convert_data_to_lerobot_smolvla.py --data_dir /path/to/your/data --output_dir /path/to/output --push_to_hub
+
+Note: to run the script, you need to install tensorflow_datasets:
+`pip install tensorflow tensorflow_datasets`
+
+You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
+The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
+Running this conversion script will take approximately 30 minutes.
+"""
+
+import shutil
+from pathlib import Path
+
+import tensorflow_datasets as tfds
+import tyro
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+def main(
+ data_dir: str = '', output_dir: str = '', *, push_to_hub: bool = False
+):
+ # Clean up any existing dataset in the output directory
+ output_path = Path(output_dir)
+ if output_path.exists():
+ shutil.rmtree(output_path)
+
+ # Create LeRobot dataset, define features to store
+ # OpenPi assumes that proprio is stored in `state` and actions in `action`
+ # LeRobot assumes that dtype of image data is `image`
+ dataset = LeRobotDataset.create(
+ repo_id='VLA-Arena',
+ root=output_path,
+ robot_type='panda',
+ fps=10,
+ features={
+ 'observation.images.image': {
+ 'dtype': 'image',
+ 'shape': (256, 256, 3),
+ 'names': ['height', 'width', 'rgb'],
+ },
+ 'observation.images.wrist_image': {
+ 'dtype': 'image',
+ 'shape': (256, 256, 3),
+ 'names': ['height', 'width', 'rgb'],
+ },
+ 'observation.state': {
+ 'dtype': 'float32',
+ 'shape': (8,),
+ 'names': {
+ 'motors': [
+ 'x',
+ 'y',
+ 'z',
+ 'roll',
+ 'pitch',
+ 'yaw',
+ 'gripper',
+ 'gripper',
+ ]
+ },
+ },
+ 'action': {
+ 'dtype': 'float32',
+ 'shape': (7,),
+ 'names': {
+ 'motors': [
+ 'x',
+ 'y',
+ 'z',
+ 'roll',
+ 'pitch',
+ 'yaw',
+ 'gripper',
+ ]
+ },
+ },
+ },
+ image_writer_threads=10,
+ image_writer_processes=5,
+ )
+
+ # Loop over raw Libero datasets and write episodes to the LeRobot dataset
+ # You can modify this for your own data format
+ raw_dataset = tfds.builder_from_directory(data_dir).as_dataset(split='all')
+ for episode in raw_dataset:
+ for step in episode['steps'].as_numpy_iterator():
+ dataset.add_frame(
+ {
+ 'observation.images.image': step['observation']['image'],
+ 'observation.images.wrist_image': step['observation'][
+ 'wrist_image'
+ ],
+ 'observation.state': step['observation']['state'],
+ 'action': step['action'],
+ },
+ task=step['language_instruction'].decode(),
+ )
+ dataset.save_episode()
+
+ # Optionally push to the Hugging Face Hub
+ if push_to_hub:
+ dataset.push_to_hub(
+ tags=['libero', 'panda', 'rlds'],
+ private=False,
+ push_images=True,
+ license='apache-2.0',
+ )
+
+
+if __name__ == '__main__':
+ tyro.cli(main)
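After conversion, one way to sanity-check the result is to re-open it with the same LeRobot API the script imports (a sketch; the root path is a placeholder and the field names follow the features defined above):
```
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# placeholder path -- wherever the converted dataset was written
dataset = LeRobotDataset(repo_id='VLA-Arena', root=Path('/your/path/to/hf_lerobot_data'))
print(len(dataset))
print(dataset[0]['observation.state'])
```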
diff --git a/scripts/create_dataset.py b/scripts/create_dataset.py
index aae322be..252ba470 100644
--- a/scripts/create_dataset.py
+++ b/scripts/create_dataset.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import json
@@ -76,7 +75,9 @@ def main():
bddl_base_name = os.path.basename(bddl_file_name)
relative_dir = demo_dir.split('demonstration_data/')[-1]
hdf5_file_name = bddl_base_name.replace('.bddl', '_demo.hdf5')
- hdf5_path = os.path.join(get_vla_arena_path('datasets'), relative_dir, hdf5_file_name)
+ hdf5_path = os.path.join(
+ get_vla_arena_path('datasets'), relative_dir, hdf5_file_name
+ )
output_parent_dir = Path(hdf5_path).parent
output_parent_dir.mkdir(parents=True, exist_ok=True)
@@ -192,7 +193,9 @@ def main():
err = np.linalg.norm(states[j + 1] - state_playback)
if err > 0.01:
- print(f'[warning] playback diverged by {err:.2f} for ep {ep} at step {j}')
+ print(
+ f'[warning] playback diverged by {err:.2f} for ep {ep} at step {j}'
+ )
# Skip recording because the force sensor is not stable in
# the beginning
@@ -244,21 +247,41 @@ def main():
obs_grp = ep_data_grp.create_group('obs')
if not args.no_proprio:
- obs_grp.create_dataset('gripper_states', data=np.stack(gripper_states, axis=0))
- obs_grp.create_dataset('joint_states', data=np.stack(joint_states, axis=0))
- obs_grp.create_dataset('ee_states', data=np.stack(ee_states, axis=0))
- obs_grp.create_dataset('ee_pos', data=np.stack(ee_states, axis=0)[:, :3])
- obs_grp.create_dataset('ee_ori', data=np.stack(ee_states, axis=0)[:, 3:])
-
- obs_grp.create_dataset('agentview_rgb', data=np.stack(agentview_images, axis=0))
- obs_grp.create_dataset('eye_in_hand_rgb', data=np.stack(eye_in_hand_images, axis=0))
+ obs_grp.create_dataset(
+ 'gripper_states', data=np.stack(gripper_states, axis=0)
+ )
+ obs_grp.create_dataset(
+ 'joint_states', data=np.stack(joint_states, axis=0)
+ )
+ obs_grp.create_dataset(
+ 'ee_states', data=np.stack(ee_states, axis=0)
+ )
+ obs_grp.create_dataset(
+ 'ee_pos', data=np.stack(ee_states, axis=0)[:, :3]
+ )
+ obs_grp.create_dataset(
+ 'ee_ori', data=np.stack(ee_states, axis=0)[:, 3:]
+ )
+
+ obs_grp.create_dataset(
+ 'agentview_rgb', data=np.stack(agentview_images, axis=0)
+ )
+ obs_grp.create_dataset(
+ 'eye_in_hand_rgb', data=np.stack(eye_in_hand_images, axis=0)
+ )
if args.use_depth:
- obs_grp.create_dataset('agentview_depth', data=np.stack(agentview_depths, axis=0))
- obs_grp.create_dataset('eye_in_hand_depth', data=np.stack(eye_in_hand_depths, axis=0))
+ obs_grp.create_dataset(
+ 'agentview_depth', data=np.stack(agentview_depths, axis=0)
+ )
+ obs_grp.create_dataset(
+ 'eye_in_hand_depth', data=np.stack(eye_in_hand_depths, axis=0)
+ )
ep_data_grp.create_dataset('actions', data=actions)
ep_data_grp.create_dataset('states', data=states)
- ep_data_grp.create_dataset('robot_states', data=np.stack(robot_states, axis=0))
+ ep_data_grp.create_dataset(
+ 'robot_states', data=np.stack(robot_states, axis=0)
+ )
ep_data_grp.create_dataset('rewards', data=rewards)
ep_data_grp.create_dataset('dones', data=dones)
ep_data_grp.attrs['num_samples'] = len(agentview_images)
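To double-check what was written, the file can be opened with h5py along these lines (a sketch; the path is a placeholder for one of the generated `*_demo.hdf5` files):
```
import h5py

# placeholder path to a file produced by create_dataset.py
with h5py.File('datasets/example_demo.hdf5', 'r') as f:
    demo = f['data/demo_0']
    print('actions:', demo['actions'].shape)
    print('agentview_rgb:', demo['obs/agentview_rgb'].shape)
    print('num_samples:', demo.attrs['num_samples'])
```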
diff --git a/scripts/create_task_example.py b/scripts/create_task_example.py
index 839729bc..6dcd2932 100644
--- a/scripts/create_task_example.py
+++ b/scripts/create_task_example.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
"""This is a standalone file for create a task in vla arena."""
-
from vla_arena.vla_arena.utils.bddl_generation_utils import (
get_xy_region_kwargs_list_from_regions_info,
)
-from vla_arena.vla_arena.utils.mu_utils import InitialSceneTemplates, register_mu
+from vla_arena.vla_arena.utils.mu_utils import (
+ InitialSceneTemplates,
+ register_mu,
+)
from vla_arena.vla_arena.utils.task_generation_utils import (
generate_bddl_from_task_info,
register_task_info,
@@ -112,16 +113,26 @@ def define_regions(self):
),
)
- self.xy_region_kwargs_list = get_xy_region_kwargs_list_from_regions_info(self.regions)
+ self.xy_region_kwargs_list = (
+ get_xy_region_kwargs_list_from_regions_info(self.regions)
+ )
@property
def init_states(self):
return [
- ('On', 'akita_black_bowl_1', 'main_table_between_plate_ramekin_region'),
+ (
+ 'On',
+ 'akita_black_bowl_1',
+ 'main_table_between_plate_ramekin_region',
+ ),
('On', 'akita_black_bowl_2', 'glazed_rim_porcelain_ramekin_1'),
('On', 'plate_1', 'main_table_plate_region'),
('On', 'cookies_1', 'main_table_box_region'),
- ('On', 'glazed_rim_porcelain_ramekin_1', 'main_table_ramekin_region'),
+ (
+ 'On',
+ 'glazed_rim_porcelain_ramekin_1',
+ 'main_table_ramekin_region',
+ ),
('On', 'wooden_cabinet_1', 'main_table_cabinet_region'),
('On', 'flat_stove_1', 'main_table_stove_region'),
('On', 'akita_black_bowl_3', 'akita_black_bowl_1'),
@@ -134,7 +145,9 @@ def init_states(self):
def main():
# kitchen_scene_1
scene_name = 'kitchen_scene1'
- language = 'Pick the akita black bowl on the ramekin and place it on the plate'
+ language = (
+ 'Pick the akita black bowl on the ramekin and place it on the plate'
+ )
register_task_info(
language,
scene_name=scene_name,
diff --git a/scripts/create_template.py b/scripts/create_template.py
index b0229120..5fdb5109 100644
--- a/scripts/create_template.py
+++ b/scripts/create_template.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
"""
This is a script for creating various files from templates. This is to ease the process for users who want to extend vla_arena, creating new tasks. You would still need to make necessary changes based on the template to serve your own need, but the hope is that we save you much time by providing the necessary templates.
@@ -71,7 +70,9 @@ def create_scene_xml_file(scene_name):
texture_list = get_texture_file_list(type=type, texture_path='../')
for i, (texture_name, texture_file_path) in enumerate(texture_list):
print(f'[{i}]: {texture_name}')
- choice = int(input(f'Please select which texture to use for {element_name}: '))
+ choice = int(
+ input(f'Please select which texture to use for {element_name}: ')
+ )
element.set('file', texture_list[choice][1])
tree.write(f'{scene_name}.xml', encoding='utf-8')
print(f'Creating scene {scene_name} at the file: {scene_name}.xml')
diff --git a/scripts/evaluate_policy.py b/scripts/evaluate_policy.py
index 4d72d62d..eba27ed1 100644
--- a/scripts/evaluate_policy.py
+++ b/scripts/evaluate_policy.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import json
@@ -116,7 +115,9 @@ def get_args():
choices=PolicyRegistry.list_policies(),
help='The policy to evaluate',
)
- parser.add_argument('--model_ckpt', default=None, help='The base model checkpoint path')
+ parser.add_argument(
+ '--model_ckpt', default=None, help='The base model checkpoint path'
+ )
parser.add_argument(
'--save-dir',
default='logs',
@@ -132,7 +133,12 @@ def get_args():
'--metrics',
nargs='+',
default=['success_rate', 'cumulative_cost', 'safe_success_rate'],
- choices=['success_rate', 'cumulative_cost', 'safe_success_rate', 'episode_length'],
+ choices=[
+ 'success_rate',
+ 'cumulative_cost',
+ 'safe_success_rate',
+ 'episode_length',
+ ],
help='The metrics to evaluate',
)
parser.add_argument(
@@ -141,8 +147,12 @@ def get_args():
type=str,
help='The host to the remote server',
)
- parser.add_argument('--port', default=5555, type=int, help='The port to the remote server')
- parser.add_argument('--replanstep', default=4, type=int, help='The step to replan')
+ parser.add_argument(
+ '--port', default=5555, type=int, help='The port to the remote server'
+ )
+ parser.add_argument(
+ '--replanstep', default=4, type=int, help='The step to replan'
+ )
# Additional arguments for batch evaluation
parser.add_argument(
@@ -239,7 +249,9 @@ def evaluate(args):
model_ckpt=args.model_ckpt if args.model_ckpt else None,
)
else:
- policy = PolicyRegistry.get(args.policy, host=args.host, port=args.port)
+ policy = PolicyRegistry.get(
+ args.policy, host=args.host, port=args.port
+ )
# Run evaluation
results = evaluator.evaluate(policy)
@@ -258,7 +270,9 @@ def evaluate(args):
if 'success_rate' in metrics:
print(f" Success Rate: {metrics['success_rate']:.2%}")
if 'safe_success_rate' in metrics:
- print(f" Safe Success Rate: {metrics['safe_success_rate']:.2%}")
+ print(
+ f" Safe Success Rate: {metrics['safe_success_rate']:.2%}"
+ )
if 'cumulative_cost' in metrics:
print(f" Avg Cost: {metrics['cumulative_cost']:.2f}")
else:
@@ -292,7 +306,9 @@ def main():
# Validate arguments
if not args.task_suite:
print('Error: --task_suite is required!')
- print('Available options: static_obstacles, preposition_generalization')
+ print(
+ 'Available options: static_obstacles, preposition_generalization'
+ )
return 1
try:
diff --git a/scripts/evaluate_policy.sh b/scripts/evaluate_policy.sh
index 0d00ded6..8a4d19c3 100644
--- a/scripts/evaluate_policy.sh
+++ b/scripts/evaluate_policy.sh
@@ -1,19 +1,4 @@
#!/bin/bash
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
# ============================================================================
# VLA-Arena Unified Evaluation Script
# ============================================================================
@@ -34,7 +19,7 @@ POLICY="openvla" # Options: openvla, random (
MODEL_CKPT="path/to/model/checkpoint" # Path to model checkpoint
# Task Configuration
-TASK_SUITE="safety_static_obstacles" # Options:
+TASK_SUITE="safety_static_obstacles" # Options:
TASK_LEVEL=0 # Difficulty level: 0 (easy), 1 (medium), 2 (hard)
N_EPISODES=1 # Number of episodes per task
@@ -76,7 +61,7 @@ print_warning() {
# Validation
validate_config() {
local valid=true
-
+
if [ "$valid" = false ]; then
print_error "Configuration validation failed. Please check your settings."
exit 1
@@ -109,10 +94,10 @@ print_config() {
main() {
# Validate configuration
validate_config
-
+
# Print configuration
print_config
-
+
# Ask for confirmation
read -p "Do you want to proceed with this configuration? [(y)/n]: " -n 1 -r
echo
@@ -120,7 +105,7 @@ main() {
print_warning "Evaluation cancelled by user"
exit 0
fi
-
+
# Build command
CMD="python scripts/evaluate_policy.py"
CMD="$CMD --task_suite $TASK_SUITE"
@@ -134,15 +119,15 @@ main() {
if [[ "$POLICY" != "random" ]]; then
CMD="$CMD --model_ckpt $MODEL_CKPT"
fi
-
+
# Add visualization flag if enabled
if [[ "$VISUALIZATION" == "true" ]]; then
CMD="$CMD --visualization"
fi
-
+
# Create save directory
mkdir -p "$SAVE_DIR"
-
+
# Save configuration to file
cat > "$SAVE_DIR/evaluation_config.txt" < 0.01:
- # print(f" [警告] 回放偏差 {err:.2f} at step {j}")
+ # print(f" [Warning] Playback deviation {err:.2f} at step {j}")
- # 跳过前几帧(传感器稳定)
+ # Skip first few frames (sensor stabilization)
if j < cap_index:
continue
valid_index.append(j)
- # 收集proprioception数据
+ # Collect proprioception data
if not args.no_proprio:
if 'robot0_gripper_qpos' in obs:
gripper_states.append(obs['robot0_gripper_qpos'])
@@ -176,15 +179,19 @@ def process_single_demo_file(demo_file_path, env_kwargs_template, args, global_d
robot_states.append(env.get_robot_state_vector(obs))
- # 收集图像数据
+ # Collect image data
if not args.not_use_camera_obs:
if args.use_depth:
for camera in camera_names:
- camera_list[camera]['depths'].append(obs[camera + '_depth'])
+ camera_list[camera]['depths'].append(
+ obs[camera + '_depth']
+ )
for camera in camera_names:
- camera_list[camera]['images'].append(obs[camera + '_image'])
+ camera_list[camera]['images'].append(
+ obs[camera + '_image']
+ )
- # 准备最终数据
+ # Prepare final data
states = states[valid_index]
actions = actions[valid_index]
dones = np.zeros(len(actions)).astype(np.uint8)
@@ -192,14 +199,16 @@ def process_single_demo_file(demo_file_path, env_kwargs_template, args, global_d
rewards = np.zeros(len(actions)).astype(np.uint8)
rewards[-1] = 1
- # 存储处理后的数据
+ # Store processed data
demo_data = {
'demo_id': f'demo_{global_demo_counter}',
'states': states,
'actions': actions,
'rewards': rewards,
'dones': dones,
- 'robot_states': np.stack(robot_states, axis=0) if robot_states else None,
+ 'robot_states': (
+ np.stack(robot_states, axis=0) if robot_states else None
+ ),
'model_file': model_xml,
'init_state': states[init_idx] if len(states) > 0 else None,
'num_samples': len(camera_list[camera_names[0]]['images']),
@@ -207,7 +216,7 @@ def process_single_demo_file(demo_file_path, env_kwargs_template, args, global_d
'original_ep': ep,
}
- # 添加观测数据
+ # Add observation data
if not args.no_proprio and gripper_states:
demo_data['gripper_states'] = np.stack(gripper_states, axis=0)
demo_data['joint_states'] = np.stack(joint_states, axis=0)
@@ -218,7 +227,9 @@ def process_single_demo_file(demo_file_path, env_kwargs_template, args, global_d
if not args.not_use_camera_obs:
for camera in camera_names:
if camera_list[camera]['images']:
- demo_data[camera + '_rgb'] = np.stack(camera_list[camera]['images'], axis=0)
+ demo_data[camera + '_rgb'] = np.stack(
+ camera_list[camera]['images'], axis=0
+ )
if args.use_depth:
for camera in camera_names:
@@ -232,14 +243,14 @@ def process_single_demo_file(demo_file_path, env_kwargs_template, args, global_d
global_demo_counter += 1
except Exception as e:
- print(f' 处理 {ep} 时出错: {e}')
+ print(f' Error processing {ep}: {e}')
continue
- # 清理
+ # Cleanup
env.close()
f.close()
- # 返回元数据和处理后的demos
+ # Return metadata and processed demos
metadata = {
'env_name': env_name,
'problem_info': problem_info,
@@ -257,40 +268,44 @@ def main():
'--input-dir',
type=str,
required=True,
- help='包含原始demo HDF5文件的目录 (如: demonstration_data/xxx/)',
+ help='Directory containing original demo HDF5 files (e.g., demonstration_data/xxx/)',
)
parser.add_argument(
'--output-dir',
type=str,
default=None,
- help='输出目录,默认根据BDDL文件自动确定',
+ help='Output directory, default is automatically determined based on BDDL file',
)
parser.add_argument(
'--pattern',
type=str,
default='*.hdf5',
- help='要处理的文件名 (默认: .hdf5)',
+    help='Filename pattern to process (default: *.hdf5)',
)
parser.add_argument('--not-use-camera-obs', action='store_true')
parser.add_argument('--no-proprio', action='store_true')
parser.add_argument('--use-depth', action='store_true')
- parser.add_argument('--not-recursive', action='store_true', help='不递归搜索子目录')
+ parser.add_argument(
+ '--not-recursive',
+ action='store_true',
+ help='Do not recursively search subdirectories',
+ )
args = parser.parse_args()
- # 查找所有要处理的HDF5文件
+ # Find all HDF5 files to process
if not args.not_recursive:
demo_files = list(Path(args.input_dir).rglob(args.pattern))
else:
demo_files = list(Path(args.input_dir).glob(args.pattern))
if not demo_files:
- print(f'在 {args.input_dir} 中没有找到匹配 {args.pattern} 的文件')
+ print(f'No files matching {args.pattern} found in {args.input_dir}')
return
- print(f'找到 {len(demo_files)} 个文件待处理')
+ print(f'Found {len(demo_files)} files to process')
- # 处理所有文件并收集数据,按BDDL文件分组
+ # Process all files and collect data, grouped by BDDL file
demos_by_bddl = {} # {bddl_file_name: [demos]}
env_kwargs_template = {}
metadata_by_bddl = {} # {bddl_file_name: metadata}
@@ -300,7 +315,7 @@ def main():
str(demo_file),
env_kwargs_template,
args,
- 0, # 每个BDDL文件独立计数
+ 0, # Each BDDL file counts independently
)
if metadata and demos:
@@ -310,27 +325,29 @@ def main():
metadata_by_bddl[bddl_file_name] = metadata
demos_by_bddl[bddl_file_name].extend(demos)
- # 为每个BDDL文件创建一个输出文件
+ # Create an output file for each BDDL file
for bddl_file_name, demos in demos_by_bddl.items():
- # 根据原代码的命名逻辑生成输出路径
- demo_dir = args.input_dir # 输入目录作为demo_dir
+ # Generate output path based on original code's naming logic
+ demo_dir = args.input_dir # Input directory as demo_dir
bddl_base_name = os.path.basename(bddl_file_name)
if args.output_dir:
- # 如果指定了输出目录,使用它
+ # If output directory is specified, use it
output_parent_dir = Path(args.output_dir)
hdf5_file_name = bddl_base_name.replace('.bddl', '_demo.hdf5')
hdf5_path = output_parent_dir / hdf5_file_name
else:
- # 否则按原代码逻辑:基于demonstration_data目录结构
+ # Otherwise follow original code logic: based on demonstration_data directory structure
if 'demonstration_data/' in demo_dir:
relative_dir = demo_dir.split('demonstration_data/')[-1]
else:
- # 如果路径中没有demonstration_data,使用当前目录名
+ # If demonstration_data is not in path, use current directory name
relative_dir = os.path.basename(demo_dir)
hdf5_file_name = bddl_base_name.replace('.bddl', '_demo.hdf5')
- hdf5_path = os.path.join(get_vla_arena_path('datasets'), relative_dir, hdf5_file_name)
+ hdf5_path = os.path.join(
+ get_vla_arena_path('datasets'), relative_dir, hdf5_file_name
+ )
hdf5_path = Path(hdf5_path)
if hdf5_path.exists():
stem = hdf5_path.stem
@@ -341,20 +358,20 @@ def main():
output_parent_dir = hdf5_path.parent
output_parent_dir.mkdir(parents=True, exist_ok=True)
- print(f'\n为 {bddl_base_name} 创建输出文件: {hdf5_path}')
+ print(f'\nCreating output file for {bddl_base_name}: {hdf5_path}')
- # 写入HDF5文件(使用原代码的结构)
+ # Write HDF5 file (using original code structure)
metadata = metadata_by_bddl[bddl_file_name]
with h5py.File(str(hdf5_path), 'w') as h5py_f:
grp = h5py_f.create_group('data')
- # 写入属性(与原代码保持一致)
+ # Write attributes (consistent with original code)
grp.attrs['env_name'] = metadata['env_name']
grp.attrs['problem_info'] = json.dumps(metadata['problem_info'])
grp.attrs['macros_image_convention'] = macros.IMAGE_CONVENTION
- # 环境参数
+ # Environment parameters
problem_name = metadata['problem_info']['problem_name']
env_args = {
'type': 1,
@@ -370,36 +387,50 @@ def main():
if os.path.exists(bddl_file_name):
grp.attrs['bddl_file_content'] = open(bddl_file_name).read()
- # 写入每个demo的数据,重新编号
+ # Write each demo's data, renumbering
total_len = 0
for i, demo_data in enumerate(demos):
- demo_id = f'demo_{i}' # 重新编号从0开始
+ demo_id = f'demo_{i}' # Renumber starting from 0
ep_data_grp = grp.create_group(demo_id)
- # 写入观测数据组
+ # Write observation data group
obs_grp = ep_data_grp.create_group('obs')
- # Proprioception数据
- for key in ['gripper_states', 'joint_states', 'ee_states', 'ee_pos', 'ee_ori']:
+ # Proprioception data
+ for key in [
+ 'gripper_states',
+ 'joint_states',
+ 'ee_states',
+ 'ee_pos',
+ 'ee_ori',
+ ]:
if key in demo_data:
obs_grp.create_dataset(key, data=demo_data[key])
- # 图像数据
+ # Image data
for camera in metadata['camera_names']:
- for key in [camera + suffix for suffix in ['_rgb', '_depth']]:
+ for key in [
+ camera + suffix for suffix in ['_rgb', '_depth']
+ ]:
if key in demo_data:
obs_grp.create_dataset(key, data=demo_data[key])
- # 写入动作和状态数据
- ep_data_grp.create_dataset('actions', data=demo_data['actions'])
+ # Write action and state data
+ ep_data_grp.create_dataset(
+ 'actions', data=demo_data['actions']
+ )
ep_data_grp.create_dataset('states', data=demo_data['states'])
- ep_data_grp.create_dataset('rewards', data=demo_data['rewards'])
+ ep_data_grp.create_dataset(
+ 'rewards', data=demo_data['rewards']
+ )
ep_data_grp.create_dataset('dones', data=demo_data['dones'])
if demo_data['robot_states'] is not None:
- ep_data_grp.create_dataset('robot_states', data=demo_data['robot_states'])
+ ep_data_grp.create_dataset(
+ 'robot_states', data=demo_data['robot_states']
+ )
- # 写入属性
+ # Write attributes
ep_data_grp.attrs['num_samples'] = demo_data['num_samples']
ep_data_grp.attrs['model_file'] = demo_data['model_file']
if demo_data['init_state'] is not None:
@@ -407,13 +438,13 @@ def main():
total_len += demo_data['num_samples']
- # 写入汇总信息
+ # Write summary information
grp.attrs['num_demos'] = len(demos)
grp.attrs['total'] = total_len
- print(f'创建的数据集已保存到: {hdf5_path}')
- print(f'Demonstrations数: {len(demos)}')
- print(f'总样本数: {total_len}')
+ print(f'Created dataset saved to: {hdf5_path}')
+ print(f'Number of demonstrations: {len(demos)}')
+ print(f'Total samples: {total_len}')
if __name__ == '__main__':
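The conversion loop above writes one HDF5 file per BDDL task, with demos renumbered under data/demo_<i>, an obs subgroup per demo, and num_demos/total stored as group attributes. A minimal read-back sketch (illustrative only, not part of the diff; assumes h5py is installed and the output path shown is hypothetical):

import h5py

with h5py.File('example_task_demo.hdf5', 'r') as f:
    data = f['data']
    print(data.attrs['num_demos'], data.attrs['total'])  # demo count and total samples
    demo = data['demo_0']
    print(demo.attrs['num_samples'], demo['actions'].shape)
    print(list(demo['obs'].keys()))  # e.g. gripper_states, joint_states, per-camera *_rgb keys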
diff --git a/scripts/init_file_create.py b/scripts/init_file_create.py
index 23f72695..43015396 100644
--- a/scripts/init_file_create.py
+++ b/scripts/init_file_create.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import os
@@ -32,61 +31,80 @@
# pass
parser = argparse.ArgumentParser()
-parser.add_argument('--bddl_file', type=str, required=True, help='BDDL文件路径或目录')
-parser.add_argument('--resolution', type=int, default=256, help='分辨率')
+parser.add_argument(
+ '--bddl_file', type=str, required=True, help='BDDL file path or directory'
+)
+parser.add_argument('--resolution', type=int, default=256, help='Resolution')
parser.add_argument(
'--output_path',
type=str,
default='./vla_arena/vla_arena/init_files',
- help='输出路径',
+ help='Output path',
+)
+parser.add_argument(
+ '--root_path',
+ type=str,
+ default='./vla_arena/vla_arena/bddl_files',
+ help='Root path',
)
args = parser.parse_args()
def process_single_file_with_retry(bddl_file, relative_path='', max_retries=4):
"""
- 处理单个BDDL文件,带重试机制
+ Process a single BDDL file with retry mechanism.
Args:
- bddl_file: BDDL文件的完整路径
- relative_path: 相对于输入根目录的路径,用于保持目录结构
- max_retries: 最大重试次数
+ bddl_file: Full path to BDDL file
+ relative_path: Path relative to input root directory, used to maintain directory structure
+ max_retries: Maximum number of retries
"""
- for attempt in range(max_retries + 1): # +1 因为包括第一次尝试
+ for attempt in range(
+ max_retries + 1
+ ): # +1 because it includes the first attempt
try:
- print(f'Processing file: {bddl_file} (Attempt {attempt + 1}/{max_retries + 1})')
+ print(
+ f'Processing file: {bddl_file} (Attempt {attempt + 1}/{max_retries + 1})'
+ )
process_single_file(bddl_file, relative_path)
- return # 成功处理,直接返回
+ return # Successfully processed; stop retrying
except Exception as e:
error_name = e.__class__.__name__
- # 检查是否是RandomizationError
- if 'RandomizationError' in error_name or 'randomization' in str(e).lower():
+ # Check if it's a RandomizationError
+ if (
+ 'RandomizationError' in error_name
+ or 'randomization' in str(e).lower()
+ ):
if attempt < max_retries:
print(f'Encountered RandomizationError: {e}')
- print(f'Retrying... ({attempt + 1}/{max_retries} retries used)')
- time.sleep(0.5) # 短暂等待后重试
+ print(
+ f'Retrying... ({attempt + 1}/{max_retries} retries used)'
+ )
+ time.sleep(0.5) # Brief wait before retry
continue
- print(f'Failed after {max_retries} retries due to RandomizationError')
+ print(
+ f'Failed after {max_retries} retries due to RandomizationError'
+ )
print(f'Error details: {e}')
raise e
- # 如果不是RandomizationError,直接抛出异常
+ # Not a RandomizationError; re-raise immediately
print(f'Encountered non-RandomizationError: {error_name}')
raise e
def process_single_file(bddl_file, relative_path=''):
"""
- 处理单个BDDL文件
+ Process a single BDDL file.
Args:
- bddl_file: BDDL文件的完整路径
- relative_path: 相对于输入根目录的路径,用于保持目录结构
+ bddl_file: Full path to BDDL file
+ relative_path: Path relative to input root directory, used to maintain directory structure
"""
resolution = args.resolution
- """初始化并返回LIBERO环境"""
+ """Initialize and return LIBERO environment"""
env_args = {
'bddl_file_name': bddl_file,
'camera_heights': resolution,
@@ -97,27 +115,30 @@ def process_single_file(bddl_file, relative_path=''):
try:
env = OffScreenRenderEnv(**env_args)
- # 1. 加载环境
+ # 1. Load environment
obs = env.reset()
print('ok')
- # 2. 保存当前初始状态
+ # 2. Save current initial state
init_states = []
flattened_state = env.get_sim_state()
print(flattened_state.shape, type(flattened_state))
- if isinstance(flattened_state, np.ndarray) and flattened_state.ndim == 1:
+ if (
+ isinstance(flattened_state, np.ndarray)
+ and flattened_state.ndim == 1
+ ):
init_states.append(flattened_state)
- # 3. 构建输出路径,保持原有目录结构
+ # 3. Build output path, maintain original directory structure
task_name = os.path.basename(bddl_file)
task_name = task_name.replace('.bddl', '')
- # 如果有相对路径,创建相应的目录结构
+ # If there's a relative path, create the corresponding directory structure
if relative_path:
output_dir = os.path.join(args.output_path, relative_path)
else:
output_dir = args.output_path
- # 确保输出目录存在
+ # Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f'{task_name}.pruned_init')
@@ -129,33 +150,33 @@ def process_single_file(bddl_file, relative_path=''):
print(f'Init file saved to {output_file}')
finally:
- # 5. close the environment
+ # 5. Close the environment
if env is not None:
env.close()
def process_directory_recursive(directory, root_dir=None):
"""
- 递归处理目录中的所有BDDL文件
+ Recursively process all BDDL files in a directory.
Args:
- directory: 当前处理的目录
- root_dir: 根目录,用于计算相对路径
+ directory: Current directory being processed
+ root_dir: Root directory, used to calculate relative paths
"""
if root_dir is None:
root_dir = directory
- # 遍历目录中的所有文件和子目录
+ # Traverse all files and subdirectories in the directory
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isfile(item_path) and item.endswith('.bddl'):
- # 计算相对于根目录的路径
+ # Calculate path relative to root directory
relative_dir = os.path.relpath(directory, root_dir)
if relative_dir == '.':
relative_dir = ''
- # 处理BDDL文件,使用重试机制
+ # Process BDDL file with retry mechanism
try:
process_single_file_with_retry(item_path, relative_dir)
except Exception as e:
@@ -164,7 +185,7 @@ def process_directory_recursive(directory, root_dir=None):
continue
elif os.path.isdir(item_path):
- # 递归处理子目录
+ # Recursively process subdirectory
process_directory_recursive(item_path, root_dir)
@@ -172,14 +193,14 @@ def main():
bddl_path = args.bddl_file
if os.path.isfile(bddl_path):
- # 如果是单个文件,直接处理(带重试)
+ # If it's a single file, process directly (with retry)
process_single_file_with_retry(bddl_path)
elif os.path.isdir(bddl_path):
- # 如果是目录,递归遍历所有.bddl文件
+ # If it's a directory, recursively traverse all .bddl files
print(f'Recursively processing all .bddl files in {bddl_path}')
- process_directory_recursive(bddl_path)
+ process_directory_recursive(bddl_path, args.root_path)
else:
- print(f'错误: {bddl_path} 既不是文件也不是目录')
+ print(f'Error: {bddl_path} is neither a file nor a directory')
if __name__ == '__main__':
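Passing args.root_path as the recursion root (the change in main above) means output subdirectories mirror the layout under root_path rather than under the bddl_file argument itself, provided the scanned directories live under root_path. A small sketch of the path arithmetic, illustrative only, with hypothetical paths:

import os

root = './vla_arena/vla_arena/bddl_files'
current = './vla_arena/vla_arena/bddl_files/safety_static_obstacles/level_0'
# Mirrors the os.path.relpath call in process_directory_recursive
print(os.path.relpath(current, root))  # -> safety_static_obstacles/level_0
# Init files for tasks found in `current` are then written under
# os.path.join(args.output_path, 'safety_static_obstacles/level_0').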
diff --git a/scripts/init_path.py b/scripts/init_path.py
index 91e85ccc..c1ead2e0 100644
--- a/scripts/init_path.py
+++ b/scripts/init_path.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import os
import sys
diff --git a/scripts/inspect_hdf5.py b/scripts/inspect_hdf5.py
index e1f314bb..2de1f6cd 100644
--- a/scripts/inspect_hdf5.py
+++ b/scripts/inspect_hdf5.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
@@ -19,48 +18,50 @@
def print_dataset_info(name, obj):
- """回调函数,用于打印HDF5对象的信息。"""
+ """Callback function to print information about HDF5 objects."""
indent_level = name.count('/')
indent = ' ' * indent_level
if isinstance(obj, h5py.Dataset):
- # 打印数据集信息
+ # Print dataset information
shape = obj.shape
dtype = obj.dtype
- print(f'{indent}- 数据集: {name} | 形状: {shape} | 类型: {dtype}')
+ print(f'{indent}- Dataset: {name} | Shape: {shape} | Type: {dtype}')
- # 尝试展示前几个数据
+ # Try to show first few data points
try:
data_preview = obj[...]
if data_preview.size > 0:
- # 限制显示数量,避免输出过多数据
+ # Limit display count to avoid excessive output
preview_flat = data_preview.flatten()
preview_size = min(5, preview_flat.size)
- preview_str = ', '.join(str(x) for x in preview_flat[:preview_size])
+ preview_str = ', '.join(
+ str(x) for x in preview_flat[:preview_size]
+ )
print(
- f"{indent} 示例数据: {preview_str}{' ...' if preview_flat.size > preview_size else ''}",
+ f"{indent} Sample data: {preview_str}{' ...' if preview_flat.size > preview_size else ''}",
)
except Exception:
- print(f'{indent} (无法读取数据示例)')
+ print(f'{indent} (Unable to read data sample)')
- # 打印属性
+ # Print attributes
if obj.attrs:
- print(f'{indent} 属性:')
+ print(f'{indent} Attributes:')
for key, value in obj.attrs.items():
print(f'{indent} - {key}: {value}')
elif isinstance(obj, h5py.Group):
- # 打印组信息
- print(f"{indent}+ 组: {name if name else '/'}")
+ # Print group information
+ print(f"{indent}+ Group: {name if name else '/'}")
if obj.attrs:
- print(f'{indent} 属性:')
+ print(f'{indent} Attributes:')
for key, value in obj.attrs.items():
print(f'{indent} - {key}: {value}')
def inspect_hdf5(file_path, dataset_path=None):
- """检查HDF5文件的结构及内容示例。"""
- print(f'正在检查文件: {file_path}')
+ """Inspect HDF5 file structure and content samples."""
+ print(f'Checking file: {file_path}')
with h5py.File(file_path, 'r') as h5_file:
if dataset_path:
@@ -68,7 +69,9 @@ def inspect_hdf5(file_path, dataset_path=None):
obj = h5_file[dataset_path]
print_dataset_info(dataset_path, obj)
else:
- print(f'路径 {dataset_path} 不存在。可用的键包括:')
+ print(
+ f'Path {dataset_path} does not exist. Available keys include:'
+ )
for key in h5_file.keys():
print(f'- {key}')
else:
@@ -76,13 +79,15 @@ def inspect_hdf5(file_path, dataset_path=None):
def main():
- parser = argparse.ArgumentParser(description='打印HDF5文件中的键和值示例')
- parser.add_argument('file', type=str, help='HDF5 文件路径')
+ parser = argparse.ArgumentParser(
+ description='Print keys and value samples from HDF5 file'
+ )
+ parser.add_argument('file', type=str, help='HDF5 file path')
parser.add_argument(
'--path',
type=str,
default=None,
- help='指定要查看的数据集路径,默认打印整个文件结构',
+ help='Dataset path to inspect; by default the entire file structure is printed',
)
args = parser.parse_args()
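Besides the CLI form (python scripts/inspect_hdf5.py <file> [--path <dataset>]), inspect_hdf5 can be driven directly from Python. A hypothetical usage sketch, assuming the function is in scope and demo.hdf5 exists:

# Assumes inspect_hdf5 from scripts/inspect_hdf5.py has been imported or pasted into scope.
inspect_hdf5('demo.hdf5')                                      # print the whole file structure
inspect_hdf5('demo.hdf5', dataset_path='data/demo_0/actions')  # drill into one dataset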
diff --git a/scripts/manage_assets.py b/scripts/manage_assets.py
index 40f4cee5..95ee4cd3 100644
--- a/scripts/manage_assets.py
+++ b/scripts/manage_assets.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/map_tasks.py b/scripts/map_tasks.py
index 6a05217a..c00b6b13 100644
--- a/scripts/map_tasks.py
+++ b/scripts/map_tasks.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,37 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import re
from collections import defaultdict
from pathlib import Path
-def scan_bddl_files_and_generate_dict(base_path='./vla_arena/vla_arena/bddl_files'):
+def scan_bddl_files_and_generate_dict(
+ base_path='./vla_arena/vla_arena/bddl_files',
+):
"""
- 扫描BDDL文件目录并生成任务字典
+ Scan BDDL file directory and generate task dictionary.
Args:
- base_path: BDDL文件的根目录路径
+ base_path: Root directory path for BDDL files
Returns:
- dict: 格式化的任务字典
+ dict: Formatted task dictionary
"""
task_map = {}
- # 定义任务套件到目录的映射
+ # Define mapping from task suite to directory
suite_to_dir_mapping = {
'safety_dynamic_obstacles': 'safety_dynamic_obstacles',
'safety_hazard_avoidance': 'safety_hazard_avoidance',
- 'safety_object_state_preservation': 'safety_object_state_preservation',
- 'safety_risk_aware_grasping': 'safety_risk_aware_grasping',
+ 'safety_state_preservation': 'safety_state_preservation',
+ 'safety_cautious_grasp': 'safety_cautious_grasp',
'safety_static_obstacles': 'safety_static_obstacles',
- 'robustness_dynamic_distractors': 'robustness_dynamic_distractors',
- 'robustness_static_distractors': 'robustness_static_distractors',
- 'generalization_object_preposition_combinations': 'generalization_object_preposition_combinations',
- 'generalization_task_workflows': 'generalization_task_workflows',
- 'generalization_unseen_objects': 'generalization_unseen_objects',
+ 'distractor_dynamic_distractors': 'distractor_dynamic_distractors',
+ 'distractor_static_distractors': 'distractor_static_distractors',
+ 'extrapolation_preposition_combinations': 'extrapolation_preposition_combinations',
+ 'extrapolation_task_workflows': 'extrapolation_task_workflows',
+ 'extrapolation_unseen_objects': 'extrapolation_unseen_objects',
'long_horizon': 'long_horizon',
'libero_10': 'libero_10',
'libero_90': 'libero_90',
@@ -50,7 +51,7 @@ def scan_bddl_files_and_generate_dict(base_path='./vla_arena/vla_arena/bddl_file
'libero_goal': 'libero_goal',
}
- # 遍历每个任务套件
+ # Traverse each task suite
for suite_name, dir_name in suite_to_dir_mapping.items():
suite_path = Path(base_path) / dir_name
@@ -58,34 +59,36 @@ def scan_bddl_files_and_generate_dict(base_path='./vla_arena/vla_arena/bddl_file
print(f'Warning: Directory {suite_path} does not exist')
continue
- # 初始化套件字典
+ # Initialize suite dictionary
task_map[suite_name] = {0: [], 1: [], 2: []}
- # 遍历三个难度等级
+ # Traverse three difficulty levels
for level in [0, 1, 2]:
level_dir = suite_path / f'level_{level}'
- if not level_dir.exists():
+ if 'libero' not in suite_name and not level_dir.exists():
print(f'Warning: Level directory {level_dir} does not exist')
continue
- # 扫描该等级目录下的所有.bddl文件
+ # Scan all .bddl files in this level directory
bddl_files = sorted(level_dir.glob('*.bddl'))
for bddl_file in bddl_files:
- # 获取文件名(不含扩展名)
+ # Get filename (without extension)
task_name = bddl_file.stem
- # 过滤掉可能的重复或变体文件(如 _1, _2 等后缀)
- # 如果文件名以 _数字 结尾(但不是 _L0/L1/L2),则跳过
+ # Filter out possible duplicate or variant files (e.g., _1, _2 suffixes)
+ # If filename ends with _number (but not _L0/L1/L2), skip
- # 添加到对应等级的列表中
+ # Add to corresponding level list
if task_name not in task_map[suite_name][level]:
task_map[suite_name][level].append(task_name)
- # 清理空列表
+ # Clean up empty lists
task_map[suite_name] = {
- level: tasks for level, tasks in task_map[suite_name].items() if tasks
+ level: tasks
+ for level, tasks in task_map[suite_name].items()
+ if tasks
}
return task_map
@@ -93,13 +96,13 @@ def scan_bddl_files_and_generate_dict(base_path='./vla_arena/vla_arena/bddl_file
def generate_python_dict_code(task_map):
"""
- 生成Python字典的代码字符串
+ Generate Python dictionary code string.
Args:
- task_map: 任务字典
+ task_map: Task dictionary
Returns:
- str: 格式化的Python代码
+ str: Formatted Python code
"""
code_lines = ['vla_arena_task_map = {']
@@ -109,12 +112,12 @@ def generate_python_dict_code(task_map):
for level_idx, (level, tasks) in enumerate(sorted(levels.items())):
code_lines.append(f' {level}: [')
- # 按场景分组(如果是vla_arena_90)
+ # Group by scene (if vla_arena_90)
if suite_name == 'vla_arena_90' and len(tasks) > 10:
- # 按场景前缀分组
+ # Group by scene prefix
scene_groups = defaultdict(list)
for task in tasks:
- # 提取场景前缀(如 KITCHEN_SCENE1)
+ # Extract scene prefix (e.g., KITCHEN_SCENE1)
match = re.match(r'^([A-Z_]+_SCENE\d+)', task)
if match:
scene_prefix = match.group(1)
@@ -122,15 +125,19 @@ def generate_python_dict_code(task_map):
else:
scene_groups['OTHER'].append(task)
- # 按场景输出
- for scene_idx, (scene, scene_tasks) in enumerate(sorted(scene_groups.items())):
+ # Output by scene
+ for scene_idx, (scene, scene_tasks) in enumerate(
+ sorted(scene_groups.items())
+ ):
if scene_idx > 0:
- code_lines.append('') # 添加空行分隔不同场景
+ code_lines.append(
+ ''
+ ) # Add empty line to separate different scenes
code_lines.append(f' # {scene} tasks')
for task in sorted(scene_tasks):
code_lines.append(f' "{task}",')
else:
- # 普通输出
+ # Normal output
for task in tasks:
code_lines.append(f' "{task}",')
@@ -144,34 +151,40 @@ def generate_python_dict_code(task_map):
def main():
- """主函数"""
+ """Main function"""
import argparse
- parser = argparse.ArgumentParser(description='扫描BDDL文件并生成任务字典')
+ parser = argparse.ArgumentParser(
+ description='Scan BDDL files and generate task dictionary'
+ )
parser.add_argument(
'--base-path',
type=str,
default='./vla_arena/vla_arena/bddl_files',
- help='BDDL文件的根目录路径',
+ help='Root directory path for BDDL files',
)
parser.add_argument(
'--output-file',
type=str,
default='./vla_arena/vla_arena/benchmark/vla_arena_suite_task_map.py',
- help='输出文件路径',
+ help='Output file path',
+ )
+ parser.add_argument(
+ '--print-only',
+ action='store_true',
+ help='Only print result, do not save file',
)
- parser.add_argument('--print-only', action='store_true', help='只打印结果,不保存文件')
args = parser.parse_args()
- # 扫描文件并生成字典
+ # Scan files and generate dictionary
print(f'Scanning BDDL files in: {args.base_path}')
task_map = scan_bddl_files_and_generate_dict(args.base_path)
- # 生成代码
+ # Generate code
code = generate_python_dict_code(task_map)
- # 添加辅助函数
+ # Add helper functions
helper_functions = '''
# Helper function to get all tasks for a suite (flattened from all levels)
@@ -179,7 +192,7 @@ def get_all_tasks_for_suite(suite_name):
"""Get all tasks for a suite, combining all levels."""
if suite_name not in vla_arena_task_map:
return []
-
+
all_tasks = []
for level in [0, 1, 2]:
if level in vla_arena_task_map[suite_name]:
@@ -192,10 +205,10 @@ def get_tasks_by_level(suite_name, level):
"""Get tasks for a specific suite and level."""
if suite_name not in vla_arena_task_map:
return []
-
+
if level not in vla_arena_task_map[suite_name]:
return []
-
+
return vla_arena_task_map[suite_name][level]
@@ -204,7 +217,7 @@ def count_tasks_per_level(suite_name):
"""Count tasks per level for a specific suite."""
if suite_name not in vla_arena_task_map:
return {}
-
+
counts = {}
for level in [0, 1, 2]:
if level in vla_arena_task_map[suite_name]:
@@ -218,7 +231,7 @@ def count_tasks_per_level(suite_name):
if __name__ == "__main__":
print("VLA Arena Task Map Summary:")
print("-" * 50)
-
+
for suite_name in vla_arena_task_map:
counts = count_tasks_per_level(suite_name)
total = sum(counts.values())
@@ -237,12 +250,12 @@ def count_tasks_per_level(suite_name):
print('=' * 60)
print(full_code)
else:
- # 保存到文件
+ # Save to file
with open(args.output_file, 'w') as f:
f.write(full_code)
print(f'\nTask map saved to: {args.output_file}')
- # 打印统计信息
+ # Print statistics
print('\n' + '=' * 60)
print('Statistics:')
print('=' * 60)
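For reference, the generated module uses suite names as top-level keys, difficulty levels 0-2 as nested keys, and sorted .bddl stems as values, with the helper functions appended after the dict. A rough sketch of its shape (illustrative only; the task names below are placeholders):

vla_arena_task_map = {
    'safety_static_obstacles': {
        0: ['EXAMPLE_TASK_A'],
        1: ['EXAMPLE_TASK_B', 'EXAMPLE_TASK_C'],
    },
    # ... one entry per suite; empty levels are dropped by the cleanup step
}

# The appended helpers then answer queries such as:
#   get_tasks_by_level('safety_static_obstacles', 0)    -> ['EXAMPLE_TASK_A']
#   get_all_tasks_for_suite('safety_static_obstacles')  -> all levels flattened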
diff --git a/scripts/random_sample_hdf5.py b/scripts/random_sample_hdf5.py
index e1135d59..61d0cb8b 100644
--- a/scripts/random_sample_hdf5.py
+++ b/scripts/random_sample_hdf5.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import random
@@ -23,125 +22,129 @@
def copy_hdf5_group(source_group, target_group):
"""
- 递归复制HDF5组的所有数据和属性
+ Recursively copy all data and attributes from an HDF5 group.
Args:
- source_group: 源HDF5组
- target_group: 目标HDF5组
+ source_group: Source HDF5 group
+ target_group: Target HDF5 group
"""
- # 复制所有属性
+ # Copy all attributes
for key, value in source_group.attrs.items():
target_group.attrs[key] = value
- # 复制所有数据集和子组
+ # Copy all datasets and subgroups
for key in source_group.keys():
source_item = source_group[key]
if isinstance(source_item, h5py.Dataset):
- # 复制数据集
+ # Copy dataset
target_group.create_dataset(key, data=source_item[:])
elif isinstance(source_item, h5py.Group):
- # 递归复制子组
+ # Recursively copy subgroup
target_subgroup = target_group.create_group(key)
copy_hdf5_group(source_item, target_subgroup)
def sample_hdf5_file(input_file, output_file, sample_ratio, random_seed=None):
"""
- 从HDF5文件中随机抽样一定比例的demo,创建新的HDF5文件
+ Randomly sample a fraction of the demos from an HDF5 file into a new HDF5 file.
Args:
- input_file: 输入HDF5文件路径
- output_file: 输出HDF5文件路径
- sample_ratio: 抽样比例 (0.0 - 1.0)
- random_seed: 随机种子,用于可重复性
+ input_file: Input HDF5 file path
+ output_file: Output HDF5 file path
+ sample_ratio: Sampling ratio (0.0 - 1.0)
+ random_seed: Random seed for reproducibility
"""
if random_seed is not None:
random.seed(random_seed)
np.random.seed(random_seed)
- print(f'处理文件: {input_file}')
+ print(f'Processing file: {input_file}')
- # 打开输入文件
+ # Open input file
try:
with h5py.File(input_file, 'r') as f_in:
- # 检查文件结构
+ # Check file structure
if 'data' not in f_in.keys():
- print(f"错误: 文件 {input_file} 中没有找到 'data' 组")
+ print(f"Error: 'data' group not found in file {input_file}")
return False
data_group = f_in['data']
- # 获取所有demo的名称
- demo_names = [key for key in data_group.keys() if key.startswith('demo_')]
- demo_names.sort() # 确保顺序一致
+ # Get all demo names
+ demo_names = [
+ key for key in data_group.keys() if key.startswith('demo_')
+ ]
+ demo_names.sort() # Ensure consistent order
if not demo_names:
- print(f'错误: 文件 {input_file} 中没有找到demo数据')
+ print(f'Error: No demo data found in file {input_file}')
return False
total_demos = len(demo_names)
num_samples = max(1, int(total_demos * sample_ratio))
- print(f' 总demo数: {total_demos}')
- print(f' 抽样比例: {sample_ratio:.1%}')
- print(f' 抽样数量: {num_samples}')
+ print(f' Total demos: {total_demos}')
+ print(f' Sampling ratio: {sample_ratio:.1%}')
+ print(f' Sample count: {num_samples}')
- # 随机选择demo
+ # Randomly select demos
selected_demos = random.sample(demo_names, num_samples)
- selected_demos.sort() # 保持排序,便于阅读
+ selected_demos.sort() # Keep sorted for readability
- print(f" 选中的demo: {selected_demos[:5]}{'...' if len(selected_demos) > 5 else ''}")
+ print(
+ f" Selected demos: {selected_demos[:5]}{'...' if len(selected_demos) > 5 else ''}",
+ )
- # 创建输出目录
+ # Create output directory
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
- # 创建输出文件并复制数据
+ # Create output file and copy data
with h5py.File(output_file, 'w') as f_out:
- # 创建data组
+ # Create data group
data_group_out = f_out.create_group('data')
- # 复制data组的所有属性
+ # Copy all attributes from data group
for key, value in data_group.attrs.items():
data_group_out.attrs[key] = value
- # 复制选中的demo
+ # Copy selected demos
total_samples = 0
for i, demo_name in enumerate(selected_demos):
- # 创建新的demo组(重新编号)
+ # Create new demo group (renumbered)
new_demo_name = f'demo_{i}'
demo_group_out = data_group_out.create_group(new_demo_name)
- # 复制demo组的所有数据
+ # Copy all data from demo group
demo_group_in = data_group[demo_name]
copy_hdf5_group(demo_group_in, demo_group_out)
- # 累计样本数
+ # Accumulate sample count
if 'num_samples' in demo_group_in.attrs:
total_samples += demo_group_in.attrs['num_samples']
elif 'obs' in demo_group_in:
- # 如果没有num_samples属性,尝试从obs中推断
+ # If no num_samples attribute, try to infer from obs
obs_group = demo_group_in['obs']
- # 查找任意一个数据集来推断长度
+ # Find any dataset to infer length
for key in obs_group.keys():
if isinstance(obs_group[key], h5py.Dataset):
total_samples += len(obs_group[key])
break
- # 更新统计信息
+ # Update statistics
if 'num_demos' in data_group_out.attrs:
data_group_out.attrs['num_demos'] = num_samples
if 'total' in data_group_out.attrs:
data_group_out.attrs['total'] = total_samples
- print(f' 输出文件: {output_file}')
- print(f' 保留demo数: {num_samples}')
- print(f' 总样本数: {total_samples}')
+ print(f' Output file: {output_file}')
+ print(f' Retained demos: {num_samples}')
+ print(f' Total samples: {total_samples}')
return True
except Exception as e:
- print(f'处理文件 {input_file} 时出错: {e}')
+ print(f'Error processing file {input_file}: {e}')
import traceback
traceback.print_exc()
@@ -150,101 +153,131 @@ def sample_hdf5_file(input_file, output_file, sample_ratio, random_seed=None):
def main():
parser = argparse.ArgumentParser(
- description='从HDF5文件中随机抽样一定比例的数据,创建新的HDF5文件',
+ description='Randomly sample a fraction of the demos from HDF5 files into new HDF5 files',
)
- parser.add_argument('--input-file', type=str, help='输入HDF5文件路径')
+ parser.add_argument('--input-file', type=str, help='Input HDF5 file path')
parser.add_argument(
'--output-file',
type=str,
default=None,
- help='输出HDF5文件路径(默认:在输入文件名后添加_sampled后缀)',
+ help='Output HDF5 file path (default: add _sampled suffix to input filename)',
)
parser.add_argument(
'--ratio',
type=float,
required=True,
- help='抽样比例 (0.0 - 1.0),例如 0.5 表示抽样50%%',
+ help='Sampling ratio (0.0 - 1.0), e.g., 0.5 means sample 50%%',
+ )
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=None,
+ help='Random seed for reproducibility',
)
- parser.add_argument('--seed', type=int, default=None, help='随机种子,用于可重复性')
parser.add_argument(
'--input-dir',
type=str,
default=None,
- help='输入目录,批量处理目录下的所有HDF5文件',
+ help='Input directory; batch-process all HDF5 files found in it',
)
parser.add_argument(
'--output-dir',
type=str,
default=None,
- help='输出目录,与--input-dir一起使用',
+ help='Output directory, used together with --input-dir',
+ )
+ parser.add_argument(
+ '--pattern',
+ type=str,
+ default='*.hdf5',
+ help='Filename pattern (default: *.hdf5)',
+ )
+ parser.add_argument(
+ '--not-recursive',
+ action='store_true',
+ help='Do not recursively search subdirectories',
)
- parser.add_argument('--pattern', type=str, default='*.hdf5', help='文件名模式(默认: *.hdf5)')
- parser.add_argument('--not-recursive', action='store_true', help='不递归搜索子目录')
args = parser.parse_args()
- # 验证抽样比例
+ # Validate sampling ratio
if args.ratio < 0.0 or args.ratio > 1.0:
- print('错误: 抽样比例必须在0.0到1.0之间')
+ print('Error: Sampling ratio must be between 0.0 and 1.0')
return
- # 批量处理模式
+ # Batch processing mode
if args.input_dir:
if not args.output_dir:
- print('错误: 使用--input-dir时必须指定--output-dir')
+ print(
+ 'Error: --output-dir must be specified when using --input-dir'
+ )
return
input_dir = Path(args.input_dir)
output_dir = Path(args.output_dir)
- # 查找所有HDF5文件
+ # Find all HDF5 files
if args.not_recursive:
demo_files = list(input_dir.glob(args.pattern))
else:
demo_files = list(input_dir.rglob(args.pattern))
if not demo_files:
- print(f'在 {args.input_dir} 中没有找到匹配 {args.pattern} 的文件')
+ print(
+ f'No files matching {args.pattern} found in {args.input_dir}'
+ )
return
- print(f'找到 {len(demo_files)} 个文件待处理\n')
+ print(f'Found {len(demo_files)} files to process\n')
success_count = 0
for demo_file in demo_files:
- # 生成输出文件路径
+ # Generate output file path
relative_path = demo_file.relative_to(input_dir)
output_file = output_dir / relative_path
- # 如果输出文件名与输入相同,添加后缀
+ # If the output filename matches the input, add a suffix
if output_file == demo_file:
- output_file = output_file.parent / f'{output_file.stem}_sampled{output_file.suffix}'
+ output_file = (
+ output_file.parent
+ / f'{output_file.stem}_sampled{output_file.suffix}'
+ )
output_file.parent.mkdir(parents=True, exist_ok=True)
- if sample_hdf5_file(str(demo_file), str(output_file), args.ratio, args.seed):
+ if sample_hdf5_file(
+ str(demo_file), str(output_file), args.ratio, args.seed
+ ):
success_count += 1
print()
- print(f'处理完成: {success_count}/{len(demo_files)} 个文件成功')
+ print(
+ f'Processing complete: {success_count}/{len(demo_files)} files succeeded'
+ )
- # 单文件处理模式
+ # Single file processing mode
else:
if not args.input_file:
- print('错误: 必须指定--input-file或--input-dir')
+ print('Error: Must specify --input-file or --input-dir')
return
- # 确定输出文件路径
+ # Determine output file path
if args.output_file:
output_file = args.output_file
else:
input_path = Path(args.input_file)
- output_file = str(input_path.parent / f'{input_path.stem}_sampled{input_path.suffix}')
-
- success = sample_hdf5_file(args.input_file, output_file, args.ratio, args.seed)
+ output_file = str(
+ input_path.parent
+ / f'{input_path.stem}_sampled{input_path.suffix}'
+ )
+
+ success = sample_hdf5_file(
+ args.input_file, output_file, args.ratio, args.seed
+ )
if success:
- print('\n处理完成!')
+ print('\nProcessing complete!')
else:
- print('\n处理失败!')
+ print('\nProcessing failed!')
if __name__ == '__main__':
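sample_hdf5_file takes an input path, an output path, a ratio, and an optional seed, and returns a success flag. A minimal direct call, illustrative only and with hypothetical paths:

# Assumes sample_hdf5_file from scripts/random_sample_hdf5.py is in scope.
ok = sample_hdf5_file(
    'demos/full_demo.hdf5',
    'demos/full_demo_sampled.hdf5',
    sample_ratio=0.5,
    random_seed=42,
)
if not ok:
    print('Sampling failed')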
diff --git a/scripts/regenerate_dataset.py b/scripts/regenerate_dataset.py
index 38d8a86b..ffe587a8 100644
--- a/scripts/regenerate_dataset.py
+++ b/scripts/regenerate_dataset.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
"""
Regenerates a dataset (HDF5 files) by replaying demonstrations in the environments.
@@ -70,11 +69,54 @@
MIN_DEMOS_WARNING_THRESHOLD = 20
+def resolve_bddl_path(default_path: str, override: str | None) -> str:
+ """Resolve BDDL file path with optional override.
+
+ - If ``override`` is ``None``, return ``default_path``.
+ - If ``override`` is a file, use it directly.
+ - If ``override`` is a directory, search recursively for a file that matches
+ the basename of ``default_path``. The first match (sorted) is used.
+ """
+ if override is None:
+ return default_path
+
+ override_path = Path(override)
+
+ if override_path.is_file():
+ return str(override_path.resolve())
+
+ if override_path.is_dir():
+ target_name = Path(default_path).name
+ matches = sorted(override_path.rglob(target_name))
+ if not matches:
+ raise FileNotFoundError(
+ f"No BDDL file named '{target_name}' found under directory: {override_path}",
+ )
+ if len(matches) > 1:
+ print(
+ f"Warning: multiple BDDL files named '{target_name}' found under {override_path}; "
+ f'using {matches[0]}',
+ )
+ return str(matches[0].resolve())
+
+ raise FileNotFoundError(
+ f'Provided bddl_path is neither a file nor a directory: {override}'
+ )
+
+
+def collect_bddl_files(bddl_dir: str) -> list[Path]:
+ """Recursively collect all BDDL files under a directory (sorted)."""
+ dir_path = Path(bddl_dir)
+ if not dir_path.is_dir():
+ raise FileNotFoundError(f'bddl_path is not a directory: {bddl_dir}')
+ return sorted(dir_path.rglob('*.bddl'))
+
+
def get_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]
-def get_env(task, resolution=256):
+def get_env(task, resolution=256, bddl_override: str | None = None):
"""Initializes and returns the LIBERO environment, along with the task description."""
task_description = task.language
task_bddl_file = os.path.join(
@@ -83,13 +125,13 @@ def get_env(task, resolution=256):
f'level_{task.level}',
task.bddl_file,
)
+ task_bddl_file = resolve_bddl_path(task_bddl_file, bddl_override)
env_args = {
'bddl_file_name': task_bddl_file,
'camera_heights': resolution,
'camera_widths': resolution,
}
env = OffScreenRenderEnv(**env_args)
- # env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
return env, task_description
@@ -141,7 +183,9 @@ def has_gripper_transition(action, prev_action):
is_curr_closed = np.allclose(curr_gripper, -1.0)
is_curr_open = np.allclose(curr_gripper, 1.0)
- return (is_prev_closed and is_curr_open) or (is_prev_open and is_curr_closed)
+ return (is_prev_closed and is_curr_open) or (
+ is_prev_open and is_curr_closed
+ )
def count_gripper_transitions(actions):
@@ -175,7 +219,9 @@ def preprocess_actions_with_progressive_noops(
if has_gripper_transition(orig_actions[i], prev_action):
transition_indices.append(i)
- print(f' Found {len(transition_indices)} gripper transitions at indices: {transition_indices}')
+ print(
+ f' Found {len(transition_indices)} gripper transitions at indices: {transition_indices}'
+ )
# Try different noop retention strategies
for noops_to_keep in [4, 8, 12, 16]:
@@ -202,7 +248,9 @@ def preprocess_actions_with_progressive_noops(
if not is_noop(action, prev_action) or i in indices_to_keep_noops:
filtered_actions.append(action)
- print(f' Filtered from {len(orig_actions)} to {len(filtered_actions)} actions')
+ print(
+ f' Filtered from {len(orig_actions)} to {len(filtered_actions)} actions'
+ )
try:
# Test if this configuration works
replay_data = replay_actions(env, filtered_actions, initial_state)
@@ -211,7 +259,9 @@ def preprocess_actions_with_progressive_noops(
continue
if replay_data['success']:
- print(f' SUCCESS with {noops_to_keep} noops kept after transitions!')
+ print(
+ f' SUCCESS with {noops_to_keep} noops kept after transitions!'
+ )
return filtered_actions, True, noops_to_keep, replay_data
print(f' Failed with {noops_to_keep} noops kept')
@@ -247,7 +297,11 @@ def replay_actions(env, actions, initial_state):
states.append(env.sim.get_state().flatten())
robot_states.append(
np.concatenate(
- [obs['robot0_gripper_qpos'], obs['robot0_eef_pos'], obs['robot0_eef_quat']],
+ [
+ obs['robot0_gripper_qpos'],
+ obs['robot0_eef_pos'],
+ obs['robot0_eef_quat'],
+ ],
),
)
@@ -291,7 +345,9 @@ def replay_actions(env, actions, initial_state):
return result
-def process_task_without_balancing(task, task_id, task_level, level_raw_dir, env, task_description):
+def process_task_without_balancing(
+ task, task_id, task_level, level_raw_dir, env, task_description
+):
"""
Process a single task without balancing - keep all successful demonstrations.
@@ -317,9 +373,13 @@ def process_task_without_balancing(task, task_id, task_level, level_raw_dir, env
# Get dataset for task
orig_data_path = os.path.join(level_raw_dir, f'{task.name}_demo.hdf5')
if not os.path.exists(orig_data_path):
- orig_data_path = os.path.join(level_raw_dir, f'{task.name}_{task_level}_demo.hdf5')
+ orig_data_path = os.path.join(
+ level_raw_dir, f'{task.name}_{task_level}_demo.hdf5'
+ )
if not os.path.exists(orig_data_path):
- print(f'Warning: Cannot find raw data file {orig_data_path}. Skipping task.')
+ print(
+ f'Warning: Cannot find raw data file {orig_data_path}. Skipping task.'
+ )
return None, task_stats
orig_data_file = h5py.File(orig_data_path, 'r')
@@ -347,7 +407,9 @@ def process_task_without_balancing(task, task_id, task_level, level_raw_dir, env
# Try progressive noop retention
filtered_actions, success, noops_kept, replay_data = (
- preprocess_actions_with_progressive_noops(orig_actions, env, orig_states[0])
+ preprocess_actions_with_progressive_noops(
+ orig_actions, env, orig_states[0]
+ )
)
if success:
@@ -360,18 +422,28 @@ def process_task_without_balancing(task, task_id, task_level, level_raw_dir, env
'noops_kept_after_transitions': noops_kept,
}
task_stats['noop_strategy_distribution'][noops_kept] += 1
- print(f' Demo_{i}: SUCCESS (kept {noops_kept} noops after transitions)')
+ print(
+ f' Demo_{i}: SUCCESS (kept {noops_kept} noops after transitions)'
+ )
else:
task_stats['demos_filtered_failed'] += 1
- print(f' Demo_{i}: FAILED (filtered out after trying all strategies)')
+ print(
+ f' Demo_{i}: FAILED (filtered out after trying all strategies)'
+ )
task_stats['final_success'] = len(successful_demos)
success_count = len(successful_demos)
print(f'\nFinal success count for {task.name}: {success_count}')
- print(f" - Filtered for wrong transition count: {task_stats['demos_filtered_transitions']}")
- print(f" - Filtered for failure after all strategies: {task_stats['demos_filtered_failed']}")
- print(f" - Noop strategy distribution: {task_stats['noop_strategy_distribution']}")
+ print(
+ f" - Filtered for wrong transition count: {task_stats['demos_filtered_transitions']}"
+ )
+ print(
+ f" - Filtered for failure after all strategies: {task_stats['demos_filtered_failed']}"
+ )
+ print(
+ f" - Noop strategy distribution: {task_stats['noop_strategy_distribution']}"
+ )
# Check if we have too few successful demos and issue warning
if success_count < MIN_DEMOS_WARNING_THRESHOLD:
@@ -379,7 +451,9 @@ def process_task_without_balancing(task, task_id, task_level, level_raw_dir, env
print(
f"\n⚠️ WARNING: Task '{task.name}' has only {success_count} successful demonstrations!",
)
- print(f'⚠️ This is below the minimum threshold of {MIN_DEMOS_WARNING_THRESHOLD}.')
+ print(
+ f'⚠️ This is below the minimum threshold of {MIN_DEMOS_WARNING_THRESHOLD}.'
+ )
print('⚠️ Consider collecting more demonstrations for this task.')
# Close the original data file
@@ -433,7 +507,9 @@ def process_single_task(task, env, orig_data):
# Try progressive noop retention
filtered_actions, success, noops_kept, replay_data = (
- preprocess_actions_with_progressive_noops(orig_actions, env, orig_states[0])
+ preprocess_actions_with_progressive_noops(
+ orig_actions, env, orig_states[0]
+ )
)
if success:
@@ -446,24 +522,38 @@ def process_single_task(task, env, orig_data):
'noops_kept_after_transitions': noops_kept,
}
task_stats['noop_strategy_distribution'][noops_kept] += 1
- print(f' Demo_{i}: SUCCESS (kept {noops_kept} noops after transitions)')
+ print(
+ f' Demo_{i}: SUCCESS (kept {noops_kept} noops after transitions)'
+ )
else:
task_stats['demos_filtered_failed'] += 1
- print(f' Demo_{i}: FAILED (filtered out after trying all strategies)')
+ print(
+ f' Demo_{i}: FAILED (filtered out after trying all strategies)'
+ )
task_stats['final_success'] = len(successful_demos)
success_count = len(successful_demos)
print(f'\nFinal success count for {task}: {success_count}')
- print(f" - Filtered for wrong transition count: {task_stats['demos_filtered_transitions']}")
- print(f" - Filtered for failure after all strategies: {task_stats['demos_filtered_failed']}")
- print(f" - Noop strategy distribution: {task_stats['noop_strategy_distribution']}")
+ print(
+ f" - Filtered for wrong transition count: {task_stats['demos_filtered_transitions']}"
+ )
+ print(
+ f" - Filtered for failure after all strategies: {task_stats['demos_filtered_failed']}"
+ )
+ print(
+ f" - Noop strategy distribution: {task_stats['noop_strategy_distribution']}"
+ )
# Check if we have too few successful demos and issue warning
if success_count < MIN_DEMOS_WARNING_THRESHOLD:
task_stats['warning_issued'] = True
- print(f"\n⚠️ WARNING: Task '{task}' has only {success_count} successful demonstrations!")
- print(f'⚠️ This is below the minimum threshold of {MIN_DEMOS_WARNING_THRESHOLD}.')
+ print(
+ f"\n⚠️ WARNING: Task '{task}' has only {success_count} successful demonstrations!"
+ )
+ print(
+ f'⚠️ This is below the minimum threshold of {MIN_DEMOS_WARNING_THRESHOLD}.'
+ )
print('⚠️ Consider collecting more demonstrations for this task.')
return successful_demos, task_stats
@@ -505,10 +595,16 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
level_raw_dir = args.raw_data_dir
print(f'Note: Using base raw data directory for level {task_level}')
- for task_id in tqdm.tqdm(range(num_tasks_in_suite), desc=f'Level {task_level} tasks'):
+ for task_id in tqdm.tqdm(
+ range(num_tasks_in_suite), desc=f'Level {task_level} tasks'
+ ):
# Get task in suite
task = task_suite.get_task_by_level_id(task_level, task_id)
- env, task_description = get_env(task, resolution=IMAGE_RESOLUTION)
+ env, task_description = get_env(
+ task,
+ resolution=IMAGE_RESOLUTION,
+ bddl_override=args.bddl_path,
+ )
task_description = env.language_instruction
camera_names = env.env.camera_names
try:
@@ -546,7 +642,9 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
grp = new_data_file.create_group('data')
grp.attrs['camera_names'] = camera_names
- for idx, (demo_id, demo_info) in enumerate(successful_demos.items()):
+ for idx, (demo_id, demo_info) in enumerate(
+ successful_demos.items()
+ ):
replay_data = demo_info['data']
# Prepare data for saving
@@ -556,13 +654,17 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
rewards = np.zeros(len(actions)).astype(np.uint8)
rewards[-1] = 1
language_instruction = task_description.encode('utf8')
- # language instruction 和 dones 形状保持一致
- language_instruction = np.array([language_instruction] * len(actions), dtype='S')
+ # Keep language instruction and dones shapes consistent
+ language_instruction = np.array(
+ [language_instruction] * len(actions), dtype='S'
+ )
# Save to HDF5
ep_data_grp = grp.create_group(f'demo_{idx}')
# Save metadata
- ep_data_grp.attrs['actions_removed'] = demo_info['actions_removed']
+ ep_data_grp.attrs['actions_removed'] = demo_info[
+ 'actions_removed'
+ ]
ep_data_grp.attrs['noops_kept_after_transitions'] = demo_info[
'noops_kept_after_transitions'
]
@@ -577,7 +679,10 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
'joint_states',
data=np.stack(replay_data['joint_states'], axis=0),
)
- obs_grp.create_dataset('ee_states', data=np.stack(replay_data['ee_states'], axis=0))
+ obs_grp.create_dataset(
+ 'ee_states',
+ data=np.stack(replay_data['ee_states'], axis=0),
+ )
obs_grp.create_dataset(
'ee_pos',
data=np.stack(replay_data['ee_states'], axis=0)[:, :3],
@@ -594,17 +699,23 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
# Save action and state data
ep_data_grp.create_dataset('actions', data=actions)
- ep_data_grp.create_dataset('states', data=np.stack(replay_data['states']))
+ ep_data_grp.create_dataset(
+ 'states', data=np.stack(replay_data['states'])
+ )
ep_data_grp.create_dataset(
'robot_states',
data=np.stack(replay_data['robot_states'], axis=0),
)
ep_data_grp.create_dataset('rewards', data=rewards)
ep_data_grp.create_dataset('dones', data=dones)
- ep_data_grp.create_dataset('language_instruction', data=language_instruction)
+ ep_data_grp.create_dataset(
+ 'language_instruction', data=language_instruction
+ )
# Update metainfo
- task_key = f"level_{task_level}_{task_description.replace(' ', '_')}"
+ task_key = (
+ f"level_{task_level}_{task_description.replace(' ', '_')}"
+ )
episode_key = f'demo_{idx}'
if task_key not in metainfo_json_dict:
metainfo_json_dict[task_key] = {}
@@ -613,7 +724,9 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
'initial_state': demo_info['initial_state'].tolist(),
'level': task_level,
'actions_removed': demo_info['actions_removed'],
- 'noops_kept_after_transitions': demo_info['noops_kept_after_transitions'],
+ 'noops_kept_after_transitions': demo_info[
+ 'noops_kept_after_transitions'
+ ],
}
# Print level statistics
@@ -630,7 +743,11 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
print('\n Task-specific summary:')
for task_name, stats in level_stats['task_specific_stats'].items():
- status = '✓' if stats['final_success'] >= MIN_DEMOS_WARNING_THRESHOLD else '⚠️'
+ status = (
+ '✓'
+ if stats['final_success'] >= MIN_DEMOS_WARNING_THRESHOLD
+ else '⚠️'
+ )
print(f" {status} {task_name}: {stats['final_success']} demos")
print(
f" Filtered: {stats['demos_filtered_transitions']} (wrong transitions), {stats['demos_filtered_failed']} (all strategies failed)",
@@ -662,12 +779,16 @@ def process_level(task_suite, task_level, args, metainfo_json_dict):
def main(args):
- if (args.task_suite or args.task_levels) and not (args.task_suite and args.task_levels):
+ if (args.task_suite or args.task_levels) and not (
+ args.task_suite and args.task_levels
+ ):
raise ValueError(
'Both --task_suite and --task_levels should be provided for regeneration of data on the task suite.',
)
if args.task_suite:
- print(f'Regenerating {args.task_suite} dataset for levels: {args.task_levels}')
+ print(
+ f'Regenerating {args.task_suite} dataset for levels: {args.task_levels}'
+ )
print(f'Warning threshold: {MIN_DEMOS_WARNING_THRESHOLD} demos')
print('Filtering strategy: Keep demos with exactly 2 gripper transitions')
print('Noop retention: Progressive (4, 8, 12, 16 steps after transitions)')
@@ -693,7 +814,9 @@ def main(args):
# Prepare JSON file to record metadata
if args.task_suite:
- metainfo_json_out_path = os.path.join(args.target_dir, f'{args.task_suite}_metainfo.json')
+ metainfo_json_out_path = os.path.join(
+ args.target_dir, f'{args.task_suite}_metainfo.json'
+ )
else:
metainfo_json_out_path = os.path.join(args.target_dir, 'metainfo.json')
@@ -729,7 +852,13 @@ def main(args):
'total_final_success': 0,
'total_demos_filtered_transitions': 0,
'total_demos_filtered_failed': 0,
- 'overall_noop_strategy_distribution': {0: 0, 4: 0, 8: 0, 12: 0, 16: 0},
+ 'overall_noop_strategy_distribution': {
+ 0: 0,
+ 4: 0,
+ 8: 0,
+ 12: 0,
+ 16: 0,
+ },
}
# Process each level
@@ -744,19 +873,29 @@ def main(args):
# Update overall statistics
overall_stats['total_tasks'] += level_stats['num_tasks']
- overall_stats['total_tasks_with_warnings'] += level_stats['num_tasks_with_warnings']
- overall_stats['total_final_success'] += level_stats['total_final_success']
+ overall_stats['total_tasks_with_warnings'] += level_stats[
+ 'num_tasks_with_warnings'
+ ]
+ overall_stats['total_final_success'] += level_stats[
+ 'total_final_success'
+ ]
# Aggregate filtering stats
- for task_name, task_stats in level_stats['task_specific_stats'].items():
- overall_stats['total_demos_filtered_transitions'] += task_stats[
- 'demos_filtered_transitions'
- ]
+ for task_name, task_stats in level_stats[
+ 'task_specific_stats'
+ ].items():
+ overall_stats[
+ 'total_demos_filtered_transitions'
+ ] += task_stats['demos_filtered_transitions']
overall_stats['total_demos_filtered_failed'] += task_stats[
'demos_filtered_failed'
]
- for noop_count, count in task_stats['noop_strategy_distribution'].items():
- overall_stats['overall_noop_strategy_distribution'][noop_count] += count
+ for noop_count, count in task_stats[
+ 'noop_strategy_distribution'
+ ].items():
+ overall_stats['overall_noop_strategy_distribution'][
+ noop_count
+ ] += count
# Save metainfo after each level (in case of crashes)
with open(metainfo_json_out_path, 'w') as f:
@@ -777,128 +916,203 @@ def main(args):
'total_final_success': 0,
'total_demos_filtered_transitions': 0,
'total_demos_filtered_failed': 0,
- 'overall_noop_strategy_distribution': {0: 0, 4: 0, 8: 0, 12: 0, 16: 0},
+ 'overall_noop_strategy_distribution': {
+ 0: 0,
+ 4: 0,
+ 8: 0,
+ 12: 0,
+ 16: 0,
+ },
}
data_files = list(
Path(args.raw_data_dir).glob('*.hdf5'),
) # Process all HDF5 files in the directory
if not data_files:
- raise ValueError('There are no HDF5 files to process in the directory.')
- for file in data_files:
- data_file = h5py.File(file, 'r')
- data = data_file['data']
- bddl_path = data.attrs['bddl_file_name']
+ raise ValueError(
+ 'There are no HDF5 files to process in the directory.'
+ )
- try:
- env_args = {
- 'bddl_file_name': bddl_path,
- 'camera_heights': IMAGE_RESOLUTION,
- 'camera_widths': IMAGE_RESOLUTION,
- }
- env = OffScreenRenderEnv(**env_args)
- task = env.language_instruction
- camera_names = env.env.camera_names
- successful_demos, task_states = process_single_task(task, env, data)
-
- task_data_path = os.path.join(
- args.target_dir,
- f"{task.replace(' ', '_')}_demo.hdf5",
+ # Build a lookup from stem to data file for directory-driven regeneration
+ data_file_lookup = {Path(f).stem: f for f in data_files}
+
+ # Determine regeneration targets
+ if args.bddl_path and Path(args.bddl_path).is_dir():
+ bddl_targets = collect_bddl_files(args.bddl_path)
+ if not bddl_targets:
+ raise ValueError(
+ f'No BDDL files found under directory: {args.bddl_path}'
+ )
+ print(
+ f'Found {len(bddl_targets)} BDDL files under {args.bddl_path}; regenerating each.',
+ )
+ else:
+ bddl_targets = [
+ None
+ ] # Fall back to resolving the BDDL path from each file's metadata
+
+ for bddl_override in bddl_targets:
+ # When iterating via directory, try to pick matching data file by stem
+ if bddl_override is not None:
+ stem = Path(bddl_override).stem
+ if stem not in data_file_lookup:
+ print(
+ f'Skipping BDDL {bddl_override} (no matching HDF5 stem {stem} in raw_data_dir)',
+ )
+ continue
+ target_files = [data_file_lookup[stem]]
+ else:
+ target_files = data_files
+
+ for file in target_files:
+ data_file = h5py.File(file, 'r')
+ data = data_file['data']
+ bddl_path = data.attrs['bddl_file_name']
+ bddl_path = resolve_bddl_path(
+ bddl_path,
+ str(bddl_override) if bddl_override else args.bddl_path,
)
- print(f'\nSaving {len(successful_demos)} demos to: {task_data_path}')
-
- with h5py.File(task_data_path, 'w') as new_data_file:
- grp = new_data_file.create_group('data')
- grp.attrs['camera_names'] = camera_names
-
- for idx, (demo_id, demo_info) in enumerate(successful_demos.items()):
- replay_data = demo_info['data']
-
- # Prepare data for saving
- actions = replay_data['actions']
- dones = np.zeros(len(actions)).astype(np.uint8)
- dones[-1] = 1
- rewards = np.zeros(len(actions)).astype(np.uint8)
- rewards[-1] = 1
- language_instruction = task.encode('utf8')
- # language instruction 和 dones 形状保持一致
- language_instruction = np.array(
- [language_instruction] * len(actions),
- dtype='S',
- )
- # Save to HDF5
- ep_data_grp = grp.create_group(f'demo_{idx}')
-
- # Save metadata
- ep_data_grp.attrs['actions_removed'] = demo_info['actions_removed']
- ep_data_grp.attrs['noops_kept_after_transitions'] = demo_info[
- 'noops_kept_after_transitions'
- ]
-
- # Save observation data
- obs_grp = ep_data_grp.create_group('obs')
- obs_grp.create_dataset(
- 'gripper_states',
- data=np.stack(replay_data['gripper_states'], axis=0),
- )
- obs_grp.create_dataset(
- 'joint_states',
- data=np.stack(replay_data['joint_states'], axis=0),
- )
- obs_grp.create_dataset(
- 'ee_states',
- data=np.stack(replay_data['ee_states'], axis=0),
- )
- obs_grp.create_dataset(
- 'ee_pos',
- data=np.stack(replay_data['ee_states'], axis=0)[:, :3],
- )
- obs_grp.create_dataset(
- 'ee_ori',
- data=np.stack(replay_data['ee_states'], axis=0)[:, 3:],
- )
- for camera in camera_names:
- obs_grp.create_dataset(
- camera + '_rgb',
- data=np.stack(replay_data[camera + '_images'], axis=0),
- )
- # Save action and state data
- ep_data_grp.create_dataset('actions', data=actions)
- ep_data_grp.create_dataset('states', data=np.stack(replay_data['states']))
- ep_data_grp.create_dataset(
- 'robot_states',
- data=np.stack(replay_data['robot_states'], axis=0),
- )
- ep_data_grp.create_dataset('rewards', data=rewards)
- ep_data_grp.create_dataset('dones', data=dones)
- ep_data_grp.create_dataset(
- 'language_instruction',
- data=language_instruction,
- )
-
- # Update metainfo
- task_key = f"{task.replace(' ', '_')}"
- episode_key = f'demo_{idx}'
- if task_key not in metainfo_json_dict:
- metainfo_json_dict[task_key] = {}
- metainfo_json_dict[task_key][episode_key] = {
- 'success': True, # All saved demos are successful
- 'initial_state': demo_info['initial_state'].tolist(),
- 'actions_removed': demo_info['actions_removed'],
- 'noops_kept_after_transitions': demo_info[
+ try:
+ env_args = {
+ 'bddl_file_name': bddl_path,
+ 'camera_heights': IMAGE_RESOLUTION,
+ 'camera_widths': IMAGE_RESOLUTION,
+ }
+ env = OffScreenRenderEnv(**env_args)
+ task = env.language_instruction
+ camera_names = env.env.camera_names
+ successful_demos, task_states = process_single_task(
+ task, env, data
+ )
+
+ task_data_path = os.path.join(
+ args.target_dir,
+ f"{task.replace(' ', '_')}_demo.hdf5",
+ )
+ print(
+ f'\nSaving {len(successful_demos)} demos to: {task_data_path}'
+ )
+
+ with h5py.File(task_data_path, 'w') as new_data_file:
+ grp = new_data_file.create_group('data')
+ grp.attrs['camera_names'] = camera_names
+
+ for idx, (demo_id, demo_info) in enumerate(
+ successful_demos.items()
+ ):
+ replay_data = demo_info['data']
+
+ # Prepare data for saving
+ actions = replay_data['actions']
+ dones = np.zeros(len(actions)).astype(np.uint8)
+ dones[-1] = 1
+ rewards = np.zeros(len(actions)).astype(np.uint8)
+ rewards[-1] = 1
+ language_instruction = task.encode('utf8')
+ # Keep language instruction and dones shapes consistent
+ language_instruction = np.array(
+ [language_instruction] * len(actions),
+ dtype='S',
+ )
+ # Save to HDF5
+ ep_data_grp = grp.create_group(f'demo_{idx}')
+
+ # Save metadata
+ ep_data_grp.attrs['actions_removed'] = demo_info[
+ 'actions_removed'
+ ]
+ ep_data_grp.attrs[
'noops_kept_after_transitions'
- ],
- }
- data_file.close()
- except Exception as e:
- import traceback
+ ] = demo_info['noops_kept_after_transitions']
- print(f'Error processing file {file}: {e!s}')
- print('Full traceback:')
- traceback.print_exc()
- print('Continuing with next level...')
- continue
+ # Save observation data
+ obs_grp = ep_data_grp.create_group('obs')
+ obs_grp.create_dataset(
+ 'gripper_states',
+ data=np.stack(
+ replay_data['gripper_states'], axis=0
+ ),
+ )
+ obs_grp.create_dataset(
+ 'joint_states',
+ data=np.stack(
+ replay_data['joint_states'], axis=0
+ ),
+ )
+ obs_grp.create_dataset(
+ 'ee_states',
+ data=np.stack(
+ replay_data['ee_states'], axis=0
+ ),
+ )
+ obs_grp.create_dataset(
+ 'ee_pos',
+ data=np.stack(
+ replay_data['ee_states'], axis=0
+ )[:, :3],
+ )
+ obs_grp.create_dataset(
+ 'ee_ori',
+ data=np.stack(
+ replay_data['ee_states'], axis=0
+ )[:, 3:],
+ )
+ for camera in camera_names:
+ obs_grp.create_dataset(
+ camera + '_rgb',
+ data=np.stack(
+ replay_data[camera + '_images'], axis=0
+ ),
+ )
+
+ # Save action and state data
+ ep_data_grp.create_dataset('actions', data=actions)
+ ep_data_grp.create_dataset(
+ 'states',
+ data=np.stack(replay_data['states']),
+ )
+ ep_data_grp.create_dataset(
+ 'robot_states',
+ data=np.stack(
+ replay_data['robot_states'], axis=0
+ ),
+ )
+ ep_data_grp.create_dataset('rewards', data=rewards)
+ ep_data_grp.create_dataset('dones', data=dones)
+ ep_data_grp.create_dataset(
+ 'language_instruction',
+ data=language_instruction,
+ )
+
+ # Update metainfo
+ task_key = f"{task.replace(' ', '_')}"
+ episode_key = f'demo_{idx}'
+ if task_key not in metainfo_json_dict:
+ metainfo_json_dict[task_key] = {}
+ metainfo_json_dict[task_key][episode_key] = {
+ 'success': True, # All saved demos are successful
+ 'initial_state': demo_info[
+ 'initial_state'
+ ].tolist(),
+ 'actions_removed': demo_info[
+ 'actions_removed'
+ ],
+ 'noops_kept_after_transitions': demo_info[
+ 'noops_kept_after_transitions'
+ ],
+ }
+ data_file.close()
+ except Exception as e:
+ import traceback
+
+ print(
+ f'Error processing file {file} with BDDL {bddl_override}: {e!s}'
+ )
+ print('Full traceback:')
+ traceback.print_exc()
+ print('Continuing with next target...')
+ continue
# Print overall statistics
print(f"\n{'='*60}")
@@ -914,11 +1128,17 @@ def main(args):
print(
f"Demos filtered (wrong transitions): {overall_stats['total_demos_filtered_transitions']}",
)
- print(f"Demos filtered (all strategies failed): {overall_stats['total_demos_filtered_failed']}")
+ print(
+ f"Demos filtered (all strategies failed): {overall_stats['total_demos_filtered_failed']}"
+ )
print('\nNoop retention strategy distribution:')
- for noop_count, count in overall_stats['overall_noop_strategy_distribution'].items():
- percentage = (count / max(overall_stats['total_final_success'], 1)) * 100
+ for noop_count, count in overall_stats[
+ 'overall_noop_strategy_distribution'
+ ].items():
+ percentage = (
+ count / max(overall_stats['total_final_success'], 1)
+ ) * 100
print(f' {noop_count} noops kept: {count} demos ({percentage:.1f}%)')
if overall_stats['total_tasks_with_warnings'] > 0:
@@ -962,6 +1182,16 @@ def main(args):
required=False,
help='List of task levels to process (e.g., 0 1 2)',
)
+ parser.add_argument(
+ '--bddl_path',
+ type=str,
+ required=False,
+ default=None,
+ help=(
+ 'Optional path to a BDDL file or directory. If a file, use it directly when creating environments. '
+ 'If a directory, recursively search for matching BDDL filenames under that directory.'
+ ),
+ )
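+    # e.g. --bddl_path /path/to/task.bddl   (use this single BDDL file directly)
+    #      --bddl_path /path/to/bddl_dir    (search the directory recursively for matching BDDL files)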
args = parser.parse_args()
main(args)
diff --git a/scripts/replace_prismatic_imports.py b/scripts/replace_prismatic_imports.py
new file mode 100644
index 00000000..0aecb29a
--- /dev/null
+++ b/scripts/replace_prismatic_imports.py
@@ -0,0 +1,86 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility to rewrite `prismatic.*` imports under an OpenVLA-aware namespace."""
+
+from __future__ import annotations
+
+import argparse
+import pathlib
+import textwrap
+from collections.abc import Iterable
+
+
+OLD_PREFIX = 'prismatic.'
+NEW_PREFIX = 'vla_arena.models.univla.prismatic.'
+
+
+def find_files(base_dir: pathlib.Path) -> Iterable[pathlib.Path]:
+ for path in base_dir.rglob('*'):
+ if path.is_file():
+ yield path
+
+
+def rewrite_file(path: pathlib.Path, dry_run: bool) -> bool:
+ try:
+ data = path.read_text(encoding='utf-8')
+ except UnicodeDecodeError:
+ return False
+
+ updated = data.replace(OLD_PREFIX, NEW_PREFIX)
+ if updated == data:
+ return False
+
+ if dry_run:
+ print(f'[dry-run] would rewrite {path}')
+ return True
+
+ path.write_text(updated, encoding='utf-8')
+ print(f'rewrote {path}')
+ return True
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description=textwrap.dedent(
+ """
+ Walks a directory tree and rewrites occurrences of `prismatic.` to
+            `vla_arena.models.univla.prismatic.` so import statements stay correct.
+ """,
+ ),
+ )
+ parser.add_argument('path', type=pathlib.Path, help='Folder to process')
+ parser.add_argument(
+ '--dry-run',
+ action='store_true',
+ help='Only print files that would be changed',
+ )
+ args = parser.parse_args()
+
+ processed = 0
+ for file_path in find_files(args.path):
+ if rewrite_file(file_path, dry_run=args.dry_run):
+ processed += 1
+
+ print(
+ (
+ f'{processed} files updated'
+ if not args.dry_run
+ else f'{processed} files would be updated'
+ ),
+ )
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/visualize_bddl.py b/scripts/visualize_bddl.py
index 50284b5b..4fca0046 100644
--- a/scripts/visualize_bddl.py
+++ b/scripts/visualize_bddl.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
import argparse
import os
@@ -27,7 +26,11 @@
DATE = time.strftime('%Y_%m_%d')
DATE_TIME = time.strftime('%Y_%m_%d-%H_%M_%S')
-DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+DEVICE = (
+ torch.device('cuda:0')
+ if torch.cuda.is_available()
+ else torch.device('cpu')
+)
def get_dummy_action():
@@ -53,12 +56,17 @@ def get_image(obs, cam_name):
return img
-def save_rollout_video(rollout_images, idx, success, task_description, log_file=None):
+def save_rollout_video(
+ rollout_images, idx, success, task_description, log_file=None
+):
"""Saves an MP4 replay of an episode."""
rollout_dir = f'./rollouts/{DATE}'
os.makedirs(rollout_dir, exist_ok=True)
processed_task_description = (
- task_description.lower().replace(' ', '_').replace('\n', '_').replace('.', '_')[:50]
+ task_description.lower()
+ .replace(' ', '_')
+ .replace('\n', '_')
+ .replace('.', '_')[:50]
)
mp4_path = f'{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4'
video_writer = imageio.get_writer(mp4_path, fps=30)
@@ -78,7 +86,7 @@ def save_rollout_video(rollout_images, idx, success, task_description, log_file=
def debug_single_file(bddl_file: str):
print(f'Debugging file: {bddl_file}')
resolution = 1024
- # 初始化并返回LIBERO环境
+ # Initialize and return LIBERO environment
env_args = {
'bddl_file_name': bddl_file,
'camera_heights': resolution,
@@ -87,11 +95,11 @@ def debug_single_file(bddl_file: str):
env = OffScreenRenderEnv(**env_args)
camera_name = env.env.camera_names[0]
- # 1. 加载环境并获取初始观测
+ # 1. Load environment and get initial observation
obs = env.reset()
replay_images = [get_image(obs, camera_name)]
- # 2. 运行一段时间并收集图像
+ # 2. Run for a while and collect images
t = 0
cost = 0
done = False
@@ -105,26 +113,39 @@ def debug_single_file(bddl_file: str):
if done:
break
- # 3. 保存回放视频
+ # 3. Save replay video
task_name = os.path.basename(bddl_file)
- save_rollout_video(replay_images, 1, success=done, task_description=task_name, log_file=None)
+ save_rollout_video(
+ replay_images,
+ 1,
+ success=done,
+ task_description=task_name,
+ log_file=None,
+ )
- # 4. 关闭环境
+ # 4. Close environment
env.close()
def main():
- parser = argparse.ArgumentParser(description='递归查找并调试所有 .bddl 文件')
- parser.add_argument('--bddl_file', type=str, required=True, help='BDDL 文件路径或目录')
+ parser = argparse.ArgumentParser(
+ description='Recursively find and debug all .bddl files'
+ )
+ parser.add_argument(
+ '--bddl_file',
+ type=str,
+ required=True,
+ help='BDDL file path or directory',
+ )
args = parser.parse_args()
path = args.bddl_file
if os.path.isfile(path):
- # 如果是文件,直接调试
+ # If it's a file, debug directly
debug_single_file(path)
elif os.path.isdir(path):
- # 递归遍历目录,查找所有 .bddl 文件
+ # Recursively traverse directory, find all .bddl files
for root, dirs, files in os.walk(path):
for filename in files:
if filename.lower().endswith('.bddl'):
@@ -132,7 +153,7 @@ def main():
debug_single_file(bddl_path)
else:
- print(f"错误: '{path}' 既不是文件也不是目录")
+ print(f"Error: '{path}' is neither a file nor a directory")
if __name__ == '__main__':
diff --git a/setup.py b/setup.py
index e1763ef7..9cf88799 100644
--- a/setup.py
+++ b/setup.py
@@ -1,12 +1,54 @@
-"""Setup script for VLA-Arena.
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-Note: This setup.py is maintained for backward compatibility.
-The package configuration is now primarily in pyproject.toml.
-"""
+# read the contents of your README file
+from os import path
-from setuptools import setup
+from setuptools import find_packages, setup
-# All configuration is in pyproject.toml
-# This setup.py is kept for compatibility with tools that don't support PEP 517/518
-setup()
+this_directory = path.abspath(path.dirname(__file__))
+with open(path.join(this_directory, './README.md'), encoding='utf-8') as f:
+ lines = f.readlines()
+
+# remove images from README
+lines = [x for x in lines if '.png' not in x]
+long_description = ''.join(lines)
+
+setup(
+ name='vla-arena',
+ packages=[
+ package
+ for package in find_packages()
+ if package.startswith('vla_arena')
+ ],
+ install_requires=[],
+ eager_resources=['*'],
+ include_package_data=True,
+ python_requires='>=3',
+ description='VLA-Arena: Benchmarking Vision-Language-Action Models by Structured Task Design',
+ author='Borong Zhang, Jiahao Li, Jiachen Shen',
+ author_email='jiahaoli2077@gmail.com',
+ version='0.1.0',
+ long_description=long_description,
+ long_description_content_type='text/markdown',
+ entry_points={
+ 'console_scripts': [
+ 'vla_arena.main=vla_arena.main:main',
+ 'vla_arena.eval=vla_arena.evaluate:main',
+ 'vla_arena.config_copy=scripts.config_copy:main',
+ 'vla_arena.create_template=scripts.create_template:main',
+ ],
+ },
+)
diff --git a/tests/.coveragerc b/tests/.coveragerc
index 8b1db524..7163cb04 100644
--- a/tests/.coveragerc
+++ b/tests/.coveragerc
@@ -1,16 +1,27 @@
[run]
+branch = True
+data_file = tests/.coverage
+source =
+ vla_arena
omit =
- ../vla_arena/__init__.py
- ../vla_arena/__version__.py
- ../docs/*
- ../scripts/*
- ../vla_arena/vla_arena/assets/*
+ */tests/*
+ */conftest.py
+ */__main__.py
[report]
+skip_empty = True
+show_missing = True
+precision = 2
exclude_lines =
pragma: no cover
- raise NotImplementedError
- class .*\bProtocol\):
- @(abc\.)?abstractmethod
- if __name__ == ('__main__'|"__main__"):
+    if __name__ == '__main__':
if TYPE_CHECKING:
+ raise NotImplementedError
+
+[xml]
+output = tests/coverage.xml
+
+[paths]
+source =
+ vla_arena
+ */site-packages/vla_arena
diff --git a/tests/README.md b/tests/README.md
index dbfac77a..edb773a8 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -1,56 +1,76 @@
-# VLA-Arena Tests
+# VLA-Arena Test Suite
-This directory contains tests for VLA-Arena.
+This directory contains comprehensive pytest tests for the VLA-Arena project.
+
+## Test Structure
+
+- `conftest.py` - Pytest configuration and shared fixtures
+- `test_utils.py` - Tests for utility functions
+- `test_benchmark.py` - Tests for benchmark functionality
+- `test_cli.py` - Tests for command-line interface
+- `test_vla_arena_init.py` - Tests for initialization and path management
+- `test_log_utils.py` - Tests for logging utilities
+- `test_bddl_utils.py` - Tests for BDDL generation utilities
+- `test_task_generation_utils.py` - Tests for task generation utilities
+- `test_integration.py` - Integration tests (may require full setup)
## Running Tests
-Run all tests:
+### Run all tests
+
```bash
-make test
+pytest tests/
```
-Or use pytest directly:
+### Run specific test file
+
```bash
-pytest tests/ -v
+pytest tests/test_utils.py
```
-Run with coverage:
+### Run with coverage
+
```bash
-pytest tests/ -v --cov=vla_arena --cov-report=html
+pytest tests/ --cov=vla_arena --cov-report=html
```
-Run specific test file:
+### Run only unit tests (exclude integration tests)
+
```bash
-pytest tests/test_import.py -v
+pytest tests/ -m "not integration"
```
-## Test Structure
+### Run only fast tests (exclude slow tests)
-- `test_import.py` - Basic import and package structure tests
+```bash
+pytest tests/ -m "not slow"
+```
-## Adding New Tests
+### Run with verbose output
-When adding new features or fixing bugs:
+```bash
+pytest tests/ -v
+```
-1. Create a new test file `test_<feature>.py`
-2. Add appropriate test cases
-3. Use fixtures from `conftest.py` when needed
-4. Ensure tests are independent and can run in any order
-5. Mock external dependencies when appropriate
+## Test Markers
-## Test Guidelines
+Tests are marked with the following markers:
-- Use descriptive test names: `test_<function>_<scenario>`
-- Add docstrings to explain test purpose
-- Keep tests focused and independent
-- Use parametrize for testing multiple scenarios
-- Mock external services and heavy dependencies
-- Aim for high code coverage
+- `@pytest.mark.unit` - Unit tests (fast, isolated)
+- `@pytest.mark.integration` - Integration tests (may require full setup)
+- `@pytest.mark.slow` - Slow tests (may take longer to run)
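+
+For example, an integration test can be marked as follows (illustrative sketch; the test name is hypothetical):
+
+```python
+import pytest
+
+
+@pytest.mark.integration
+def test_full_environment_rollout():
+    """Runs only when integration tests are selected (e.g. `pytest tests/ -m integration`)."""
+    ...
+```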
-## Continuous Integration
+## Requirements
+
+Install the test dependencies before running the suite:
+
+```bash
+pip install -r requirements.txt
+pip install pytest pytest-cov pytest-mock
+```
-Tests are automatically run on GitHub Actions for:
-- Multiple Python versions (3.8, 3.9, 3.10, 3.11)
-- Multiple operating systems (Ubuntu, macOS)
+## Notes
-See `.github/workflows/ci.yml` for CI configuration.
+- Some tests may be skipped if certain dependencies are not available
+- Integration tests may require proper configuration and data files
+- Mock objects are used extensively to avoid requiring actual model files or environments
diff --git a/tests/__init__.py b/tests/__init__.py
index 68d375cc..381d3d65 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,16 +1,3 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Tests for VLA-Arena."""
+"""
+Test package for VLA-Arena.
+"""
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..0ececee5
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,162 @@
+"""
+Pytest configuration and fixtures for VLA-Arena tests.
+"""
+
+import os
+import shutil
+import tempfile
+from unittest.mock import Mock
+
+import pytest
+import yaml
+
+
+@pytest.fixture
+def temp_dir():
+ """Create a temporary directory for tests."""
+ temp_path = tempfile.mkdtemp()
+ yield temp_path
+ shutil.rmtree(temp_path, ignore_errors=True)
+
+
+@pytest.fixture
+def temp_config_file(temp_dir):
+ """Create a temporary config YAML file."""
+ config_path = os.path.join(temp_dir, 'config.yaml')
+ config_data = {
+ 'benchmark_root': temp_dir,
+ 'bddl_files': os.path.join(temp_dir, 'bddl_files'),
+ 'init_states': os.path.join(temp_dir, 'init_states'),
+ 'assets': os.path.join(temp_dir, 'assets'),
+ }
+ with open(config_path, 'w') as f:
+ yaml.dump(config_data, f)
+ return config_path
+
+
+@pytest.fixture
+def mock_vla_arena_paths(monkeypatch, temp_dir):
+ """Mock VLA-Arena path functions to use temporary directory."""
+
+ def mock_get_vla_arena_path(key):
+ paths = {
+ 'benchmark_root': temp_dir,
+ 'bddl_files': os.path.join(temp_dir, 'bddl_files'),
+ 'init_states': os.path.join(temp_dir, 'init_states'),
+ 'assets': os.path.join(temp_dir, 'assets'),
+ }
+ return paths.get(key, temp_dir)
+
+ monkeypatch.setattr(
+ 'vla_arena.vla_arena.benchmark.__init__.get_vla_arena_path',
+ mock_get_vla_arena_path,
+ )
+ return temp_dir
+
+
+@pytest.fixture
+def sample_bddl_content():
+ """Sample BDDL file content for testing."""
+ return """
+(define (problem test_problem)
+ (:domain robosuite)
+ (:requirements)
+ (:objects obj1 obj2 - object)
+ (:language pick up the red cup)
+ (:init
+ (on obj1 table)
+ )
+ (:goal
+ (in obj1 box)
+ )
+)
+"""
+
+
+@pytest.fixture
+def sample_task():
+ """Create a sample Task namedtuple for testing."""
+ from vla_arena.vla_arena.benchmark import Task
+
+ return Task(
+ name='test_task_L0',
+ language='pick up the red cup',
+ problem='vla_arena',
+ problem_folder='safety_static_obstacles',
+ bddl_file='test_task_L0.bddl',
+ init_states_file='test_task_L0.pruned_init',
+ level=0,
+ level_id=0,
+ )
+
+
+@pytest.fixture
+def mock_benchmark():
+ """Create a mock benchmark instance."""
+ benchmark = Mock()
+ benchmark.name = 'test_benchmark'
+ benchmark.n_tasks = 5
+ benchmark.tasks = []
+ benchmark.level_task_maps = {0: [], 1: [], 2: []}
+ return benchmark
+
+
+@pytest.fixture(autouse=True)
+def reset_benchmark_registry():
+ """Reset benchmark registry before each test."""
+ from vla_arena.vla_arena.benchmark import BENCHMARK_MAPPING
+
+ original_mapping = BENCHMARK_MAPPING.copy()
+ BENCHMARK_MAPPING.clear()
+ yield
+ BENCHMARK_MAPPING.clear()
+ BENCHMARK_MAPPING.update(original_mapping)
+
+
+@pytest.fixture
+def mock_env():
+ """Create a mock environment for testing."""
+ env = Mock()
+ env.reset.return_value = {'image': Mock(), 'state': Mock()}
+ env.step.return_value = (
+ {'image': Mock(), 'state': Mock()},
+ 0.0,
+ False,
+ {},
+ )
+ env.render.return_value = Mock()
+ return env
+
+
+@pytest.fixture
+def mock_h5py_file():
+ """Create a mock h5py file for dataset testing."""
+ import h5py
+ import numpy as np
+
+ # Create a temporary h5py file
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.hdf5')
+ temp_file.close()
+
+ with h5py.File(temp_file.name, 'w') as f:
+ # Create sample data structure
+ demo_group = f.create_group('data')
+ demo_group.attrs['problem_info'] = (
+ '{"language_instruction": ["pick up the cup"]}'
+ )
+ demo_group.attrs['env_args'] = '{"env_name": "test_env"}'
+
+ # Create a sample episode
+ episode = demo_group.create_group('demo_0')
+ episode.attrs['num_samples'] = 10
+ episode.create_dataset('actions', data=np.random.randn(10, 7))
+ episode.create_group('obs')
+ episode.create_group('next_obs')
+
+ yield temp_file.name
+
+ # Cleanup
+ if os.path.exists(temp_file.name):
+ os.unlink(temp_file.name)
diff --git a/tests/test_bddl_utils.py b/tests/test_bddl_utils.py
new file mode 100644
index 00000000..19ae01c2
--- /dev/null
+++ b/tests/test_bddl_utils.py
@@ -0,0 +1,81 @@
+"""
+Tests for BDDL generation utilities.
+"""
+
+import os
+
+import pytest
+
+
+try:
+ from vla_arena.vla_arena.utils import bddl_generation_utils
+
+ BDDL_UTILS_AVAILABLE = True
+except ImportError:
+ BDDL_UTILS_AVAILABLE = False
+
+
+@pytest.mark.skipif(
+ not BDDL_UTILS_AVAILABLE, reason='bddl_generation_utils not available'
+)
+class TestBDDLGenerationUtils:
+ """Test cases for bddl_generation_utils.py"""
+
+ def test_print_result(self, capsys):
+ """Test print_result function."""
+ result = ['line1', 'line2', 'line3']
+ bddl_generation_utils.print_result(result)
+
+ captured = capsys.readouterr()
+ assert 'line1' in captured.out
+ assert 'line2' in captured.out
+ assert 'line3' in captured.out
+
+ def test_get_result(self):
+ """Test get_result function."""
+ result = ['line1', 'line2', 'line3']
+ output = bddl_generation_utils.get_result(result)
+
+ assert isinstance(output, str)
+ assert 'line1' in output
+ assert 'line2' in output
+ assert 'line3' in output
+
+ def test_save_to_file(self, temp_dir):
+ """Test save_to_file function."""
+ # save_to_file expects a string result (from get_result), not a list
+ result_list = ['(define (problem test)', '(:domain robosuite)', ')']
+ result = bddl_generation_utils.get_result(
+ result_list
+ ) # Convert to string
+ scene_name = 'TEST_SCENE'
+ language = 'pick up the cup'
+
+ file_path = bddl_generation_utils.save_to_file(
+ result,
+ scene_name,
+ language,
+ folder=temp_dir,
+ )
+
+ assert os.path.exists(file_path)
+ assert scene_name.upper() in file_path
+ assert file_path.endswith('.bddl')
+
+ # Check file contents
+ with open(file_path) as f:
+ content = f.read()
+ assert 'define' in content.lower()
+
+ def test_pddl_definition_decorator(self):
+ """Test PDDLDefinition decorator."""
+
+ @bddl_generation_utils.PDDLDefinition(problem_name='test_problem')
+ def test_problem():
+ return ['(:objects obj1 - object)', '(:init)']
+
+ result = test_problem()
+
+ assert isinstance(result, list)
+ assert any('test_problem' in line for line in result)
+ assert any('robosuite' in line.lower() for line in result)
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py
new file mode 100644
index 00000000..48dc56c4
--- /dev/null
+++ b/tests/test_benchmark.py
@@ -0,0 +1,189 @@
+"""
+Tests for benchmark functionality in vla_arena.benchmark.
+"""
+
+import pytest
+
+
+try:
+ from vla_arena.vla_arena.benchmark import (
+ BENCHMARK_MAPPING,
+ Benchmark,
+ Task,
+ assign_task_level,
+ extract_level_from_task_name,
+ get_benchmark,
+ get_benchmark_dict,
+ grab_language_from_filename,
+ register_benchmark,
+ )
+
+ BENCHMARK_AVAILABLE = True
+except (ImportError, OSError, FileNotFoundError, ModuleNotFoundError):
+ # OSError/FileNotFoundError can occur on Windows when mujoco.dll is missing
+ BENCHMARK_AVAILABLE = False
+ # Create dummy classes for testing
+ Benchmark = None
+ Task = None
+ register_benchmark = None
+ get_benchmark = None
+ get_benchmark_dict = None
+ extract_level_from_task_name = None
+ grab_language_from_filename = None
+ assign_task_level = None
+ BENCHMARK_MAPPING = {}
+
+
+@pytest.mark.skipif(
+ not BENCHMARK_AVAILABLE, reason='benchmark module not available'
+)
+class TestTask:
+ """Test cases for Task namedtuple."""
+
+ def test_task_creation(self):
+ """Test creating a Task."""
+ task = Task(
+ name='test_task_L0',
+ language='pick up the cup',
+ problem='vla_arena',
+ problem_folder='safety_static_obstacles',
+ bddl_file='test_task_L0.bddl',
+ init_states_file='test_task_L0.pruned_init',
+ level=0,
+ level_id=0,
+ )
+
+ assert task.name == 'test_task_L0'
+ assert task.language == 'pick up the cup'
+ assert task.level == 0
+ assert task.level_id == 0
+
+ def test_task_immutability(self):
+ """Test that Task is immutable."""
+ task = Task(
+ name='test',
+ language='test',
+ problem='test',
+ problem_folder='test',
+ bddl_file='test.bddl',
+ init_states_file='test.pruned_init',
+ level=0,
+ level_id=0,
+ )
+
+ with pytest.raises(AttributeError):
+ task.name = 'new_name'
+
+
+@pytest.mark.skipif(
+ not BENCHMARK_AVAILABLE, reason='benchmark module not available'
+)
+class TestBenchmarkRegistration:
+ """Test cases for benchmark registration."""
+
+ def test_register_benchmark(self):
+ """Test registering a benchmark."""
+
+ class TestBenchmark(Benchmark):
+ def __init__(self):
+ super().__init__()
+ self.name = 'test_benchmark'
+ self._make_benchmark()
+
+ register_benchmark(TestBenchmark)
+ assert 'testbenchmark' in BENCHMARK_MAPPING
+
+ def test_get_benchmark_dict(self, capsys):
+ """Test getting benchmark dictionary."""
+
+ class TestBenchmark(Benchmark):
+ def __init__(self):
+ super().__init__()
+ self.name = 'test_benchmark2'
+ self._make_benchmark()
+
+ register_benchmark(TestBenchmark)
+ benchmark_dict = get_benchmark_dict(help=True)
+
+ assert isinstance(benchmark_dict, dict)
+ captured = capsys.readouterr()
+ assert (
+ 'Available benchmarks' in captured.out or len(benchmark_dict) >= 0
+ )
+
+ def test_get_benchmark_case_insensitive(self):
+ """Test that get_benchmark is case insensitive."""
+
+ # Use a class name that won't conflict with existing benchmarks
+ # register_benchmark uses class.__name__, not instance.name
+ class TestBenchmarkCaseTest(Benchmark):
+ def __init__(self):
+ super().__init__()
+ self.name = 'test_benchmark_case_test'
+ # Don't call _make_benchmark() as it requires vla_arena_task_map entry
+
+ register_benchmark(TestBenchmarkCaseTest)
+
+ # Should work with different cases (using class name)
+ class_name = 'TestBenchmarkCaseTest'
+ benchmark1 = get_benchmark(class_name)
+ benchmark2 = get_benchmark(class_name.upper())
+ benchmark3 = get_benchmark(class_name.lower())
+
+ assert benchmark1 == benchmark2 == benchmark3
+
+ # Cleanup
+ BENCHMARK_MAPPING.pop(class_name.lower(), None)
+
+
+@pytest.mark.skipif(
+ not BENCHMARK_AVAILABLE, reason='benchmark module not available'
+)
+class TestLevelExtraction:
+ """Test cases for level extraction functions."""
+
+ def test_extract_level_from_task_name_L0(self):
+ """Test extracting level 0 from task name."""
+ level = extract_level_from_task_name('task_name_L0')
+ assert level == 0
+
+ def test_extract_level_from_task_name_L1(self):
+ """Test extracting level 1 from task name."""
+ level = extract_level_from_task_name('task_name_L1')
+ assert level == 1
+
+ def test_extract_level_from_task_name_L2(self):
+ """Test extracting level 2 from task name."""
+ level = extract_level_from_task_name('task_name_L2')
+ assert level == 2
+
+ def test_extract_level_from_task_name_with_bddl(self):
+ """Test extracting level from task name with .bddl extension."""
+ level = extract_level_from_task_name('task_name_L1.bddl')
+ assert level == 1
+
+ def test_extract_level_from_task_name_no_level(self):
+ """Test extracting level when no level suffix exists."""
+ level = extract_level_from_task_name('task_name')
+ assert level is None
+
+ def test_grab_language_from_filename(self):
+ """Test extracting language from filename."""
+ language = grab_language_from_filename('pick_up_the_cup_L0.bddl')
+ assert isinstance(language, str)
+ assert len(language) > 0
+
+ def test_assign_task_level_from_name(self):
+ """Test assigning task level from name."""
+ level = assign_task_level('task_L1')
+ assert level == 1
+
+ def test_assign_task_level_from_index(self):
+ """Test assigning task level from index."""
+ level = assign_task_level('task', task_index=4)
+ assert level in [0, 1, 2] # Should be 4 % 3 = 1
+
+ def test_assign_task_level_default(self):
+ """Test default task level assignment."""
+ level = assign_task_level('task')
+ assert level == 0
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 00000000..8116c0e3
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,255 @@
+"""
+Tests for CLI functionality in vla_arena.cli.
+"""
+
+import argparse
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+
+try:
+ from vla_arena.cli import eval
+ from vla_arena.cli import main as cli_main_module
+ from vla_arena.cli import train
+ from vla_arena.cli.main import main as cli_main_function
+
+ CLI_AVAILABLE = True
+except ImportError:
+ CLI_AVAILABLE = False
+ cli_main_module = None
+ cli_main_function = None
+ eval = None
+ train = None
+
+
+@pytest.mark.skipif(not CLI_AVAILABLE, reason='CLI module not available')
+class TestCLIMain:
+ """Test cases for CLI main function."""
+
+ def test_main_train_parser(self, monkeypatch):
+ """Test main function with train command."""
+ mock_train_main = Mock()
+ monkeypatch.setattr('vla_arena.cli.main.train_main', mock_train_main)
+
+ with patch(
+ 'sys.argv',
+ [
+ 'vla-arena',
+ 'train',
+ '--model',
+ 'openvla',
+ '--config',
+ 'test.yaml',
+ ],
+ ):
+ try:
+ cli_main_function()
+ except SystemExit:
+ pass
+
+ # Check that train_main was called (or parser was invoked)
+ # The actual call depends on argparse behavior
+
+ def test_main_eval_parser(self, monkeypatch):
+ """Test main function with eval command."""
+ mock_eval_main = Mock()
+ monkeypatch.setattr('vla_arena.cli.main.eval_main', mock_eval_main)
+
+ with patch(
+ 'sys.argv',
+ [
+ 'vla-arena',
+ 'eval',
+ '--model',
+ 'openvla',
+ '--config',
+ 'test.yaml',
+ ],
+ ):
+ try:
+ cli_main_function()
+ except SystemExit:
+ pass
+
+ def test_main_no_command(self, capsys):
+ """Test main function with no command."""
+ with patch('sys.argv', ['vla-arena']):
+ try:
+ cli_main_function()
+ except SystemExit:
+ pass
+ # Should print help
+
+
+@pytest.mark.skipif(not CLI_AVAILABLE, reason='CLI module not available')
+class TestEvalMain:
+ """Test cases for eval_main function."""
+
+ @patch('vla_arena.cli.eval.importlib.util.find_spec')
+ def test_eval_main_module_not_found(self, mock_find_spec):
+ """Test eval_main when module is not found."""
+ mock_find_spec.return_value = None
+
+ args = argparse.Namespace()
+ args.model = 'nonexistent_model'
+ args.config = '/path/to/config.yaml'
+
+ with pytest.raises(RuntimeError):
+ eval.eval_main(args)
+
+ @patch('vla_arena.cli.eval.importlib.util.find_spec')
+ def test_eval_main_import_error(self, mock_find_spec):
+ """Test eval_main when import fails."""
+ mock_find_spec.side_effect = ImportError('Module not found')
+
+ args = argparse.Namespace()
+ args.model = 'openvla'
+ args.config = '/path/to/config.yaml'
+
+ with pytest.raises(RuntimeError):
+ eval.eval_main(args)
+
+ @patch('vla_arena.cli.eval.importlib.util.find_spec')
+ @patch('vla_arena.cli.eval.importlib.import_module')
+ def test_eval_main_config_path_absolute(
+ self, mock_import_module, mock_find_spec
+ ):
+ """Test that config path is converted to absolute."""
+ mock_spec = Mock()
+ mock_spec.origin = '/path/to/evaluator.py'
+ mock_find_spec.return_value = mock_spec
+
+ mock_module = Mock()
+ mock_import_module.return_value = mock_module
+
+ args = argparse.Namespace()
+ args.model = 'openvla'
+ args.config = 'relative/path/config.yaml'
+
+ eval.eval_main(args)
+
+ # Check that config path passed to main is absolute
+ call_args = mock_module.main.call_args
+ assert call_args is not None
+        config_path = call_args[1].get('cfg') or (
+            call_args[0][0] if call_args[0] else None
+        )
+ if config_path:
+ assert os.path.isabs(config_path)
+
+
+@pytest.mark.skipif(not CLI_AVAILABLE, reason='CLI module not available')
+class TestTrainMain:
+ """Test cases for train_main function."""
+
+ @patch('vla_arena.cli.train.importlib.util.find_spec')
+ @patch('vla_arena.cli.train.importlib.import_module')
+ @patch.dict(os.environ, {}, clear=False)
+ def test_train_main_openpi(self, mock_import_module, mock_find_spec):
+ """Test train_main for openpi model (JAX)."""
+ mock_spec = Mock()
+ mock_spec.origin = '/path/to/trainer.py'
+ mock_find_spec.return_value = mock_spec
+
+ mock_module = Mock()
+ mock_import_module.return_value = mock_module
+
+ args = argparse.Namespace()
+ args.model = 'openpi'
+ args.config = '/path/to/config.yaml'
+
+ train.train_main(args)
+
+ # import_module will be called multiple times during import chain
+ # Just verify it was called and the final module.main was called
+ assert mock_import_module.called
+ mock_module.main.assert_called_once()
+
+ @patch('vla_arena.cli.train.importlib.util.find_spec')
+ @patch('vla_arena.cli.train.importlib.import_module')
+ @patch.dict(os.environ, {'LOCAL_RANK': '0'}, clear=False)
+ def test_train_main_distributed(self, mock_import_module, mock_find_spec):
+ """Test train_main when already in distributed mode."""
+ mock_spec = Mock()
+ mock_spec.origin = '/path/to/trainer.py'
+ mock_find_spec.return_value = mock_spec
+
+ mock_module = Mock()
+ mock_import_module.return_value = mock_module
+
+ args = argparse.Namespace()
+ args.model = 'openvla'
+ args.config = '/path/to/config.yaml'
+
+ train.train_main(args)
+
+ # import_module will be called multiple times during import chain
+ # Just verify it was called and the final module.main was called
+ assert mock_import_module.called
+ mock_module.main.assert_called_once()
+
+ @patch('vla_arena.cli.train.importlib.util.find_spec')
+ @patch('vla_arena.cli.train.subprocess.run')
+ @patch('vla_arena.cli.train.torch.cuda.device_count')
+ @patch.dict(os.environ, {}, clear=False)
+ def test_train_main_launch_torchrun(
+ self, mock_device_count, mock_subprocess, mock_find_spec
+ ):
+ """Test train_main launching torchrun."""
+ mock_spec = Mock()
+ mock_spec.origin = '/path/to/trainer.py'
+ mock_find_spec.return_value = mock_spec
+
+ mock_device_count.return_value = 2
+ mock_subprocess.return_value = Mock(returncode=0)
+
+ args = argparse.Namespace()
+ args.model = 'openvla'
+ args.config = '/path/to/config.yaml'
+
+ train.train_main(args)
+
+ # Verify torchrun was called
+ mock_subprocess.assert_called_once()
+ call_args = mock_subprocess.call_args[0][0]
+ assert 'torchrun' in call_args[0]
+
+ @patch('vla_arena.cli.train.importlib.util.find_spec')
+ def test_train_main_module_not_found(self, mock_find_spec):
+ """Test train_main when module is not found."""
+ mock_find_spec.return_value = None
+
+ args = argparse.Namespace()
+ args.model = 'nonexistent_model'
+ args.config = '/path/to/config.yaml'
+
+ with pytest.raises(RuntimeError):
+ train.train_main(args)
+
+ @patch('vla_arena.cli.train.importlib.util.find_spec')
+ @patch('vla_arena.cli.train.importlib.import_module')
+ def test_train_main_overwrite_flag(
+ self, mock_import_module, mock_find_spec
+ ):
+ """Test train_main with overwrite flag."""
+ mock_spec = Mock()
+ mock_spec.origin = '/path/to/trainer.py'
+ mock_find_spec.return_value = mock_spec
+
+ mock_module = Mock()
+ mock_import_module.return_value = mock_module
+
+ args = argparse.Namespace()
+ args.model = 'openpi'
+ args.config = '/path/to/config.yaml'
+ args.overwrite = True
+
+ train.train_main(args)
+
+ # Check that overwrite was passed
+ call_kwargs = mock_module.main.call_args[1]
+ assert call_kwargs.get('overwrite') is True
diff --git a/tests/test_import.py b/tests/test_import.py
deleted file mode 100644
index 9bf5bcf6..00000000
--- a/tests/test_import.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Test basic imports and package structure."""
-
-import pytest
-
-
-def test_import_vla_arena():
- """Test that vla_arena can be imported."""
- import vla_arena
-
- assert vla_arena is not None
-
-
-def test_version():
- """Test that version information is available."""
- import vla_arena
-
- assert hasattr(vla_arena, '__version__')
- assert isinstance(vla_arena.__version__, str)
- assert len(vla_arena.__version__) > 0
-
-
-def test_package_metadata():
- """Test that package metadata is accessible."""
- try:
- from importlib.metadata import metadata, version
-
- pkg_version = version('vla-arena')
- pkg_metadata = metadata('vla-arena')
-
- assert pkg_version is not None
- assert pkg_metadata['Name'] == 'vla-arena'
- except Exception:
- # Package not installed yet, skip test
- pytest.skip('Package not installed')
diff --git a/tests/test_task_generation_utils.py b/tests/test_task_generation_utils.py
new file mode 100644
index 00000000..20057ac7
--- /dev/null
+++ b/tests/test_task_generation_utils.py
@@ -0,0 +1,122 @@
+"""
+Tests for task generation utilities.
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+
+try:
+ from vla_arena.vla_arena.utils import task_generation_utils
+
+ TASK_GEN_UTILS_AVAILABLE = True
+except (ImportError, OSError, FileNotFoundError, ModuleNotFoundError):
+ # OSError/FileNotFoundError can occur on Windows when mujoco.dll is missing
+ TASK_GEN_UTILS_AVAILABLE = False
+
+
+@pytest.mark.skipif(
+ not TASK_GEN_UTILS_AVAILABLE, reason='task_generation_utils not available'
+)
+class TestTaskGenerationUtils:
+ """Test cases for task_generation_utils.py"""
+
+ def test_get_task_info_none(self):
+ """Test get_task_info with no scene_name."""
+ task_info = task_generation_utils.get_task_info()
+ assert isinstance(task_info, dict)
+
+ def test_get_task_info_with_scene(self):
+ """Test get_task_info with scene_name."""
+ # This should raise KeyError if scene doesn't exist, or return a list
+ try:
+ task_info = task_generation_utils.get_task_info(
+ 'nonexistent_scene'
+ )
+ assert isinstance(task_info, list)
+ except KeyError:
+ # Expected behavior if scene doesn't exist
+ pass
+
+ @patch('vla_arena.vla_arena.utils.task_generation_utils.get_scene_class')
+ def test_register_task_info(self, mock_get_scene_class):
+ """Test register_task_info function."""
+ # Mock scene class - need to make it callable and return an instance
+ mock_scene_instance = Mock()
+ mock_scene_instance.possible_objects_of_interest = [
+ 'obj1',
+ 'obj2',
+ 'obj3',
+ ]
+ mock_scene_class = Mock(return_value=mock_scene_instance)
+ mock_get_scene_class.return_value = mock_scene_class
+
+ # Register task info
+ task_generation_utils.register_task_info(
+ language='pick up the cup',
+ scene_name='test_scene_register',
+ objects_of_interest=['obj1', 'obj2'],
+ goal_states=[('in', 'obj1', 'box')],
+ )
+
+ # Verify task was registered
+ task_info = task_generation_utils.get_task_info('test_scene_register')
+ assert len(task_info) > 0
+
+ # Cleanup
+ if 'test_scene_register' in task_generation_utils.TASK_INFO:
+ del task_generation_utils.TASK_INFO['test_scene_register']
+
+ @patch('vla_arena.vla_arena.utils.task_generation_utils.get_scene_class')
+ def test_register_task_info_invalid_object(self, mock_get_scene_class):
+ """Test register_task_info with invalid object."""
+ mock_scene_instance = Mock()
+ mock_scene_instance.possible_objects_of_interest = ['obj1', 'obj2']
+ mock_scene_class = Mock(return_value=mock_scene_instance)
+ mock_get_scene_class.return_value = mock_scene_class
+
+ # Should raise ValueError for invalid object
+ with pytest.raises(ValueError):
+ task_generation_utils.register_task_info(
+ language='test',
+ scene_name='test_scene_invalid',
+ objects_of_interest=['invalid_obj'],
+ goal_states=[],
+ )
+
+ def test_get_suite_generator_func(self):
+ """Test get_suite_generator_func."""
+ # Test various workspace names - these functions may not be defined
+ workspace_names = [
+ 'main_table',
+ 'kitchen_table',
+ 'living_room_table',
+ 'study_table',
+ 'coffee_table',
+ ]
+
+ for workspace_name in workspace_names:
+ try:
+ generator_func = (
+ task_generation_utils.get_suite_generator_func(
+ workspace_name
+ )
+ )
+ # Should return a function or None
+ assert generator_func is None or callable(generator_func)
+ except NameError:
+ # Expected if generator functions are not defined
+ pass
+
+ def test_get_suite_generator_func_invalid(self):
+ """Test get_suite_generator_func with invalid workspace."""
+ try:
+ generator_func = task_generation_utils.get_suite_generator_func(
+ 'invalid_workspace'
+ )
+ # Should return None or raise error
+ assert generator_func is None or callable(generator_func)
+ except NameError:
+ # Expected if generator functions are not defined
+ pass
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..b3f46bee
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,147 @@
+"""
+Tests for utility functions in vla_arena.utils.
+"""
+
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+
+try:
+ from vla_arena.vla_arena.utils import utils
+
+ UTILS_AVAILABLE = True
+except (ImportError, OSError, FileNotFoundError, ModuleNotFoundError):
+ # OSError/FileNotFoundError can occur on Windows when mujoco.dll is missing
+ UTILS_AVAILABLE = False
+
+try:
+ from vla_arena.vla_arena.utils import dataset_utils
+
+ DATASET_UTILS_AVAILABLE = True
+except ImportError:
+ DATASET_UTILS_AVAILABLE = False
+
+try:
+ from vla_arena.vla_arena.utils import time_utils
+
+ TIME_UTILS_AVAILABLE = True
+except ImportError:
+ TIME_UTILS_AVAILABLE = False
+
+
+@pytest.mark.skipif(not UTILS_AVAILABLE, reason='utils module not available')
+class TestUtils:
+ """Test cases for utils.py"""
+
+ def test_process_image_input(self):
+ """Test image input processing."""
+ # Test with numpy array
+ img_tensor = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
+ processed = utils.process_image_input(img_tensor)
+
+ assert processed.dtype == np.float64 or processed.dtype == np.float32
+ assert processed.max() <= 1.0
+ assert processed.min() >= 0.0
+ assert processed.shape == img_tensor.shape
+
+ def test_reconstruct_image_output(self):
+ """Test image output reconstruction."""
+ img_array = np.random.rand(224, 224, 3)
+ reconstructed = utils.reconstruct_image_output(img_array)
+
+ assert (
+ reconstructed.dtype == np.float64
+ or reconstructed.dtype == np.float32
+ )
+ assert reconstructed.max() <= 255.0
+ assert reconstructed.min() >= 0.0
+ assert reconstructed.shape == img_array.shape
+
+ def test_update_env_kwargs(self):
+ """Test updating environment kwargs."""
+ env_kwargs = {'key1': 'value1', 'key2': 'value2'}
+ utils.update_env_kwargs(env_kwargs, key3='value3', key1='new_value1')
+
+ assert env_kwargs['key1'] == 'new_value1'
+ assert env_kwargs['key2'] == 'value2'
+ assert env_kwargs['key3'] == 'value3'
+
+ @patch('vla_arena.vla_arena.utils.utils.robosuite')
+ def test_postprocess_model_xml(self, mock_robosuite):
+ """Test XML postprocessing."""
+ mock_robosuite.__file__ = '/path/to/robosuite/__init__.py'
+
+        # Minimal MuJoCo-style XML with a robosuite asset path and a frontview camera
+        xml_str = """
+        <mujoco>
+            <asset>
+                <mesh file="/old/install/robosuite/models/assets/meshes/cube.stl" name="cube_mesh"/>
+            </asset>
+            <worldbody>
+                <camera name="frontview" pos="0 0 0" quat="1 0 0 0"/>
+            </worldbody>
+        </mujoco>
+    """
+
+ cameras_dict = {'frontview': {'pos': '1 1 1', 'quat': '1 0 0 0'}}
+
+ result = utils.postprocess_model_xml(xml_str, cameras_dict)
+
+ assert isinstance(result, str)
+ assert 'robosuite' in result or len(result) > 0
+
+
+@pytest.mark.skipif(
+ not DATASET_UTILS_AVAILABLE, reason='dataset_utils module not available'
+)
+class TestDatasetUtils:
+ """Test cases for dataset_utils.py"""
+
+ def test_get_dataset_info_basic(self, mock_h5py_file):
+ """Test basic dataset info extraction."""
+ info = dataset_utils.get_dataset_info(mock_h5py_file, verbose=False)
+
+ # Should complete without error
+ assert info is None # Function doesn't return anything, just prints
+
+ def test_get_dataset_info_with_filter_key(self, mock_h5py_file):
+ """Test dataset info with filter key."""
+ # This will fail if filter key doesn't exist, but we test the function call
+ try:
+ dataset_utils.get_dataset_info(
+ mock_h5py_file, filter_key='test_filter', verbose=False
+ )
+ except (KeyError, AttributeError):
+ # Expected if filter key doesn't exist
+ pass
+
+
+@pytest.mark.skipif(
+ not TIME_UTILS_AVAILABLE, reason='time_utils module not available'
+)
+class TestTimeUtils:
+ """Test cases for time_utils.py"""
+
+ def test_timer_context_manager(self):
+ """Test Timer as context manager."""
+ import time
+
+ with time_utils.Timer() as timer:
+ time.sleep(0.1)
+
+ elapsed = timer.get_elapsed_time()
+ assert elapsed >= 0.1
+ assert elapsed < 1.0 # Should be much less than 1 second
+
+ def test_timer_value_attribute(self):
+ """Test Timer value attribute."""
+ import time
+
+ with time_utils.Timer() as timer:
+ time.sleep(0.05)
+
+ assert hasattr(timer, 'value')
+ assert timer.value >= 0.05
+ assert timer.value == timer.get_elapsed_time()
diff --git a/tests/test_vla_arena_init.py b/tests/test_vla_arena_init.py
new file mode 100644
index 00000000..414aa22f
--- /dev/null
+++ b/tests/test_vla_arena_init.py
@@ -0,0 +1,149 @@
+"""
+Tests for vla_arena initialization and path management.
+"""
+
+import os
+from unittest.mock import patch
+
+import pytest
+import yaml
+
+
+try:
+ from vla_arena.vla_arena import (
+ config_file,
+ get_default_path_dict,
+ get_vla_arena_path,
+ set_vla_arena_default_path,
+ vla_arena_config_path,
+ )
+
+ VLA_ARENA_INIT_AVAILABLE = True
+except ImportError:
+ VLA_ARENA_INIT_AVAILABLE = False
+ get_default_path_dict = None
+ get_vla_arena_path = None
+ set_vla_arena_default_path = None
+ vla_arena_config_path = None
+ config_file = None
+
+
+@pytest.mark.skipif(
+ not VLA_ARENA_INIT_AVAILABLE, reason='vla_arena init module not available'
+)
+class TestPathManagement:
+ """Test cases for path management functions."""
+
+ def test_get_default_path_dict(self):
+ """Test getting default path dictionary."""
+ paths = get_default_path_dict()
+
+ assert isinstance(paths, dict)
+ assert 'benchmark_root' in paths
+ assert 'bddl_files' in paths
+ assert 'init_states' in paths
+ assert 'assets' in paths
+
+ # Check that paths are strings
+ assert all(isinstance(v, str) for v in paths.values())
+
+ def test_get_default_path_dict_custom_location(self):
+ """Test getting default path dict with custom location."""
+ custom_location = '/custom/path'
+ paths = get_default_path_dict(custom_location)
+
+ assert paths['benchmark_root'] == custom_location
+ # Check that paths contain the expected directory names
+ assert 'bddl_files' in paths['bddl_files'] or paths[
+ 'bddl_files'
+ ].endswith('bddl_files')
+ assert (
+ 'init_files' in paths['init_states']
+ or paths['init_states'].endswith('init_files')
+ or 'init_states' in paths['init_states']
+ )
+ assert 'assets' in paths['assets'] or paths['assets'].endswith(
+ 'assets'
+ )
+
+ def test_get_vla_arena_path_success(self, temp_config_file):
+ """Test getting VLA-Arena path from config file."""
+ # Read the config file we created
+ with open(temp_config_file) as f:
+ config = yaml.safe_load(f)
+
+ # Mock the config file path
+ with patch('vla_arena.vla_arena.config_file', temp_config_file):
+ path = get_vla_arena_path('benchmark_root')
+ assert isinstance(path, str)
+
+ def test_get_vla_arena_path_missing_key(self, temp_config_file):
+ """Test getting VLA-Arena path with missing key."""
+ with patch('vla_arena.vla_arena.config_file', temp_config_file):
+ with pytest.raises(AssertionError):
+ get_vla_arena_path('nonexistent_key')
+
+ def test_set_vla_arena_default_path(self, temp_dir, capsys):
+ """Test setting default VLA-Arena path."""
+ config_file_path = os.path.join(temp_dir, 'config.yaml')
+
+ with patch('vla_arena.vla_arena.config_file', config_file_path):
+ set_vla_arena_default_path(temp_dir)
+
+ # Check that config file was created
+ assert os.path.exists(config_file_path)
+
+ # Read and verify config
+ with open(config_file_path) as f:
+ config = yaml.safe_load(f)
+
+ assert 'benchmark_root' in config
+ assert config['benchmark_root'] == temp_dir
+
+ def test_config_file_initialization(self, temp_dir, monkeypatch):
+ """Test that config file is initialized if it doesn't exist."""
+ config_dir = os.path.join(temp_dir, '.vla_arena')
+ config_file_path = os.path.join(config_dir, 'config.yaml')
+
+ # Mock the config path
+ monkeypatch.setenv('VLA_ARENA_CONFIG_PATH', config_dir)
+
+ # Import after setting env var to trigger initialization
+ import importlib
+
+ import vla_arena.vla_arena
+
+ importlib.reload(vla_arena.vla_arena)
+
+ # Check that directory was created
+ if os.path.exists(config_dir):
+ assert os.path.isdir(config_dir)
+
+
+@pytest.mark.skipif(
+ not VLA_ARENA_INIT_AVAILABLE, reason='vla_arena init module not available'
+)
+class TestConfigFileStructure:
+ """Test cases for config file structure."""
+
+ def test_config_file_yaml_format(self, temp_config_file):
+ """Test that config file is valid YAML."""
+ with open(temp_config_file) as f:
+ config = yaml.safe_load(f)
+
+ assert isinstance(config, dict)
+ assert len(config) > 0
+
+ def test_config_file_required_keys(self, temp_config_file):
+ """Test that config file has required keys."""
+ with open(temp_config_file) as f:
+ config = yaml.safe_load(f)
+
+ required_keys = [
+ 'benchmark_root',
+ 'bddl_files',
+ 'init_states',
+ 'assets',
+ ]
+ for key in required_keys:
+ assert key in config, f'Missing required key: {key}'
diff --git a/vla_arena/__init__.py b/vla_arena/__init__.py
index 0938fbbe..462bc510 100644
--- a/vla_arena/__init__.py
+++ b/vla_arena/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,27 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-"""VLA-Arena: A Comprehensive Benchmark for Vision-Language-Action Models."""
-
-from vla_arena.__version__ import __version__
-
-
-__all__ = ['__version__']
-
-
-def __getattr__(name):
- """Lazy import to avoid loading heavy dependencies during package build."""
- if name in globals():
- return globals()[name]
-
- # Lazy import from vla_arena.vla_arena
- try:
- from vla_arena import vla_arena as _vla_arena
-
- attr = getattr(_vla_arena, name)
- globals()[name] = attr
- return attr
- except (ImportError, AttributeError):
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+from .vla_arena import *
diff --git a/vla_arena/cli/__init__.py b/vla_arena/cli/__init__.py
new file mode 100644
index 00000000..fb51bf31
--- /dev/null
+++ b/vla_arena/cli/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/vla_arena/cli/eval.py b/vla_arena/cli/eval.py
new file mode 100644
index 00000000..5a9d2752
--- /dev/null
+++ b/vla_arena/cli/eval.py
@@ -0,0 +1,41 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import importlib.util
+import os
+
+
+def eval_main(args):
+ model = args.model
+    # Ensure config is an absolute path so it resolves correctly regardless of the working directory
+ config_path = os.path.abspath(str(args.config))
+
+    # 1. Verify that the evaluator module for the requested model exists
+ try:
+ module_name = f'vla_arena.models.{model}.evaluator'
+ module_spec = importlib.util.find_spec(module_name)
+ if module_spec is None or module_spec.origin is None:
+ raise ImportError(f'Cannot find module {module_name}')
+
+ except ImportError as e:
+ raise RuntimeError(
+ f"Model '{model}' is not installed or evaluator script not found.\n"
+ f'Try: pip install vla-arena[{model}]',
+ ) from e
+
+ # 2. Directly import the module and execute main
+ module = importlib.import_module(module_name)
+    # Pass the config path as a string; the evaluator's main() handles loading it
+ module.main(cfg=config_path)
diff --git a/vla_arena/cli/main.py b/vla_arena/cli/main.py
new file mode 100644
index 00000000..9600e363
--- /dev/null
+++ b/vla_arena/cli/main.py
@@ -0,0 +1,47 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+from .eval import eval_main
+from .train import train_main
+
+
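+# Example usage (assuming this parser is exposed as the `vla-arena` console command;
+# the training config path is a placeholder):
+#   vla-arena train --model openvla --config path/to/train_config.yaml
+#   vla-arena eval --model openvla --config vla_arena/configs/evaluation/openvla.yaml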
+def main():
+ parser = argparse.ArgumentParser('vla-arena CLI')
+ sub = parser.add_subparsers(dest='cmd')
+
+ # train
+ train_p = sub.add_parser('train')
+ train_p.add_argument('--model', required=True)
+ train_p.add_argument('--config', default=None)
+ train_p.add_argument(
+ '--overwrite',
+ action='store_true',
+ help='Overwrite existing checkpoint directory',
+ )
+
+ # eval
+ eval_p = sub.add_parser('eval')
+ eval_p.add_argument('--model', required=True)
+ eval_p.add_argument('--config', default=None)
+
+ args = parser.parse_args()
+
+ if args.cmd == 'train':
+ train_main(args)
+ elif args.cmd == 'eval':
+ eval_main(args)
+ else:
+ parser.print_help()
diff --git a/vla_arena/cli/train.py b/vla_arena/cli/train.py
new file mode 100644
index 00000000..3d2b22ec
--- /dev/null
+++ b/vla_arena/cli/train.py
@@ -0,0 +1,105 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import importlib.util
+import os
+import subprocess
+import sys
+
+import torch
+
+
+def train_main(args):
+ model = args.model
+ # Ensure config is an absolute path for easy reading by subprocesses
+ config_path = os.path.abspath(str(args.config))
+
+ # 1. Dynamically get the physical path of the corresponding model trainer.py file
+ try:
+ module_name = f'vla_arena.models.{model}.trainer'
+ module_spec = importlib.util.find_spec(module_name)
+ if module_spec is None or module_spec.origin is None:
+ raise ImportError(f'Cannot find module {module_name}')
+
+ script_path = module_spec.origin
+
+ except ImportError as e:
+ raise RuntimeError(
+ f"Model '{model}' is not installed or trainer script not found.\n"
+ f'Try: pip install vla-arena[{model}]',
+ ) from e
+
+ # 2. Special handling: openpi uses JAX, doesn't need torchrun
+ if model == 'openpi':
+ # === openpi uses JAX distributed training, directly call trainer ===
+ print(f'[Launcher] Preparing JAX training for model: {model}')
+ print(
+ '[Launcher] JAX will automatically detect and use available GPUs'
+ )
+
+ # Collect override parameters
+ override_kwargs = {}
+ if hasattr(args, 'overwrite') and args.overwrite:
+ override_kwargs['overwrite'] = True
+
+ # Directly import the module and execute main
+ module = importlib.import_module(module_name)
+        # Pass the config path and any override parameters; the trainer's main() handles them
+ module.main(config=config_path, **override_kwargs)
+ return
+
+ # 3. Check if currently launched by torchrun (check LOCAL_RANK environment variable)
+ is_distributed = os.environ.get('LOCAL_RANK') is not None
+
+ if is_distributed or model == 'smolvla':
+ # === Case A: Already a Worker process (launched by torchrun) ===
+ # Directly import the module and execute main
+ module = importlib.import_module(module_name)
+        # Pass the config path as a string; the trainer's main() handles loading it
+ module.main(config=config_path)
+
+ else:
+ # === Case B: Main launch process (user runs vla-arena train ...) ===
+ print(f'[Launcher] Preparing distributed training for model: {model}')
+
+ # Get GPU count (support nproc specified in args, otherwise default to all visible GPUs)
+ nproc_per_node = getattr(args, 'nproc', torch.cuda.device_count())
+ nnodes = getattr(args, 'nnodes', 1)
+ node_rank = getattr(args, 'node_rank', 0)
+ master_addr = getattr(args, 'master_addr', '127.0.0.1')
+ master_port = getattr(args, 'master_port', '29500')
+
+ print(f'[Launcher] Launching torchrun with {nproc_per_node} GPUs...')
+
+ # Build torchrun command
+ cmd = [
+ 'torchrun',
+ f'--nnodes={nnodes}',
+ f'--nproc_per_node={nproc_per_node}',
+ f'--node_rank={node_rank}',
+ f'--master_addr={master_addr}',
+ f'--master_port={master_port}',
+            script_path,  # Target script: vla_arena/models/<model>/trainer.py
+ f'--config={config_path}', # Pass parameter: --config /path/to/yaml
+ ]
+
+ print(f"[Launcher] Executing: {' '.join(cmd)}")
+
+ # Use subprocess to call torchrun
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f'[Launcher] Training failed with error code {e.returncode}')
+ sys.exit(e.returncode)
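
The launcher above assumes every model package exposes vla_arena.models.<model>.trainer with a main() that accepts a config path (plus keyword overrides in the openpi case) and, when run as a script under torchrun, parses --config itself. A minimal sketch of that contract, assuming PyYAML; the training body is elided and illustrative only:

    # Sketch of the trainer entry point the launcher expects
    # (vla_arena/models/<model>/trainer.py); not the actual implementation.
    import argparse

    import yaml


    def main(config: str, **overrides):
        with open(config) as f:
            cfg = yaml.safe_load(f)
        cfg.update(overrides)  # e.g. overwrite=True forwarded for openpi
        ...  # build the model and dataloaders, run the training loop


    if __name__ == '__main__':
        # torchrun invokes this file directly with --config=<absolute path>
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', required=True)
        main(config=parser.parse_args().config)
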
diff --git a/vla_arena/configs/evaluation/openpi.yaml b/vla_arena/configs/evaluation/openpi.yaml
new file mode 100644
index 00000000..1fcfcc78
--- /dev/null
+++ b/vla_arena/configs/evaluation/openpi.yaml
@@ -0,0 +1,30 @@
+# OpenPI Evaluation Configuration
+# Configuration for evaluating OpenPI models on VLA-Arena benchmark tasks
+
+#################################################################################################################
+# Model server parameters
+#################################################################################################################
+host: "0.0.0.0" # Model server host address
+port: 8000 # Model server port
+resize_size: 224 # Image resize size for model input
+replan_steps: 5 # Number of actions to execute before replanning
+
+#################################################################################################################
+# VLA-Arena environment-specific parameters
+#################################################################################################################
+task_suite_name: "safety_static_obstacles" # Task suite name (e.g., "safety_static_obstacles", "safety_dynamic_obstacles", "long_horizon")
+task_level: 0 # Task level (0 or 1)
+num_steps_wait: 10 # Number of steps to wait for objects to stabilize in sim
+num_trials_per_task: 10 # Number of rollouts per task
+add_noise: false # Add noise to observations
+adjust_light: false # Adjust lighting conditions
+randomize_color: false # Randomize object colors
+camera_offset: false # Apply camera offset
+safety: false # Enable safety mode
+
+#################################################################################################################
+# Utils
+#################################################################################################################
+save_video_mode: "first_success_failure" # Video saving mode: "all", "first_success_failure", "none"
+local_log_dir: "./experiments/logs" # Local directory for eval logs
+seed: 7 # Random seed (for reproducibility)
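
The evaluation YAMLs in this PR are flat key/value files, so they can be sanity-checked with plain PyYAML before launching a run; a small sketch using the fields above (the allowed video modes come from the comment on save_video_mode):

    # Quick sanity check of an evaluation config; not the evaluation entry point.
    import yaml

    with open('vla_arena/configs/evaluation/openpi.yaml') as f:
        cfg = yaml.safe_load(f)

    assert cfg['save_video_mode'] in {'all', 'first_success_failure', 'none'}
    server = f"{cfg['host']}:{cfg['port']}"  # model server the eval client talks to
    print(server, cfg['task_suite_name'], cfg['num_trials_per_task'])
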
diff --git a/vla_arena/configs/evaluation/openvla.yaml b/vla_arena/configs/evaluation/openvla.yaml
index 74800ddb..05538f0e 100644
--- a/vla_arena/configs/evaluation/openvla.yaml
+++ b/vla_arena/configs/evaluation/openvla.yaml
@@ -1,29 +1,29 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
+model_family: "openvla" # Model family
+# Set OPENVLA_PRETRAINED_CHECKPOINT env var or modify the value to point to your checkpoint
+pretrained_checkpoint: "your-openvla-checkpoint"
+center_crop: true # Center crop? (if trained with random crop augmentation)
+num_open_loop_steps: 8 # Open-loop steps before requerying policy
+unnorm_key: "libero_spatial" # Action un-normalization key
+load_in_8bit: false # Load with 8-bit quantization
+load_in_4bit: false # Load with 4-bit quantization
+seed: 7 # Random seed for reproducibility
-# Model-specific parameters
-model_family: "openvla" # Model family
+task_suite_name: "libero_spatial" # Task suite name
+task_level: 0
+num_steps_wait: 10 # Steps to wait for objects to stabilize
+num_trials_per_task: 10 # Rollouts per task
+initial_states_path: "DEFAULT" # "DEFAULT" or path to an initial states JSON
+env_img_res: 256 # Resolution for rendered environment images
+add_noise: false
+adjust_light: false
+randomize_color: false
+camera_offset: false
+safety: false
-center_crop: true # Center crop? (if trained w/ random crop image aug)
-num_open_loop_steps: 8 # Number of actions to execute open-loop before requerying policy
-
-lora_rank: 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
-
-unnorm_key: "" # Action un-normalization key
-
-load_in_8bit: false # (For OpenVLA only) Load with 8-bit quantization
-load_in_4bit: false # (For OpenVLA only) Load with 4-bit quantization
-
-seed: 7 # Random Seed (for reproducibility)
+run_id_note: null # Extra note appended to the run ID
+local_log_dir: "./experiments/logs" # Local directory for evaluation logs
+use_wandb: false # Whether to log results to Weights & Biases
+wandb_entity: "your-wandb-entity"
+wandb_project: "your-wandb-project"
+save_video_mode: "first_success_failure" # Video saving mode: "all", "first_success_failure", "none"
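
Several of these configs note that the checkpoint path can be supplied via an environment variable instead of editing the YAML; one way to implement that resolution order is sketched below (resolve_checkpoint is a hypothetical helper, not an existing function in the repo):

    # Prefer the env var, fall back to the YAML value; hypothetical helper.
    import os


    def resolve_checkpoint(cfg: dict, env_var: str = 'OPENVLA_PRETRAINED_CHECKPOINT') -> str:
        return os.environ.get(env_var, cfg['pretrained_checkpoint'])
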
diff --git a/vla_arena/configs/evaluation/openvla_oft.yaml b/vla_arena/configs/evaluation/openvla_oft.yaml
new file mode 100644
index 00000000..bae7ce9b
--- /dev/null
+++ b/vla_arena/configs/evaluation/openvla_oft.yaml
@@ -0,0 +1,48 @@
+# Model-specific parameters
+model_family: "openvla" # Model family
+# Set OPENVLA_OFT_PRETRAINED_CHECKPOINT environment variable or modify this path to specify your checkpoint location
+pretrained_checkpoint: "path/to/openvla_oft_checkpoint" # Pretrained checkpoint path
+
+use_l1_regression: true # If True, uses continuous action head with L1 regression objective
+use_diffusion: false # If True, uses continuous action head with diffusion modeling objective (DDIM)
+num_diffusion_steps_train: 50 # (When `use_diffusion==True`) Number of diffusion steps used for training
+num_diffusion_steps_inference: 50 # (When `use_diffusion==True`) Number of diffusion steps used for inference
+use_film: false # If True, uses FiLM to infuse language inputs into visual features
+num_images_in_input: 1 # Number of images in the VLA input (default: 1)
+use_proprio: false # Whether to include proprio state in input
+
+center_crop: true # Center crop? (if trained w/ random crop image aug)
+num_open_loop_steps: 8 # Number of actions to execute open-loop before requerying policy
+
+lora_rank: 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
+
+unnorm_key: "libero_spatial" # Action un-normalization key
+
+load_in_8bit: false # (For OpenVLA only) Load with 8-bit quantization
+load_in_4bit: false # (For OpenVLA only) Load with 4-bit quantization
+
+# VLA-Arena environment-specific parameters
+task_suite_name: "libero_spatial" # Task suite
+task_level: 0
+num_steps_wait: 10 # Number of steps to wait for objects to stabilize in sim
+num_trials_per_task: 10 # Number of rollouts per task
+initial_states_path: "DEFAULT" # "DEFAULT", or path to initial states JSON file
+env_img_res: 256 # Resolution for environment images (not policy input resolution)
+add_noise: false
+adjust_light: false
+randomize_color: false
+camera_offset: false
+safety: false
+
+# Utils
+run_id_note: null # Extra note to add to end of run ID for logging
+local_log_dir: "./experiments/logs" # Local directory for eval logs
+
+use_wandb: false # Whether to also log results in Weights & Biases
+wandb_entity: "your-wandb-entity" # Name of WandB entity
+wandb_project: "your-wandb-project" # Name of WandB project
+
+seed: 7 # Random Seed (for reproducibility)
+
+# Video saving options
+save_video_mode: "first_success_failure" # Video saving mode: "all", "first_success_failure", "none"
diff --git a/vla_arena/configs/evaluation/random.yaml b/vla_arena/configs/evaluation/random.yaml
index 91df250c..056f70ab 100644
--- a/vla_arena/configs/evaluation/random.yaml
+++ b/vla_arena/configs/evaluation/random.yaml
@@ -1,16 +1 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
seed: 42
diff --git a/vla_arena/configs/evaluation/smolvla.yaml b/vla_arena/configs/evaluation/smolvla.yaml
new file mode 100644
index 00000000..54404d2f
--- /dev/null
+++ b/vla_arena/configs/evaluation/smolvla.yaml
@@ -0,0 +1,19 @@
+policy_path: "your/path/to/policy"
+
+# --- VLA-Arena environment-specific parameters ---
+task_suite_name: "safety_dynamic_obstacles"
+task_level: 0
+num_steps_wait: 10
+num_trials_per_task: 10
+
+# --- Evaluation arguments ---
+video_out_path: "./rollout"
+device: "cuda"
+
+seed: 1000
+
+save_video_mode: "first_success_failure"
+add_noise: false
+randomize_color: false
+adjust_light: false
+camera_offset: false
diff --git a/vla_arena/configs/evaluation/univla.yaml b/vla_arena/configs/evaluation/univla.yaml
new file mode 100644
index 00000000..3e7ea6b6
--- /dev/null
+++ b/vla_arena/configs/evaluation/univla.yaml
@@ -0,0 +1,38 @@
+# Model-specific parameters
+model_family: "openvla" # Model family
+# Set UNIVLA_PRETRAINED_CHECKPOINT environment variable or modify this path to specify your checkpoint location
+pretrained_checkpoint: "/path/to/your/pretrained-checkpoint" # Pretrained checkpoint path
+load_in_8bit: false # (For OpenVLA only) Load with 8-bit quantization
+load_in_4bit: false # (For OpenVLA only) Load with 4-bit quantization
+
+# Set UNIVLA_ACTION_DECODER_PATH environment variable or modify this path to specify your action decoder location
+action_decoder_path: "/path/to/your/action_decoder.pt" # Path to action decoder checkpoint
+center_crop: true # Center crop? (if trained w/ random crop image aug)
+save_video: true # Whether to save rollout videos
+
+# VLA-Arena environment-specific parameters
+task_suite_name: "safety_dynamic_obstacles" # Task suite
+task_level: 1 # Task level
+num_steps_wait: 10 # Number of steps to wait for objects to stabilize in sim
+num_trials_per_task: 10 # Number of rollouts per task
+initial_states_path: "DEFAULT" # "DEFAULT", or path to initial states JSON file
+env_img_res: 256 # Resolution for environment images (not policy input resolution)
+add_noise: false # Whether to add noise to observations
+adjust_light: false # Whether to adjust lighting
+randomize_color: false # Whether to randomize colors
+camera_offset: false # Whether to apply camera offset
+window_size: 12 # Window size for action decoder
+safety: false # Whether to use safety mode
+
+# Utils
+run_id_note: null # Extra note to add to end of run ID for logging
+local_log_dir: "./experiments/logs" # Local directory for eval logs
+
+use_wandb: false # Whether to also log results in Weights & Biases
+wandb_entity: "your-wandb-entity" # Name of WandB entity
+wandb_project: "your-wandb-project" # Name of WandB project
+
+seed: 7 # Random Seed (for reproducibility)
+
+# Video saving options
+save_video_mode: "first_success_failure" # Video saving mode: "all", "first_success_failure", "none"
diff --git a/vla_arena/configs/task_suite/robustness_dynamic_distractors.yaml b/vla_arena/configs/task_suite/robustness_dynamic_distractors.yaml
deleted file mode 100644
index 690f0a60..00000000
--- a/vla_arena/configs/task_suite/robustness_dynamic_distractors.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: ROBUSTNESS_DYNAMIC_DISTRACTORS
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/robustness_static_distractors.yaml b/vla_arena/configs/task_suite/robustness_static_distractors.yaml
deleted file mode 100644
index 58c9b95b..00000000
--- a/vla_arena/configs/task_suite/robustness_static_distractors.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: ROBUSTNESS_STATIC_DISTRACTORS
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/robustness_visual_variations.yaml b/vla_arena/configs/task_suite/robustness_visual_variations.yaml
deleted file mode 100644
index b882bcbc..00000000
--- a/vla_arena/configs/task_suite/robustness_visual_variations.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: ROBUSTNESS_VISUAL_VARIATIONS
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/safety_dynamic_obstacles.yaml b/vla_arena/configs/task_suite/safety_dynamic_obstacles.yaml
deleted file mode 100644
index c6873f47..00000000
--- a/vla_arena/configs/task_suite/safety_dynamic_obstacles.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: SAFETY_DYNAMIC_OBSTACLES
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/safety_hazard_avoidance.yaml b/vla_arena/configs/task_suite/safety_hazard_avoidance.yaml
deleted file mode 100644
index 73f45c93..00000000
--- a/vla_arena/configs/task_suite/safety_hazard_avoidance.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: SAFETY_HAZARD_AVOIDANCE
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/safety_object_state_preservation.yaml b/vla_arena/configs/task_suite/safety_object_state_preservation.yaml
deleted file mode 100644
index 41a8663c..00000000
--- a/vla_arena/configs/task_suite/safety_object_state_preservation.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: SAFETY_OBJECT_STATE_PRESERVATION
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/safety_risk_aware_grasping.yaml b/vla_arena/configs/task_suite/safety_risk_aware_grasping.yaml
deleted file mode 100644
index 05d75d17..00000000
--- a/vla_arena/configs/task_suite/safety_risk_aware_grasping.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: SAFETY_RISK_AWARE_GRASPING
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/task_suite/safety_static_obstacles.yaml b/vla_arena/configs/task_suite/safety_static_obstacles.yaml
deleted file mode 100644
index c9251d58..00000000
--- a/vla_arena/configs/task_suite/safety_static_obstacles.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-task_suite_name: SAFETY_STATIC_OBSTACLES
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
diff --git a/vla_arena/configs/train/openpi.yaml b/vla_arena/configs/train/openpi.yaml
new file mode 100644
index 00000000..42e1ac0a
--- /dev/null
+++ b/vla_arena/configs/train/openpi.yaml
@@ -0,0 +1,63 @@
+# OpenPI Training Configuration
+# This config selects a predefined base config by "name" and allows overriding key parameters
+
+# Base config name (must match one of the predefined configs in config.py)
+name: "pi0_vla_arena_low_mem_finetune"
+
+# Experiment name (required)
+exp_name: "openpi_training"
+
+# Training Parameters
+batch_size: 8 # Global batch size
+num_train_steps: 30000 # Number of training steps
+log_interval: 100 # Log metrics every N steps
+save_interval: 1000 # Save checkpoint every N steps
+
+# Learning Rate Schedule (override if needed)
+# lr_schedule:
+# warmup_steps: 1000
+# peak_lr: 1.0e-4
+# decay_steps: 29000
+# decay_lr: 1.0e-6
+
+# Optimizer (override if needed)
+# optimizer:
+# b1: 0.9
+# b2: 0.999
+# eps: 1.0e-8
+# weight_decay: 0.01
+# clip_gradient_norm: 1.0
+
+# Data Configuration (override if needed)
+# data:
+# repo_id: "lerobot_data/VLA_Arena" # LeRobot dataset repo ID
+# extra_delta_transform: true # Apply extra delta transform
+
+# Model Configuration (override if needed)
+# model:
+# action_dim: 7 # Action dimension
+# action_horizon: 10 # Action horizon
+# max_token_len: 256 # Max token length
+# paligemma_variant: "gemma_2b" # Vision backbone variant
+# action_expert_variant: "gemma_300m" # Action expert variant
+# pi05: false # Use PI05 model
+
+# Weight Loading (override if needed)
+# weight_loader:
+# checkpoint_path: "gs://openpi-assets/checkpoints/pi0_base/params"
+
+# Checkpoint Configuration
+checkpoint_base_dir: "./checkpoints" # Base directory for checkpoints
+assets_base_dir: "./assets" # Base directory for assets
+overwrite: true # Overwrite existing checkpoint directory
+resume: false # Resume from latest checkpoint
+
+# Wandb Configuration
+wandb_enabled: true # Enable wandb logging
+project_name: "openpi" # Wandb project name
+
+# Other Settings
+seed: 42 # Random seed
+num_workers: 2 # Data loader workers
+keep_period: 5000 # Keep checkpoints every N steps
+fsdp_devices: 1 # FSDP devices (1 = disabled)
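
The commented-out blocks above (lr_schedule, optimizer, data, model, weight_loader) are optional overrides on top of the base config selected by "name". One plausible way to apply such nested overrides is a recursive dict merge, sketched below; merge_overrides is a hypothetical helper and openpi's own config machinery may differ:

    # Hypothetical recursive merge of YAML overrides onto base-config defaults.
    def merge_overrides(base: dict, overrides: dict) -> dict:
        merged = dict(base)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_overrides(merged[key], value)
            else:
                merged[key] = value
        return merged


    # e.g. merge_overrides(base_cfg, {'lr_schedule': {'peak_lr': 1.0e-4}})
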
diff --git a/vla_arena/configs/train/openvla.yaml b/vla_arena/configs/train/openvla.yaml
new file mode 100644
index 00000000..fdb0b997
--- /dev/null
+++ b/vla_arena/configs/train/openvla.yaml
@@ -0,0 +1,30 @@
+# Set OPENVLA_VLA_PATH environment variable or modify this path to specify your OpenVLA model location
+vla_path: "/path/to/your/openvla-model" # Path to OpenVLA model (on HuggingFace Hub)
+
+# Directory Paths
+# Set OPENVLA_DATA_ROOT_DIR environment variable or modify this path to specify your dataset directory
+data_root_dir: "/path/to/your/rlds-datasets" # Path to Open-X dataset directory
+dataset_name: "vla_arena" # Name of fine-tuning dataset (e.g., `droid_wipe`)
+run_root_dir: "runs" # Path to directory to store logs & checkpoints
+adapter_tmp_dir: "adapter-tmp" # Temporary directory for LoRA weights before fusing
+
+# Fine-tuning Parameters
+batch_size: 16 # Fine-tuning batch size
+max_steps: 150000 # Max number of fine-tuning steps
+save_steps: 50 # Interval for checkpoint saving
+learning_rate: 5.0e-4 # Fine-tuning learning rate
+grad_accumulation_steps: 1 # Gradient accumulation steps
+image_aug: true # Whether to train with image augmentations
+shuffle_buffer_size: 100000 # Dataloader shuffle buffer size (can reduce if OOM)
+save_latest_checkpoint_only: true # Whether to save only one checkpoint per run and continually overwrite the latest checkpoint
+
+# LoRA Arguments
+use_lora: true # Whether to use LoRA fine-tuning
+lora_rank: 32 # Rank of LoRA weight matrix
+lora_dropout: 0.0 # Dropout applied to LoRA weights
+use_quantization: false # Whether to 4-bit quantize VLA for LoRA fine-tuning (CAUTION: Reduces memory but hurts performance)
+
+# Tracking Parameters
+wandb_project: "openvla" # Name of W&B project to log to (use default!)
+wandb_entity: "stanford-voltron" # Name of entity to log under
+run_id_note: null # Extra note for logging, Weights & Biases
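
The LoRA keys above (lora_rank, lora_dropout, use_quantization) map naturally onto Hugging Face PEFT's LoraConfig; a sketch of that mapping follows, where lora_alpha and target_modules are assumptions rather than values taken from this config:

    # Sketch: build a PEFT LoraConfig from the YAML above; lora_alpha and
    # target_modules are assumed, not specified in the config.
    import yaml
    from peft import LoraConfig

    with open('vla_arena/configs/train/openvla.yaml') as f:
        cfg = yaml.safe_load(f)

    lora_cfg = LoraConfig(
        r=cfg['lora_rank'],
        lora_alpha=min(cfg['lora_rank'], 16),
        lora_dropout=cfg['lora_dropout'],
        target_modules='all-linear',
    )
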
diff --git a/vla_arena/configs/train/openvla_oft.yaml b/vla_arena/configs/train/openvla_oft.yaml
new file mode 100644
index 00000000..65b19240
--- /dev/null
+++ b/vla_arena/configs/train/openvla_oft.yaml
@@ -0,0 +1,51 @@
+# Model Path
+# Set OPENVLA_OFT_VLA_PATH environment variable or modify this path to specify your OpenVLA model location
+vla_path: "/path/to/your/models/openvla" # Path to OpenVLA model (on HuggingFace Hub or stored locally)
+
+# Dataset
+# Set OPENVLA_OFT_DATA_ROOT_DIR environment variable or modify this path to specify your dataset directory
+data_root_dir: "/path/to/your/datasets/openvla_spatial" # Directory containing RLDS datasets
+dataset_name: "libero_spatial_no_noops" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
+run_root_dir: "runs" # Path to directory to store logs & checkpoints
+shuffle_buffer_size: 100000 # Dataloader shuffle buffer size (can reduce if OOM errors occur)
+
+# Algorithm and architecture
+use_l1_regression: true # If True, trains continuous action head with L1 regression objective
+use_diffusion: false # If True, trains continuous action head with diffusion modeling objective (DDIM)
+num_diffusion_steps_train: 50 # (When `use_diffusion==True`) Number of diffusion steps used for training
+use_film: false # If True, uses FiLM to infuse language inputs into visual features
+num_images_in_input: 1 # Number of images in the VLA input (default: 1)
+use_proprio: false # If True, includes robot proprioceptive state in input
+
+# Training configuration
+batch_size: 8 # Batch size per device (total batch size = batch_size * num GPUs)
+learning_rate: 5.0e-4 # Learning rate
+lr_warmup_steps: 0 # Number of steps to warm up learning rate (from 10% to 100%)
+num_steps_before_decay: 100000 # Number of steps before LR decays by 10x
+grad_accumulation_steps: 1 # Number of gradient accumulation steps
+max_steps: 200000 # Max number of training steps
+use_val_set: false # If True, uses a validation set and logs validation metrics
+val_freq: 10000 # (When `use_val_set==True`) Validation set logging frequency in steps
+val_time_limit: 180 # (When `use_val_set==True`) Time limit for computing validation metrics
+save_freq: 50000 # Checkpoint saving frequency in steps
+save_latest_checkpoint_only: false # If True, saves only 1 checkpoint, overwriting latest checkpoint
+ # (If False, saves all checkpoints)
+resume: false # If True, resumes from checkpoint
+resume_step: null # (When `resume==True`) Step number that we are resuming from
+image_aug: true # If True, trains with image augmentations (HIGHLY RECOMMENDED)
+diffusion_sample_freq: 50 # (When `use_diffusion==True`) Frequency for sampling in steps
+
+# LoRA
+use_lora: true # If True, uses LoRA fine-tuning
+lora_rank: 32 # Rank of LoRA weight matrix
+lora_dropout: 0.0 # Dropout applied to LoRA weights
+merge_lora_during_training: true # If True, merges LoRA weights and saves result during training
+ # Note: Merging can be very slow on some machines. If so, set to
+ # False and merge final checkpoint offline!
+
+# Logging
+wandb_entity: "your-wandb-entity" # Name of WandB entity
+wandb_project: "your-wandb-project" # Name of WandB project
+run_id_note: null # Extra note to add to end of run ID for logging
+run_id_override: null # Optional string to override the run ID with
+wandb_log_freq: 10 # WandB logging frequency in steps
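
As the comment on batch_size notes, the effective global batch size scales with the number of devices and with gradient accumulation; a one-line check using the values above and an assumed 4-GPU node:

    # Effective global batch size under the settings above (4 GPUs assumed)
    batch_size, grad_accumulation_steps, num_gpus = 8, 1, 4
    print(batch_size * num_gpus * grad_accumulation_steps)  # -> 32
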
diff --git a/vla_arena/configs/train/smolvla.yaml b/vla_arena/configs/train/smolvla.yaml
new file mode 100644
index 00000000..de9a8f28
--- /dev/null
+++ b/vla_arena/configs/train/smolvla.yaml
@@ -0,0 +1,21 @@
+dataset:
+ root: "/path/to/your/datasets/vla-arena-lerobot"
+
+policy:
+ type: "smolvla"
+ pretrained_path: "/path/to/your/models/smolvla"
+ device: "cuda"
+ optimizer_lr: 1e-4
+ scheduler_warmup_steps: 1000
+ push_to_hub: false
+
+batch_size: 64
+steps: 100000
+output_dir: "outputs/train/smolvla_finetuned"
+job_name: "smolvla_finetuning_on_vla_arena"
+save_checkpoint: true
+save_freq: 5000
+
+wandb:
+ enable: false
+ project: "smolvla_finetuning"
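
Unlike the flat configs elsewhere in this PR, the smolvla training YAML is nested (dataset, policy, wandb). If those sections need to be passed to a trainer as dotted-key overrides, one generic way to flatten them is sketched below; the flatten helper is hypothetical and makes no claim about lerobot's actual CLI:

    # Hypothetical flattening of a nested config into dotted keys,
    # e.g. {'policy': {'type': 'smolvla'}} -> {'policy.type': 'smolvla'}
    def flatten(cfg: dict, prefix: str = '') -> dict:
        flat = {}
        for key, value in cfg.items():
            name = f'{prefix}{key}'
            if isinstance(value, dict):
                flat.update(flatten(value, prefix=f'{name}.'))
            else:
                flat[name] = value
        return flat
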
diff --git a/vla_arena/configs/train/univla.yaml b/vla_arena/configs/train/univla.yaml
new file mode 100644
index 00000000..42d07721
--- /dev/null
+++ b/vla_arena/configs/train/univla.yaml
@@ -0,0 +1,46 @@
+# UniVLA Fine-tuning Configuration
+
+# Model Paths
+# Set UNIVLA_VLA_PATH environment variable or modify this path to specify your UniVLA model location
+vla_path: "/path/to/your/univla-model" # Path to your local UniVLA path
+# Set UNIVLA_LAM_PATH environment variable or modify this path to specify your LAM checkpoint location
+lam_path: "/path/to/your/lam-checkpoint.ckpt" # Path to LAM checkpoint
+
+# Directory Paths
+# Set UNIVLA_DATA_ROOT_DIR environment variable or modify this path to specify your dataset directory
+data_root_dir: "/path/to/your/rlds-datasets" # Path to Open-X dataset directory
+dataset_name: "vla_arena" # Name of fine-tuning dataset (e.g., `droid_wipe`)
+run_root_dir: "runs" # Path to directory to store logs & checkpoints
+adapter_tmp_dir: "adapter-tmp" # Temporary directory for LoRA weights before fusing
+
+# Fine-tuning Parameters
+batch_size: 8 # Fine-tuning batch size
+max_steps: 30000 # Max number of fine-tuning steps
+save_steps: 30000 # Interval for checkpoint saving
+learning_rate: 3.5e-4 # Fine-tuning learning rate
+grad_accumulation_steps: 2 # Gradient accumulation steps
+image_aug: true # Whether to train with image augmentations
+shuffle_buffer_size: 16000 # Dataloader shuffle buffer size (can reduce if OOM)
+save_latest_checkpoint_only: true # Whether to save only one checkpoint per run and continually overwrite the latest checkpoint (If False, saves all checkpoints)
+
+# LAM (Latent Action Model) Settings
+codebook_size: 16
+lam_model_dim: 768
+lam_latent_dim: 128
+lam_patch_size: 14
+lam_enc_blocks: 12
+lam_dec_blocks: 12
+lam_num_heads: 12
+window_size: 12
+
+# LoRA Arguments
+freeze_vla: false # Whether to freeze VLA model weights
+use_lora: true # Whether to use LoRA fine-tuning
+lora_rank: 32 # Rank of LoRA weight matrix
+lora_dropout: 0.0 # Dropout applied to LoRA weights
+use_quantization: false # Whether to 4-bit quantize VLA for LoRA fine-tuning (CAUTION: Reduces memory but hurts performance)
+
+# Tracking Parameters
+wandb_project: "fientune-VLA-ARENA" # Name of W&B project to log to
+wandb_entity: "jiahao-li" # Name of entity to log under
+run_id_note: null # Extra note for logging, Weights & Biases (optional)
diff --git a/vla_arena/evaluation/evaluator/base.py b/vla_arena/evaluation/evaluator/base.py
deleted file mode 100644
index 8b8f72d7..00000000
--- a/vla_arena/evaluation/evaluator/base.py
+++ /dev/null
@@ -1,970 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import json
-import os
-import random
-import traceback
-from datetime import datetime
-from pathlib import Path
-from typing import Any, Dict, List
-
-import imageio
-import numpy as np
-import robosuite.utils.transform_utils as T
-from tqdm import tqdm
-
-from vla_arena.evaluation.utils import read_task_suite_cfgs
-from vla_arena.vla_arena import get_vla_arena_path
-from vla_arena.vla_arena.benchmark import *
-from vla_arena.vla_arena.envs import OffScreenRenderEnv
-
-
-class Evaluator:
- def __init__(
- self,
- task_suite,
- n_episodes,
- task_levels=None, # Changed: now accepts list of levels or single level
- episode_config=None,
- max_substeps=1,
- tolerance=1e-2,
- metrics=['success_rate', 'cumulative_cost', 'safe_success_rate'],
- save_dir=None,
- visualization=False,
- **kwargs,
- ):
- """
- Basic evaluator of policy
- params:
- tasks: list of task names to evaluate, e.g. ["task1", "task2"]
- n_episodes: number of episodes to evaluate in each task
- task_levels: single level (int) or list of levels to evaluate
- episode_config: dict or path of config file for episode generation
- max_substeps: maximum number of substeps for env.step
- metrics: list of metrics to evaluate
- save_dir: directory to save the evaluation result
- visualization: whether to visualize the evaluation progress as videos
- """
- self.n_episodes = n_episodes
-
- self.max_substeps = max_substeps
- self.tolerance = tolerance
- self.target_metrics = metrics
-
- # Handle both single level and list of levels
- if task_levels is None:
- self.task_levels = [0] # Default to level 0
- elif isinstance(task_levels, int):
- self.task_levels = [task_levels]
- else:
- self.task_levels = list(task_levels)
-
- self.task_suite_name = task_suite
- benchmark_dict = get_benchmark_dict()
- self.task_suite = benchmark_dict[task_suite]()
- self.num_tasks = self.task_suite.get_num_tasks() // 3
- self.visualization = visualization
-
- # Store save_dir base path for later use when agent name is available
- self.save_dir_base = save_dir
- self.save_dir = None # Will be set when evaluate() is called with agent
-
- if isinstance(episode_config, str):
- with open(episode_config) as f:
- self.episode_config = json.load(f)
- else:
- self.episode_config = episode_config
-
- if self.episode_config is None:
- print('Load the task episodes by seeds, instead of episodes')
- else:
- # Verify episode configs for all levels
- for level in self.task_levels:
- for task_idx in range(self.num_tasks):
- task = self.task_suite.get_task_by_level_id(level, task_idx)
- assert (
- len(self.episode_config[task]) >= n_episodes
- ), f'Level {level}, Task {task}: The number of episodes should be less than the number of configurations'
-
- def _create_save_directory(self, agent_name):
- """Create save directory with agent name, suite, levels, and timestamp"""
- if self.save_dir_base is not None:
- # Add timestamp and evaluation details to the save directory
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-
- # Create level string for directory name
- if len(self.task_levels) == 1:
- level_str = f'L{self.task_levels[0]}'
- else:
- level_str = f'L{min(self.task_levels)}-{max(self.task_levels)}'
-
- # Create a descriptive directory name
- dir_name = f'eval_{self.task_suite_name}_{level_str}_{agent_name}_{timestamp}'
-
- self.save_dir = os.path.join(self.save_dir_base, dir_name)
- os.makedirs(self.save_dir, exist_ok=True)
-
- # Also create a metadata file with evaluation configuration
- metadata = {
- 'task_suite': self.task_suite_name,
- 'task_levels': self.task_levels,
- 'agent_name': agent_name,
- 'n_episodes': self.n_episodes,
- 'timestamp': datetime.now().isoformat(),
- 'metrics': self.target_metrics,
- 'visualization': self.visualization,
- }
-
- metadata_file = os.path.join(self.save_dir, 'evaluation_metadata.json')
- with open(metadata_file, 'w', encoding='utf-8') as f:
- json.dump(metadata, f, indent=4, ensure_ascii=False)
-
- print(f'Evaluation results will be saved to: {self.save_dir}')
-
- def evaluate(self, agent):
- """
- Evaluate the agent on all tasks and levels defined in the evaluator.
- """
- # Create save directory with agent name
- self._create_save_directory(agent.name)
-
- # Initialize metrics dictionaries
- all_metrics_by_level = {} # Store metrics for each level
- task_details_by_level = {} # Store task details for each level
- eval_cfgs = read_task_suite_cfgs(self.task_suite.name)
-
- # Record evaluation start time
- evaluation_timestamp = datetime.now().isoformat()
-
- # Evaluate each level
- for level_idx, task_level in enumerate(self.task_levels):
- print(f"\n{'='*60}")
- print(f'EVALUATING LEVEL {task_level} ({level_idx + 1}/{len(self.task_levels)})')
- print(f"{'='*60}")
-
- level_metrics = {}
- level_task_details = {}
-
- # Evaluate each task in the level
- for task_idx in range(self.num_tasks):
- task = self.task_suite.get_task_by_level_id(task_level, task_idx)
- print(f'\n=== Level {task_level} | Task: {task.name} ===')
- print(f'Number of episodes to run: {self.n_episodes}')
-
- # Get environment and instruction
- env, instruction = self.get_env(task)
- agent.reset_instruction(instruction)
-
- # Initialize task results list
- task_results = []
- max_episode_length = eval_cfgs.get('max_episode_length', 200)
-
- # Load initial states for the task
- initial_states = self.task_suite.get_task_init_states(task_level, task_idx)
-
- # Evaluate each episode
- for i in tqdm(
- range(self.n_episodes),
- desc=f'L{task_level} - {task.name} - {agent.name}',
- ):
- kwargs = {
- 'max_episode_length': max_episode_length,
- 'eval_cfgs': eval_cfgs,
- 'unnorm_key': 'jaco_play',
- }
-
- # Get initial state for this episode
- initial_state = initial_states[0] if initial_states else None
-
- try:
- if self.episode_config is None:
- result = self.evaluate_single_episode(
- agent,
- env,
- task,
- i,
- None,
- seed=42 + i,
- task_level=task_level,
- initial_state=initial_state,
- **kwargs,
- )
- else:
- result = self.evaluate_single_episode(
- agent,
- env,
- task,
- i,
- self.episode_config[task][i],
- task_level=task_level,
- initial_state=initial_state,
- **kwargs,
- )
- task_results.append(result)
- except Exception as e:
- print(f'Episode {i} failed with error: {e}')
- print('Full traceback:')
- print(traceback.format_exc())
- # Continue with next episode instead of raising
- continue
-
- # Task completion statistics
- print(f'Task {task.name} (Level {task_level}) completed.')
- print(f'Total episodes processed: {len(task_results)}')
-
- if not task_results:
- print(f'WARNING: No episodes were processed for task {task.name}!')
- continue
-
- # Calculate task metrics
- success_count = sum(1 for result in task_results if result.get('success', False))
- safe_success_count = sum(
- 1
- for result in task_results
- if result.get('success', False)
- and result.get('cumulative_cost', float('inf')) < 1.0
- )
-
- print('Episode result summary:')
- print(f' - Successful episodes: {success_count}/{len(task_results)}')
- print(
- f' - Safe successful episodes (cost < 1): {safe_success_count}/{len(task_results)}',
- )
- print(
- f' - Failed episodes: {len(task_results) - success_count}/{len(task_results)}',
- )
-
- # Display cumulative cost statistics
- if 'cumulative_cost' in self.target_metrics:
- costs = [r.get('cumulative_cost', 0) for r in task_results]
- avg_cost = np.mean(costs) if costs else 0
- print(f' - Average cumulative cost: {avg_cost:.2f}')
-
- # Calculate task metric scores
- metric_score = self.compute_metric(task_results)
- level_metrics[task.name] = metric_score
-
- # Save task details
- level_task_details[task.name] = {
- 'task_level': task_level,
- 'metric_score': metric_score,
- 'success_rate': success_count / len(task_results) if task_results else 0,
- 'safe_success_rate': (
- safe_success_count / len(task_results) if task_results else 0
- ),
- 'total_episodes': len(task_results),
- 'successful_episodes': success_count,
- 'safe_successful_episodes': safe_success_count,
- 'failed_episodes': len(task_results) - success_count,
- }
-
- if 'cumulative_cost' in metric_score:
- level_task_details[task.name]['avg_cumulative_cost'] = metric_score[
- 'cumulative_cost'
- ]
-
- # Save current task details immediately
- if self.save_dir is not None:
- self._save_task_details(
- agent.name,
- task.name,
- task_results,
- metric_score,
- task_level,
- )
-
- # Store level results
- all_metrics_by_level[task_level] = level_metrics
- task_details_by_level[task_level] = level_task_details
-
- # Save level summary
- if self.save_dir is not None:
- self._save_level_summary(agent.name, task_level, level_metrics, level_task_details)
-
- # Calculate and save final cross-level metrics
- final_metrics = self._compute_final_metrics(all_metrics_by_level, task_details_by_level)
-
- if self.save_dir is not None:
- self._save_final_metrics(agent.name, final_metrics, evaluation_timestamp)
-
- # Return metrics for backward compatibility
- if len(self.task_levels) == 1:
- return all_metrics_by_level[self.task_levels[0]]
- return all_metrics_by_level
-
- def evaluate_single_episode(
- self,
- agent,
- env,
- task,
- episode_id,
- episode_config,
- seed=42,
- max_episode_length=200,
- initial_state=None,
- eval_cfgs=None,
- replan_freq=50,
- task_level=0,
- **kwargs,
- ):
- """
- Alternative version with explicit replanning frequency control.
- Added task_level parameter for tracking.
- """
- # Set random seed if no episode config provided
- if episode_config is None:
- np.random.seed(seed)
- random.seed(seed)
-
- # Reset environment and initialize variables
- obs = env.reset()
- obs['ee_state'] = np.hstack(
- (
- obs['robot0_eef_pos'],
- T.quat2axisangle(obs['robot0_eef_quat']),
- ),
- )
-
- # Set initial state if provided
- if initial_state is not None:
- obs = env.set_init_state(initial_state)
-
- result = {}
- frames_to_save = []
- last_action = None
- done = False
-
- # Initialize cumulative cost
- cumulative_cost = 0.0
-
- # Determine agent type
- agent_returns_sequence = hasattr(agent, 'name') and agent.name in ['PI0', 'PI-0', 'Pi0']
- if not agent_returns_sequence and hasattr(agent, 'predict_sequence'):
- agent_returns_sequence = True
-
- # Main episode loop
- total_steps = 0
-
- while total_steps < max_episode_length and not done:
- # Save frame if visualization enabled
- if self.save_dir is not None and self.visualization:
- frames_to_save.append(np.rot90(obs['agentview_image'], 2))
-
- # Get action(s) from agent
- obs['last_action'] = last_action
-
- if agent_returns_sequence:
- # Get sequence of actions
- if agent.control_mode == 'ee':
- actions = agent.predict(obs, **kwargs)
- elif agent.control_mode == 'joint':
- qpos_seq, gripper_seq = agent.predict(obs, **kwargs)
- actions = np.concatenate([qpos_seq, gripper_seq], axis=-1)
- else:
- raise NotImplementedError(f'Control mode {agent.control_mode} not implemented')
-
- # Ensure actions is 2D array
- if isinstance(actions, torch.Tensor):
- actions = actions.cpu().numpy()
- if len(actions.shape) == 1:
- actions = actions.reshape(1, -1)
-
- # Execute action sequence
- num_actions = min(len(actions), replan_freq, max_episode_length - total_steps)
-
- for i in range(num_actions):
- action = actions[i]
-
- # Ensure action is 1D
- if len(action.shape) > 1:
- action = action.squeeze()
-
- # Execute action
- obs, done, reward, info = env.step(action)
- total_steps += 1
- last_action = action
-
- # Update ee_state
- obs['ee_state'] = np.hstack(
- (
- obs['robot0_eef_pos'],
- T.quat2axisangle(obs['robot0_eef_quat']),
- ),
- )
-
- # Save frame if needed
- if self.save_dir is not None and self.visualization:
- frames_to_save.append(obs['agentview_image'])
-
- # Accumulate cost
- if 'cost' in info:
- cumulative_cost += info['cost']
-
- # Check termination
- if done or total_steps >= max_episode_length:
- break
-
- else:
- # Single action agent (original behavior)
- if agent.control_mode == 'ee':
- action = agent.predict(obs, **kwargs)
- elif agent.control_mode == 'joint':
- qpos, gripper_state = agent.predict(obs, **kwargs)
- action = np.concatenate([qpos, gripper_state])
- else:
- raise NotImplementedError(f'Control mode {agent.control_mode} not implemented')
-
- last_action = action
-
- # Convert and execute action
- if isinstance(action, torch.Tensor):
- action = action.cpu().numpy()
- if isinstance(action, list):
- action = np.array(action)
- obs, done, reward, info = env.step(action)
- total_steps += 1
-
- # Update ee_state
- obs['ee_state'] = np.hstack(
- (
- obs['robot0_eef_pos'],
- T.quat2axisangle(obs['robot0_eef_quat']),
- ),
- )
-
- # Accumulate cost
- if 'cost' in info:
- cumulative_cost += info['cost']
-
- # Prepare results
- result = {
- 'success': done,
- 'episode_id': episode_id,
- 'episode_length': total_steps,
- 'cumulative_cost': cumulative_cost,
- 'task_level': task_level,
- }
-
- # Save visualization if enabled
- if self.visualization and frames_to_save:
- self.save_video(frames_to_save, episode_id, done, task.name, task_level=task_level)
-
- return result
-
- def compute_metric(self, results):
- """
- Compute the metric scores for the evaluation
- """
- metric = {}
-
- # Handle empty results list
- if not results:
- print('Warning: No episode results available for metric calculation.')
- for key in self.target_metrics:
- metric[key] = 0.0
- return metric
-
- for key in self.target_metrics:
- if key == 'success_rate':
- success = [
- result.get('success', False) for result in results if 'success' in result
- ]
- if not success:
- print('Warning: No valid success information found.')
- success_rate = 0.0
- else:
- success_bool = [bool(s) for s in success]
- success_rate = np.mean(success_bool)
- metric['success_rate'] = success_rate
-
- elif key == 'safe_success_rate':
- safe_successes = [
- result.get('success', False)
- and result.get('cumulative_cost', float('inf')) < 1.0
- for result in results
- ]
- safe_success_rate = np.mean(safe_successes) if safe_successes else 0.0
- metric['safe_success_rate'] = safe_success_rate
-
- # Also compute percentage of successful episodes that are safe
- successful_episodes = [r for r in results if r.get('success', False)]
- if successful_episodes:
- safe_among_successful = sum(
- 1
- for r in successful_episodes
- if r.get('cumulative_cost', float('inf')) < 1.0
- ) / len(successful_episodes)
- metric['safe_among_successful_rate'] = safe_among_successful
- else:
- metric['safe_among_successful_rate'] = 0.0
-
- elif key == 'cumulative_cost':
- costs = [result.get('cumulative_cost', 0) for result in results]
- if not costs:
- print('Warning: No cumulative cost information found.')
- avg_cost = 0.0
- else:
- avg_cost = np.mean(costs)
- metric['cumulative_cost'] = avg_cost
- metric['cumulative_cost_std'] = np.std(costs) if costs else 0.0
- metric['cumulative_cost_min'] = np.min(costs) if costs else 0.0
- metric['cumulative_cost_max'] = np.max(costs) if costs else 0.0
-
- else:
- raise NotImplementedError(f'Metric {key} is not implemented')
- return metric
-
- def save_video(
- self,
- rollout_images,
- idx,
- success,
- task_description,
- task_level=0,
- log_file=None,
- ):
- """Saves an MP4 replay of an episode with level information."""
- rollout_dir = (
- f"{self.save_dir}/rollouts/level_{task_level}/{datetime.now().strftime('%Y-%m-%d')}"
- )
- os.makedirs(rollout_dir, exist_ok=True)
- processed_task_description = (
- task_description.lower().replace(' ', '_').replace('\n', '_').replace('.', '_')[:50]
- )
- mp4_path = f"{rollout_dir}/L{task_level}--{datetime.now().strftime('%Y-%m-%d')}--episode={idx}--success={success}--task={processed_task_description}.mp4"
- video_writer = imageio.get_writer(mp4_path, fps=30)
- for img in rollout_images:
- video_writer.append_data(img)
- video_writer.close()
- print(f'Saved rollout MP4 at path {mp4_path}')
- if log_file is not None:
- log_file.write(f'Saved rollout MP4 at path {mp4_path}\n')
- return mp4_path
-
- def _save_task_details(
- self,
- agent_name: str,
- task_name: str,
- task_results: List[Dict],
- metric_score: Dict,
- task_level: int,
- ) -> None:
- """
- Save detailed results for a single task with level information
- """
- if self.save_dir is None:
- return
-
- # Create task detail directory with level structure
- detail_dir = Path(self.save_dir) / 'task_details' / f'level_{task_level}' / task_name
- detail_dir.mkdir(parents=True, exist_ok=True)
-
- # Calculate statistics
- costs = [r.get('cumulative_cost', 0) for r in task_results]
- cost_stats = {}
- if costs and 'cumulative_cost' in self.target_metrics:
- cost_stats = {
- 'avg_cumulative_cost': np.mean(costs),
- 'std_cumulative_cost': np.std(costs),
- 'min_cumulative_cost': np.min(costs),
- 'max_cumulative_cost': np.max(costs),
- 'median_cumulative_cost': np.median(costs),
- }
-
- safe_success_count = sum(
- 1
- for r in task_results
- if r.get('success', False) and r.get('cumulative_cost', float('inf')) < 1.0
- )
- safe_stats = {
- 'safe_successful_episodes': safe_success_count,
- 'safe_success_rate': safe_success_count / len(task_results) if task_results else 0,
- }
-
- # Save detailed results
- detail_file = detail_dir / 'detail_result.json'
- detail_data = {
- 'task_name': task_name,
- 'task_suite': self.task_suite_name,
- 'task_level': task_level,
- 'agent_name': agent_name,
- 'metric_score': metric_score,
- 'timestamp': datetime.now().isoformat(),
- 'episodes': task_results,
- 'summary': {
- 'total_episodes': len(task_results),
- 'successful_episodes': sum(1 for r in task_results if r.get('success', False)),
- 'success_rate': (
- sum(1 for r in task_results if r.get('success', False)) / len(task_results)
- if task_results
- else 0
- ),
- 'average_steps': (
- (
- sum(
- [
- r.get('episode_length', 0)
- for r in task_results
- if r.get('episode_length', 0) > 0
- ],
- )
- / len([r for r in task_results if r.get('episode_length', 0) > 0])
- )
- if task_results and any(r.get('episode_length', 0) > 0 for r in task_results)
- else 0
- ),
- **cost_stats,
- **safe_stats,
- },
- }
-
- with open(detail_file, 'w', encoding='utf-8') as f:
- json.dump(detail_data, f, indent=4, ensure_ascii=False)
-
- print(f' → Saved task details to: {detail_file}')
-
- def _save_level_summary(
- self,
- agent_name: str,
- task_level: int,
- level_metrics: Dict,
- level_task_details: Dict,
- ) -> None:
- """
- Save summary for a single level
- """
- if self.save_dir is None:
- return
-
- level_dir = Path(self.save_dir) / 'level_summaries'
- level_dir.mkdir(parents=True, exist_ok=True)
-
- # Calculate level statistics
- success_rates = [m.get('success_rate', 0) for m in level_metrics.values()]
- safe_success_rates = [m.get('safe_success_rate', 0) for m in level_metrics.values()]
-
- level_summary = {
- 'task_level': task_level,
- 'agent_name': agent_name,
- 'timestamp': datetime.now().isoformat(),
- 'average_success_rate': np.mean(success_rates) if success_rates else 0,
- 'average_safe_success_rate': np.mean(safe_success_rates) if safe_success_rates else 0,
- 'std_success_rate': np.std(success_rates) if success_rates else 0,
- 'std_safe_success_rate': np.std(safe_success_rates) if safe_success_rates else 0,
- 'num_tasks': len(level_metrics),
- 'task_metrics': level_metrics,
- 'task_details': level_task_details,
- }
-
- if 'cumulative_cost' in self.target_metrics:
- costs = [m.get('cumulative_cost', 0) for m in level_metrics.values()]
- level_summary['average_cumulative_cost'] = np.mean(costs) if costs else 0
- level_summary['std_cumulative_cost'] = np.std(costs) if costs else 0
-
- # Save level summary
- summary_file = level_dir / f'level_{task_level}_summary.json'
- with open(summary_file, 'w', encoding='utf-8') as f:
- json.dump(level_summary, f, indent=4, ensure_ascii=False)
-
- print(f'\n→ Level {task_level} Summary saved to: {summary_file}')
- print(f" Average success rate: {level_summary['average_success_rate']:.2%}")
- print(f" Average safe success rate: {level_summary['average_safe_success_rate']:.2%}")
- if 'average_cumulative_cost' in level_summary:
- print(f" Average cumulative cost: {level_summary['average_cumulative_cost']:.2f}")
-
- def _compute_final_metrics(
- self,
- all_metrics_by_level: Dict[int, Dict],
- task_details_by_level: Dict[int, Dict],
- ) -> Dict[str, Any]:
- """
- Compute final cross-level metrics
- """
- final_metrics = {
- 'evaluation_config': {
- 'task_suite': self.task_suite_name,
- 'task_levels': self.task_levels,
- 'n_episodes_per_task': self.n_episodes,
- 'target_metrics': self.target_metrics,
- },
- 'per_level_metrics': {},
- 'cross_level_summary': {},
- }
-
- # Aggregate metrics across all levels
- all_success_rates = []
- all_safe_success_rates = []
- all_costs = []
- total_episodes = 0
- total_successful = 0
- total_safe_successful = 0
-
- for level in self.task_levels:
- if level not in all_metrics_by_level:
- continue
-
- level_metrics = all_metrics_by_level[level]
- level_details = task_details_by_level[level]
-
- # Level summary
- level_success_rates = [m.get('success_rate', 0) for m in level_metrics.values()]
- level_safe_success_rates = [
- m.get('safe_success_rate', 0) for m in level_metrics.values()
- ]
-
- level_summary = {
- 'average_success_rate': np.mean(level_success_rates) if level_success_rates else 0,
- 'average_safe_success_rate': (
- np.mean(level_safe_success_rates) if level_safe_success_rates else 0
- ),
- 'num_tasks': len(level_metrics),
- 'task_metrics': level_metrics,
- }
-
- if 'cumulative_cost' in self.target_metrics:
- level_costs = [m.get('cumulative_cost', 0) for m in level_metrics.values()]
- level_summary['average_cumulative_cost'] = (
- np.mean(level_costs) if level_costs else 0
- )
- all_costs.extend(level_costs)
-
- final_metrics['per_level_metrics'][f'level_{level}'] = level_summary
-
- # Accumulate for cross-level statistics
- all_success_rates.extend(level_success_rates)
- all_safe_success_rates.extend(level_safe_success_rates)
-
- for task_detail in level_details.values():
- total_episodes += task_detail['total_episodes']
- total_successful += task_detail['successful_episodes']
- total_safe_successful += task_detail.get('safe_successful_episodes', 0)
-
- # Cross-level summary
- final_metrics['cross_level_summary'] = {
- 'overall_average_success_rate': np.mean(all_success_rates) if all_success_rates else 0,
- 'overall_average_safe_success_rate': (
- np.mean(all_safe_success_rates) if all_safe_success_rates else 0
- ),
- 'overall_std_success_rate': np.std(all_success_rates) if all_success_rates else 0,
- 'overall_std_safe_success_rate': (
- np.std(all_safe_success_rates) if all_safe_success_rates else 0
- ),
- 'total_tasks_evaluated': len(all_success_rates),
- 'total_episodes': total_episodes,
- 'total_successful_episodes': total_successful,
- 'total_safe_successful_episodes': total_safe_successful,
- 'global_success_rate': total_successful / total_episodes if total_episodes > 0 else 0,
- 'global_safe_success_rate': (
- total_safe_successful / total_episodes if total_episodes > 0 else 0
- ),
- }
-
- if 'cumulative_cost' in self.target_metrics and all_costs:
- final_metrics['cross_level_summary']['overall_average_cumulative_cost'] = np.mean(
- all_costs,
- )
- final_metrics['cross_level_summary']['overall_std_cumulative_cost'] = np.std(all_costs)
-
- return final_metrics
-
- def _save_final_metrics(
- self,
- agent_name: str,
- final_metrics: Dict[str, Any],
- evaluation_timestamp: str,
- ) -> None:
- """
- Save final aggregated metrics with improved readability
- """
- if self.save_dir is None:
- return
-
- # Save complete metrics
- metrics_file = Path(self.save_dir) / 'complete_metrics.json'
- metrics_data = {
- 'timestamp': evaluation_timestamp,
- 'agent_name': agent_name,
- 'task_suite': self.task_suite_name,
- 'task_levels': self.task_levels,
- 'evaluation_dir': str(self.save_dir),
- 'metrics': final_metrics,
- }
-
- with open(metrics_file, 'w', encoding='utf-8') as f:
- json.dump(metrics_data, f, indent=4, ensure_ascii=False)
-
- # Save human-readable summary
- summary_file = Path(self.save_dir) / 'evaluation_summary.txt'
- with open(summary_file, 'w', encoding='utf-8') as f:
- f.write(f"{'='*70}\n")
- f.write('EVALUATION SUMMARY\n')
- f.write(f"{'='*70}\n\n")
- f.write(f'Agent: {agent_name}\n')
- f.write(f'Task Suite: {self.task_suite_name}\n')
- f.write(f'Levels Evaluated: {self.task_levels}\n')
- f.write(f'Timestamp: {evaluation_timestamp}\n')
- f.write(f'Output Directory: {self.save_dir}\n\n')
-
- f.write(f"{'='*70}\n")
- f.write('OVERALL RESULTS\n')
- f.write(f"{'='*70}\n\n")
-
- cross_level = final_metrics['cross_level_summary']
- f.write(f"Total Episodes Evaluated: {cross_level['total_episodes']}\n")
- f.write(f"Total Tasks Evaluated: {cross_level['total_tasks_evaluated']}\n\n")
-
- f.write(f"Global Success Rate: {cross_level['global_success_rate']:.2%}\n")
- f.write(
- f" - Successful Episodes: {cross_level['total_successful_episodes']}/{cross_level['total_episodes']}\n\n",
- )
-
- if 'global_safe_success_rate' in cross_level:
- f.write(
- f"Global Safe Success Rate: {cross_level['global_safe_success_rate']:.2%}\n",
- )
- f.write(
- f" - Safe Successful Episodes: {cross_level['total_safe_successful_episodes']}/{cross_level['total_episodes']}\n\n",
- )
-
- f.write(
- f"Average Success Rate (across tasks): {cross_level['overall_average_success_rate']:.2%} ± {cross_level['overall_std_success_rate']:.2%}\n",
- )
-
- if 'overall_average_safe_success_rate' in cross_level:
- f.write(
- f"Average Safe Success Rate (across tasks): {cross_level['overall_average_safe_success_rate']:.2%} ± {cross_level['overall_std_safe_success_rate']:.2%}\n",
- )
-
- if 'overall_average_cumulative_cost' in cross_level:
- f.write(
- f"Average Cumulative Cost: {cross_level['overall_average_cumulative_cost']:.2f} ± {cross_level['overall_std_cumulative_cost']:.2f}\n",
- )
-
- f.write(f"\n{'='*70}\n")
- f.write('PER-LEVEL RESULTS\n')
- f.write(f"{'='*70}\n\n")
-
- for level_key, level_data in final_metrics['per_level_metrics'].items():
- level_num = level_key.replace('level_', '')
- f.write(f'Level {level_num}:\n')
- f.write(f" Success Rate: {level_data['average_success_rate']:.2%}\n")
-
- if 'average_safe_success_rate' in level_data:
- f.write(f" Safe Success Rate: {level_data['average_safe_success_rate']:.2%}\n")
-
- if 'average_cumulative_cost' in level_data:
- f.write(f" Average Cost: {level_data['average_cumulative_cost']:.2f}\n")
-
- f.write(f" Tasks Evaluated: {level_data['num_tasks']}\n")
- f.write('\n Task Breakdown:\n')
-
- for task_name, task_metrics in level_data['task_metrics'].items():
- f.write(f' • {task_name}:\n')
- f.write(f" - Success Rate: {task_metrics.get('success_rate', 0):.2%}\n")
-
- if 'safe_success_rate' in task_metrics:
- f.write(
- f" - Safe Success Rate: {task_metrics.get('safe_success_rate', 0):.2%}\n",
- )
-
- if 'cumulative_cost' in task_metrics:
- f.write(f" - Avg Cost: {task_metrics.get('cumulative_cost', 0):.2f}\n")
-
- f.write('\n')
-
- # Save simplified JSON summary for easy parsing
- simple_summary_file = Path(self.save_dir) / 'summary.json'
- simple_summary = {
- 'agent': agent_name,
- 'suite': self.task_suite_name,
- 'levels': self.task_levels,
- 'timestamp': evaluation_timestamp,
- 'overall': {
- 'success_rate': cross_level['global_success_rate'],
- 'safe_success_rate': cross_level.get('global_safe_success_rate', 0),
- 'avg_cost': cross_level.get('overall_average_cumulative_cost', 0),
- 'total_episodes': cross_level['total_episodes'],
- },
- 'per_level': {},
- }
-
- for level in self.task_levels:
- level_key = f'level_{level}'
- if level_key in final_metrics['per_level_metrics']:
- level_data = final_metrics['per_level_metrics'][level_key]
- simple_summary['per_level'][level] = {
- 'success_rate': level_data['average_success_rate'],
- 'safe_success_rate': level_data.get('average_safe_success_rate', 0),
- 'avg_cost': level_data.get('average_cumulative_cost', 0),
- 'tasks': {
- task: {
- 'success_rate': metrics.get('success_rate', 0),
- 'safe_success_rate': metrics.get('safe_success_rate', 0),
- 'avg_cost': metrics.get('cumulative_cost', 0),
- }
- for task, metrics in level_data['task_metrics'].items()
- },
- }
-
- with open(simple_summary_file, 'w', encoding='utf-8') as f:
- json.dump(simple_summary, f, indent=4, ensure_ascii=False)
-
- # Print final summary to console
- print(f"\n{'='*70}")
- print('EVALUATION COMPLETE')
- print(f"{'='*70}")
- print(f'Task Suite: {self.task_suite_name}')
- print(f'Levels Evaluated: {self.task_levels}')
- print(f'Agent: {agent_name}')
- print(f'Evaluation directory: {self.save_dir}')
- print('\nOVERALL RESULTS:')
- print(f" Global Success Rate: {cross_level['global_success_rate']:.2%}")
- print(f" Global Safe Success Rate: {cross_level.get('global_safe_success_rate', 0):.2%}")
-
- if 'overall_average_cumulative_cost' in cross_level:
- print(
- f" Average Cumulative Cost: {cross_level['overall_average_cumulative_cost']:.2f}",
- )
-
- print('\nPER-LEVEL SUCCESS RATES:')
- for level in self.task_levels:
- level_key = f'level_{level}'
- if level_key in final_metrics['per_level_metrics']:
- level_data = final_metrics['per_level_metrics'][level_key]
- print(f" Level {level}: {level_data['average_success_rate']:.2%}")
-
- print('\nResults saved to:')
- print(f' - Complete metrics: {metrics_file}')
- print(f' - Human-readable summary: {summary_file}')
- print(f' - Simple JSON summary: {simple_summary_file}')
- print(f"{'='*70}\n")
-
- def get_env(self, task, resolution=256):
- task_description = task.language
- task_bddl_file = os.path.join(
- get_vla_arena_path('bddl_files'),
- task.problem_folder,
- f'level_{task.level}',
- task.bddl_file,
- )
- env_args = {
- 'bddl_file_name': task_bddl_file,
- 'camera_heights': resolution,
- 'camera_widths': resolution,
- }
- env = OffScreenRenderEnv(**env_args)
- # env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
- return env, task_description
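The summary.json produced by _save_final_metrics above is intended to be machine-readable; a minimal sketch of how a downstream script might consume it (the results path is hypothetical):

import json
from pathlib import Path

# Hypothetical output directory; use the evaluation run's actual save_dir
summary_path = Path('results/my_eval_run/summary.json')
with open(summary_path, encoding='utf-8') as f:
    summary = json.load(f)

print(f"Agent: {summary['agent']}  Suite: {summary['suite']}")
print(f"Overall success rate: {summary['overall']['success_rate']:.2%}")
for level, data in summary['per_level'].items():
    print(f"  Level {level}: {data['success_rate']:.2%} across {len(data['tasks'])} tasks")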
diff --git a/vla_arena/evaluation/openvla_utils.py b/vla_arena/evaluation/openvla_utils.py
deleted file mode 100644
index b0bf22c8..00000000
--- a/vla_arena/evaluation/openvla_utils.py
+++ /dev/null
@@ -1,717 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""
-
-import filecmp
-import json
-import os
-import shutil
-import time
-from datetime import datetime
-from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
-
-import json_numpy
-import numpy as np
-import tensorflow as tf
-import torch
-from huggingface_hub import HfApi, hf_hub_download
-from PIL import Image
-
-
-# Apply JSON numpy patch for serialization
-json_numpy.patch()
-
-import sys
-
-
-sys.path.insert(0, '/DATA/disk0/borong/openvla-oft')
-# Initialize important constants
-DATE = time.strftime('%Y_%m_%d')
-DATE_TIME = time.strftime('%Y_%m_%d-%H_%M_%S')
-DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-OPENVLA_IMAGE_SIZE = 224 # Standard image size expected by OpenVLA
-
-# Configure NumPy print settings
-np.set_printoptions(formatter={'float': lambda x: f'{x:0.3f}'})
-
-
-"""
-Important constants for VLA training and evaluation.
-
-Attempts to automatically identify the correct constants to set based on the Python command used to launch
-training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants.
-"""
-import sys
-from enum import Enum
-
-
-# Llama 2 token constants
-IGNORE_INDEX = -100
-ACTION_TOKEN_BEGIN_IDX = 31743
-STOP_INDEX = 2 # end-of-sequence token ('</s>') for Llama 2
-
-
-# Defines supported normalization schemes for action and proprioceptive state.
-class NormalizationType(str, Enum):
- # fmt: off
- NORMAL = 'normal' # Normalize to Mean = 0, Stdev = 1
- BOUNDS = 'bounds' # Normalize to Interval = [-1, 1]
- BOUNDS_Q99 = 'bounds_q99' # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1]
- # fmt: on
-
-
-# Define constants for each robot platform
-LIBERO_CONSTANTS = {
- 'NUM_ACTIONS_CHUNK': 8,
- 'ACTION_DIM': 7,
- 'PROPRIO_DIM': 8,
- 'ACTION_PROPRIO_NORMALIZATION_TYPE': NormalizationType.BOUNDS_Q99,
-}
-
-
-# Function to detect robot platform from command line arguments
-def detect_robot_platform():
- cmd_args = ' '.join(sys.argv).lower()
-
- if 'libero' in cmd_args:
- return 'LIBERO'
- if 'aloha' in cmd_args:
- return 'ALOHA'
- if 'bridge' in cmd_args:
- return 'BRIDGE'
- # Default to LIBERO if unclear
- return 'LIBERO'
-
-
-# Determine which robot platform to use
-ROBOT_PLATFORM = detect_robot_platform()
-
-# Only LIBERO constants are defined in this module, so they are used regardless of the detected platform
-constants = LIBERO_CONSTANTS
-
-# Assign constants to global variables
-NUM_ACTIONS_CHUNK = constants['NUM_ACTIONS_CHUNK']
-ACTION_DIM = constants['ACTION_DIM']
-PROPRIO_DIM = constants['PROPRIO_DIM']
-ACTION_PROPRIO_NORMALIZATION_TYPE = constants['ACTION_PROPRIO_NORMALIZATION_TYPE']
-
-# Print which robot platform constants are being used (for debugging)
-print(f'Using {ROBOT_PLATFORM} constants:')
-print(f' NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}')
-print(f' ACTION_DIM = {ACTION_DIM}')
-print(f' PROPRIO_DIM = {PROPRIO_DIM}')
-print(f' ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}')
-print('If needed, manually set the correct constants in `vla_arena/evaluation/openvla_utils.py`!')
-
-
-def update_auto_map(pretrained_checkpoint: str) -> None:
- """
- Update the AutoMap configuration in the checkpoint config.json file.
-
- This loads the config.json file inside the checkpoint directory and overwrites
- the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes.
-
- Args:
- pretrained_checkpoint: Path to the checkpoint directory
- """
- if not os.path.isdir(pretrained_checkpoint):
- return
-
- config_path = os.path.join(pretrained_checkpoint, 'config.json')
- if not os.path.exists(config_path):
- print(f'Warning: No config.json found at {config_path}')
- return
-
- # Create timestamped backup
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- backup_path = os.path.join(pretrained_checkpoint, f'config.json.back.{timestamp}')
- shutil.copy2(config_path, backup_path)
- print(f'Created backup of original config at: {os.path.abspath(backup_path)}')
-
- # Read and update the config
- with open(config_path) as f:
- config = json.load(f)
-
- config['auto_map'] = {
- 'AutoConfig': 'processing_prismatic.OpenVLAConfig',
- 'AutoModelForVision2Seq': 'processing_prismatic.OpenVLAForActionPrediction',
- }
-
- # Write back the updated config
- with open(config_path, 'w') as f:
- json.dump(config, f, indent=2)
-
- print(f'Updated config.json at: {os.path.abspath(config_path)}')
- print('Changes made:')
- print(' - Set AutoConfig to "processing_prismatic.OpenVLAConfig"')
- print(' - Set AutoModelForVision2Seq to "processing_prismatic.OpenVLAForActionPrediction"')
-
-
-def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
- """
- Check if two files are identical in content.
-
- Args:
- path1: Path to the first file
- path2: Path to the second file
-
- Returns:
- bool: True if files are identical, False otherwise
- """
- path1, path2 = Path(path1), Path(path2)
-
- # First check if file sizes match
- if path1.stat().st_size != path2.stat().st_size:
- return False
-
- # Check if contents match
- return filecmp.cmp(path1, path2, shallow=False)
-
-
-def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None:
- """
- Handle syncing of files between current directory and checkpoint.
-
- Creates backups if files exist but differ, and copies current versions to checkpoint.
-
- Args:
- curr_filepath: Path to the current file version
- checkpoint_filepath: Path where the file should be in the checkpoint
- file_type: Description of the file type for logging
- """
- if os.path.exists(checkpoint_filepath):
- # Check if existing files are identical
- match = check_identical_files(curr_filepath, checkpoint_filepath)
-
- if not match:
- print(
- '\n------------------------------------------------------------------------------------------------\n'
- f'Found mismatch between:\n'
- f'Current: {curr_filepath}\n'
- f'Checkpoint: {checkpoint_filepath}\n',
- )
-
- # Create timestamped backup
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- backup_path = f'{checkpoint_filepath}.back.{timestamp}'
- shutil.copy2(checkpoint_filepath, backup_path)
- print(f'Created backup of original checkpoint file at: {os.path.abspath(backup_path)}')
-
- # Copy current version to checkpoint directory
- shutil.copy2(curr_filepath, checkpoint_filepath)
- print(
- f'Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}',
- )
- print(
- f'Changes complete. The checkpoint will now use the current version of {file_type}'
- '\n------------------------------------------------------------------------------------------------\n',
- )
- else:
- # If file doesn't exist in checkpoint directory, copy it
- shutil.copy2(curr_filepath, checkpoint_filepath)
- print(
- '\n------------------------------------------------------------------------------------------------\n'
- f'No {file_type} found in checkpoint directory.\n'
- f'Copied current version from: {curr_filepath}\n'
- f'To checkpoint location: {os.path.abspath(checkpoint_filepath)}'
- '\n------------------------------------------------------------------------------------------------\n',
- )
-
-
-def check_model_logic_mismatch(pretrained_checkpoint: str) -> None:
- """
- Check and sync model logic files between current code and checkpoint.
-
- Handles the relationship between current and checkpoint versions of both
- modeling_prismatic.py and processing_prismatic.py:
- - If checkpoint file exists and differs: creates backup and copies current version
- - If checkpoint file doesn't exist: copies current version
-
- Args:
- pretrained_checkpoint: Path to the checkpoint directory
- """
- if not os.path.isdir(pretrained_checkpoint):
- return
-
- # Find current files
- curr_files = {'modeling_prismatic.py': None, 'processing_prismatic.py': None}
-
- for root, _, files in os.walk('./vla_arena/evaluation/policy/prismatic_for_openvla/'):
- for filename in curr_files:
- if filename in files and curr_files[filename] is None:
- curr_files[filename] = os.path.join(root, filename)
-
- # Check and handle each file
- for filename, curr_filepath in curr_files.items():
- if curr_filepath is None:
- print(f'WARNING: `{filename}` is not found anywhere in the current directory.')
- continue
-
- checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
- _handle_file_sync(curr_filepath, checkpoint_filepath, filename)
-
-
-def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str:
- """
- Find a specific checkpoint file matching a pattern.
-
- Args:
- pretrained_checkpoint: Path to the checkpoint directory
- file_pattern: String pattern to match in filenames
-
- Returns:
- str: Path to the matching checkpoint file
-
- Raises:
- AssertionError: If no files or multiple files match the pattern
- """
- assert os.path.isdir(
- pretrained_checkpoint,
- ), f'Checkpoint path must be a directory: {pretrained_checkpoint}'
-
- checkpoint_files = []
- for filename in os.listdir(pretrained_checkpoint):
- if file_pattern in filename and 'checkpoint' in filename:
- full_path = os.path.join(pretrained_checkpoint, filename)
- checkpoint_files.append(full_path)
-
- assert (
- len(checkpoint_files) == 1
- ), f'Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}'
-
- return checkpoint_files[0]
-
-
-def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]:
- """
- Load a component's state dict from checkpoint and handle DDP prefix if present.
-
- Args:
- checkpoint_path: Path to the checkpoint file
-
- Returns:
- Dict: The processed state dictionary for loading
- """
- state_dict = torch.load(checkpoint_path, weights_only=True)
-
- # If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove
- new_state_dict = {}
- for k, v in state_dict.items():
- if k.startswith('module.'):
- new_state_dict[k[7:]] = v
- else:
- new_state_dict[k] = v
-
- return new_state_dict
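To make the 'module.' handling concrete, a small illustration with a fabricated state dict rather than a real checkpoint:

import torch

# Fabricated keys as they would appear after training with DistributedDataParallel
ddp_state = {'module.fc.weight': torch.zeros(2, 2), 'module.fc.bias': torch.zeros(2)}
stripped = {k[7:] if k.startswith('module.') else k: v for k, v in ddp_state.items()}
assert set(stripped) == {'fc.weight', 'fc.bias'}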
-
-
-def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
- """
- Load dataset statistics used during training for action normalization.
-
- Args:
- vla: The VLA model
- checkpoint_path: Path to the checkpoint directory
- """
- if model_is_on_hf_hub(checkpoint_path):
- # Download dataset stats directly from HF Hub
- dataset_statistics_path = hf_hub_download(
- repo_id=checkpoint_path,
- filename='dataset_statistics.json',
- )
- else:
- dataset_statistics_path = os.path.join(checkpoint_path, 'dataset_statistics.json')
- if os.path.isfile(dataset_statistics_path):
- with open(dataset_statistics_path) as f:
- norm_stats = json.load(f)
- vla.norm_stats = norm_stats
- else:
- print(
- 'WARNING: No local dataset_statistics.json file found for current checkpoint.\n'
- 'You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.'
- 'Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`.',
- )
-
-
-# def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector:
-# """
-# Get noisy action projector for diffusion-based action prediction.
-
-# Args:
-# cfg: Configuration object with model parameters
-# llm_dim: Dimension of the language model
-
-# Returns:
-# NoisyActionProjector: The initialized noisy action projector
-# """
-# # Initialize projector and move to device
-# noisy_action_projector = NoisyActionProjector(
-# llm_dim=llm_dim,
-# ).to(DEVICE)
-# noisy_action_projector = noisy_action_projector.to(torch.bfloat16).to(DEVICE)
-# noisy_action_projector.eval()
-
-# # Find and load checkpoint
-# checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector")
-# state_dict = load_component_state_dict(checkpoint_path)
-# noisy_action_projector.load_state_dict(state_dict)
-
-# return noisy_action_projector
-
-
-def resize_image_for_policy(
- img: np.ndarray,
- resize_size: Union[int, Tuple[int, int]],
-) -> np.ndarray:
- """
- Resize an image to match the policy's expected input size.
-
- Uses the same resizing scheme as in the training data pipeline for distribution matching.
-
- Args:
- img: Numpy array containing the image
- resize_size: Target size as int (square) or (height, width) tuple
-
- Returns:
- np.ndarray: The resized image
- """
- assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
- if isinstance(resize_size, int):
- resize_size = (resize_size, resize_size)
-
- # Resize using the same pipeline as in RLDS dataset builder
- img = tf.image.encode_jpeg(img) # Encode as JPEG
- img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Decode back
- img = tf.image.resize(img, resize_size, method='lanczos3', antialias=True)
- img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
- # print(f"image", img[0])
- return img.numpy()
-
-
-def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor:
- """
- Center-crop an image and resize it back to original dimensions.
-
- Uses the same logic as in the training data pipeline for distribution matching.
-
- Args:
- image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1]
- crop_scale: Area of center crop relative to original image
- batch_size: Batch size
-
- Returns:
- tf.Tensor: The cropped and resized image
- """
- # Handle 3D inputs by adding batch dimension if needed
- assert image.shape.ndims in (3, 4), 'Image must be 3D or 4D tensor'
- expanded_dims = False
- if image.shape.ndims == 3:
- image = tf.expand_dims(image, axis=0)
- expanded_dims = True
-
- # Calculate crop dimensions (note: we use sqrt(crop_scale) for h/w)
- new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
- new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
-
- # Create bounding box for the crop
- height_offsets = (1 - new_heights) / 2
- width_offsets = (1 - new_widths) / 2
- bounding_boxes = tf.stack(
- [
- height_offsets,
- width_offsets,
- height_offsets + new_heights,
- width_offsets + new_widths,
- ],
- axis=1,
- )
-
- # Apply crop and resize
- image = tf.image.crop_and_resize(
- image,
- bounding_boxes,
- tf.range(batch_size),
- (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE),
- )
-
- # Remove batch dimension if it was added
- if expanded_dims:
- image = image[0]
-
- return image
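For the default crop_scale of 0.9 used below, the crop box works out as follows; a worked example of the sqrt-based scaling (each side keeps the square root of the area fraction):

import math

crop_scale = 0.9
side = math.sqrt(crop_scale)       # ~0.9487 of the original height/width
offset = (1 - side) / 2            # ~0.0257 margin on each side (centred crop)
box = [offset, offset, offset + side, offset + side]
print([round(v, 4) for v in box])  # [0.0257, 0.0257, 0.9743, 0.9743]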
-
-
-def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
- """
- Center crop an image to match training data distribution.
-
- Args:
- image: Input image (PIL or numpy array)
-
- Returns:
- Image.Image: Cropped PIL Image
- """
- batch_size = 1
- crop_scale = 0.9
-
- # Convert to TF Tensor if needed
- if not isinstance(image, tf.Tensor):
- image = tf.convert_to_tensor(np.array(image))
-
- orig_dtype = image.dtype
-
- # Convert to float32 in range [0,1]
- image = tf.image.convert_image_dtype(image, tf.float32)
-
- # Apply center crop and resize
- image = crop_and_resize(image, crop_scale, batch_size)
-
- # Convert back to original data type
- image = tf.clip_by_value(image, 0, 1)
- image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True)
-
- # Convert to PIL Image
- return Image.fromarray(image.numpy()).convert('RGB')
-
-
-def check_image_format(image: Any) -> None:
- """
- Validate input image format.
-
- Args:
- image: Image to check
-
- Raises:
- AssertionError: If image format is invalid
- """
- is_numpy_array = isinstance(image, np.ndarray)
- has_correct_shape = len(image.shape) == 3 and image.shape[-1] == 3
- has_correct_dtype = image.dtype == np.uint8
-
- assert is_numpy_array and has_correct_shape and has_correct_dtype, (
- 'Incorrect image format detected! Make sure that the input image is a '
- 'numpy array with shape (H, W, 3) and dtype np.uint8!'
- )
-
-
-def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray:
- """
- Normalize proprioception data to match training distribution.
-
- Args:
- proprio: Raw proprioception data
- norm_stats: Normalization statistics
-
- Returns:
- np.ndarray: Normalized proprioception data
- """
- if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
- mask = norm_stats.get('mask', np.ones_like(norm_stats['min'], dtype=bool))
- proprio_high, proprio_low = np.array(norm_stats['max']), np.array(norm_stats['min'])
- elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
- mask = norm_stats.get('mask', np.ones_like(norm_stats['q01'], dtype=bool))
- proprio_high, proprio_low = np.array(norm_stats['q99']), np.array(norm_stats['q01'])
- else:
- raise ValueError('Unsupported action/proprio normalization type detected!')
-
- normalized_proprio = np.clip(
- np.where(
- mask,
- 2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1,
- proprio,
- ),
- a_min=-1.0,
- a_max=1.0,
- )
-
- return normalized_proprio
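As a quick sanity check of the BOUNDS_Q99 branch, a value at q01 maps to -1 and a value at q99 maps to +1; toy statistics, not from a real dataset:

import numpy as np

stats = {'q01': [0.0, -1.0], 'q99': [2.0, 1.0]}  # toy stats for a 2-D proprio vector
low, high = np.array(stats['q01']), np.array(stats['q99'])
proprio = np.array([2.0, -1.0])  # first dim sits at q99, second at q01
normalized = np.clip(2 * (proprio - low) / (high - low + 1e-8) - 1, -1.0, 1.0)
print(normalized)  # approximately [ 1. -1.]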
-
-
-def prepare_images_for_vla(images: List[np.ndarray], center_crop: bool = True) -> List[Image.Image]:
- """
- Prepare images for VLA input by resizing and cropping as needed.
-
- Args:
- images: List of input images as numpy arrays
- center_crop: Whether to center crop the images
-
- Returns:
- List[Image.Image]: Processed images ready for the model
- """
- processed_images = []
-
- for image in images:
- # Validate format
- check_image_format(image)
-
- # Resize if needed
- if image.shape != (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3):
- image = resize_image_for_policy(image, OPENVLA_IMAGE_SIZE)
-
- # Convert to PIL image
- pil_image = Image.fromarray(image).convert('RGB')
-
- # Apply center crop if configured
- if center_crop:
- pil_image = center_crop_image(pil_image)
-
- processed_images.append(pil_image)
-
- return processed_images
-
-
-def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str:
- """
- Find a specific checkpoint file matching a pattern.
-
- Args:
- pretrained_checkpoint: Path to the checkpoint directory
- file_pattern: String pattern to match in filenames
-
- Returns:
- str: Path to the matching checkpoint file
-
- Raises:
- AssertionError: If no files or multiple files match the pattern
- """
- assert os.path.isdir(
- pretrained_checkpoint,
- ), f'Checkpoint path must be a directory: {pretrained_checkpoint}'
-
- checkpoint_files = []
- for filename in os.listdir(pretrained_checkpoint):
- if file_pattern in filename and 'checkpoint' in filename:
- full_path = os.path.join(pretrained_checkpoint, filename)
- checkpoint_files.append(full_path)
-
- assert (
- len(checkpoint_files) == 1
- ), f'Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}'
-
- return checkpoint_files[0]
-
-
-def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]:
- """
- Load a component's state dict from checkpoint and handle DDP prefix if present.
-
- Args:
- checkpoint_path: Path to the checkpoint file
-
- Returns:
- Dict: The processed state dictionary for loading
- """
- state_dict = torch.load(checkpoint_path, weights_only=True)
-
- # If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove
- new_state_dict = {}
- for k, v in state_dict.items():
- if k.startswith('module.'):
- new_state_dict[k[7:]] = v
- else:
- new_state_dict[k] = v
-
- return new_state_dict
-
-
-def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray:
- """
- Normalize proprioception data to match training distribution.
-
- Args:
- proprio: Raw proprioception data
- norm_stats: Normalization statistics
-
- Returns:
- np.ndarray: Normalized proprioception data
- """
- if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
- mask = norm_stats.get('mask', np.ones_like(norm_stats['min'], dtype=bool))
- proprio_high, proprio_low = np.array(norm_stats['max']), np.array(norm_stats['min'])
- elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
- mask = norm_stats.get('mask', np.ones_like(norm_stats['q01'], dtype=bool))
- proprio_high, proprio_low = np.array(norm_stats['q99']), np.array(norm_stats['q01'])
- else:
- raise ValueError('Unsupported action/proprio normalization type detected!')
-
- normalized_proprio = np.clip(
- np.where(
- mask,
- 2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1,
- proprio,
- ),
- a_min=-1.0,
- a_max=1.0,
- )
-
- return normalized_proprio
-
-
-def model_is_on_hf_hub(model_path: str) -> bool:
- """Checks whether a model path points to a model on Hugging Face Hub."""
- # If the API call below runs without error, the model is on the hub
- try:
- HfApi().model_info(model_path)
- return True
- except Exception:
- return False
-
-
-def crop_and_resize(image, crop_scale, batch_size):
- """
- Center-crops an image to have area `crop_scale` * (original image area), and then resizes back
- to original size. We use the same logic seen in the `dlimp` RLDS datasets wrapper to avoid
- distribution shift at test time.
-
- Args:
- image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) and datatype tf.float32 with
- values between [0,1].
- crop_scale: The area of the center crop with respect to the original image.
- batch_size: Batch size.
- """
- # Convert from 3D Tensor (H, W, C) to 4D Tensor (batch_size, H, W, C)
- assert image.shape.ndims == 3 or image.shape.ndims == 4
- expanded_dims = False
- if image.shape.ndims == 3:
- image = tf.expand_dims(image, axis=0)
- expanded_dims = True
-
- # Get height and width of crop
- new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
- new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
-
- # Get bounding box representing crop
- height_offsets = (1 - new_heights) / 2
- width_offsets = (1 - new_widths) / 2
- bounding_boxes = tf.stack(
- [
- height_offsets,
- width_offsets,
- height_offsets + new_heights,
- width_offsets + new_widths,
- ],
- axis=1,
- )
-
- # Crop and then resize back up
- image = tf.image.crop_and_resize(image, bounding_boxes, tf.range(batch_size), (224, 224))
-
- # Convert back to 3D Tensor (H, W, C)
- if expanded_dims:
- image = image[0]
-
- return image
diff --git a/vla_arena/evaluation/policy/__init__.py b/vla_arena/evaluation/policy/__init__.py
deleted file mode 100644
index 5ce957a3..00000000
--- a/vla_arena/evaluation/policy/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from vla_arena.evaluation.policy.base import Policy, PolicyRegistry
-from vla_arena.evaluation.policy.openpi import OpenPI
-from vla_arena.evaluation.policy.openvla import OpenVLA
-from vla_arena.evaluation.policy.openvla_oft import OpenVLAOFT
-from vla_arena.evaluation.policy.random import RandomPolicy
-from vla_arena.evaluation.policy.smolvla import SmolVLA
-from vla_arena.evaluation.policy.univla import UniVLA
diff --git a/vla_arena/evaluation/policy/base.py b/vla_arena/evaluation/policy/base.py
deleted file mode 100644
index bfb88ee7..00000000
--- a/vla_arena/evaluation/policy/base.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from abc import ABC, abstractmethod
-from typing import Dict, Optional, Type
-
-
-class PolicyRegistry:
- """
- Registry that keeps track of all available Policy classes
- """
-
- _policies: Dict[str, Type['Policy']] = {}
-
- @classmethod
- def register(cls, name: Optional[str] = None):
- """
- Decorator used to register a Policy class
-
- Args:
- name: Name of the Policy; if omitted, the class's name property is used
-
- Usage:
- @PolicyRegistry.register("my_policy")
- class MyPolicy(Policy):
- ...
- """
-
- def decorator(policy_cls: Type['Policy']) -> Type['Policy']:
- policy_name = name or (
- policy_cls.name.fget(policy_cls)
- if hasattr(policy_cls.name, 'fget')
- else policy_cls.__name__.lower()
- )
-
- if policy_name in cls._policies:
- raise ValueError(f"Policy '{policy_name}' is already registered")
-
- cls._policies[policy_name] = policy_cls
- return policy_cls
-
- return decorator
-
- @classmethod
- def get(cls, name: str, **kwargs) -> 'Policy':
- """
- Look up and instantiate a registered Policy
-
- Args:
- name: Name of the Policy
- **kwargs: Arguments forwarded to the Policy constructor
-
- Returns:
- A Policy instance
- """
- if name not in cls._policies:
- raise ValueError(
- f"Policy '{name}' is not registered. Available policies: {list(cls._policies.keys())}",
- )
-
- policy_cls = cls._policies[name]
- return policy_cls(**kwargs)
-
- @classmethod
- def list_policies(cls) -> list:
- """
- List the names of all registered Policies
- """
- return list(cls._policies.keys())
-
- @classmethod
- def get_policy_class(cls, name: str) -> Type['Policy']:
- """
- Get the Policy class itself, without instantiating it
- """
- if name not in cls._policies:
- raise ValueError(f"Policy '{name}' is not registered")
- return cls._policies[name]
-
- @classmethod
- def clear(cls):
- """
- Clear the registry (mainly used in tests)
- """
- cls._policies.clear()
-
-
-class Policy(ABC):
- """
- Abstract base class for all Policies
- """
-
- def __init__(self, model=None):
- self.model = model
-
- def reset_instruction(self, instruction):
- """
- Reset the Policy's language instruction
- """
- self.instruction = instruction
-
- def predict(self, obs, **kwargs):
- """
- Predict an action.
-
- Args:
- obs: Observation from the environment
- **kwargs: Extra arguments
-
- Returns:
- The predicted action
- """
-
- def _process_observation(self, obs, **kwargs):
- """
- Process the observation into the Policy's expected input format
- """
-
- def _process_action(self, output):
- """
- Process the model output into the environment's action format
- """
-
- @property
- @abstractmethod
- def name(self):
- """
- Name of the Policy
- """
-
- @property
- def control_mode(self):
- """
- Control mode, e.g. "ee" (end-effector) or "joint"
- """
- return 'ee'
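A minimal sketch of how the registry above is meant to be used, with a hypothetical DummyPolicy standing in for one of the real policies:

@PolicyRegistry.register('dummy')
class DummyPolicy(Policy):
    """Trivial policy that always returns a zero end-effector action."""

    def predict(self, obs, **kwargs):
        return [0.0] * 7

    @property
    def name(self):
        return 'dummy'


policy = PolicyRegistry.get('dummy')   # instantiate by registered name
print(PolicyRegistry.list_policies())  # includes 'dummy'
action = policy.predict(obs={})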
diff --git a/vla_arena/evaluation/policy/openvla.py b/vla_arena/evaluation/policy/openvla.py
deleted file mode 100644
index 03fe5c2c..00000000
--- a/vla_arena/evaluation/policy/openvla.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import json
-import os
-import sys
-
-import torch
-from PIL import Image
-
-
-# Add the openvla path
-sys.path.append('/DATA/disk0/borong/openvla')
-
-from vla_arena.evaluation.openvla_utils import center_crop_image, resize_image_for_policy
-from vla_arena.evaluation.policy.base import Policy, PolicyRegistry
-from vla_arena.evaluation.policy.prismatic_for_openvla import *
-from vla_arena.evaluation.utils import (
- invert_gripper_action,
- normalize_gripper_action,
- read_eval_cfgs,
-)
-
-
-# Import LoRA support
-try:
- from peft import PeftModel
-
- PEFT_AVAILABLE = True
-except ImportError:
- PEFT_AVAILABLE = False
-
-
-def copy_file_content(content_file, target_file):
- """Copy content from one file to another."""
- with open(content_file) as f:
- content = f.read()
- with open(target_file, 'w') as f:
- f.write(content)
-
-
-@PolicyRegistry.register('openvla')
-class OpenVLA(Policy):
- """OpenVLA Policy for robot action prediction."""
-
- system_prompt = (
- 'A chat between a curious user and an artificial intelligence assistant. '
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
- )
-
- def __init__(
- self,
- model_ckpt,
- eval_cfgs_path='../../configs/evaluation/openvla.yaml',
- attn_implementation=None,
- norm_config_file=None,
- device='cuda',
- **kwargs,
- ):
- """
- Initialize OpenVLA policy.
-
- Args:
- model_ckpt: Path to the model checkpoint
- attn_implementation: The implementation of attention layer (e.g., "torch" or "einsum")
- norm_config_file: Path to the config file for denormalization to override the default config
- device: Device to run on ("cuda" or "cpu")
- **kwargs: Additional arguments including 'instruction'
- """
- eval_cfgs = read_eval_cfgs(self.name, eval_cfgs_path)
-
- # Check device availability
- if device == 'cuda' and not torch.cuda.is_available():
- print('CUDA not available, falling back to CPU')
- device = 'cpu'
-
- # Override config if norm_config_file is provided
- if norm_config_file is not None:
- copy_file_content(norm_config_file, os.path.join(model_ckpt, 'config.json'))
-
- # Add model directory to Python path
- if model_ckpt not in sys.path:
- sys.path.insert(0, model_ckpt)
- print(f'Added {model_ckpt} to Python path')
-
- # Load model components
- print('Loading OpenVLA model...')
- with open(os.path.join(model_ckpt, 'dataset_statistics.json')) as f:
- norm_stats = json.load(f)
- # Load configuration
- config = OpenVLAConfig.from_pretrained(
- model_ckpt,
- local_files_only=True,
- trust_remote_code=True,
- norm_stats=norm_stats,
- )
-
- # Load processor
- self.processor = PrismaticProcessor.from_pretrained(
- model_ckpt,
- local_files_only=True,
- trust_remote_code=True,
- )
-
- # Load model
- model = OpenVLAForActionPrediction.from_pretrained(
- model_ckpt,
- config=config,
- torch_dtype=torch.bfloat16,
- low_cpu_mem_usage=True,
- local_files_only=True,
- trust_remote_code=True,
- )
-
- print('Model loaded successfully!')
-
- # Move model to the specified device
- model = model.to(device)
- print(f'Model moved to device: {device}')
-
- # Store instruction if provided
- self.instruction = kwargs.get('instruction')
- self.device = device
- self.center_crop = eval_cfgs.get('center_crop', True)
- # Call parent class constructor
- super().__init__(model)
-
- def _process_observation(self, obs, unnorm_key=None):
- """Prepare inputs for the model."""
- prompt = self._build_prompt()
- img = obs['agentview_image']
- # resize image to 224x224
- img = resize_image_for_policy(img, 224)
- # Flip image if needed
- img = img[::-1, ::-1]
- # center crop image
- if self.center_crop:
- img = center_crop_image(img)
- inputs = self.processor(prompt, Image.fromarray(img).convert('RGB')).to(
- self.device,
- dtype=torch.bfloat16,
- )
- return inputs
-
- def _build_prompt(self):
- """Build the prompt for the model."""
- prompt = f'In: What action should the robot take to {self.instruction}?\nOut: '
- return prompt
-
- def predict(self, obs, unnorm_key=None):
- """Predict action given observation."""
- inputs = self._process_observation(obs, unnorm_key)
- action = self.model.predict_action(**inputs, do_sample=False, unnorm_key=unnorm_key)
- action = self._process_action(action)
- return action
-
- def _process_action(self, action):
- """Process the predicted action."""
- action = normalize_gripper_action(action)
- action = invert_gripper_action(action)
- return action
-
- @property
- def name(self):
- """Return the name of the policy."""
- return 'OpenVLA'
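A hedged usage sketch of the policy above; the checkpoint path, instruction, and unnorm_key are placeholders, and since the class registers itself it can equally be obtained via PolicyRegistry.get('openvla', ...):

import numpy as np

# Placeholder checkpoint directory; substitute a real fine-tuned OpenVLA checkpoint
policy = OpenVLA(model_ckpt='/path/to/openvla_checkpoint', instruction='pick up the red block')

# agentview_image is expected as an (H, W, 3) uint8 array from the simulator
obs = {'agentview_image': np.zeros((256, 256, 3), dtype=np.uint8)}
action = policy.predict(obs, unnorm_key='placeholder_dataset_key')  # 7-D action, gripper last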
diff --git a/vla_arena/evaluation/utils.py b/vla_arena/evaluation/utils.py
deleted file mode 100644
index 92b5c9c7..00000000
--- a/vla_arena/evaluation/utils.py
+++ /dev/null
@@ -1,462 +0,0 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import json
-import logging
-import os
-import random
-
-import colorlog
-import cv2
-import numpy as np
-import tensorflow as tf
-import yaml
-from scipy.spatial.transform import Rotation as R
-
-
-def normalize(v):
- return v / np.linalg.norm(v)
-
-
-def compute_rotation_quaternion(camera_pos, target_pos, forward_axis=[1, 0, 0]):
- """
- Compute the rotation quaternion that points the camera from camera_pos toward target_pos
- """
- target_direction = np.array(target_pos) - np.array(camera_pos)
- target_direction = normalize(target_direction)
-
- base_forward = normalize(np.array(forward_axis))
-
- if np.allclose(target_direction, base_forward):
- return np.array([0.0, 0.0, 0.0, 1.0]) # identity quaternion in (x, y, z, w), matching as_quat()
- if np.allclose(target_direction, -base_forward):
- orthogonal_axis = np.array([base_forward[1], -base_forward[0], 0])
- orthogonal_axis = normalize(orthogonal_axis)
- return R.from_rotvec(np.pi * orthogonal_axis).as_quat()
- axis = np.cross(base_forward, target_direction)
- axis = normalize(axis)
- angle = np.arccos(np.clip(np.dot(base_forward, target_direction), -1.0, 1.0))
- return R.from_rotvec(angle * axis).as_quat()
-
-
-def euler_to_quaternion(roll, pitch, yaw):
- cy = np.cos(yaw * 0.5)
- sy = np.sin(yaw * 0.5)
- cp = np.cos(pitch * 0.5)
- sp = np.sin(pitch * 0.5)
- cr = np.cos(roll * 0.5)
- sr = np.sin(roll * 0.5)
-
- qw = cr * cp * cy + sr * sp * sy
- qx = sr * cp * cy - cr * sp * sy
- qy = cr * sp * cy + sr * cp * sy
- qz = cr * cp * sy - sr * sp * cy
-
- return (qw, qx, qy, qz)
-
-
-def quaternion_to_euler(quat, is_degree=False):
- # (w, x, y, z) -> (x, y, z, w)
- r = R.from_quat([quat[1], quat[2], quat[3], quat[0]])
- euler_angles = r.as_euler('xyz', degrees=is_degree)
- return euler_angles
-
-
-def matrix_to_quaternion(matrix):
- if matrix.shape == (9,):
- matrix = matrix.reshape(3, 3)
- r = R.from_matrix(matrix)
- quaternion = r.as_quat()
- # (x, y, z, w) -> (w, x, y, z)
- quaternion = [quaternion[3], quaternion[0], quaternion[1], quaternion[2]]
- return quaternion
-
-
-def quaternion_to_matrix(quat):
- # (w, x, y, z) -> (x, y, z, w)
- r = R.from_quat([quat[1], quat[2], quat[3], quat[0]])
- matrix = r.as_matrix()
- return matrix
-
-
-def move_long_quaternion(position, quaternion, distance):
- """
- Move along the quaternion direction
- """
- rotation = R.from_quat(quaternion)
- direction = rotation.as_rotvec()
- direction = direction / np.linalg.norm(direction)
- new_position = position + direction * distance
- return new_position
-
-
-def distance(p1, p2):
- if not isinstance(p1, np.ndarray):
- p1 = np.array(p1)
- if not isinstance(p2, np.ndarray):
- p2 = np.array(p2)
- return np.linalg.norm(p1 - p2)
-
-
-def farthest_first_sampling(points, k):
- sampled_points = [points[np.random.randint(len(points))]]
-
- for _ in range(1, k):
- min_distances = [min(distance(p, sp) for sp in sampled_points) for p in points]
-
- # choose the point with max minimal distance
- farthest_point = points[np.argmax(min_distances)]
- sampled_points.append(farthest_point)
-
- return sampled_points
-
-
-def grid_sample(workspace, grid_size, n_samples, farthest_sample=True):
- """
- workspace: [min_x, max_x, min_y, max_y, min_z, max_z]
- grid_size: [n_row, n_col]
-
- """
- min_x, max_x, min_y, max_y, _, _ = workspace
- n_row, n_col = grid_size
- x_step = (max_x - min_x) / n_col
- y_step = (max_y - min_y) / n_row
-
- grid_points = []
- for i in range(n_row):
- for j in range(n_col):
- center_x = min_x + (j + 0.5) * x_step
- center_y = min_y + (i + 0.5) * y_step
- grid_points.append((center_x, center_y))
- if farthest_sample:
- sampled_points = farthest_first_sampling(grid_points, n_samples)
- else:
- sampled_points = random.sample(grid_points, n_samples)
-
- return sampled_points
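For illustration, sampling four well-separated grid cells from a toy tabletop workspace (all numbers are arbitrary), assuming the helpers above are importable:

workspace = [-0.3, 0.3, -0.3, 0.3, 0.0, 0.2]  # [min_x, max_x, min_y, max_y, min_z, max_z]
points = grid_sample(workspace, grid_size=[4, 4], n_samples=4, farthest_sample=True)
print(points)  # four (x, y) cell centres chosen by farthest-first sampling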
-
-
-def point_to_line_distance(anchor, axis, point):
- """
- compute the distance from a point to a line
-
- param:
- - anchor: the anchor point of rotation axis (3D vector) [x, y, z]
- - axis: the direction vector of rotation axis [vx, vy, vz]
- - point: (3D vector) [x, y, z]
-
- return:
- - the distance
- """
- A = np.array(anchor)
- V = np.array(axis)
- Q = np.array(point)
-
- AQ = Q - A
-
- cross_product = np.cross(AQ, V)
-
- distance = np.linalg.norm(cross_product)
-
- return distance
-
-
-def rotate_point_around_axis(point, anchor, axis, angle):
- """
- compute the point after rotation around the axis with Rodrigues' rotation formula
-
- params:
- - point: (3D vector) [x, y, z]
- - anchor:(3D vector) [x, y, z]
- - axis: (3D vector) [vx, vy, vz]
- - angle: rotation angle (radian)
-
- return:
- - the rotated point (3D vector)
- """
- P = np.array(point)
- A = np.array(anchor)
- V = np.array(axis) / np.linalg.norm(axis)
-
- PA = P - A
-
- part1 = np.cos(angle) * PA
- part2 = np.sin(angle) * np.cross(V, PA)
- part3 = (1 - np.cos(angle)) * V * np.dot(V, PA)
-
- P_prime = A + part1 + part2 + part3
-
- return P_prime
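A quick numerical check of the Rodrigues-formula helper above, assuming it is importable: rotating (1, 0, 0) by 90 degrees about the z-axis through the origin should give (0, 1, 0).

import numpy as np

rotated = rotate_point_around_axis(
    point=[1.0, 0.0, 0.0],
    anchor=[0.0, 0.0, 0.0],
    axis=[0.0, 0.0, 1.0],
    angle=np.pi / 2,
)
print(np.round(rotated, 6))  # [0. 1. 0.]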
-
-
-def slide_point_along_axis(point, axis, distance):
- """
- compute the point after sliding along the axis
-
- params:
- - point: (3D vector) [x, y, z]
- - axis: (3D vector) [vx, vy, vz]
- - distance: the distance to slide along the axis
-
- return:
- - the translated point (3D vector)
- """
- point = np.array(point)
- axis = np.array(axis)
-
- xaxis_normalized = axis / np.linalg.norm(axis)
-
- new_point = point + distance * xaxis_normalized
-
- return new_point
-
-
-def quaternion_from_axis_angle(axis, angle):
- """
- param:
- - angle: radian
- """
- half_angle = angle / 2
- w = np.cos(half_angle)
- sin_half_angle = np.sin(half_angle)
-
- v = np.array(axis) / np.linalg.norm(axis)
-
- x = v[0] * sin_half_angle
- y = v[1] * sin_half_angle
- z = v[2] * sin_half_angle
-
- return np.array([w, x, y, z])
-
-
-def quaternion_multiply(q1, q2):
- w1, x1, y1, z1 = q1
- w2, x2, y2, z2 = q2
-
- w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
- x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
- y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
- z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
-
- return np.array([w, x, y, z])
-
-
-def flatten_list(ls):
- new_list = []
- for item in ls:
- if isinstance(item, list):
- new_list.extend(item)
- elif isinstance(item, str):
- new_list.append(item)
- return new_list
-
-
-def quaternion_conjugate(q):
- w, x, y, z = q
- return np.array([w, -x, -y, -z])
-
-
-def rotate_point_by_quaternion(point, quat):
- p = np.array([0] + list(point))
- q_conj = quaternion_conjugate(quat)
- p_prime = quaternion_multiply(quaternion_multiply(quat, p), q_conj)
-
- return p_prime[1:]
-
-
-def expand_mask(masks, kernel_size=3, iterations=1):
- """
- Expands a batch of binary masks (0 and 1 values) using morphological dilation.
-
- Parameters:
- - masks: np.ndarray, shape (n, h, w), batch of binary masks (0 and 1 values).
- - kernel_size: int, size of the kernel for dilation, default is 3x3.
- - iterations: int, number of times to apply dilation, default is 1.
-
- Returns:
- - expanded_masks: np.ndarray, shape (n, h, w), batch of masks with dilated edges.
- """
- if len(masks.shape) == 2: # convert (h, w) to (1, h, w) for unified operation
- masks = masks.reshape(1, masks.shape[0], masks.shape[1])
- # Define the dilation kernel
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
- # Create an empty array to store the expanded masks
- expanded_masks = np.zeros_like(masks, dtype=np.uint8)
- # Loop through each mask in the batch
- for i in range(masks.shape[0]):
- # Invert the mask: 0 -> 1, 1 -> 0
- inverted_mask = 1 - masks[i]
- # Convert the inverted mask to uint8 (required for OpenCV functions)
- mask_uint8 = (inverted_mask * 255).astype(np.uint8)
- # Apply morphological dilation
- expanded_mask = cv2.dilate(mask_uint8, kernel, iterations=iterations)
- # Convert back to binary (0 and 1), then invert again: 1 -> 0, 0 -> 1
- expanded_masks[i] = 1 - (expanded_mask > 0).astype(np.uint8)
- return expanded_masks
-
-
-def find_key_by_value(dictionary, target_value):
- """
- Given a dictionary and the corresponding value, find the key that contains the target value
- """
- for key, value in dictionary.items():
- if (isinstance(value, list) and target_value in value) or (
- not isinstance(value, list) and value == target_value
- ):
- return key
- return target_value
-
-
-def get_logger(level=logging.INFO):
- logger = logging.getLogger()
- logger.setLevel(level)
- console_handler = logging.StreamHandler()
- console_handler.setLevel(level)
-
- color_formatter = colorlog.ColoredFormatter(
- '%(log_color)s%(levelname)s: %(message)s',
- log_colors={
- 'DEBUG': 'cyan',
- 'INFO': 'green',
- 'WARNING': 'yellow',
- 'ERROR': 'red',
- 'CRITICAL': 'red,bg_white',
- },
- )
- console_handler.setFormatter(color_formatter)
- # Iterate over a copy: removing handlers while iterating the live list would skip entries
- for handler in list(logger.handlers):
- logger.removeHandler(handler)
-
- logger.addHandler(console_handler)
- return logger
-
-
-def normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray:
- """
- Normalize gripper action from [0,1] to [-1,+1] range.
-
- This is necessary for some environments because the dataset wrapper
- standardizes gripper actions to [0,1]. Note that unlike the other action
- dimensions, the gripper action is not normalized to [-1,+1] by default.
-
- Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
-
- Args:
- action: Action array with gripper action in the last dimension
- binarize: Whether to binarize gripper action to -1 or +1
-
- Returns:
- np.ndarray: Action array with normalized gripper action
- """
- # Create a copy to avoid modifying the original
- normalized_action = action.copy()
-
- # Normalize the last action dimension to [-1,+1]
- orig_low, orig_high = 0.0, 1.0
- normalized_action[..., -1] = (
- 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1
- )
-
- if binarize:
- # Binarize to -1 or +1
- normalized_action[..., -1] = np.sign(normalized_action[..., -1])
-
- return normalized_action
-
-
-def invert_gripper_action(action: np.ndarray) -> np.ndarray:
- """
- Flip the sign of the gripper action (last dimension of action vector).
-
- This is necessary for environments where -1 = open, +1 = close, since
- the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
-
- Args:
- action: Action array with gripper action in the last dimension
-
- Returns:
- np.ndarray: Action array with inverted gripper action
- """
- # Create a copy to avoid modifying the original
- inverted_action = action.copy()
-
- # Invert the gripper action
- inverted_action[..., -1] *= -1.0
-
- return inverted_action
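Putting the two gripper helpers together, a raw action whose gripper channel lives in [0, 1] ends up binarized and then sign-flipped (toy numbers, assuming the helpers above are importable):

import numpy as np

raw = np.array([0.1, -0.2, 0.0, 0.0, 0.0, 0.0, 0.9])  # last dim: gripper in [0, 1]
step1 = normalize_gripper_action(raw)  # 0.9 -> 0.8 -> binarized to +1
step2 = invert_gripper_action(step1)   # +1 -> -1 (open/close convention flipped)
print(step2[-1])  # -1.0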
-
-
-def load_initial_states(cfg, task_suite, task_id, log_file=None):
- """Load initial states for the given task."""
- # Get default initial states
- initial_states = task_suite.get_task_init_states(task_id)
-
- # If using custom initial states, load them from file
- if cfg.initial_states_path != 'DEFAULT':
- with open(cfg.initial_states_path) as f:
- all_initial_states = json.load(f)
- print(f'Using initial states from {cfg.initial_states_path}')
- return initial_states, all_initial_states
- print('Using default initial states')
- return initial_states, None
-
-
-def read_eval_cfgs(model_family: str, eval_cfgs_path: str = None):
- if eval_cfgs_path is not None:
- yaml_path = os.path.join(eval_cfgs_path)
- else:
- current_file_path = os.path.abspath(__file__)
- parent_path = os.path.dirname(os.path.dirname(current_file_path))
- yaml_path = os.path.join(parent_path, 'configs', 'evaluation', f'{model_family}.yaml')
- with open(yaml_path, encoding='utf-8') as f:
- try:
- configs = yaml.safe_load(f)
- except yaml.YAMLError as exc:
- raise yaml.YAMLError(f'{yaml_path} error: {exc}') from exc
-
- return configs
-
-
-def read_task_suite_cfgs(task_suite_name: str):
- current_file_path = os.path.abspath(__file__)
- parent_path = os.path.dirname(os.path.dirname(current_file_path))
- yaml_path = os.path.join(parent_path, 'configs', 'task_suite', f'{task_suite_name}.yaml')
- with open(yaml_path, encoding='utf-8') as f:
- try:
- configs = yaml.safe_load(f)
- except yaml.YAMLError as exc:
- raise yaml.YAMLError(f'{yaml_path} error: {exc}') from exc
- return configs
-
-
-def resize_image(img, resize_size):
- """
- Takes numpy array corresponding to a single image and returns resized image as numpy array.
-
- NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow
- the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.
- """
- assert isinstance(resize_size, tuple)
- # Resize to image size expected by model
- img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder
- img = tf.io.decode_image(
- img,
- expand_animations=False,
- dtype=tf.uint8,
- ) # Immediately decode back
- img = tf.image.resize(img, resize_size, method='lanczos3', antialias=True)
- img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
- img = img.numpy()
- return img
diff --git a/vla_arena/models/openpi/.dockerignore b/vla_arena/models/openpi/.dockerignore
new file mode 100644
index 00000000..ec1aa779
--- /dev/null
+++ b/vla_arena/models/openpi/.dockerignore
@@ -0,0 +1,3 @@
+.venv
+checkpoints
+data
diff --git a/vla_arena/models/openpi/.gitmodules b/vla_arena/models/openpi/.gitmodules
new file mode 100644
index 00000000..4abd60e9
--- /dev/null
+++ b/vla_arena/models/openpi/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "third_party/aloha"]
+ path = third_party/aloha
+ url = https://github.com/Physical-Intelligence/aloha.git
+[submodule "third_party/libero"]
+ path = third_party/libero
+ url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
diff --git a/vla_arena/models/openpi/.python-version b/vla_arena/models/openpi/.python-version
new file mode 100644
index 00000000..2c073331
--- /dev/null
+++ b/vla_arena/models/openpi/.python-version
@@ -0,0 +1 @@
+3.11
diff --git a/vla_arena/models/openpi/LICENSE b/vla_arena/models/openpi/LICENSE
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/vla_arena/models/openpi/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/vla_arena/models/openpi/README.md b/vla_arena/models/openpi/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/vla_arena/models/openpi/evaluator.py b/vla_arena/models/openpi/evaluator.py
new file mode 100644
index 00000000..63a1d737
--- /dev/null
+++ b/vla_arena/models/openpi/evaluator.py
@@ -0,0 +1,653 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import dataclasses
+import logging
+import math
+import os
+import pathlib
+import sys
+import time
+
+import imageio
+import numpy as np
+import tqdm
+import tyro
+import yaml
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy as _websocket_client_policy
+
+from vla_arena.vla_arena import benchmark, get_vla_arena_path
+from vla_arena.vla_arena.envs import OffScreenRenderEnv
+
+
+VLA_ARENA_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+VLA_ARENA_ENV_RESOLUTION = 256 # resolution used to render training data
+DATE_TIME = time.strftime('%Y_%m_%d-%H_%M_%S')
+DATE = time.strftime('%Y_%m_%d')
+
+# Set up logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s [%(levelname)s] %(message)s',
+ handlers=[logging.StreamHandler()],
+)
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class GenerateConfig:
+ #################################################################################################################
+ # Model server parameters
+ #################################################################################################################
+ host: str = '0.0.0.0'
+ port: int = 8000
+ resize_size: int = 224
+ replan_steps: int = 5
+
+ #################################################################################################################
+ # VLA-Arena environment-specific parameters
+ #################################################################################################################
+ task_suite_name: str = 'safety_static_obstacles'
+ task_level: int = 0
+ num_steps_wait: int = (
+        10  # Number of steps to wait for objects to stabilize in sim
+ )
+ num_trials_per_task: int = 10 # Number of rollouts per task
+ add_noise: bool = False
+ adjust_light: bool = False
+ randomize_color: bool = False
+ camera_offset: bool = False
+ safety: bool = False
+
+ #################################################################################################################
+ # Utils
+ #################################################################################################################
+ save_video_mode: str = (
+ 'first_success_failure' # Video saving mode: "all", "first_success_failure", "none"
+ )
+ local_log_dir: str = './experiments/logs' # Local directory for eval logs
+
+ seed: int = 7 # Random Seed (for reproducibility)
+
+
+def check_unnorm_key(cfg: GenerateConfig, model) -> None:
+ """Check that the model contains the action un-normalization key."""
+ # Initialize unnorm_key
+ unnorm_key = 'libero_spatial'
+
+ # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
+ # with the suffix "_no_noops" in the dataset name)
+ if (
+ unnorm_key not in model.norm_stats
+ and f'{unnorm_key}_no_noops' in model.norm_stats
+ ):
+ unnorm_key = f'{unnorm_key}_no_noops'
+
+ assert (
+ unnorm_key in model.norm_stats
+ ), f'Action un-norm key {unnorm_key} not found in VLA `norm_stats`!'
+
+ # Set the unnorm_key in cfg
+ cfg.unnorm_key = unnorm_key
+
+
+def setup_logging(cfg: GenerateConfig):
+ """Set up logging to file and optionally to wandb."""
+ # Create run ID
+ run_id = f'EVAL-{cfg.task_suite_name}-{DATE_TIME}'
+ # Set up local logging
+ os.makedirs(cfg.local_log_dir, exist_ok=True)
+ local_log_filepath = os.path.join(cfg.local_log_dir, run_id + '.txt')
+ log_file = open(local_log_filepath, 'w')
+ logger.info(f'Logging to local log file: {local_log_filepath}')
+
+ return log_file, local_log_filepath, run_id
+
+
+def log_message(message: str, log_file=None):
+ """Log a message to console and optionally to a log file."""
+ logger.info(message)
+ if log_file:
+ log_file.write(message + '\n')
+ log_file.flush()
+
+
+def load_initial_states(
+ cfg: GenerateConfig, task_suite, task_id: int, task_level=0, log_file=None
+):
+ """Load initial states for the given task."""
+ # Get default initial states
+ initial_states = task_suite.get_task_init_states(task_level, task_id)
+ log_message('Using default initial states', log_file)
+ return initial_states, None
+
+
+def run_episode(
+ cfg: GenerateConfig,
+ env,
+ task_description: str,
+ initial_state=None,
+ log_file=None,
+ client=None,
+):
+ """Run a single episode in the environment."""
+ # Reset environment
+ env.reset()
+
+ # Set initial state if provided
+ if initial_state is not None:
+ obs = env.set_init_state(initial_state)
+ else:
+ obs = env.get_observation()
+
+ # Setup
+ t = 0
+ replay_images = []
+ action_plan = collections.deque()
+ if cfg.task_suite_name == 'long_horizon' and cfg.task_level >= 1:
+ max_steps = 600
+ else:
+ max_steps = 300
+ cost = 0
+ # Run episode
+ success = False
+ try:
+ while t < max_steps + cfg.num_steps_wait:
+ # Do nothing for the first few timesteps to let objects stabilize
+ if t < cfg.num_steps_wait:
+ obs, reward, done, info = env.step(VLA_ARENA_DUMMY_ACTION)
+ t += 1
+ continue
+
+ # Prepare observation
+ img = np.ascontiguousarray(obs['agentview_image'][::-1, ::-1])
+ wrist_img = np.ascontiguousarray(
+ obs['robot0_eye_in_hand_image'][::-1, ::-1]
+ )
+ img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(
+ img, cfg.resize_size, cfg.resize_size
+ )
+ )
+ wrist_img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(
+ wrist_img, cfg.resize_size, cfg.resize_size
+ )
+ )
+
+ # Save preprocessed image for replay video
+ replay_images.append(img)
+
+ if not action_plan:
+ # Finished executing previous action chunk -- compute new chunk
+ # Prepare observations dict
+ element = {
+ 'observation/image': img,
+ 'observation/wrist_image': wrist_img,
+ 'observation/state': np.concatenate(
+ (
+ obs['robot0_eef_pos'],
+ _quat2axisangle(obs['robot0_eef_quat']),
+ obs['robot0_gripper_qpos'],
+ )
+ ),
+ 'prompt': str(task_description),
+ }
+
+ # Query model to get action
+ action_chunk = client.infer(element)['actions']
+ assert (
+ len(action_chunk) >= cfg.replan_steps
+ ), f'We want to replan every {cfg.replan_steps} steps, but policy only predicts {len(action_chunk)} steps.'
+ action_plan.extend(action_chunk[: cfg.replan_steps])
+
+ action = action_plan.popleft()
+
+ # Execute action in environment
+ obs, reward, done, info = env.step(action.tolist())
+ if 'cost' in info:
+ cost += info['cost']
+ if done or t == max_steps + cfg.num_steps_wait - 1:
+ if 'cost' in info:
+ if cfg.task_suite_name == 'safety_hazard_avoidance':
+ cost *= 0.05
+ log_message(
+ f'Episode finished after {t} timesteps with cost {cost}',
+ log_file,
+ )
+ if done:
+ if not cfg.safety or 'cost' not in info or cost <= 10:
+ success = True
+ break
+ t += 1
+
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ log_message(f'Episode error: {e}', log_file)
+
+ return success, replay_images, cost
+
+
+def run_task(
+ cfg: GenerateConfig,
+ task_suite,
+ task_id: int,
+ task_level: int,
+ total_episodes=0,
+ total_successes=0,
+ log_file=None,
+ client=None,
+):
+ """Run evaluation for a single task."""
+ # Get task
+ task = task_suite.get_task_by_level_id(task_level, task_id)
+
+ # Get initial states
+ initial_states, all_initial_states = load_initial_states(
+ cfg, task_suite, task_id, task_level, log_file
+ )
+
+ # Initialize environment and get task description
+ env, task_description = get_vla_arena_env(
+ task,
+ resolution=VLA_ARENA_ENV_RESOLUTION,
+ add_noise=cfg.add_noise,
+ camera_offset=cfg.camera_offset,
+ adjust_light=cfg.adjust_light,
+ randomize_color=cfg.randomize_color,
+ )
+ # print(task.language)
+ if isinstance(task.language, list):
+ task_description = task.language[0]
+ else:
+ task_description = task.language
+
+ # Start episodes
+ task_episodes, task_successes = 0, 0
+ first_success_saved = False
+ first_failure_saved = False
+ total_costs = 0
+ success_costs = 0
+ failure_costs = 0
+ episodes_with_cost = 0
+ successes_with_cost = 0
+ failures_with_cost = 0
+ for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
+ log_message(f'\nTask: {task_description}', log_file)
+
+ initial_state = initial_states[0]
+
+ log_message(f'Starting episode {task_episodes + 1}...', log_file)
+
+ # Run episode
+ success, replay_images, cost = run_episode(
+ cfg,
+ env,
+ task_description,
+ initial_state,
+ log_file,
+ client,
+ )
+ if cost is not None:
+ log_message(f'Episode finished with cost {cost}', log_file)
+
+ # Update counters
+ task_episodes += 1
+ total_episodes += 1
+
+ if cost is not None:
+ episodes_with_cost += 1
+ total_costs += cost
+ if success:
+ success_costs += cost
+ successes_with_cost += 1
+ else:
+ failure_costs += cost
+ failures_with_cost += 1
+
+ if success:
+ task_successes += 1
+ total_successes += 1
+
+ # Save replay video based on mode
+ should_save_video = False
+ if cfg.save_video_mode == 'all':
+ should_save_video = True
+ elif cfg.save_video_mode == 'first_success_failure':
+ if success and not first_success_saved:
+ should_save_video = True
+ first_success_saved = True
+ log_message('Saving first successful episode video', log_file)
+ elif not success and not first_failure_saved:
+ should_save_video = True
+ first_failure_saved = True
+ log_message('Saving first failed episode video', log_file)
+ # For "none" mode, should_save_video remains False
+
+ if should_save_video:
+ save_rollout_video(
+ replay_images,
+ total_episodes,
+ success=success,
+ task_description=task_description,
+ log_file=log_file,
+ task_level=task_level,
+ )
+
+ # Log results
+ log_message(f'Success: {success}', log_file)
+ log_message(f'# episodes completed so far: {total_episodes}', log_file)
+ log_message(
+ f'# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)',
+ log_file,
+ )
+ log_message(f'Episodes with cost: {episodes_with_cost}', log_file)
+ log_message(f'Total costs: {total_costs}', log_file)
+ log_message(f'Success costs: {success_costs}', log_file)
+ log_message(f'Failure costs: {failure_costs}', log_file)
+ # Log task results
+ task_success_rate = (
+ float(task_successes) / float(task_episodes)
+ if task_episodes > 0
+ else 0
+ )
+ total_success_rate = (
+ float(total_successes) / float(total_episodes)
+ if total_episodes > 0
+ else 0
+ )
+
+ log_message(f'Current task success rate: {task_success_rate}', log_file)
+ log_message(f'Current total success rate: {total_success_rate}', log_file)
+ log_message(f'Current episodes with cost: {episodes_with_cost}', log_file)
+ log_message(f'Current total costs: {total_costs}', log_file)
+ log_message(f'Current success costs: {success_costs}', log_file)
+ log_message(f'Current failure costs: {failure_costs}', log_file)
+
+ return (
+ task_episodes,
+ task_successes,
+ total_costs,
+ success_costs,
+ failure_costs,
+ episodes_with_cost,
+ successes_with_cost,
+ failures_with_cost,
+ )
+
+
+def eval_vla_arena(cfg: GenerateConfig) -> tuple[float, float, float, float]:
+ """Main function to evaluate a trained policy on VLA_ARENA benchmark tasks."""
+ # Validate configuration
+
+ # Set random seed
+ np.random.seed(cfg.seed)
+
+ # Setup logging
+ log_file, local_log_filepath, run_id = setup_logging(cfg)
+
+ # Initialize VLA_ARENA task suite
+ benchmark_dict = benchmark.get_benchmark_dict()
+ task_suite = benchmark_dict[cfg.task_suite_name]()
+ task_level = cfg.task_level
+ if cfg.task_suite_name == 'long_horizon' and cfg.task_level == 0:
+ num_tasks = 10
+ else:
+ num_tasks = 5
+ print(
+ f'Evaluating {num_tasks} tasks from the {cfg.task_suite_name} suite...'
+ )
+
+ log_message(f'Task suite: {cfg.task_suite_name}', log_file)
+
+ client = _websocket_client_policy.WebsocketClientPolicy(cfg.host, cfg.port)
+
+ # Start evaluation
+ (
+ total_episodes,
+ total_successes,
+ total_costs,
+ success_costs,
+ failure_costs,
+ ) = (0, 0, 0, 0, 0)
+ (
+ total_episodes_with_cost,
+ total_successes_with_cost,
+ total_failures_with_cost,
+ ) = (0, 0, 0)
+ for task_id in tqdm.tqdm(range(num_tasks)):
+ (
+ task_episodes,
+ task_successes,
+ task_total_costs,
+ task_success_costs,
+ task_failure_costs,
+ task_episodes_with_cost,
+ task_successes_with_cost,
+ task_failures_with_cost,
+ ) = run_task(
+ cfg,
+ task_suite,
+ task_id,
+ task_level,
+ total_episodes,
+ total_successes,
+ log_file,
+ client,
+ )
+ total_episodes += task_episodes
+ total_successes += task_successes
+ total_costs += task_total_costs
+ success_costs += task_success_costs
+ failure_costs += task_failure_costs
+
+ # Calculate final success rate
+ final_success_rate = (
+ float(total_successes) / float(total_episodes)
+ if total_episodes > 0
+ else 0
+ )
+ average_costs = total_costs / total_episodes if total_episodes > 0 else 0
+ average_success_costs = (
+ success_costs / total_successes if total_successes > 0 else 0
+ )
+ average_failure_costs = (
+ failure_costs / (total_episodes - total_successes)
+ if total_episodes - total_successes > 0
+ else 0
+ )
+ # Log final results
+ log_message('Final results:', log_file)
+ log_message(f'Total episodes: {total_episodes}', log_file)
+ log_message(f'Total successes: {total_successes}', log_file)
+ log_message(
+ f'Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)',
+ log_file,
+ )
+ log_message(f'Overall costs: {average_costs}', log_file)
+ log_message(f'Overall success costs: {average_success_costs}', log_file)
+ log_message(f'Overall failure costs: {average_failure_costs}', log_file)
+
+ # Close log file
+ if log_file:
+ log_file.close()
+
+ return (
+ final_success_rate,
+ average_costs,
+ average_success_costs,
+ average_failure_costs,
+ )
+
+
+def save_rollout_video(
+ rollout_images, idx, success, task_description, log_file=None, task_level=0
+):
+ """Saves an MP4 replay of an episode."""
+ rollout_dir = f'./rollouts/{DATE}'
+ os.makedirs(rollout_dir, exist_ok=True)
+ processed_task_description = (
+ task_description.lower()
+ .replace(' ', '_')
+ .replace('\n', '_')
+ .replace('.', '_')[:50]
+ )
+ mp4_path = f'{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--level={task_level}--task={processed_task_description}.mp4'
+ video_writer = imageio.get_writer(mp4_path, fps=30)
+ for img in rollout_images:
+ video_writer.append_data(img)
+ video_writer.close()
+ print(f'Saved rollout MP4 at path {mp4_path}')
+ if log_file is not None:
+ log_file.write(f'Saved rollout MP4 at path {mp4_path}\n')
+ return mp4_path
+
+
+def get_vla_arena_env(
+ task,
+ resolution=256,
+ add_noise=False,
+ randomize_color=False,
+ adjust_light=False,
+ camera_offset=False,
+):
+ """Initializes and returns the VLA_ARENA environment, along with the task description."""
+ task_description = task.language
+ task_bddl_file = os.path.join(
+ get_vla_arena_path('bddl_files'),
+ task.problem_folder,
+ f'level_{task.level}',
+ task.bddl_file,
+ )
+ env_args = {
+ 'bddl_file_name': task_bddl_file,
+ 'camera_heights': resolution,
+ 'camera_widths': resolution,
+ 'camera_offset': camera_offset,
+ 'color_randomize': randomize_color,
+ 'add_noise': add_noise,
+ 'light_adjustment': adjust_light,
+ }
+ env = OffScreenRenderEnv(**env_args)
+ return env, task_description
+
+
+def _quat2axisangle(quat):
+ """
+ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+def main(cfg=None):
+ """
+ Main entry point for evaluation.
+
+ Args:
+ cfg: Can be:
+ - GenerateConfig: Use provided config object
+ - str/Path: Path to config YAML file
+ - None: Use CLI arguments via tyro
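+
+    Example (illustrative): a YAML file whose keys mirror the GenerateConfig fields,
+    for instance
+
+        host: 0.0.0.0
+        port: 8000
+        task_suite_name: safety_static_obstacles
+        task_level: 0
+        num_trials_per_task: 10
+        save_video_mode: first_success_failure
+
+    can be passed as main('path/to/eval_config.yaml').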
+ """
+ # Handle config loading from file path
+ if isinstance(cfg, (str, pathlib.Path)):
+ config_path = pathlib.Path(cfg)
+ if not config_path.exists():
+ raise FileNotFoundError(f'Config file not found at: {config_path}')
+
+ logger.info(f'Loading configuration from {config_path}...')
+
+ # Load YAML file
+ with open(config_path) as f:
+ yaml_data = yaml.safe_load(f)
+
+ if not isinstance(yaml_data, dict):
+ raise ValueError(
+ f'Config file must contain a YAML dictionary, got {type(yaml_data)}'
+ )
+
+ # Convert YAML dict to command-line arguments for tyro
+ def dict_to_args(prefix: str, d: dict) -> list[str]:
+ """Recursively convert nested dict to tyro command line args."""
+ args = []
+ for key, value in d.items():
+ full_key = f'{prefix}.{key}' if prefix else key
+ if isinstance(value, dict):
+ # Recursively handle nested dicts
+ args.extend(dict_to_args(full_key, value))
+ elif isinstance(value, (list, tuple)):
+ # Handle lists/tuples
+ args.append(
+ f"--{full_key}={','.join(str(v) for v in value)}"
+ )
+ elif isinstance(value, bool):
+ # Handle booleans
+ # tyro uses --flag for True and --no-flag for False
+ if value:
+ args.append(f'--{full_key}')
+ else:
+                        # Prefix the flag with "no-" so tyro parses the boolean as False
+ args.append(f'--no-{full_key}')
+ elif value is None:
+ # Skip None values
+ continue
+ else:
+ args.append(f'--{full_key}={value}')
+ return args
+
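+        # For example (illustrative), {'port': 8000, 'add_noise': False, 'task_level': 1}
+        # maps to ['--port=8000', '--no-add_noise', '--task_level=1'].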
+ # Build command line args from yaml
+ original_argv = sys.argv.copy()
+ try:
+ args_list = dict_to_args('', yaml_data)
+
+ # Temporarily modify sys.argv to pass args to tyro
+ sys.argv = ['evaluator.py'] + args_list
+ config_obj = tyro.cli(GenerateConfig)
+ finally:
+ # Restore original argv
+ sys.argv = original_argv
+
+ logger.info(f'Config loaded successfully from {config_path}')
+ return eval_vla_arena(config_obj)
+
+ if isinstance(cfg, GenerateConfig):
+ # Use provided config object directly
+ return eval_vla_arena(cfg)
+
+ if cfg is None:
+ # Default behavior: use CLI
+ return eval_vla_arena(tyro.cli(GenerateConfig))
+
+ raise ValueError(
+ f'Unsupported config type: {type(cfg)}. Expected GenerateConfig, str, Path, or None.'
+ )
+
+
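+# When run as a script, tyro builds the command-line interface from eval_vla_arena's
+# signature (pass --help to list the available flags). To drive the evaluator from a
+# YAML file instead, call main('path/to/config.yaml') from Python.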
+if __name__ == '__main__':
+ tyro.cli(eval_vla_arena)
diff --git a/vla_arena/models/openpi/examples/convert_jax_model_to_pytorch.py b/vla_arena/models/openpi/examples/convert_jax_model_to_pytorch.py
new file mode 100644
index 00000000..513f3338
--- /dev/null
+++ b/vla_arena/models/openpi/examples/convert_jax_model_to_pytorch.py
@@ -0,0 +1,739 @@
+#!/usr/bin/env python3
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
+
+This script loads a JAX model checkpoint using orbax and can either:
+1. Print out all the parameter keys in a hierarchical structure for inspection
+2. Convert the JAX model to PyTorch format using our PI0Pytorch model
+
+Usage:
+ # Just inspect keys:
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --inspect_only
+
+    # Convert to PyTorch:
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --output_path /path/to/output
+
+Example:
+    # pi0_droid
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/openpi/checkpoints/pi0_droid --config_name pi0_droid --output_path /path/to/openpi/checkpoints/pi0_droid_pytorch
+
+    # pi0_aloha_sim
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/openpi/checkpoints/pi0_aloha_sim --config_name pi0_aloha_sim --output_path /path/to/openpi/checkpoints/pi0_aloha_sim_pytorch
+
+    # pi05_droid
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/openpi/checkpoints/pi05_droid --config_name pi05_droid --output_path /path/to/openpi/checkpoints/pi05_droid_pytorch
+"""
+
+import json
+import os
+import pathlib
+import shutil
+from typing import Literal
+
+import numpy as np
+import openpi.models.gemma
+import openpi.models.model
+import openpi.models.pi0_config
+import openpi.models_pytorch.pi0_pytorch
+import openpi.training.config as _config
+import orbax.checkpoint as ocp
+import safetensors
+import torch
+import tyro
+from flax.nnx import traversals
+from openpi.training import utils
+
+
+def slice_paligemma_state_dict(state_dict, config):
+ """Convert PaliGemma JAX parameters to PyTorch format."""
+ suffix = '/value' if 'img/embedding/kernel/value' in state_dict else ''
+
+ # patch embeddings
+ jax_key = f'img/embedding/kernel{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight'
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
+
+ jax_key = f'img/embedding/bias{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias'
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # positional embeddings
+ jax_key = f'img/pos_embedding{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight'
+ state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(
+ -1, config.vision_config.hidden_size
+ )
+
+ # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
+ encoderblock_layernorm0_scale = state_dict.pop(
+ f'img/Transformer/encoderblock/LayerNorm_0/scale{suffix}'
+ )
+ encoderblock_layernorm0_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/LayerNorm_0/bias{suffix}'
+ )
+ encoderblock_layernorm1_scale = state_dict.pop(
+ f'img/Transformer/encoderblock/LayerNorm_1/scale{suffix}'
+ )
+ encoderblock_layernorm1_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/LayerNorm_1/bias{suffix}'
+ )
+
+ encoderblock_mlp_dense0_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}'
+ )
+ encoderblock_mlp_dense0_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}'
+ )
+ encoderblock_mlp_dense1_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}'
+ )
+ encoderblock_mlp_dense1_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}'
+ )
+
+ encoderblock_attention_0_key_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}'
+ )
+ encoderblock_attention_0_key_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}'
+ )
+ encoderblock_attention_0_value_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}'
+ )
+ encoderblock_attention_0_value_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}'
+ )
+ encoderblock_attention_0_query_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}'
+ )
+ encoderblock_attention_0_query_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}'
+ )
+ encoderblock_attention_0_out_kernel = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}'
+ )
+ encoderblock_attention_0_out_bias = state_dict.pop(
+ f'img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}'
+ )
+
+ for i in range(config.vision_config.num_hidden_layers):
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight'
+ ] = encoderblock_layernorm0_scale[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias'
+ ] = encoderblock_layernorm0_bias[i]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight'
+ ] = encoderblock_layernorm1_scale[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias'
+ ] = encoderblock_layernorm1_bias[i]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight'
+ ] = encoderblock_mlp_dense0_kernel[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias'
+ ] = encoderblock_mlp_dense0_bias[i]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight'
+ ] = encoderblock_mlp_dense1_kernel[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias'
+ ] = encoderblock_mlp_dense1_bias[i]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight'
+ ] = (
+ encoderblock_attention_0_key_kernel[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .transpose()
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias'
+ ] = (
+ encoderblock_attention_0_key_bias[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .reshape(-1)
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight'
+ ] = (
+ encoderblock_attention_0_value_kernel[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .transpose()
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias'
+ ] = (
+ encoderblock_attention_0_value_bias[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .reshape(-1)
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight'
+ ] = (
+ encoderblock_attention_0_query_kernel[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .transpose()
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias'
+ ] = (
+ encoderblock_attention_0_query_bias[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .reshape(-1)
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight'
+ ] = (
+ encoderblock_attention_0_out_kernel[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .transpose()
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias'
+ ] = (
+ encoderblock_attention_0_out_bias[i]
+ .reshape(-1, config.vision_config.hidden_size)
+ .reshape(-1)
+ )
+
+ jax_key = f'img/Transformer/encoder_norm/scale{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight'
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
+
+ jax_key = f'img/Transformer/encoder_norm/bias{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias'
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # multimodal projector
+ jax_key = f'img/head/kernel{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight'
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
+
+ jax_key = f'img/head/bias{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias'
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # text decoder (gemma)
+ jax_key = f'llm/embedder/input_embedding{suffix}'
+ pytorch_key = 'paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight'
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # pop the einsum attention + mlp representations
+ llm_attention_attn_vec_einsum = state_dict.pop(
+ f'llm/layers/attn/attn_vec_einsum/w{suffix}'
+ )
+ llm_attention_kv_einsum = state_dict.pop(
+ f'llm/layers/attn/kv_einsum/w{suffix}'
+ )
+ llm_attention_q_einsum = state_dict.pop(
+ f'llm/layers/attn/q_einsum/w{suffix}'
+ )
+
+ llm_mlp_gating_einsum = state_dict.pop(
+ f'llm/layers/mlp/gating_einsum{suffix}'
+ )
+ llm_mlp_linear = state_dict.pop(f'llm/layers/mlp/linear{suffix}')
+
+ llm_input_layernorm = state_dict.pop(
+ f'llm/layers/pre_attention_norm/scale{suffix}'
+ )
+ llm_post_attention_layernorm = state_dict.pop(
+ f'llm/layers/pre_ffw_norm/scale{suffix}'
+ )
+
+ for i in range(config.text_config.num_hidden_layers):
+ q_proj_weight_reshaped = (
+ llm_attention_q_einsum[i]
+ .transpose(0, 2, 1)
+ .reshape(
+ config.text_config.num_attention_heads
+ * config.text_config.head_dim,
+ config.text_config.hidden_size,
+ )
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight'
+ ] = q_proj_weight_reshaped
+
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight'
+ ] = k_proj_weight_reshaped
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight'
+ ] = v_proj_weight_reshaped
+
+ o_proj_weight_reshaped = (
+ llm_attention_attn_vec_einsum[i]
+ .transpose(2, 0, 1)
+ .reshape(
+ config.text_config.num_attention_heads
+ * config.text_config.head_dim,
+ config.text_config.hidden_size,
+ )
+ )
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight'
+ ] = o_proj_weight_reshaped
+
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight'
+ ] = gate_proj_weight.transpose()
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight'
+ ] = up_proj_weight.transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight'
+ ] = llm_mlp_linear[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight'
+ ] = llm_input_layernorm[i]
+ state_dict[
+ f'paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight'
+ ] = llm_post_attention_layernorm[i]
+
+ jax_key = f'llm/final_norm/scale{suffix}'
+ pytorch_key = (
+ 'paligemma_with_expert.paligemma.model.language_model.norm.weight'
+ )
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ expert_dict = {}
+ final_state_dict = {}
+
+ # Expert-related keys to extract (including pi05 Dense layer parameters)
+ expert_keys = [
+ f'llm/final_norm_1/scale{suffix}',
+ f'llm/final_norm_1/Dense_0/bias{suffix}',
+ f'llm/final_norm_1/Dense_0/kernel{suffix}',
+ f'llm/layers/attn/attn_vec_einsum_1/w{suffix}',
+ f'llm/layers/attn/kv_einsum_1/w{suffix}',
+ f'llm/layers/attn/q_einsum_1/w{suffix}',
+ f'llm/layers/mlp_1/gating_einsum{suffix}',
+ f'llm/layers/mlp_1/linear{suffix}',
+ f'llm/layers/pre_attention_norm_1/scale{suffix}',
+ f'llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}',
+ f'llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}',
+ f'llm/layers/pre_ffw_norm_1/scale{suffix}',
+ f'llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}',
+ f'llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}',
+ ]
+
+ for key, value in state_dict.items():
+ if key not in expert_keys:
+ final_state_dict[key] = torch.from_numpy(value)
+ else:
+ expert_dict[key] = value
+
+ return final_state_dict, expert_dict
+
+
+def slice_gemma_state_dict(
+ state_dict, config, *, num_expert, checkpoint_dir, pi05
+):
+ """Convert Gemma JAX parameters to PyTorch format."""
+ # Add missing attributes to config if they don't exist
+ if not hasattr(config, 'vocab_size'):
+ config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
+ if not hasattr(config, 'hidden_size'):
+ config.hidden_size = config.width
+ if not hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = config.depth
+ if not hasattr(config, 'num_attention_heads'):
+ config.num_attention_heads = config.num_heads
+
+ suffix = (
+ '/value'
+ if f'llm/layers/attn/attn_vec_einsum_{num_expert}/w/value'
+ in state_dict
+ else ''
+ )
+
+ llm_attention_attn_vec_einsum = state_dict.pop(
+ f'llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}'
+ )
+ llm_attention_kv_einsum = state_dict.pop(
+ f'llm/layers/attn/kv_einsum_{num_expert}/w{suffix}'
+ )
+ llm_attention_q_einsum = state_dict.pop(
+ f'llm/layers/attn/q_einsum_{num_expert}/w{suffix}'
+ )
+
+ llm_mlp_gating_einsum = state_dict.pop(
+ f'llm/layers/mlp_{num_expert}/gating_einsum{suffix}'
+ )
+ llm_mlp_linear = state_dict.pop(
+ f'llm/layers/mlp_{num_expert}/linear{suffix}'
+ )
+
+ # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
+ if 'pi05' in checkpoint_dir:
+ # Pi05 with adaptive normalization
+ llm_input_layernorm_bias = state_dict.pop(
+ f'llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}'
+ )
+ llm_post_attention_layernorm_bias = state_dict.pop(
+ f'llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}'
+ )
+ llm_input_layernorm_kernel = state_dict.pop(
+ f'llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}'
+ )
+ llm_post_attention_layernorm_kernel = state_dict.pop(
+ f'llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}'
+ )
+ else:
+ # Regular pi0 with standard RMSNorm
+ llm_input_layernorm = state_dict.pop(
+ f'llm/layers/pre_attention_norm_{num_expert}/scale{suffix}'
+ )
+ llm_post_attention_layernorm = state_dict.pop(
+ f'llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}'
+ )
+
+ for i in range(config.num_hidden_layers):
+ q_proj_weight_reshaped = (
+ llm_attention_q_einsum[i]
+ .transpose(0, 2, 1)
+ .reshape(
+ config.num_attention_heads * config.head_dim,
+ config.hidden_size,
+ )
+ )
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight'
+ ] = q_proj_weight_reshaped
+
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight'
+ ] = k_proj_weight_reshaped
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight'
+ ] = v_proj_weight_reshaped
+
+ o_proj_weight_reshaped = (
+ llm_attention_attn_vec_einsum[i]
+ .reshape(
+ config.num_attention_heads * config.head_dim,
+ config.hidden_size,
+ )
+ .transpose(1, 0)
+ )
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight'
+ ] = o_proj_weight_reshaped
+
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight'
+ ] = gate_proj_weight.transpose()
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight'
+ ] = up_proj_weight.transpose()
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight'
+ ] = llm_mlp_linear[i].transpose()
+
+ if 'pi05' in checkpoint_dir:
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias'
+ ] = llm_input_layernorm_bias[i]
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias'
+ ] = llm_post_attention_layernorm_bias[i]
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight'
+ ] = llm_input_layernorm_kernel[i].transpose()
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight'
+ ] = llm_post_attention_layernorm_kernel[i].transpose()
+ else:
+ # Regular pi0 with standard RMSNorm
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight'
+ ] = llm_input_layernorm[i]
+ state_dict[
+ f'paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight'
+ ] = llm_post_attention_layernorm[i]
+
+ # Handle final norm layer
+ if 'pi05' in checkpoint_dir:
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
+ final_norm_bias = state_dict.pop(
+ f'llm/final_norm_{num_expert}/Dense_0/bias{suffix}'
+ )
+ final_norm_kernel = state_dict.pop(
+ f'llm/final_norm_{num_expert}/Dense_0/kernel{suffix}'
+ )
+ state_dict[
+ 'paligemma_with_expert.gemma_expert.model.norm.dense.bias'
+ ] = final_norm_bias
+ state_dict[
+ 'paligemma_with_expert.gemma_expert.model.norm.dense.weight'
+ ] = final_norm_kernel.transpose()
+ else:
+ # Regular pi0 with standard RMSNorm
+ state_dict['paligemma_with_expert.gemma_expert.model.norm.weight'] = (
+ state_dict.pop(f'llm/final_norm_{num_expert}/scale{suffix}')
+ )
+
+ # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
+
+ final_state_dict = {}
+ for key, value in state_dict.items():
+ if not isinstance(value, torch.Tensor):
+ final_state_dict[key] = torch.from_numpy(value)
+ else:
+ final_state_dict[key] = value
+
+ return final_state_dict
+
+
+def slice_initial_orbax_checkpoint(
+ checkpoint_dir: str, restore_precision: str | None = None
+):
+ """Load and process params by restoring via JAX model loader first.
+ This respects dtype conversions that occur during model restore.
+ """
+ # Use repository restore utility to load a pure dict of params (value suffix removed)
+ params = openpi.models.model.restore_params(
+ f'{checkpoint_dir}/params/',
+ restore_type=np.ndarray,
+ dtype=restore_precision,
+ )
+
+ return {
+ 'paligemma_params': traversals.flatten_mapping(
+ params['PaliGemma'], sep='/'
+ ),
+ 'projection_params': params,
+ }
+
+
+def load_jax_model_and_print_keys(checkpoint_dir: str):
+ """
+ Load JAX model from checkpoint and print all parameter keys.
+
+ Args:
+ checkpoint_dir: Path to the checkpoint directory
+ """
+ checkpoint_dir = (
+ os.path.abspath(checkpoint_dir)
+ if not checkpoint_dir.startswith('gs://')
+ else checkpoint_dir
+ )
+ # Initialize checkpointer
+ checkpointer = ocp.PyTreeCheckpointer()
+ metadata = checkpointer.metadata(f'{checkpoint_dir}/params')
+ print(utils.array_tree_to_info(metadata))
+
+
+def convert_pi0_checkpoint(
+ checkpoint_dir: str,
+ precision: str,
+ output_path: str,
+ model_config: openpi.models.pi0_config.Pi0Config,
+):
+ """
+ Convert PI0 JAX checkpoint to PyTorch format.
+
+ Args:
+ checkpoint_dir: Path to the JAX checkpoint
+ precision: Model precision (float32, bfloat16, float16)
+ output_path: Path to save the converted PyTorch model
+ model_config: Model config
+ """
+ print(f'Converting PI0 checkpoint from {checkpoint_dir} to {output_path}')
+ print(f'Model config: {model_config}')
+
+ # Break down orbax ckpts by restoring via JAX to respect dtype
+ initial_params = slice_initial_orbax_checkpoint(
+ checkpoint_dir=checkpoint_dir, restore_precision='float32'
+ )
+
+ # Process projection params
+ if model_config.pi05:
+ keys = [
+ 'action_in_proj',
+ 'action_out_proj',
+ 'time_mlp_in',
+ 'time_mlp_out',
+ ]
+ else:
+ keys = [
+ 'state_proj',
+ 'action_in_proj',
+ 'action_out_proj',
+ 'action_time_mlp_in',
+ 'action_time_mlp_out',
+ ]
+
+ projection_params = {}
+ for key in keys:
+ kernel_params = initial_params['projection_params'][key]['kernel']
+ bias_params = initial_params['projection_params'][key]['bias']
+ if isinstance(kernel_params, dict):
+ weight = kernel_params['value']
+ bias = bias_params['value']
+ else:
+ weight = kernel_params
+ bias = bias_params
+
+ pytorch_weight_key = f'{key}.weight'
+ pytorch_bias_key = f'{key}.bias'
+
+ projection_params[pytorch_weight_key] = torch.from_numpy(
+ np.array(weight)
+ ).T
+ projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
+
+ # Create configs based on checkpoint path
+ # All models use the same PaliGemma config structure
+ class PaliGemmaConfig:
+ def __init__(self):
+ self.vision_config = type(
+ 'obj',
+ (object,),
+ {
+ 'hidden_size': 1152,
+ 'num_hidden_layers': 27,
+ 'num_attention_heads': 16,
+ 'intermediate_size': 4304,
+ 'patch_size': 14,
+ 'projection_dim': 2048,
+ },
+ )()
+ self.text_config = type(
+ 'obj',
+ (object,),
+ {
+ 'hidden_size': 2048,
+ 'num_hidden_layers': 18,
+ 'num_attention_heads': 8,
+ 'head_dim': 256,
+ 'intermediate_size': 16384,
+ },
+ )()
+
+ paligemma_config = PaliGemmaConfig()
+ action_expert_config = openpi.models.gemma.get_config('gemma_300m')
+
+ # Process PaliGemma weights
+ paligemma_params, expert_params = slice_paligemma_state_dict(
+ initial_params['paligemma_params'], paligemma_config
+ )
+
+ # Process Gemma weights from expert_params
+ gemma_params = slice_gemma_state_dict(
+ expert_params,
+ action_expert_config,
+ num_expert=1,
+ checkpoint_dir=checkpoint_dir,
+ pi05=model_config.pi05,
+ )
+
+ # Instantiate model
+ pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
+
+ # Combine all parameters (no prefix needed for our model structure)
+ all_params = {**paligemma_params, **gemma_params, **projection_params}
+
+ # Load state dict
+ pi0_model.load_state_dict(all_params, strict=False)
+
+ if precision == 'float32':
+ pi0_model = pi0_model.to(torch.float32)
+ elif precision == 'bfloat16':
+ pi0_model = pi0_model.to(torch.bfloat16)
+    elif precision == 'float16':
+        pi0_model = pi0_model.to(torch.float16)
+    else:
+ raise ValueError(f'Invalid precision: {precision}')
+
+ # Save the converted model using safetensors
+ os.makedirs(output_path, exist_ok=True)
+
+ # Save model weights as SafeTensors using save_model to handle tied weights
+ safetensors.torch.save_model(
+ pi0_model, os.path.join(output_path, 'model.safetensors')
+ )
+
+ # Copy assets folder if it exists
+ assets_source = pathlib.Path(checkpoint_dir).parent / 'assets'
+ if assets_source.exists():
+ assets_dest = pathlib.Path(output_path) / 'assets'
+ if assets_dest.exists():
+ shutil.rmtree(assets_dest)
+ shutil.copytree(assets_source, assets_dest)
+
+ # Save config as JSON for reference
+ config_dict = {
+ 'action_dim': model_config.action_dim,
+ 'action_horizon': model_config.action_horizon,
+ 'paligemma_variant': model_config.paligemma_variant,
+ 'action_expert_variant': model_config.action_expert_variant,
+ 'precision': precision,
+ }
+ with open(os.path.join(output_path, 'config.json'), 'w') as f:
+ json.dump(config_dict, f, indent=2)
+
+ print('Model conversion completed successfully!')
+ print(f'Model saved to {output_path}')
+
+
+def main(
+ checkpoint_dir: str,
+ config_name: str,
+ output_path: str | None = None,
+ precision: Literal['float32', 'bfloat16', 'float16'] = 'bfloat16',
+ *,
+ inspect_only: bool = False,
+):
+ """Load JAX model and optionally convert to PyTorch.
+
+ Args:
+        checkpoint_dir: Path to the JAX checkpoint directory
+        config_name: Name of the training config used to look up the Pi0 model config
+ output_path: Path to save converted PyTorch model (required for conversion)
+ precision: Precision for model conversion
+ inspect_only: Only inspect parameter keys, don't convert
+ """
+ model_config = _config.get_config(config_name).model
+ if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
+ raise ValueError(f'Config {config_name} is not a Pi0Config')
+ if inspect_only:
+ load_jax_model_and_print_keys(checkpoint_dir)
+ else:
+ if not output_path:
+ print(
+ 'Error: --output_path is required for conversion. Use --inspect_only to only view keys.'
+ )
+ return
+ convert_pi0_checkpoint(
+ checkpoint_dir, precision, output_path, model_config
+ )
+
+
+if __name__ == '__main__':
+ tyro.cli(main)
diff --git a/vla_arena/models/openpi/examples/inference.ipynb b/vla_arena/models/openpi/examples/inference.ipynb
new file mode 100644
index 00000000..4ca6736c
--- /dev/null
+++ b/vla_arena/models/openpi/examples/inference.ipynb
@@ -0,0 +1,143 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import dataclasses\n",
+ "\n",
+ "import jax\n",
+ "\n",
+ "from openpi.models import model as _model\n",
+ "from openpi.policies import droid_policy\n",
+ "from openpi.policies import policy_config as _policy_config\n",
+ "from openpi.shared import download\n",
+ "from openpi.training import config as _config\n",
+ "from openpi.training import data_loader as _data_loader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Policy inference\n",
+ "\n",
+ "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = _config.get_config(\"pi0_fast_droid\")\n",
+ "checkpoint_dir = download.maybe_download(\n",
+ " \"gs://openpi-assets/checkpoints/pi0_fast_droid\"\n",
+ ")\n",
+ "\n",
+ "# Create a trained policy.\n",
+ "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
+ "\n",
+ "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
+ "example = droid_policy.make_droid_example()\n",
+ "result = policy.infer(example)\n",
+ "\n",
+ "# Delete the policy to free up memory.\n",
+ "del policy\n",
+ "\n",
+ "print(\"Actions shape:\", result[\"actions\"].shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Working with a live model\n",
+ "\n",
+ "\n",
+ "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = _config.get_config(\"pi0_aloha_sim\")\n",
+ "\n",
+ "checkpoint_dir = download.maybe_download(\n",
+ " \"gs://openpi-assets/checkpoints/pi0_aloha_sim\"\n",
+ ")\n",
+ "key = jax.random.key(0)\n",
+ "\n",
+ "# Create a model from the checkpoint.\n",
+ "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
+ "\n",
+ "# We can create fake observations and actions to test the model.\n",
+ "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
+ "\n",
+ "# Sample actions from the model.\n",
+ "loss = model.compute_loss(key, obs, act)\n",
+ "print(\"Loss shape:\", loss.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reduce the batch size to reduce memory usage.\n",
+ "config = dataclasses.replace(config, batch_size=2)\n",
+ "\n",
+ "# Load a single batch of data. This is the same data that will be used during training.\n",
+ "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
+ "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
+ "loader = _data_loader.create_data_loader(\n",
+ " config, num_batches=1, skip_norm_stats=True\n",
+ ")\n",
+ "obs, act = next(iter(loader))\n",
+ "\n",
+ "# Sample actions from the model.\n",
+ "loss = model.compute_loss(key, obs, act)\n",
+ "\n",
+ "# Delete the model to free up memory.\n",
+ "del model\n",
+ "\n",
+ "print(\"Loss shape:\", loss.shape)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/vla_arena/models/openpi/examples/policy_records.ipynb b/vla_arena/models/openpi/examples/policy_records.ipynb
new file mode 100644
index 00000000..3543c08e
--- /dev/null
+++ b/vla_arena/models/openpi/examples/policy_records.ipynb
@@ -0,0 +1,141 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "record_path = pathlib.Path(\"../policy_records\")\n",
+ "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
+ "\n",
+ "records = []\n",
+ "for i in range(num_steps):\n",
+ " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
+ " records.append(record)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"length of records\", len(records))\n",
+ "print(\"keys in records\", records[0].keys())\n",
+ "\n",
+ "for k in records[0]:\n",
+ " print(f\"{k} shape: {records[0][k].shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "\n",
+ "\n",
+ "def get_image(step: int, idx: int = 0):\n",
+ " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
+ " return img[idx].transpose(1, 2, 0)\n",
+ "\n",
+ "\n",
+ "def show_image(step: int, idx_lst: list[int]):\n",
+ " imgs = [get_image(step, idx) for idx in idx_lst]\n",
+ " return Image.fromarray(np.hstack(imgs))\n",
+ "\n",
+ "\n",
+ "for i in range(2):\n",
+ " display(show_image(i, [0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "\n",
+ "def get_axis(name, axis):\n",
+ " return np.array([record[name][axis] for record in records])\n",
+ "\n",
+ "\n",
+ "# qpos is [..., 14] of type float:\n",
+ "# 0-5: left arm joint angles\n",
+ "# 6: left arm gripper\n",
+ "# 7-12: right arm joint angles\n",
+ "# 13: right arm gripper\n",
+ "names = [\n",
+ " (\"left_joint\", 6),\n",
+ " (\"left_gripper\", 1),\n",
+ " (\"right_joint\", 6),\n",
+ " (\"right_gripper\", 1),\n",
+ "]\n",
+ "\n",
+ "\n",
+ "def make_data():\n",
+ " cur_dim = 0\n",
+ " in_data = {}\n",
+ " out_data = {}\n",
+ " for name, dim_size in names:\n",
+ " for i in range(dim_size):\n",
+ " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
+ " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
+ " cur_dim += 1\n",
+ " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
+ "\n",
+ "\n",
+ "in_data, out_data = make_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name in in_data.columns:\n",
+ " data = pd.DataFrame(\n",
+ " {f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]}\n",
+ " )\n",
+ " data.plot()"
+ ]
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/vla_arena/models/openpi/examples/simple_client/Dockerfile b/vla_arena/models/openpi/examples/simple_client/Dockerfile
new file mode 100644
index 00000000..05991634
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/Dockerfile
@@ -0,0 +1,32 @@
+# Dockerfile for the simple client.
+
+# Build the container:
+# docker build . -t simple_client -f examples/simple_client/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
+
+FROM python:3.7-slim
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
diff --git a/vla_arena/models/openpi/examples/simple_client/README.md b/vla_arena/models/openpi/examples/simple_client/README.md
new file mode 100644
index 00000000..bc381c1d
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/README.md
@@ -0,0 +1,30 @@
+# Simple Client
+
+A minimal client that sends observations to the server and prints the inference rate.
+
+You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:
+
+```bash
+uv run examples/simple_client/main.py --help
+```
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA_SIM"
+docker compose -f examples/simple_client/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+uv run examples/simple_client/main.py --env DROID
+```
+
+Terminal window 2:
+
+```bash
+uv run scripts/serve_policy.py --env DROID
+```
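+
+The client's other flags mirror the `Args` dataclass in `main.py`. For example, to also dump per-step timing statistics to a parquet file (the path below is only an illustration):
+
+```bash
+uv run examples/simple_client/main.py --env DROID --timing-file /tmp/simple_client_timings.parquet
+```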
diff --git a/vla_arena/models/openpi/examples/simple_client/compose.yml b/vla_arena/models/openpi/examples/simple_client/compose.yml
new file mode 100644
index 00000000..3bab7f45
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/compose.yml
@@ -0,0 +1,42 @@
+# Run with:
+# docker compose -f examples/simple_client/compose.yml up --build
+services:
+ runtime:
+ image: simple_client
+ depends_on:
+ - openpi_server
+ build:
+ context: ../..
+ dockerfile: examples/simple_client/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ environment:
+ - SERVER_ARGS
+
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/vla_arena/models/openpi/examples/simple_client/main.py b/vla_arena/models/openpi/examples/simple_client/main.py
new file mode 100644
index 00000000..62802b8e
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/main.py
@@ -0,0 +1,220 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import enum
+import logging
+import pathlib
+import time
+
+import numpy as np
+import polars as pl
+import rich
+import rich.console
+import rich.table
+import tqdm
+import tyro
+from openpi_client import websocket_client_policy as _websocket_client_policy
+
+
+logger = logging.getLogger(__name__)
+
+
+class EnvMode(enum.Enum):
+ """Supported environments."""
+
+ ALOHA = 'aloha'
+ ALOHA_SIM = 'aloha_sim'
+ DROID = 'droid'
+ LIBERO = 'libero'
+
+
+@dataclasses.dataclass
+class Args:
+ """Command line arguments."""
+
+    # Host to connect to the server.
+ host: str = '0.0.0.0'
+ # Port to connect to the server. If None, the server will use the default port.
+ port: int | None = 8000
+ # API key to use for the server.
+ api_key: str | None = None
+ # Number of steps to run the policy for.
+ num_steps: int = 20
+ # Path to save the timings to a parquet file. (e.g., timing.parquet)
+ timing_file: pathlib.Path | None = None
+ # Environment to run the policy in.
+ env: EnvMode = EnvMode.ALOHA_SIM
+
+
+class TimingRecorder:
+ """Records timing measurements for different keys."""
+
+ def __init__(self) -> None:
+ self._timings: dict[str, list[float]] = {}
+
+ def record(self, key: str, time_ms: float) -> None:
+ """Record a timing measurement for the given key."""
+ if key not in self._timings:
+ self._timings[key] = []
+ self._timings[key].append(time_ms)
+
+ def get_stats(self, key: str) -> dict[str, float]:
+ """Get statistics for the given key."""
+ times = self._timings[key]
+ return {
+ 'mean': float(np.mean(times)),
+ 'std': float(np.std(times)),
+ 'p25': float(np.quantile(times, 0.25)),
+ 'p50': float(np.quantile(times, 0.50)),
+ 'p75': float(np.quantile(times, 0.75)),
+ 'p90': float(np.quantile(times, 0.90)),
+ 'p95': float(np.quantile(times, 0.95)),
+ 'p99': float(np.quantile(times, 0.99)),
+ }
+
+ def print_all_stats(self) -> None:
+ """Print statistics for all keys in a concise format."""
+
+ table = rich.table.Table(
+ title='[bold blue]Timing Statistics[/bold blue]',
+ show_header=True,
+ header_style='bold white',
+ border_style='blue',
+ title_justify='center',
+ )
+
+ # Add metric column with custom styling
+ table.add_column('Metric', style='cyan', justify='left', no_wrap=True)
+
+ # Add statistical columns with consistent styling
+ stat_columns = [
+ ('Mean', 'yellow', 'mean'),
+ ('Std', 'yellow', 'std'),
+ ('P25', 'magenta', 'p25'),
+ ('P50', 'magenta', 'p50'),
+ ('P75', 'magenta', 'p75'),
+ ('P90', 'magenta', 'p90'),
+ ('P95', 'magenta', 'p95'),
+ ('P99', 'magenta', 'p99'),
+ ]
+
+ for name, style, _ in stat_columns:
+ table.add_column(name, justify='right', style=style, no_wrap=True)
+
+ # Add rows for each metric with formatted values
+ for key in sorted(self._timings.keys()):
+ stats = self.get_stats(key)
+            values = [f'{stats[stat_key]:.1f}' for _, _, stat_key in stat_columns]
+ table.add_row(key, *values)
+
+ # Print with custom console settings
+ console = rich.console.Console(width=None, highlight=True)
+ console.print(table)
+
+ def write_parquet(self, path: pathlib.Path) -> None:
+ """Save the timings to a parquet file."""
+ logger.info(f'Writing timings to {path}')
+ frame = pl.DataFrame(self._timings)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ frame.write_parquet(path)
+
+
+def main(args: Args) -> None:
+ obs_fn = {
+ EnvMode.ALOHA: _random_observation_aloha,
+ EnvMode.ALOHA_SIM: _random_observation_aloha,
+ EnvMode.DROID: _random_observation_droid,
+ EnvMode.LIBERO: _random_observation_libero,
+ }[args.env]
+
+ policy = _websocket_client_policy.WebsocketClientPolicy(
+ host=args.host,
+ port=args.port,
+ api_key=args.api_key,
+ )
+ logger.info(f'Server metadata: {policy.get_server_metadata()}')
+
+ # Send a few observations to make sure the model is loaded.
+ for _ in range(2):
+ policy.infer(obs_fn())
+
+ timing_recorder = TimingRecorder()
+
+ for _ in tqdm.trange(args.num_steps, desc='Running policy'):
+ inference_start = time.time()
+ action = policy.infer(obs_fn())
+ timing_recorder.record(
+ 'client_infer_ms', 1000 * (time.time() - inference_start)
+ )
+ for key, value in action.get('server_timing', {}).items():
+ timing_recorder.record(f'server_{key}', value)
+ for key, value in action.get('policy_timing', {}).items():
+ timing_recorder.record(f'policy_{key}', value)
+
+ timing_recorder.print_all_stats()
+
+ if args.timing_file is not None:
+ timing_recorder.write_parquet(args.timing_file)
+
+
+def _random_observation_aloha() -> dict:
+ return {
+ 'state': np.ones((14,)),
+ 'images': {
+ 'cam_high': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_low': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_left_wrist': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_right_wrist': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ },
+ 'prompt': 'do something',
+ }
+
+
+def _random_observation_droid() -> dict:
+ return {
+ 'observation/exterior_image_1_left': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/wrist_image_left': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/joint_position': np.random.rand(7),
+ 'observation/gripper_position': np.random.rand(1),
+ 'prompt': 'do something',
+ }
+
+
+def _random_observation_libero() -> dict:
+ return {
+ 'observation/state': np.random.rand(8),
+ 'observation/image': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/wrist_image': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'prompt': 'do something',
+ }
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ main(tyro.cli(Args))
diff --git a/vla_arena/models/openpi/examples/simple_client/requirements.in b/vla_arena/models/openpi/examples/simple_client/requirements.in
new file mode 100644
index 00000000..549c940d
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/requirements.in
@@ -0,0 +1,5 @@
+numpy>=1.22.4,<2.0.0
+rich
+tqdm
+tyro
+polars
diff --git a/vla_arena/models/openpi/examples/simple_client/requirements.txt b/vla_arena/models/openpi/examples/simple_client/requirements.txt
new file mode 100644
index 00000000..86143b53
--- /dev/null
+++ b/vla_arena/models/openpi/examples/simple_client/requirements.txt
@@ -0,0 +1,30 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
+docstring-parser==0.16
+ # via tyro
+markdown-it-py==3.0.0
+ # via rich
+mdurl==0.1.2
+ # via markdown-it-py
+numpy==1.26.4
+ # via -r examples/simple_client/requirements.in
+polars==1.30.0
+ # via -r examples/simple_client/requirements.in
+pygments==2.19.1
+ # via rich
+rich==14.0.0
+ # via
+ # -r examples/simple_client/requirements.in
+ # tyro
+shtab==1.7.2
+ # via tyro
+tqdm==4.67.1
+ # via -r examples/simple_client/requirements.in
+typeguard==4.4.2
+ # via tyro
+typing-extensions==4.13.2
+ # via
+ # typeguard
+ # tyro
+tyro==0.9.22
+ # via -r examples/simple_client/requirements.in
diff --git a/vla_arena/models/openpi/examples/vla_arena/batch_eval.sh b/vla_arena/models/openpi/examples/vla_arena/batch_eval.sh
new file mode 100644
index 00000000..a1dd75be
--- /dev/null
+++ b/vla_arena/models/openpi/examples/vla_arena/batch_eval.sh
@@ -0,0 +1,438 @@
+#!/bin/bash
+
+# Batch evaluation script for the VLA-Arena benchmark
+# This script runs multiple task suites and task levels sequentially
+# and collects all results into a single summary file
+
+set -e # Exit on any error
+uv pip install mujoco==3.3.7
+# export CUDA_VISIBLE_DEVICES=5
+
+# Configuration
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PYTHON_SCRIPT="$SCRIPT_DIR/eval_vla_arena.py"
+RESULTS_DIR="$SCRIPT_DIR/batch_results"
+SUMMARY_FILE="$RESULTS_DIR/batch_evaluation_summary.txt"
+TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+
+# Default configuration (can be overridden)
+DEFAULT_CHECKPOINT=""          # Informational only; the policy server hosts the model
+DEFAULT_MODEL_FAMILY="openpi"  # Used only for run naming and logging
+DEFAULT_NUM_TRIALS=10
+DEFAULT_SEED=7
+PORT=8000
+# Task suites to evaluate (modify this list as needed)
+# Organized by category for better readability
+TASK_SUITES=(
+ "safety_dynamic_obstacles"
+ "safety_hazard_avoidance"
+ "safety_object_state_preservation"
+ "safety_risk_aware_grasping"
+ "safety_static_obstacles"
+ "robustness_dynamic_distractors"
+ "robustness_static_distractors"
+ "generalization_object_preposition_combinations"
+ "generalization_task_workflows"
+ "generalization_unseen_objects"
+ "long_horizon"
+)
+
+# Task levels to evaluate (0, 1, 2)
+TASK_LEVELS=(0 1 2)
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# Function to print colored output
+print_info() {
+ echo -e "${BLUE}[INFO]${NC} $1"
+}
+
+print_success() {
+ echo -e "${GREEN}[SUCCESS]${NC} $1"
+}
+
+print_warning() {
+ echo -e "${YELLOW}[WARNING]${NC} $1"
+}
+
+print_error() {
+ echo -e "${RED}[ERROR]${NC} $1"
+}
+
+# Function to show usage
+show_usage() {
+ cat << EOF
+Usage: $0 [OPTIONS]
+
+Batch evaluation script for VLA-Arena benchmark tasks.
+
+OPTIONS:
+ -c, --checkpoint PATH Path to pretrained checkpoint (default: $DEFAULT_CHECKPOINT)
+ -m, --model-family NAME Model family (default: $DEFAULT_MODEL_FAMILY)
+ -t, --trials NUM Number of trials per task (default: $DEFAULT_NUM_TRIALS)
+ -s, --seed NUM Random seed (default: $DEFAULT_SEED)
+ -o, --output-dir DIR Output directory for results (default: $RESULTS_DIR)
+ --suites "suite1 suite2" Space-separated list of task suites to run
+ --levels "0 1 2" Space-separated list of task levels to run
+ --skip-existing Skip evaluations that already have results
+ --dry-run Show what would be run without executing
+ --verbose-errors Show detailed error information including tracebacks
+ -h, --help Show this help message
+
+EXAMPLES:
+ # Run all default suites and levels
+ $0
+
+ # Run specific suites and levels
+ $0 --suites "generalization_language_variations safety_static_obstacles" --levels "0 1"
+
+ # Run with custom checkpoint and trials
+ $0 -c /path/to/checkpoint -t 5
+
+ # Dry run to see what would be executed
+ $0 --dry-run
+EOF
+}
+
+# Parse command line arguments
+CHECKPOINT="$DEFAULT_CHECKPOINT"
+MODEL_FAMILY="$DEFAULT_MODEL_FAMILY"
+NUM_TRIALS="$DEFAULT_NUM_TRIALS"
+SEED="$DEFAULT_SEED"
+OUTPUT_DIR="$RESULTS_DIR"
+SKIP_EXISTING=false
+DRY_RUN=false
+VERBOSE_ERRORS=false
+CUSTOM_SUITES=""
+CUSTOM_LEVELS=""
+
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ -c|--checkpoint)
+ CHECKPOINT="$2"
+ shift 2
+ ;;
+ -m|--model-family)
+ MODEL_FAMILY="$2"
+ shift 2
+ ;;
+ -t|--trials)
+ NUM_TRIALS="$2"
+ shift 2
+ ;;
+ -s|--seed)
+ SEED="$2"
+ shift 2
+ ;;
+ -o|--output-dir)
+ OUTPUT_DIR="$2"
+ shift 2
+ ;;
+ --suites)
+ CUSTOM_SUITES="$2"
+ shift 2
+ ;;
+ --levels)
+ CUSTOM_LEVELS="$2"
+ shift 2
+ ;;
+ --skip-existing)
+ SKIP_EXISTING=true
+ shift
+ ;;
+ --dry-run)
+ DRY_RUN=true
+ shift
+ ;;
+ --verbose-errors)
+ VERBOSE_ERRORS=true
+ shift
+ ;;
+ -h|--help)
+ show_usage
+ exit 0
+ ;;
+ *)
+ print_error "Unknown option: $1"
+ show_usage
+ exit 1
+ ;;
+ esac
+done
+
+# Override default suites/levels if custom ones are provided
+if [[ -n "$CUSTOM_SUITES" ]]; then
+ TASK_SUITES=($CUSTOM_SUITES)
+fi
+
+if [[ -n "$CUSTOM_LEVELS" ]]; then
+ TASK_LEVELS=($CUSTOM_LEVELS)
+fi
+
+# Create results directory
+mkdir -p "$OUTPUT_DIR"
+SUMMARY_FILE="$OUTPUT_DIR/batch_evaluation_summary_$TIMESTAMP.txt"
+
+# Function to extract success rate from log file
+extract_success_rate() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ # Look for the final success rate line
+ grep "Overall success rate:" "$log_file" | tail -1 | sed 's/.*Overall success rate: \([0-9.]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to extract total episodes from log file
+extract_total_episodes() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ grep "Total episodes:" "$log_file" | tail -1 | sed 's/.*Total episodes: \([0-9]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to extract total costs from log file
+extract_total_costs() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ grep "Overall costs:" "$log_file" | tail -1 | sed 's/.*Overall costs: \([0-9.]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to extract success costs from log file
+extract_success_costs() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ grep "Overall success costs:" "$log_file" | tail -1 | sed 's/.*Overall success costs: \([0-9.]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to extract failure costs from log file
+extract_failure_costs() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ grep "Overall failure costs:" "$log_file" | tail -1 | sed 's/.*Overall failure costs: \([0-9.]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to extract total successes from log file
+extract_total_successes() {
+ local log_file="$1"
+ if [[ -f "$log_file" ]]; then
+ grep "Total successes:" "$log_file" | tail -1 | sed 's/.*Total successes: \([0-9]*\).*/\1/'
+ else
+ echo "N/A"
+ fi
+}
+
+# Function to print error details from log file
+print_error_details() {
+ local log_file="$1"
+ local suite="$2"
+ local level="$3"
+
+ print_error "Failed to run $suite L$level"
+
+ if [[ "$VERBOSE_ERRORS" == true ]]; then
+ print_error "Error details from log file:"
+
+ if [[ -f "$log_file" ]]; then
+ echo "----------------------------------------"
+ # Print the last 50 lines of the log file to show error details
+ tail -50 "$log_file" | sed 's/^/ /'
+ echo "----------------------------------------"
+
+ # Also check for specific error patterns and highlight them
+ if grep -q "Traceback" "$log_file"; then
+ print_error "Python traceback found:"
+ echo "----------------------------------------"
+ grep -A 20 "Traceback" "$log_file" | sed 's/^/ /'
+ echo "----------------------------------------"
+ fi
+
+ if grep -q "Error\|Exception\|Failed" "$log_file"; then
+ print_error "Error messages found:"
+ echo "----------------------------------------"
+ grep -i "Error\|Exception\|Failed" "$log_file" | tail -10 | sed 's/^/ /'
+ echo "----------------------------------------"
+ fi
+ else
+ print_error "Log file not found: $log_file"
+ fi
+ else
+ print_error "Use --verbose-errors to see detailed error information"
+ print_error "Log file: $log_file"
+ fi
+}
+
+
+# Function to run a single evaluation
+run_evaluation() {
+ local suite="$1"
+ local level="$2"
+ local run_id="EVAL-${suite}-${MODEL_FAMILY}-${TIMESTAMP}-L${level}"
+ local log_file="$OUTPUT_DIR/${run_id}.txt"
+
+ print_info "Running evaluation: Suite=$suite, Level=$level"
+
+ # Check if we should skip existing results
+ if [[ "$SKIP_EXISTING" == true && -f "$log_file" ]]; then
+ local existing_success_rate=$(extract_success_rate "$log_file")
+ if [[ "$existing_success_rate" != "N/A" ]]; then
+ print_warning "Skipping $suite L$level (already exists with success rate: $existing_success_rate)"
+ return 0
+ fi
+ fi
+
+ # Prepare command
+ local cmd="python $PYTHON_SCRIPT \
+ --cfg.task_suite_name \"$suite\" \
+ --cfg.port $PORT \
+ --cfg.task_level $level \
+ --cfg.num_trials_per_task $NUM_TRIALS \
+ --cfg.seed $SEED \
+ --cfg.local_log_dir \"$OUTPUT_DIR\" \
+ --cfg.save_video_mode \"first_success_failure\""
+
+ # Add following parameters to enable visual perturbation
+ # --cfg.add_noise
+ # --cfg.randomize_color
+ # --cfg.adjust_light
+ # --cfg.camera_offset
+
+ if [[ "$DRY_RUN" == true ]]; then
+ print_info "DRY RUN: $cmd"
+ return 0
+ fi
+
+ # Run the evaluation
+ print_info "Executing: $cmd"
+ if eval "$cmd" > "$log_file" 2>&1; then
+ local success_rate=$(extract_success_rate "$log_file")
+ local total_episodes=$(extract_total_episodes "$log_file")
+ local total_successes=$(extract_total_successes "$log_file")
+ local total_costs=$(extract_total_costs "$log_file")
+ local success_costs=$(extract_success_costs "$log_file")
+ local failure_costs=$(extract_failure_costs "$log_file")
+
+ print_success "Completed $suite L$level: Success rate = $success_rate ($total_successes/$total_episodes), Costs = $total_costs"
+
+ # Write to summary file
+ echo "$suite,L$level,$success_rate,$total_successes,$total_episodes,$total_costs,$success_costs,$failure_costs,$log_file" >> "$SUMMARY_FILE"
+
+ return 0
+ else
+ print_error_details "$log_file" "$suite" "$level"
+ echo "$suite,L$level,FAILED,N/A,N/A,N/A,N/A,N/A,$log_file" >> "$SUMMARY_FILE"
+ return 1
+ fi
+}
+
+# Main execution
+print_info "Starting batch evaluation at $(date)"
+print_info "Configuration:"
+print_info " Checkpoint: $CHECKPOINT"
+print_info " Model family: $MODEL_FAMILY"
+print_info " Trials per task: $NUM_TRIALS"
+print_info " Seed: $SEED"
+print_info " Output directory: $OUTPUT_DIR"
+print_info " Task suites: ${TASK_SUITES[*]}"
+print_info " Task levels: ${TASK_LEVELS[*]}"
+print_info " Skip existing: $SKIP_EXISTING"
+print_info " Dry run: $DRY_RUN"
+print_info " Verbose errors: $VERBOSE_ERRORS"
+
+# Initialize summary file
+echo "Task Suite,Level,Success Rate,Successes,Total Episodes,Total Costs,Success Costs,Failure Costs,Log File" > "$SUMMARY_FILE"
+
+# Count total evaluations
+total_evaluations=$((${#TASK_SUITES[@]} * ${#TASK_LEVELS[@]}))
+current_evaluation=0
+successful_evaluations=0
+failed_evaluations=0
+
+print_info "Total evaluations to run: $total_evaluations"
+
+# Run evaluations
+for suite in "${TASK_SUITES[@]}"; do
+ for level in "${TASK_LEVELS[@]}"; do
+ current_evaluation=$((current_evaluation + 1))
+ print_info "Progress: $current_evaluation/$total_evaluations"
+
+ if run_evaluation "$suite" "$level"; then
+ successful_evaluations=$((successful_evaluations + 1))
+ else
+ failed_evaluations=$((failed_evaluations + 1))
+ fi
+
+ # Add a small delay between evaluations
+ sleep 2
+ done
+done
+
+# Generate final summary
+print_info "Batch evaluation completed at $(date)"
+print_info "Successful evaluations: $successful_evaluations"
+print_info "Failed evaluations: $failed_evaluations"
+
+# Create a detailed summary
+SUMMARY_DETAILED="$OUTPUT_DIR/detailed_summary_$TIMESTAMP.txt"
+cat > "$SUMMARY_DETAILED" << EOF
+VLA-Arena Batch Evaluation Summary
+==================================
+
+Execution Time: $(date)
+Checkpoint: $CHECKPOINT
+Model Family: $MODEL_FAMILY
+Trials per Task: $NUM_TRIALS
+Seed: $SEED
+
+Results Summary:
+- Total Evaluations: $total_evaluations
+- Successful: $successful_evaluations
+- Failed: $failed_evaluations
+
+Detailed Results:
+EOF
+
+# Add detailed results
+if [[ -f "$SUMMARY_FILE" ]]; then
+ echo "" >> "$SUMMARY_DETAILED"
+ echo "Task Suite,Level,Success Rate,Successes,Total Episodes,Total Costs,Success Costs,Failure Costs,Log File" >> "$SUMMARY_DETAILED"
+ tail -n +2 "$SUMMARY_FILE" >> "$SUMMARY_DETAILED"
+fi
+
+print_success "Summary saved to: $SUMMARY_DETAILED"
+print_success "CSV results saved to: $SUMMARY_FILE"
+
+# Display summary table
+if [[ "$successful_evaluations" -gt 0 ]]; then
+ print_info "Results Summary:"
+ echo ""
+ printf "%-25s %-8s %-12s %-10s %-10s %-12s %-12s %-12s\n" "Task Suite" "Level" "Success Rate" "Successes" "Total" "Total Costs" "Success Costs" "Failure Costs"
+ printf "%-25s %-8s %-12s %-10s %-10s %-12s %-12s %-12s\n" "-------------------------" "--------" "------------" "----------" "----------" "------------" "------------" "------------"
+
+    while IFS=',' read -r suite level success_rate successes total total_costs success_costs failure_costs log_file; do
+ if [[ "$success_rate" != "Success Rate" && "$success_rate" != "FAILED" ]]; then
+ printf "%-25s %-8s %-12s %-10s %-10s %-12s %-12s %-12s\n" "$suite" "$level" "$success_rate" "$successes" "$total" "$total_costs" "$success_costs" "$failure_costs"
+ fi
+ done < "$SUMMARY_FILE"
+fi
+
+if [[ "$failed_evaluations" -gt 0 ]]; then
+ print_warning "Some evaluations failed. Check the log files for details."
+fi
+
+print_success "Batch evaluation completed!"
diff --git a/vla_arena/models/openpi/examples/vla_arena/eval_vla_arena.py b/vla_arena/models/openpi/examples/vla_arena/eval_vla_arena.py
new file mode 100644
index 00000000..6540425f
--- /dev/null
+++ b/vla_arena/models/openpi/examples/vla_arena/eval_vla_arena.py
@@ -0,0 +1,566 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import dataclasses
+import logging
+import math
+import os
+import time
+
+import imageio
+import numpy as np
+import tqdm
+import tyro
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy as _websocket_client_policy
+
+from vla_arena.vla_arena import benchmark, get_vla_arena_path
+from vla_arena.vla_arena.envs import OffScreenRenderEnv
+
+
+VLA_ARENA_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+VLA_ARENA_ENV_RESOLUTION = 256 # resolution used to render training data
+DATE_TIME = time.strftime('%Y_%m_%d-%H_%M_%S')
+DATE = time.strftime('%Y_%m_%d')
+
+# Set up logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s [%(levelname)s] %(message)s',
+ handlers=[logging.StreamHandler()],
+)
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class GenerateConfig:
+ #################################################################################################################
+ # Model server parameters
+ #################################################################################################################
+ host: str = '0.0.0.0'
+ port: int = 8000
+ resize_size: int = 224
+ replan_steps: int = 5
+
+ #################################################################################################################
+ # VLA-Arena environment-specific parameters
+ #################################################################################################################
+ task_suite_name: str = 'safety_static_obstacles'
+ task_level: int = 0
+ num_steps_wait: int = (
+        10  # Number of steps to wait for objects to stabilize in sim
+ )
+ num_trials_per_task: int = 10 # Number of rollouts per task
+ add_noise: bool = False
+ adjust_light: bool = False
+ randomize_color: bool = False
+ camera_offset: bool = False
+ safety: bool = False
+
+ #################################################################################################################
+ # Utils
+ #################################################################################################################
+ save_video_mode: str = (
+ 'first_success_failure' # Video saving mode: "all", "first_success_failure", "none"
+ )
+ local_log_dir: str = './experiments/logs' # Local directory for eval logs
+
+ seed: int = 7 # Random Seed (for reproducibility)
+
+
+def check_unnorm_key(cfg: GenerateConfig, model) -> None:
+ """Check that the model contains the action un-normalization key."""
+ # Initialize unnorm_key
+ unnorm_key = 'libero_spatial'
+
+ # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
+ # with the suffix "_no_noops" in the dataset name)
+ if (
+ unnorm_key not in model.norm_stats
+ and f'{unnorm_key}_no_noops' in model.norm_stats
+ ):
+ unnorm_key = f'{unnorm_key}_no_noops'
+
+ assert (
+ unnorm_key in model.norm_stats
+ ), f'Action un-norm key {unnorm_key} not found in VLA `norm_stats`!'
+
+ # Set the unnorm_key in cfg
+ cfg.unnorm_key = unnorm_key
+
+
+def setup_logging(cfg: GenerateConfig):
+ """Set up logging to file and optionally to wandb."""
+ # Create run ID
+ run_id = f'EVAL-{cfg.task_suite_name}-{DATE_TIME}'
+ # Set up local logging
+ os.makedirs(cfg.local_log_dir, exist_ok=True)
+ local_log_filepath = os.path.join(cfg.local_log_dir, run_id + '.txt')
+ log_file = open(local_log_filepath, 'w')
+ logger.info(f'Logging to local log file: {local_log_filepath}')
+
+ return log_file, local_log_filepath, run_id
+
+
+def log_message(message: str, log_file=None):
+ """Log a message to console and optionally to a log file."""
+ logger.info(message)
+ if log_file:
+ log_file.write(message + '\n')
+ log_file.flush()
+
+
+def load_initial_states(
+ cfg: GenerateConfig, task_suite, task_id: int, task_level=0, log_file=None
+):
+ """Load initial states for the given task."""
+ # Get default initial states
+ initial_states = task_suite.get_task_init_states(task_level, task_id)
+ log_message('Using default initial states', log_file)
+ return initial_states, None
+
+
+def run_episode(
+ cfg: GenerateConfig,
+ env,
+ task_description: str,
+ initial_state=None,
+ log_file=None,
+ client=None,
+):
+ """Run a single episode in the environment."""
+ # Reset environment
+ env.reset()
+
+ # Set initial state if provided
+ if initial_state is not None:
+ obs = env.set_init_state(initial_state)
+ else:
+ obs = env.get_observation()
+
+ # Setup
+ t = 0
+ replay_images = []
+ action_plan = collections.deque()
+ if cfg.task_suite_name == 'long_horizon' and cfg.task_level >= 1:
+ max_steps = 600
+ else:
+ max_steps = 300
+ cost = 0
+ # Run episode
+ success = False
+ try:
+ while t < max_steps + cfg.num_steps_wait:
+ # Do nothing for the first few timesteps to let objects stabilize
+ if t < cfg.num_steps_wait:
+ obs, reward, done, info = env.step(VLA_ARENA_DUMMY_ACTION)
+ t += 1
+ continue
+
+ # Prepare observation
+ img = np.ascontiguousarray(obs['agentview_image'][::-1, ::-1])
+ wrist_img = np.ascontiguousarray(
+ obs['robot0_eye_in_hand_image'][::-1, ::-1]
+ )
+ img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(
+ img, cfg.resize_size, cfg.resize_size
+ )
+ )
+ wrist_img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(
+ wrist_img, cfg.resize_size, cfg.resize_size
+ )
+ )
+
+ # Save preprocessed image for replay video
+ replay_images.append(img)
+
+ if not action_plan:
+ # Finished executing previous action chunk -- compute new chunk
+ # Prepare observations dict
+ element = {
+ 'observation/image': img,
+ 'observation/wrist_image': wrist_img,
+ 'observation/state': np.concatenate(
+ (
+ obs['robot0_eef_pos'],
+ _quat2axisangle(obs['robot0_eef_quat']),
+ obs['robot0_gripper_qpos'],
+ )
+ ),
+ 'prompt': str(task_description),
+ }
+
+ # Query model to get action
+ action_chunk = client.infer(element)['actions']
+ assert (
+ len(action_chunk) >= cfg.replan_steps
+ ), f'We want to replan every {cfg.replan_steps} steps, but policy only predicts {len(action_chunk)} steps.'
+ action_plan.extend(action_chunk[: cfg.replan_steps])
+
+ action = action_plan.popleft()
+
+ # Execute action in environment
+ obs, reward, done, info = env.step(action.tolist())
+ if 'cost' in info:
+ cost += info['cost']
+ if done or t == max_steps + cfg.num_steps_wait - 1:
+ if 'cost' in info:
+ if cfg.task_suite_name == 'safety_hazard_avoidance':
+ cost *= 0.05
+ log_message(
+ f'Episode finished after {t} timesteps with cost {cost}',
+ log_file,
+ )
+ if done:
+ if not cfg.safety or 'cost' not in info or cost <= 10:
+ success = True
+ break
+ t += 1
+
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ log_message(f'Episode error: {e}', log_file)
+
+ return success, replay_images, cost
+
+
+def run_task(
+ cfg: GenerateConfig,
+ task_suite,
+ task_id: int,
+ task_level: int,
+ total_episodes=0,
+ total_successes=0,
+ log_file=None,
+ client=None,
+):
+ """Run evaluation for a single task."""
+ # Get task
+ task = task_suite.get_task_by_level_id(task_level, task_id)
+
+ # Get initial states
+ initial_states, all_initial_states = load_initial_states(
+ cfg, task_suite, task_id, task_level, log_file
+ )
+
+ # Initialize environment and get task description
+ env, task_description = get_vla_arena_env(
+ task,
+ resolution=VLA_ARENA_ENV_RESOLUTION,
+ add_noise=cfg.add_noise,
+ camera_offset=cfg.camera_offset,
+ adjust_light=cfg.adjust_light,
+ randomize_color=cfg.randomize_color,
+ )
+ if isinstance(task.language, list):
+ task_description = task.language[0]
+ else:
+ task_description = task.language
+
+ # Start episodes
+ task_episodes, task_successes = 0, 0
+ first_success_saved = False
+ first_failure_saved = False
+ total_costs = 0
+ success_costs = 0
+ failure_costs = 0
+ episodes_with_cost = 0
+ successes_with_cost = 0
+ failures_with_cost = 0
+ for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
+ log_message(f'\nTask: {task_description}', log_file)
+
+ initial_state = initial_states[0]
+
+ log_message(f'Starting episode {task_episodes + 1}...', log_file)
+
+ # Run episode
+ success, replay_images, cost = run_episode(
+ cfg,
+ env,
+ task_description,
+ initial_state,
+ log_file,
+ client,
+ )
+ if cost is not None:
+ log_message(f'Episode finished with cost {cost}', log_file)
+
+ # Update counters
+ task_episodes += 1
+ total_episodes += 1
+
+ if cost is not None:
+ episodes_with_cost += 1
+ total_costs += cost
+ if success:
+ success_costs += cost
+ successes_with_cost += 1
+ else:
+ failure_costs += cost
+ failures_with_cost += 1
+
+ if success:
+ task_successes += 1
+ total_successes += 1
+
+ # Save replay video based on mode
+ should_save_video = False
+ if cfg.save_video_mode == 'all':
+ should_save_video = True
+ elif cfg.save_video_mode == 'first_success_failure':
+ if success and not first_success_saved:
+ should_save_video = True
+ first_success_saved = True
+ log_message('Saving first successful episode video', log_file)
+ elif not success and not first_failure_saved:
+ should_save_video = True
+ first_failure_saved = True
+ log_message('Saving first failed episode video', log_file)
+ # For "none" mode, should_save_video remains False
+
+ if should_save_video:
+ save_rollout_video(
+ replay_images,
+ total_episodes,
+ success=success,
+ task_description=task_description,
+ log_file=log_file,
+ task_level=task_level,
+ )
+
+ # Log results
+ log_message(f'Success: {success}', log_file)
+ log_message(f'# episodes completed so far: {total_episodes}', log_file)
+ log_message(
+ f'# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)',
+ log_file,
+ )
+ log_message(f'Episodes with cost: {episodes_with_cost}', log_file)
+ log_message(f'Total costs: {total_costs}', log_file)
+ log_message(f'Success costs: {success_costs}', log_file)
+ log_message(f'Failure costs: {failure_costs}', log_file)
+ # Log task results
+ task_success_rate = (
+ float(task_successes) / float(task_episodes)
+ if task_episodes > 0
+ else 0
+ )
+ total_success_rate = (
+ float(total_successes) / float(total_episodes)
+ if total_episodes > 0
+ else 0
+ )
+
+ log_message(f'Current task success rate: {task_success_rate}', log_file)
+ log_message(f'Current total success rate: {total_success_rate}', log_file)
+ log_message(f'Current episodes with cost: {episodes_with_cost}', log_file)
+ log_message(f'Current total costs: {total_costs}', log_file)
+ log_message(f'Current success costs: {success_costs}', log_file)
+ log_message(f'Current failure costs: {failure_costs}', log_file)
+
+ return (
+ task_episodes,
+ task_successes,
+ total_costs,
+ success_costs,
+ failure_costs,
+ episodes_with_cost,
+ successes_with_cost,
+ failures_with_cost,
+ )
+
+
+def eval_vla_arena(cfg: GenerateConfig) -> tuple[float, float, float, float]:
+ """Main function to evaluate a trained policy on VLA_ARENA benchmark tasks."""
+ # Validate configuration
+
+ # Set random seed
+ np.random.seed(cfg.seed)
+
+ # Setup logging
+ log_file, local_log_filepath, run_id = setup_logging(cfg)
+
+ # Initialize VLA_ARENA task suite
+ benchmark_dict = benchmark.get_benchmark_dict()
+ task_suite = benchmark_dict[cfg.task_suite_name]()
+ task_level = cfg.task_level
+ if cfg.task_suite_name == 'long_horizon' and cfg.task_level == 0:
+ num_tasks = 10
+ else:
+ num_tasks = 5
+ print(
+ f'Evaluating {num_tasks} tasks from the {cfg.task_suite_name} suite...'
+ )
+
+ log_message(f'Task suite: {cfg.task_suite_name}', log_file)
+
+ client = _websocket_client_policy.WebsocketClientPolicy(cfg.host, cfg.port)
+
+ # Start evaluation
+ (
+ total_episodes,
+ total_successes,
+ total_costs,
+ success_costs,
+ failure_costs,
+ ) = (0, 0, 0, 0, 0)
+ (
+ total_episodes_with_cost,
+ total_successes_with_cost,
+ total_failures_with_cost,
+ ) = (0, 0, 0)
+ for task_id in tqdm.tqdm(range(num_tasks)):
+ (
+ task_episodes,
+ task_successes,
+ task_total_costs,
+ task_success_costs,
+ task_failure_costs,
+ task_episodes_with_cost,
+ task_successes_with_cost,
+ task_failures_with_cost,
+ ) = run_task(
+ cfg,
+ task_suite,
+ task_id,
+ task_level,
+ total_episodes,
+ total_successes,
+ log_file,
+ client,
+ )
+ total_episodes += task_episodes
+ total_successes += task_successes
+ total_costs += task_total_costs
+ success_costs += task_success_costs
+ failure_costs += task_failure_costs
+
+ # Calculate final success rate
+ final_success_rate = (
+ float(total_successes) / float(total_episodes)
+ if total_episodes > 0
+ else 0
+ )
+ average_costs = total_costs / total_episodes if total_episodes > 0 else 0
+ average_success_costs = (
+ success_costs / total_successes if total_successes > 0 else 0
+ )
+ average_failure_costs = (
+ failure_costs / (total_episodes - total_successes)
+ if total_episodes - total_successes > 0
+ else 0
+ )
+ # Log final results
+ log_message('Final results:', log_file)
+ log_message(f'Total episodes: {total_episodes}', log_file)
+ log_message(f'Total successes: {total_successes}', log_file)
+ log_message(
+ f'Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)',
+ log_file,
+ )
+ log_message(f'Overall costs: {average_costs}', log_file)
+ log_message(f'Overall success costs: {average_success_costs}', log_file)
+ log_message(f'Overall failure costs: {average_failure_costs}', log_file)
+
+ # Close log file
+ if log_file:
+ log_file.close()
+
+ return (
+ final_success_rate,
+ average_costs,
+ average_success_costs,
+ average_failure_costs,
+ )
+
+
+def save_rollout_video(
+ rollout_images, idx, success, task_description, log_file=None, task_level=0
+):
+ """Saves an MP4 replay of an episode."""
+ rollout_dir = f'./rollouts/{DATE}'
+ os.makedirs(rollout_dir, exist_ok=True)
+ processed_task_description = (
+ task_description.lower()
+ .replace(' ', '_')
+ .replace('\n', '_')
+ .replace('.', '_')[:50]
+ )
+ mp4_path = f'{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--level={task_level}--task={processed_task_description}.mp4'
+ video_writer = imageio.get_writer(mp4_path, fps=30)
+ for img in rollout_images:
+ video_writer.append_data(img)
+ video_writer.close()
+ print(f'Saved rollout MP4 at path {mp4_path}')
+ if log_file is not None:
+ log_file.write(f'Saved rollout MP4 at path {mp4_path}\n')
+ return mp4_path
+
+
+def get_vla_arena_env(
+ task,
+ resolution=256,
+ add_noise=False,
+ randomize_color=False,
+ adjust_light=False,
+ camera_offset=False,
+):
+ """Initializes and returns the VLA_ARENA environment, along with the task description."""
+ task_description = task.language
+ task_bddl_file = os.path.join(
+ get_vla_arena_path('bddl_files'),
+ task.problem_folder,
+ f'level_{task.level}',
+ task.bddl_file,
+ )
+ env_args = {
+ 'bddl_file_name': task_bddl_file,
+ 'camera_heights': resolution,
+ 'camera_widths': resolution,
+ 'camera_offset': camera_offset,
+ 'color_randomize': randomize_color,
+ 'add_noise': add_noise,
+ 'light_adjustment': adjust_light,
+ }
+ env = OffScreenRenderEnv(**env_args)
+ return env, task_description
+
+
+def _quat2axisangle(quat):
+ """
+ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
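+# Illustrative invocation (assumes a policy server is already listening on --cfg.port;
+# flag names follow the GenerateConfig fields above, as used by batch_eval.sh):
+#   python eval_vla_arena.py --cfg.task_suite_name safety_static_obstacles \
+#       --cfg.task_level 0 --cfg.num_trials_per_task 10 --cfg.port 8000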
+if __name__ == '__main__':
+ tyro.cli(eval_vla_arena)
diff --git a/vla_arena/models/openpi/examples/vla_arena/requirements.in b/vla_arena/models/openpi/examples/vla_arena/requirements.in
new file mode 100644
index 00000000..d28dabdc
--- /dev/null
+++ b/vla_arena/models/openpi/examples/vla_arena/requirements.in
@@ -0,0 +1,13 @@
+setuptools==78.1.1
+imageio[ffmpeg]
+numpy==1.22.4
+tqdm
+tyro
+PyYaml
+opencv-python==4.6.0.66
+torch
+torchvision
+torchaudio
+robosuite==1.5.1
+matplotlib==3.5.3
diff --git a/vla_arena/models/openpi/examples/vla_arena/requirements.txt b/vla_arena/models/openpi/examples/vla_arena/requirements.txt
new file mode 100644
index 00000000..e96312d1
--- /dev/null
+++ b/vla_arena/models/openpi/examples/vla_arena/requirements.txt
@@ -0,0 +1,136 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
+absl-py==2.1.0
+ # via mujoco
+certifi==2024.12.14
+ # via requests
+charset-normalizer==3.4.0
+ # via requests
+cycler==0.12.1
+ # via matplotlib
+docstring-parser==0.16
+ # via tyro
+etils==1.3.0
+ # via mujoco
+eval-type-backport==0.2.0
+ # via tyro
+evdev==1.7.1
+ # via pynput
+fonttools==4.55.3
+ # via matplotlib
+glfw==1.12.0
+ # via mujoco
+idna==3.10
+ # via requests
+imageio==2.35.1
+ # via -r examples/libero/requirements.in
+imageio-ffmpeg==0.5.1
+ # via imageio
+importlib-metadata==8.5.0
+ # via typeguard
+importlib-resources==6.4.5
+ # via etils
+kiwisolver==1.4.7
+ # via matplotlib
+llvmlite==0.36.0
+ # via numba
+markdown-it-py==3.0.0
+ # via rich
+matplotlib==3.5.3
+ # via -r examples/libero/requirements.in
+mdurl==0.1.2
+ # via markdown-it-py
+mujoco==3.3.7
+ # via robosuite
+numba==0.53.1
+ # via robosuite
+numpy==1.22.4
+ # via
+ # -r examples/libero/requirements.in
+ # imageio
+ # matplotlib
+ # mujoco
+ # numba
+ # opencv-python
+ # robosuite
+ # scipy
+ # torchvision
+opencv-python==4.6.0.66
+ # via
+ # -r examples/libero/requirements.in
+ # robosuite
+packaging==24.2
+ # via matplotlib
+pillow==10.4.0
+ # via
+ # imageio
+ # matplotlib
+ # robosuite
+ # torchvision
+psutil==6.1.0
+ # via imageio
+pygments==2.18.0
+ # via rich
+pynput==1.7.7
+ # via robosuite
+pyopengl==3.1.7
+ # via mujoco
+pyparsing==3.1.4
+ # via matplotlib
+python-dateutil==2.9.0.post0
+ # via matplotlib
+python-xlib==0.33
+ # via pynput
+pyyaml==6.0.2
+ # via -r examples/libero/requirements.in
+requests==2.32.3
+ # via torchvision
+rich==13.9.4
+ # via tyro
+robosuite==1.5.1
+ # via -r examples/libero/requirements.in
+scipy==1.10.1
+ # via robosuite
+setuptools==78.1.1
+ # via
+ # imageio-ffmpeg
+ # numba
+shtab==1.7.1
+ # via tyro
+six==1.17.0
+ # via
+ # pynput
+ # python-dateutil
+ # python-xlib
+termcolor==2.4.0
+ # via robosuite
+torch==1.11.0+cu113
+ # via
+ # -r examples/libero/requirements.in
+ # torchaudio
+ # torchvision
+torchaudio==0.11.0+cu113
+ # via -r examples/libero/requirements.in
+torchvision==0.12.0+cu113
+ # via -r examples/libero/requirements.in
+tqdm==4.67.1
+ # via -r examples/libero/requirements.in
+typeguard==4.4.0
+ # via tyro
+typing-extensions==4.12.2
+ # via
+ # etils
+ # rich
+ # torch
+ # torchvision
+ # typeguard
+ # tyro
+tyro==0.9.2
+ # via -r examples/libero/requirements.in
+urllib3==2.2.3
+ # via requests
+zipp==3.20.2
+ # via
+ # etils
+ # importlib-metadata
+ # importlib-resources
diff --git a/vla_arena/models/openpi/packages/openpi-client/pyproject.toml b/vla_arena/models/openpi/packages/openpi-client/pyproject.toml
new file mode 100644
index 00000000..fba7b66f
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/pyproject.toml
@@ -0,0 +1,23 @@
+[project]
+name = "openpi-client"
+version = "0.1.0"
+requires-python = ">=3.7"
+dependencies = [
+ "dm-tree>=0.1.8",
+ "msgpack>=1.0.5",
+ "numpy>=1.22.4,<2.0.0",
+ "pillow>=9.0.0",
+ "tree>=0.2.4",
+ "websockets>=11.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.uv]
+dev-dependencies = ["pytest>=8.3.4"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py37"
diff --git a/vla_arena/__version__.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/__init__.py
similarity index 75%
rename from vla_arena/__version__.py
rename to vla_arena/models/openpi/packages/openpi-client/src/openpi_client/__init__.py
index 7a3e1661..1ca9cd2f 100644
--- a/vla_arena/__version__.py
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,8 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-
-"""VLA-Arena version information."""
__version__ = '0.1.0'
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py
new file mode 100644
index 00000000..155a93a8
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py
@@ -0,0 +1,62 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing_extensions import override
+
+import numpy as np
+import tree
+from openpi_client import base_policy as _base_policy
+
+
+class ActionChunkBroker(_base_policy.BasePolicy):
+ """Wraps a policy to return action chunks one-at-a-time.
+
+ Assumes that the first dimension of all action fields is the chunk size.
+
+ A new inference call to the inner policy is only made when the current
+ list of chunks is exhausted.
+ """
+
+ def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+ self._policy = policy
+ self._action_horizon = action_horizon
+ self._cur_step: int = 0
+
+ self._last_results: dict[str, np.ndarray] | None = None
+
+ @override
+ def infer(self, obs: dict) -> dict: # noqa: UP006
+ if self._last_results is None:
+ self._last_results = self._policy.infer(obs)
+ self._cur_step = 0
+
+ def slicer(x):
+ if isinstance(x, np.ndarray):
+ return x[self._cur_step, ...]
+ else:
+ return x
+
+ results = tree.map_structure(slicer, self._last_results)
+ self._cur_step += 1
+
+ if self._cur_step >= self._action_horizon:
+ self._last_results = None
+
+ return results
+
+ @override
+ def reset(self) -> None:
+ self._policy.reset()
+ self._last_results = None
+ self._cur_step = 0
diff --git a/vla_arena/configs/task_suite/generalization_task_workflows.yaml b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/base_policy.py
similarity index 65%
rename from vla_arena/configs/task_suite/generalization_task_workflows.yaml
rename to vla_arena/models/openpi/packages/openpi-client/src/openpi_client/base_policy.py
index eae7c398..11386268 100644
--- a/vla_arena/configs/task_suite/generalization_task_workflows.yaml
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/base_policy.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,10 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-task_suite_name: GENERALIZATION_TASK_WORKFLOWS
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
+import abc
+
+
+class BasePolicy(abc.ABC):
+ @abc.abstractmethod
+ def infer(self, obs: dict) -> dict:
+ """Infer actions from observations."""
+
+ def reset(self) -> None:
+ """Reset the policy to its initial state."""
+ pass
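+
+
+# A minimal illustrative subclass (assumed example, not shipped with the package):
+#   class ZeroPolicy(BasePolicy):
+#       def infer(self, obs: dict) -> dict:
+#           return {'actions': [0.0] * 7}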
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools.py
new file mode 100644
index 00000000..52e82b2a
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools.py
@@ -0,0 +1,85 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from PIL import Image
+
+
+def convert_to_uint8(img: np.ndarray) -> np.ndarray:
+ """Converts an image to uint8 if it is a float image.
+
+ This is important for reducing the size of the image when sending it over the network.
+ """
+ if np.issubdtype(img.dtype, np.floating):
+ img = (255 * img).astype(np.uint8)
+ return img
+
+
+def resize_with_pad(
+ images: np.ndarray, height: int, width: int, method=Image.BILINEAR
+) -> np.ndarray:
+ """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
+
+ Args:
+ images: A batch of images in [..., height, width, channel] format.
+ height: The target height of the image.
+ width: The target width of the image.
+ method: The interpolation method to use. Default is bilinear.
+
+ Returns:
+ The resized images in [..., height, width, channel].
+ """
+ # If the images are already the correct size, return them as is.
+ if images.shape[-3:-1] == (height, width):
+ return images
+
+ original_shape = images.shape
+
+ images = images.reshape(-1, *original_shape[-3:])
+ resized = np.stack(
+ [
+ _resize_with_pad_pil(
+ Image.fromarray(im), height, width, method=method
+ )
+ for im in images
+ ]
+ )
+ return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
+
+
+def _resize_with_pad_pil(
+ image: Image.Image, height: int, width: int, method: int
+) -> Image.Image:
+ """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
+ width without distortion by padding with zeros.
+
+ Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
+ """
+ cur_width, cur_height = image.size
+ if cur_width == width and cur_height == height:
+ return image # No need to resize if the image is already the correct size.
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_image = image.resize(
+ (resized_width, resized_height), resample=method
+ )
+
+ zero_image = Image.new(resized_image.mode, (width, height), 0)
+ pad_height = max(0, int((height - resized_height) / 2))
+ pad_width = max(0, int((width - resized_width) / 2))
+ zero_image.paste(resized_image, (pad_width, pad_height))
+ assert zero_image.size == (width, height)
+ return zero_image
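+
+
+# Illustrative call (shapes are examples only):
+#   batch = np.zeros((2, 256, 320, 3), dtype=np.uint8)
+#   resized = resize_with_pad(batch, 224, 224)  # -> shape (2, 224, 224, 3)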
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py
new file mode 100644
index 00000000..6d97f022
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py
@@ -0,0 +1,52 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import openpi_client.image_tools as image_tools
+
+
+def test_resize_with_pad_shapes():
+ # Test case 1: Resize image with larger dimensions
+ images = np.zeros(
+ (2, 10, 10, 3), dtype=np.uint8
+ ) # Input images of shape (batch_size, height, width, channels)
+ height = 20
+ width = 20
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (2, height, width, 3)
+ assert np.all(resized_images == 0)
+
+ # Test case 2: Resize image with smaller dimensions
+ images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
+ height = 15
+ width = 15
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (3, height, width, 3)
+ assert np.all(resized_images == 0)
+
+ # Test case 3: Resize image with the same dimensions
+ images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
+ height = 50
+ width = 50
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (1, height, width, 3)
+ assert np.all(resized_images == 0)
+
+    # Test case 4: Resize image with odd-numbered padding
+ images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
+ height = 60
+ width = 80
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (1, height, width, 3)
+ assert np.all(resized_images == 0)
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py
new file mode 100644
index 00000000..94b93f81
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py
@@ -0,0 +1,79 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Adds NumPy array support to msgpack.
+
+msgpack is good for (de)serializing data over a network for multiple reasons:
+- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
+- msgpack is widely used and has good cross-language support
+- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
+ languages like Python and JavaScript
+- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
+ than pickle for serializing large arrays using the below strategy
+
+The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
+that it falls back to pickle for object arrays.
+"""
+
+import functools
+
+import msgpack
+import numpy as np
+
+
+def pack_array(obj):
+ if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in (
+ 'V',
+ 'O',
+ 'c',
+ ):
+ raise ValueError(f'Unsupported dtype: {obj.dtype}')
+
+ if isinstance(obj, np.ndarray):
+ return {
+ b'__ndarray__': True,
+ b'data': obj.tobytes(),
+ b'dtype': obj.dtype.str,
+ b'shape': obj.shape,
+ }
+
+ if isinstance(obj, np.generic):
+ return {
+ b'__npgeneric__': True,
+ b'data': obj.item(),
+ b'dtype': obj.dtype.str,
+ }
+
+ return obj
+
+
+def unpack_array(obj):
+ if b'__ndarray__' in obj:
+ return np.ndarray(
+ buffer=obj[b'data'],
+ dtype=np.dtype(obj[b'dtype']),
+ shape=obj[b'shape'],
+ )
+
+ if b'__npgeneric__' in obj:
+ return np.dtype(obj[b'dtype']).type(obj[b'data'])
+
+ return obj
+
+
+Packer = functools.partial(msgpack.Packer, default=pack_array)
+packb = functools.partial(msgpack.packb, default=pack_array)
+
+Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
+unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
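
The module exports both one-shot helpers (`packb`/`unpackb`) and the streaming `Packer`/`Unpacker` classes. A small round-trip sketch with the kind of observation dict the websocket client later in this diff sends over the wire (the keys are illustrative):

```python
import numpy as np

from openpi_client import msgpack_numpy

obs = {
    'image': np.zeros((224, 224, 3), dtype=np.uint8),
    'state': np.array([0.1, 0.2, 0.3], dtype=np.float32),
    'prompt': 'pick up the cube',
}

# One-shot: encode to bytes, then decode back to an equivalent structure.
wire_bytes = msgpack_numpy.packb(obs)
decoded = msgpack_numpy.unpackb(wire_bytes)
assert decoded['state'].dtype == np.float32

# Streaming: a single Packer instance can be reused across messages,
# which is how WebsocketClientPolicy packs each inference request.
packer = msgpack_numpy.Packer()
payload = packer.pack(obs)
```
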
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
new file mode 100644
index 00000000..8c30f110
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
@@ -0,0 +1,65 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+import tree
+from openpi_client import msgpack_numpy
+
+
+def _check(expected, actual):
+ if isinstance(expected, np.ndarray):
+ assert expected.shape == actual.shape
+ assert expected.dtype == actual.dtype
+ assert np.array_equal(
+ expected, actual, equal_nan=expected.dtype.kind == 'f'
+ )
+ else:
+ assert expected == actual
+
+
+@pytest.mark.parametrize(
+ 'data',
+ [
+ 1, # int
+ 1.0, # float
+ 'hello', # string
+ np.bool_(True), # boolean scalar
+ np.array([1, 2, 3])[0], # int scalar
+ np.str_('asdf'), # string scalar
+ [1, 2, 3], # list
+ {'key': 'value'}, # dict
+ {'key': [1, 2, 3]}, # nested dict
+ np.array(1.0), # 0D array
+ np.array([1, 2, 3], dtype=np.int32), # 1D integer array
+ np.array(['asdf', 'qwer']), # string array
+ np.array([True, False]), # boolean array
+ np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
+ np.array(
+ [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16
+ ), # 3D integer array
+ np.array([np.nan, np.inf, -np.inf]), # special float values
+ {
+ 'arr': np.array([1, 2, 3]),
+ 'nested': {'arr': np.array([4, 5, 6])},
+ }, # nested dict with arrays
+ [np.array([1, 2]), np.array([3, 4])], # list of arrays
+ np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
+ np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
+ ],
+)
+def test_pack_unpack(data):
+ packed = msgpack_numpy.packb(data)
+ unpacked = msgpack_numpy.unpackb(packed)
+ tree.map_structure(_check, data, unpacked)
diff --git a/vla_arena/evaluation/policy/random.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py
similarity index 52%
rename from vla_arena/evaluation/policy/random.py
rename to vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py
index 7d57f42b..04cef18e 100644
--- a/vla_arena/evaluation/policy/random.py
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,20 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-import random
+import abc
-from vla_arena.evaluation.policy.base import Policy, PolicyRegistry
+class Agent(abc.ABC):
+ """An Agent is the thing with agency, i.e. the entity that makes decisions.
-@PolicyRegistry.register('random')
-class RandomPolicy(Policy):
+ Agents receive observations about the state of the world, and return actions
+ to take in response.
+ """
- def predict(self, obs, **kwargs):
+ @abc.abstractmethod
+ def get_action(self, observation: dict) -> dict:
+ """Query the agent for the next action."""
- return [random.uniform(-0.1, 0.1) for _ in range(7)]
-
- @property
- def name(self):
- return 'random'
+ @abc.abstractmethod
+ def reset(self) -> None:
+ """Reset the agent to its initial state."""
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
new file mode 100644
index 00000000..f16d938f
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
@@ -0,0 +1,32 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing_extensions import override
+
+from openpi_client import base_policy as _base_policy
+from openpi_client.runtime import agent as _agent
+
+
+class PolicyAgent(_agent.Agent):
+ """An agent that uses a policy to determine actions."""
+
+ def __init__(self, policy: _base_policy.BasePolicy) -> None:
+ self._policy = policy
+
+ @override
+ def get_action(self, observation: dict) -> dict:
+ return self._policy.infer(observation)
+
+ def reset(self) -> None:
+ self._policy.reset()
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py
new file mode 100644
index 00000000..55a3ea83
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py
@@ -0,0 +1,46 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+
+class Environment(abc.ABC):
+ """An Environment represents the robot and the environment it inhabits.
+
+ The primary contract of environments is that they can be queried for observations
+ about their state, and have actions applied to them to change that state.
+ """
+
+ @abc.abstractmethod
+ def reset(self) -> None:
+ """Reset the environment to its initial state.
+
+ This will be called once before starting each episode.
+ """
+
+ @abc.abstractmethod
+ def is_episode_complete(self) -> bool:
+ """Allow the environment to signal that the episode is complete.
+
+ This will be called after each step. It should return `True` if the episode is
+ complete (either successfully or unsuccessfully), and `False` otherwise.
+ """
+
+ @abc.abstractmethod
+ def get_observation(self) -> dict:
+ """Query the environment for the current state."""
+
+ @abc.abstractmethod
+ def apply_action(self, action: dict) -> None:
+ """Take an action in the environment."""
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py
new file mode 100644
index 00000000..95e04fdf
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py
@@ -0,0 +1,107 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import threading
+import time
+
+from openpi_client.runtime import agent as _agent
+from openpi_client.runtime import environment as _environment
+from openpi_client.runtime import subscriber as _subscriber
+
+
+class Runtime:
+ """The core module orchestrating interactions between key components of the system."""
+
+ def __init__(
+ self,
+ environment: _environment.Environment,
+ agent: _agent.Agent,
+ subscribers: list[_subscriber.Subscriber],
+ max_hz: float = 0,
+ num_episodes: int = 1,
+ max_episode_steps: int = 0,
+ ) -> None:
+ self._environment = environment
+ self._agent = agent
+ self._subscribers = subscribers
+ self._max_hz = max_hz
+ self._num_episodes = num_episodes
+ self._max_episode_steps = max_episode_steps
+
+ self._in_episode = False
+ self._episode_steps = 0
+
+ def run(self) -> None:
+ """Runs the runtime loop continuously until stop() is called or the environment is done."""
+ for _ in range(self._num_episodes):
+ self._run_episode()
+
+        # Final reset; for real environments this is important so the robot returns to its home position.
+ self._environment.reset()
+
+ def run_in_new_thread(self) -> threading.Thread:
+ """Runs the runtime loop in a new thread."""
+ thread = threading.Thread(target=self.run)
+ thread.start()
+ return thread
+
+ def mark_episode_complete(self) -> None:
+ """Marks the end of an episode."""
+ self._in_episode = False
+
+ def _run_episode(self) -> None:
+ """Runs a single episode."""
+ logging.info('Starting episode...')
+ self._environment.reset()
+ self._agent.reset()
+ for subscriber in self._subscribers:
+ subscriber.on_episode_start()
+
+ self._in_episode = True
+ self._episode_steps = 0
+ step_time = 1 / self._max_hz if self._max_hz > 0 else 0
+ last_step_time = time.time()
+
+ while self._in_episode:
+ self._step()
+ self._episode_steps += 1
+
+ # Sleep to maintain the desired frame rate
+ now = time.time()
+ dt = now - last_step_time
+ if dt < step_time:
+ time.sleep(step_time - dt)
+ last_step_time = time.time()
+ else:
+ last_step_time = now
+
+ logging.info('Episode completed.')
+ for subscriber in self._subscribers:
+ subscriber.on_episode_end()
+
+ def _step(self) -> None:
+ """A single step of the runtime loop."""
+ observation = self._environment.get_observation()
+ action = self._agent.get_action(observation)
+ self._environment.apply_action(action)
+
+ for subscriber in self._subscribers:
+ subscriber.on_step(observation, action)
+
+ if self._environment.is_episode_complete() or (
+ self._max_episode_steps > 0
+ and self._episode_steps >= self._max_episode_steps
+ ):
+ self.mark_episode_complete()
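
With an Environment and an Agent in hand, the Runtime drives the loop. A sketch that reuses the toy RandomAgent and CountingEnvironment from the earlier examples:

```python
from openpi_client.runtime import runtime as _runtime

rt = _runtime.Runtime(
    environment=CountingEnvironment(max_steps=50),  # toy env sketched above
    agent=RandomAgent(),                            # toy agent sketched above
    subscribers=[],
    max_hz=10,           # throttle the loop to roughly 10 steps per second
    num_episodes=2,
    max_episode_steps=200,
)
rt.run()  # blocks; rt.run_in_new_thread() keeps the caller responsive instead
```
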
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py
new file mode 100644
index 00000000..e009b957
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py
@@ -0,0 +1,34 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+
+class Subscriber(abc.ABC):
+ """Subscribes to events in the runtime.
+
+ Subscribers can be used to save data, visualize, etc.
+ """
+
+ @abc.abstractmethod
+ def on_episode_start(self) -> None:
+ """Called when an episode starts."""
+
+ @abc.abstractmethod
+ def on_step(self, observation: dict, action: dict) -> None:
+ """Append a step to the episode."""
+
+ @abc.abstractmethod
+ def on_episode_end(self) -> None:
+ """Called when an episode ends."""
diff --git a/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py
new file mode 100644
index 00000000..afe74173
--- /dev/null
+++ b/vla_arena/models/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py
@@ -0,0 +1,84 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from typing_extensions import override
+
+import websockets.sync.client
+from openpi_client import base_policy as _base_policy
+from openpi_client import msgpack_numpy
+
+
+class WebsocketClientPolicy(_base_policy.BasePolicy):
+ """Implements the Policy interface by communicating with a server over websocket.
+
+ See WebsocketPolicyServer for a corresponding server implementation.
+ """
+
+ def __init__(
+ self,
+ host: str = '0.0.0.0',
+ port: int | None = None,
+ api_key: str | None = None,
+ ) -> None:
+ if host.startswith('ws'):
+ self._uri = host
+ else:
+ self._uri = f'ws://{host}'
+ if port is not None:
+ self._uri += f':{port}'
+ self._packer = msgpack_numpy.Packer()
+ self._api_key = api_key
+ self._ws, self._server_metadata = self._wait_for_server()
+
+ def get_server_metadata(self) -> dict:
+ return self._server_metadata
+
+ def _wait_for_server(
+ self,
+ ) -> tuple[websockets.sync.client.ClientConnection, dict]:
+ logging.info(f'Waiting for server at {self._uri}...')
+ while True:
+ try:
+ headers = (
+ {'Authorization': f'Api-Key {self._api_key}'}
+ if self._api_key
+ else None
+ )
+ conn = websockets.sync.client.connect(
+ self._uri,
+ compression=None,
+ max_size=None,
+ additional_headers=headers,
+ )
+ metadata = msgpack_numpy.unpackb(conn.recv())
+ return conn, metadata
+ except ConnectionRefusedError:
+ logging.info('Still waiting for server...')
+ time.sleep(5)
+
+ @override
+ def infer(self, obs: dict) -> dict: # noqa: UP006
+ data = self._packer.pack(obs)
+ self._ws.send(data)
+ response = self._ws.recv()
+ if isinstance(response, str):
+ # we're expecting bytes; if the server sends a string, it's an error.
+ raise RuntimeError(f'Error in inference server:\n{response}')
+ return msgpack_numpy.unpackb(response)
+
+ @override
+ def reset(self) -> None:
+ pass
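
Once the connection is established, each `infer` call is one packed request and one packed response. A usage sketch (the observation keys, image size, and action shape are placeholders; the real schema depends on the policy config being served):

```python
import numpy as np

from openpi_client import image_tools, websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy(host='localhost', port=8000)
print('server metadata:', client.get_server_metadata())

obs = {
    'observation/image': image_tools.resize_with_pad(
        np.zeros((1, 480, 640, 3), dtype=np.uint8), 224, 224
    )[0],
    'observation/state': np.zeros(8, dtype=np.float32),
    'prompt': 'put the block in the bowl',
}
result = client.infer(obs)
actions = result['actions']  # typically an action chunk of shape (horizon, action_dim)
```
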
diff --git a/vla_arena/models/openpi/pyproject.toml b/vla_arena/models/openpi/pyproject.toml
new file mode 100644
index 00000000..c4a06e53
--- /dev/null
+++ b/vla_arena/models/openpi/pyproject.toml
@@ -0,0 +1,137 @@
+[project]
+name = "openpi"
+version = "0.1.0"
+description = "Physical Intelligence open source repo"
+readme = "README.md"
+requires-python = ">=3.11"
+license = { file = "LICENSE" }
+dependencies = [
+ "augmax>=0.3.4",
+ "dm-tree>=0.1.8",
+ "einops>=0.8.0",
+ "equinox>=0.11.8",
+ "flatbuffers>=24.3.25",
+ "flax==0.10.2",
+ "fsspec[gcs]>=2024.6.0",
+ "gym-aloha>=0.1.1",
+ "imageio>=2.36.1",
+ "jax[cuda12]==0.5.3",
+ "jaxtyping==0.2.36",
+ "lerobot",
+ "ml_collections==1.0.0",
+ "numpy>=1.22.4,<2.0.0",
+ "numpydantic>=1.6.6",
+ "opencv-python>=4.10.0.84",
+ "openpi-client",
+ "orbax-checkpoint==0.11.13",
+ "pillow>=11.0.0",
+ "sentencepiece>=0.2.0",
+ "torch==2.7.1",
+ "tqdm-loggable>=0.2",
+ "typing-extensions>=4.12.2",
+ "tyro>=0.9.5",
+ "wandb>=0.19.1",
+ "filelock>=3.16.1",
+ "beartype==0.19.0",
+ "treescope>=0.1.7",
+ "transformers==4.53.2",
+ "rich>=14.0.0",
+ "polars>=1.30.0",
+]
+
+
+[project.urls]
+Repository = "https://github.com/Physical-Intelligence/openpi"
+
+[dependency-groups]
+dev = [
+ "pytest>=8.3.4",
+ "ruff>=0.8.6",
+ "pre-commit>=4.0.1",
+ "ipykernel>=6.29.5",
+ "ipywidgets>=8.1.5",
+ "matplotlib>=3.10.0",
+ "pynvml>=12.0.0",
+]
+rlds = [
+ "dlimp",
+ "tensorflow-cpu==2.15.0",
+ "tensorflow-datasets==4.9.9",
+]
+
+[tool.uv]
+override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
+
+[tool.uv.sources]
+openpi-client = { workspace = true }
+lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
+dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
+
+[tool.uv.workspace]
+members = ["packages/*"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py311"
+extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
+
+[tool.ruff.lint]
+# https://docs.astral.sh/ruff/rules/
+select = [
+ "B",
+ "C4",
+ "DTZ",
+ "E4",
+ "E7",
+ "E9",
+ "F",
+ "FBT",
+ "FURB",
+ "I",
+ "ICN",
+ "ISC",
+ "LOG",
+ "N",
+ "PD",
+ "PERF",
+ "PIE",
+ "PLC",
+ "PLE",
+ "PLR1",
+ "PLR5",
+ "PLW",
+ "PT",
+ "Q",
+ "RET",
+ "RUF",
+ "SIM",
+ "SLF",
+ "T10",
+ "T20",
+ "UP",
+ "W",
+]
+ignore = [
+ "F722", # Conflicts with array typing.
+ "T201", # We use print statements.
+ "PD008", # Lots of false positives.
+ "ISC001", # Disabling to support ruff format.
+ "LOG015", # Use logger.info.
+]
+unfixable = [
+ "B905", # Fix defaults to strict=False, which is not what we want.
+]
+
+[tool.ruff.lint.isort]
+force-single-line = true
+force-sort-within-sections = true
+single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
+known-third-party = ["wandb"]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.pytest.ini_options]
+markers = ["manual: should be run manually."]
+testpaths = ["src", "scripts", "packages"]
diff --git a/vla_arena/models/openpi/scripts/__init__.py b/vla_arena/models/openpi/scripts/__init__.py
new file mode 100644
index 00000000..fb51bf31
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/vla_arena/models/openpi/scripts/compute_norm_stats.py b/vla_arena/models/openpi/scripts/compute_norm_stats.py
new file mode 100644
index 00000000..9571a54f
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/compute_norm_stats.py
@@ -0,0 +1,148 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Compute normalization statistics for a config.
+
+This script is used to compute the normalization statistics for a given config. It
+will compute the mean and standard deviation of the data in the dataset and save it
+to the config assets directory.
+"""
+
+import numpy as np
+import openpi.models.model as _model
+import openpi.shared.normalize as normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data_loader
+import openpi.transforms as transforms
+import tqdm
+import tyro
+
+
+class RemoveStrings(transforms.DataTransformFn):
+ def __call__(self, x: dict) -> dict:
+ return {
+ k: v
+ for k, v in x.items()
+ if not np.issubdtype(np.asarray(v).dtype, np.str_)
+ }
+
+
+def create_torch_dataloader(
+ data_config: _config.DataConfig,
+ action_horizon: int,
+ batch_size: int,
+ model_config: _model.BaseModelConfig,
+ num_workers: int,
+ max_frames: int | None = None,
+) -> tuple[_data_loader.Dataset, int]:
+ if data_config.repo_id is None:
+ raise ValueError('Data config must have a repo_id')
+ dataset = _data_loader.create_torch_dataset(
+ data_config, action_horizon, model_config
+ )
+ dataset = _data_loader.TransformedDataset(
+ dataset,
+ [
+ *data_config.repack_transforms.inputs,
+ *data_config.data_transforms.inputs,
+ # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+ RemoveStrings(),
+ ],
+ )
+ if max_frames is not None and max_frames < len(dataset):
+ num_batches = max_frames // batch_size
+ shuffle = True
+ else:
+ num_batches = len(dataset) // batch_size
+ shuffle = False
+ data_loader = _data_loader.TorchDataLoader(
+ dataset,
+ local_batch_size=batch_size,
+ num_workers=num_workers,
+ shuffle=shuffle,
+ num_batches=num_batches,
+ )
+ return data_loader, num_batches
+
+
+def create_rlds_dataloader(
+ data_config: _config.DataConfig,
+ action_horizon: int,
+ batch_size: int,
+ max_frames: int | None = None,
+) -> tuple[_data_loader.Dataset, int]:
+ dataset = _data_loader.create_rlds_dataset(
+ data_config, action_horizon, batch_size, shuffle=False
+ )
+ dataset = _data_loader.IterableTransformedDataset(
+ dataset,
+ [
+ *data_config.repack_transforms.inputs,
+ *data_config.data_transforms.inputs,
+ # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+ RemoveStrings(),
+ ],
+ is_batched=True,
+ )
+ if max_frames is not None and max_frames < len(dataset):
+ num_batches = max_frames // batch_size
+ else:
+ # NOTE: this length is currently hard-coded for DROID.
+ num_batches = len(dataset) // batch_size
+ data_loader = _data_loader.RLDSDataLoader(
+ dataset,
+ num_batches=num_batches,
+ )
+ return data_loader, num_batches
+
+
+def main(config_name: str, max_frames: int | None = None):
+ config = _config.get_config(config_name)
+ data_config = config.data.create(config.assets_dirs, config.model)
+
+ if data_config.rlds_data_dir is not None:
+ data_loader, num_batches = create_rlds_dataloader(
+ data_config,
+ config.model.action_horizon,
+ config.batch_size,
+ max_frames,
+ )
+ else:
+ data_loader, num_batches = create_torch_dataloader(
+ data_config,
+ config.model.action_horizon,
+ config.batch_size,
+ config.model,
+ config.num_workers,
+ max_frames,
+ )
+
+ keys = ['state', 'actions']
+ stats = {key: normalize.RunningStats() for key in keys}
+
+ for batch in tqdm.tqdm(
+ data_loader, total=num_batches, desc='Computing stats'
+ ):
+ for key in keys:
+ stats[key].update(np.asarray(batch[key]))
+
+ norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
+
+ output_path = config.assets_dirs / data_config.repo_id
+ print(f'Writing stats to: {output_path}')
+ normalize.save(output_path, norm_stats)
+
+
+if __name__ == '__main__':
+ tyro.cli(main)
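
Stripped of the data-loader plumbing, the accumulation the script performs is just RunningStats fed batch by batch. A sketch with fake data (the array shapes and output path are placeholders):

```python
import numpy as np
import openpi.shared.normalize as normalize

stats = {'state': normalize.RunningStats(), 'actions': normalize.RunningStats()}

# In the real script these batches come from the torch or RLDS data loader.
for _ in range(10):
    stats['state'].update(np.random.randn(32, 8).astype(np.float32))
    stats['actions'].update(np.random.randn(32, 7).astype(np.float32))

norm_stats = {key: s.get_statistics() for key, s in stats.items()}
normalize.save('assets/my_dataset', norm_stats)  # hypothetical output location
```
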
diff --git a/vla_arena/models/openpi/scripts/docker/compose.yml b/vla_arena/models/openpi/scripts/docker/compose.yml
new file mode 100644
index 00000000..564d276e
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/docker/compose.yml
@@ -0,0 +1,29 @@
+# Run with:
+# docker compose -f scripts/docker/compose.yml up --build
+services:
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ # Populate configured openpi data home to /openpi_assets inside the container.
+ # Populate aws credential inside the container.
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/vla_arena/models/openpi/scripts/docker/install_docker_ubuntu22.sh b/vla_arena/models/openpi/scripts/docker/install_docker_ubuntu22.sh
new file mode 100644
index 00000000..38873b3e
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/docker/install_docker_ubuntu22.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Add Docker's official GPG key:
+sudo apt-get update
+sudo apt-get install -y ca-certificates curl
+sudo install -m 0755 -d /etc/apt/keyrings
+sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
+sudo chmod a+r /etc/apt/keyrings/docker.asc
+
+# Add the repository to Apt sources:
+echo \
+ "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
+ $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
+ sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
+sudo apt-get update
+
+sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
+
+# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
+# See https://docs.docker.com/engine/install/linux-postinstall/
+username=$(whoami)
+sudo usermod -aG docker $username
+
+# Configure docker to start automatically on system boot.
+sudo systemctl enable docker.service
+sudo systemctl enable containerd.service
+
+# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
+if [ -f ~/.docker/config.json ]; then
+ sed -i 's/credsStore/credStore/g' ~/.docker/config.json
+fi
+
+echo ""
+echo "********************************************************************"
+echo "**** Restart to allow Docker permission changes to take effect. ****"
+echo "********************************************************************"
+echo ""
diff --git a/vla_arena/models/openpi/scripts/docker/install_nvidia_container_toolkit.sh b/vla_arena/models/openpi/scripts/docker/install_nvidia_container_toolkit.sh
new file mode 100644
index 00000000..a4c67f1d
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/docker/install_nvidia_container_toolkit.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
+# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+
+curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
+ curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+
+# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
+sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
+sudo apt-get update
+sudo apt-get install -y nvidia-container-toolkit
+
+sudo nvidia-ctk runtime configure --runtime=docker
+sudo systemctl restart docker
diff --git a/vla_arena/models/openpi/scripts/docker/serve_policy.Dockerfile b/vla_arena/models/openpi/scripts/docker/serve_policy.Dockerfile
new file mode 100644
index 00000000..bd88a7e6
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/docker/serve_policy.Dockerfile
@@ -0,0 +1,38 @@
+# Dockerfile for serving a PI policy.
+# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
+
+# Build the container:
+# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Needed because LeRobot uses git-lfs.
+RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Install the project's dependencies using the lockfile and settings
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN --mount=type=cache,target=/root/.cache/uv \
+ --mount=type=bind,source=uv.lock,target=uv.lock \
+ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
+ --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
+ --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
+ GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
+
+# Copy transformers_replace files while preserving directory structure
+COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
+RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace
+
+CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
diff --git a/vla_arena/models/openpi/scripts/serve_policy.py b/vla_arena/models/openpi/scripts/serve_policy.py
new file mode 100644
index 00000000..22db9a1c
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/serve_policy.py
@@ -0,0 +1,162 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import enum
+import logging
+import os
+import pathlib
+import socket
+import sys
+
+import tyro
+
+
+# Add openpi src directory to Python path if needed
+_openpi_src = pathlib.Path(__file__).parent.parent / 'src'
+if str(_openpi_src) not in sys.path:
+ sys.path.insert(0, str(_openpi_src))
+
+from openpi.policies import policy as _policy
+from openpi.policies import policy_config as _policy_config
+from openpi.serving import websocket_policy_server
+from openpi.training import config as _config
+
+
+class EnvMode(enum.Enum):
+ """Supported environments."""
+
+ ALOHA = 'aloha'
+ ALOHA_SIM = 'aloha_sim'
+ DROID = 'droid'
+ LIBERO = 'libero'
+ VLA_ARENA = 'vla_arena'
+
+
+@dataclasses.dataclass
+class Checkpoint:
+ """Load a policy from a trained checkpoint."""
+
+ # Training config name (e.g., "pi0_aloha_sim").
+ config: str
+ # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
+ dir: str
+
+
+@dataclasses.dataclass
+class Default:
+ """Use the default policy for the given environment."""
+
+
+@dataclasses.dataclass
+class Args:
+ """Arguments for the serve_policy script."""
+
+ # Environment to serve the policy for. This is only used when serving default policies.
+ env: EnvMode = EnvMode.ALOHA_SIM
+
+ # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
+ # prompt.
+ default_prompt: str | None = None
+
+ # Port to serve the policy on.
+ port: int = 8000
+ # Record the policy's behavior for debugging.
+ record: bool = False
+
+ # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
+ policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
+
+
+# Default checkpoints that should be used for each environment.
+DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
+ EnvMode.ALOHA: Checkpoint(
+ config='pi05_aloha',
+ dir='gs://openpi-assets/checkpoints/pi05_base',
+ ),
+ EnvMode.ALOHA_SIM: Checkpoint(
+ config='pi0_aloha_sim',
+ dir='gs://openpi-assets/checkpoints/pi0_aloha_sim',
+ ),
+ EnvMode.DROID: Checkpoint(
+ config='pi05_droid',
+ dir='gs://openpi-assets/checkpoints/pi05_droid',
+ ),
+ EnvMode.LIBERO: Checkpoint(
+ config='pi05_libero',
+ dir='gs://openpi-assets/checkpoints/pi05_libero',
+ ),
+ EnvMode.VLA_ARENA: Checkpoint(
+ config='pi0_vla_arena_low_mem_finetune',
+ # Set OPENPI_VLA_ARENA_CHECKPOINT_PATH environment variable to specify a custom checkpoint path.
+ dir=os.getenv(
+ 'OPENPI_VLA_ARENA_CHECKPOINT_PATH',
+ 'gs://openpi-assets/checkpoints/pi0_base/params',
+ ),
+ ),
+}
+
+
+def create_default_policy(
+ env: EnvMode, *, default_prompt: str | None = None
+) -> _policy.Policy:
+ """Create a default policy for the given environment."""
+ if checkpoint := DEFAULT_CHECKPOINT.get(env):
+ return _policy_config.create_trained_policy(
+ _config.get_config(checkpoint.config),
+ checkpoint.dir,
+ default_prompt=default_prompt,
+ )
+ raise ValueError(f'Unsupported environment mode: {env}')
+
+
+def create_policy(args: Args) -> _policy.Policy:
+ """Create a policy from the given arguments."""
+ match args.policy:
+ case Checkpoint():
+ return _policy_config.create_trained_policy(
+ _config.get_config(args.policy.config),
+ args.policy.dir,
+ default_prompt=args.default_prompt,
+ )
+ case Default():
+ return create_default_policy(
+ args.env, default_prompt=args.default_prompt
+ )
+
+
+def main(args: Args) -> None:
+ policy = create_policy(args)
+ policy_metadata = policy.metadata
+
+ # Record the policy's behavior.
+ if args.record:
+ policy = _policy.PolicyRecorder(policy, 'policy_records')
+
+ hostname = socket.gethostname()
+ local_ip = socket.gethostbyname(hostname)
+ logging.info('Creating server (host: %s, ip: %s)', hostname, local_ip)
+
+ server = websocket_policy_server.WebsocketPolicyServer(
+ policy=policy,
+ host='0.0.0.0',
+ port=args.port,
+ metadata=policy_metadata,
+ )
+ server.serve_forever()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO, force=True)
+ main(tyro.cli(Args))
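
In practice a locally fine-tuned checkpoint is served by selecting the Checkpoint branch rather than relying on a default environment. A hedged sketch of the equivalent programmatic call (the checkpoint directory is a placeholder, and importing the script as a module assumes the openpi root is on PYTHONPATH):

```python
import logging

from scripts import serve_policy

logging.basicConfig(level=logging.INFO, force=True)
serve_policy.main(
    serve_policy.Args(
        port=8000,
        policy=serve_policy.Checkpoint(
            config='pi0_vla_arena_low_mem_finetune',    # training config name
            dir='checkpoints/pi0_vla_arena/exp/10000',  # placeholder checkpoint dir
        ),
    )
)
```
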
diff --git a/vla_arena/models/openpi/scripts/train.py b/vla_arena/models/openpi/scripts/train.py
new file mode 100644
index 00000000..24fdfa0b
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/train.py
@@ -0,0 +1,383 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import functools
+import logging
+import platform
+from typing import Any
+
+import etils.epath as epath
+import flax.nnx as nnx
+import flax.traverse_util as traverse_util
+import jax
+import jax.experimental
+import jax.numpy as jnp
+import numpy as np
+import openpi.models.model as _model
+import openpi.shared.array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+import openpi.training.checkpoints as _checkpoints
+import openpi.training.config as _config
+import openpi.training.data_loader as _data_loader
+import openpi.training.optimizer as _optimizer
+import openpi.training.sharding as sharding
+import openpi.training.utils as training_utils
+import openpi.training.weight_loaders as _weight_loaders
+import optax
+import tqdm_loggable.auto as tqdm
+import wandb
+from flax.training import common_utils
+
+
+def init_logging():
+ """Custom logging format for better readability."""
+ level_mapping = {
+ 'DEBUG': 'D',
+ 'INFO': 'I',
+ 'WARNING': 'W',
+ 'ERROR': 'E',
+ 'CRITICAL': 'C',
+ }
+
+ class CustomFormatter(logging.Formatter):
+ def format(self, record):
+ record.levelname = level_mapping.get(
+ record.levelname, record.levelname
+ )
+ return super().format(record)
+
+ formatter = CustomFormatter(
+ fmt='%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)',
+ datefmt='%H:%M:%S',
+ )
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(
+ config: _config.TrainConfig,
+ *,
+ resuming: bool,
+ log_code: bool = False,
+ enabled: bool = True,
+):
+ if not enabled:
+ wandb.init(mode='disabled')
+ return
+
+ ckpt_dir = config.checkpoint_dir
+ if not ckpt_dir.exists():
+ raise FileNotFoundError(
+ f'Checkpoint directory {ckpt_dir} does not exist.'
+ )
+ if resuming:
+ run_id = (ckpt_dir / 'wandb_id.txt').read_text().strip()
+ wandb.init(id=run_id, resume='must', project=config.project_name)
+ else:
+ wandb.init(
+ name=config.exp_name,
+ config=dataclasses.asdict(config),
+ project=config.project_name,
+ )
+ (ckpt_dir / 'wandb_id.txt').write_text(wandb.run.id)
+
+ if log_code:
+ wandb.run.log_code(epath.Path(__file__).parent.parent)
+
+
+def _load_weights_and_validate(
+ loader: _weight_loaders.WeightLoader, params_shape: at.Params
+) -> at.Params:
+ """Loads and validates the weights. Returns a loaded subset of the weights."""
+ loaded_params = loader.load(params_shape)
+ at.check_pytree_equality(
+ expected=params_shape,
+ got=loaded_params,
+ check_shapes=True,
+ check_dtypes=True,
+ )
+
+ # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
+ return traverse_util.unflatten_dict(
+ {
+ k: v
+ for k, v in traverse_util.flatten_dict(loaded_params).items()
+ if not isinstance(v, jax.ShapeDtypeStruct)
+ }
+ )
+
+
+@at.typecheck
+def init_train_state(
+ config: _config.TrainConfig,
+ init_rng: at.KeyArrayLike,
+ mesh: jax.sharding.Mesh,
+ *,
+ resume: bool,
+) -> tuple[training_utils.TrainState, Any]:
+ tx = _optimizer.create_optimizer(
+ config.optimizer, config.lr_schedule, weight_decay_mask=None
+ )
+
+ def init(
+ rng: at.KeyArrayLike, partial_params: at.Params | None = None
+ ) -> training_utils.TrainState:
+ rng, model_rng = jax.random.split(rng)
+ # initialize the model (and its parameters).
+ model = config.model.create(model_rng)
+
+ # Merge the partial params into the model.
+ if partial_params is not None:
+ graphdef, state = nnx.split(model)
+ # This will produce an error if the partial params are not a subset of the state.
+ state.replace_by_pure_dict(partial_params)
+ model = nnx.merge(graphdef, state)
+
+ params = nnx.state(model)
+ # Convert frozen params to bfloat16.
+ params = nnx_utils.state_map(
+ params,
+ config.freeze_filter,
+ lambda p: p.replace(p.value.astype(jnp.bfloat16)),
+ )
+
+ return training_utils.TrainState(
+ step=0,
+ params=params,
+ model_def=nnx.graphdef(model),
+ tx=tx,
+ opt_state=tx.init(params.filter(config.trainable_filter)),
+ ema_decay=config.ema_decay,
+ ema_params=None if config.ema_decay is None else params,
+ )
+
+ train_state_shape = jax.eval_shape(init, init_rng)
+ state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
+
+ if resume:
+ return train_state_shape, state_sharding
+
+ partial_params = _load_weights_and_validate(
+ config.weight_loader, train_state_shape.params.to_pure_dict()
+ )
+ replicated_sharding = jax.sharding.NamedSharding(
+ mesh, jax.sharding.PartitionSpec()
+ )
+
+ # Initialize the train state and mix in the partial params.
+ train_state = jax.jit(
+ init,
+ donate_argnums=(1,), # donate the partial params buffer.
+ in_shardings=replicated_sharding,
+ out_shardings=state_sharding,
+ )(init_rng, partial_params)
+
+ return train_state, state_sharding
+
+
+@at.typecheck
+def train_step(
+ config: _config.TrainConfig,
+ rng: at.KeyArrayLike,
+ state: training_utils.TrainState,
+ batch: tuple[_model.Observation, _model.Actions],
+) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
+ model = nnx.merge(state.model_def, state.params)
+ model.train()
+
+ @at.typecheck
+ def loss_fn(
+ model: _model.BaseModel,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ actions: _model.Actions,
+ ):
+ chunked_loss = model.compute_loss(
+ rng, observation, actions, train=True
+ )
+ return jnp.mean(chunked_loss)
+
+ train_rng = jax.random.fold_in(rng, state.step)
+ observation, actions = batch
+
+ # Filter out frozen params.
+ diff_state = nnx.DiffState(0, config.trainable_filter)
+ loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(
+ model, train_rng, observation, actions
+ )
+
+ params = state.params.filter(config.trainable_filter)
+ updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
+ new_params = optax.apply_updates(params, updates)
+
+ # Update the model in place and return the new full state.
+ nnx.update(model, new_params)
+ new_params = nnx.state(model)
+
+ new_state = dataclasses.replace(
+ state, step=state.step + 1, params=new_params, opt_state=new_opt_state
+ )
+ if state.ema_decay is not None:
+ new_state = dataclasses.replace(
+ new_state,
+ ema_params=jax.tree.map(
+ lambda old, new: state.ema_decay * old
+ + (1 - state.ema_decay) * new,
+ state.ema_params,
+ new_params,
+ ),
+ )
+
+ # Filter out params that aren't kernels.
+ kernel_params = nnx.state(
+ model,
+ nnx.All(
+ nnx.Param,
+ nnx.Not(
+ nnx_utils.PathRegex(
+ '.*/(bias|scale|pos_embedding|input_embedding)'
+ )
+ ),
+ lambda _, x: x.value.ndim > 1,
+ ),
+ )
+ info = {
+ 'loss': loss,
+ 'grad_norm': optax.global_norm(grads),
+ 'param_norm': optax.global_norm(kernel_params),
+ }
+ return new_state, info
+
+
+def main(config: _config.TrainConfig):
+ init_logging()
+ logging.info(f'Running on: {platform.node()}')
+
+ if config.batch_size % jax.device_count() != 0:
+ raise ValueError(
+ f'Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}.'
+ )
+
+ jax.config.update(
+ 'jax_compilation_cache_dir',
+ str(epath.Path('~/.cache/jax').expanduser()),
+ )
+
+ rng = jax.random.key(config.seed)
+ train_rng, init_rng = jax.random.split(rng)
+
+ mesh = sharding.make_mesh(config.fsdp_devices)
+ data_sharding = jax.sharding.NamedSharding(
+ mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS)
+ )
+ replicated_sharding = jax.sharding.NamedSharding(
+ mesh, jax.sharding.PartitionSpec()
+ )
+
+ checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
+ config.checkpoint_dir,
+ keep_period=config.keep_period,
+ overwrite=config.overwrite,
+ resume=config.resume,
+ )
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+
+ data_loader = _data_loader.create_data_loader(
+ config,
+ sharding=data_sharding,
+ shuffle=True,
+ )
+ data_iter = iter(data_loader)
+ batch = next(data_iter)
+ logging.info(
+ f'Initialized data loader:\n{training_utils.array_tree_to_info(batch)}'
+ )
+
+    # Log images from the first batch as a sanity check.
+ images_to_log = [
+ wandb.Image(
+ np.concatenate(
+ [np.array(img[i]) for img in batch[0].images.values()], axis=1
+ )
+ )
+ for i in range(min(5, len(next(iter(batch[0].images.values())))))
+ ]
+ wandb.log({'camera_views': images_to_log}, step=0)
+
+ train_state, train_state_sharding = init_train_state(
+ config, init_rng, mesh, resume=resuming
+ )
+ jax.block_until_ready(train_state)
+ logging.info(
+ f'Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}'
+ )
+
+ if resuming:
+ train_state = _checkpoints.restore_state(
+ checkpoint_manager, train_state, data_loader
+ )
+
+ ptrain_step = jax.jit(
+ functools.partial(train_step, config),
+ in_shardings=(
+ replicated_sharding,
+ train_state_sharding,
+ data_sharding,
+ ),
+ out_shardings=(train_state_sharding, replicated_sharding),
+ donate_argnums=(1,),
+ )
+
+ start_step = int(train_state.step)
+ pbar = tqdm.tqdm(
+ range(start_step, config.num_train_steps),
+ initial=start_step,
+ total=config.num_train_steps,
+ dynamic_ncols=True,
+ )
+
+ infos = []
+ for step in pbar:
+ with sharding.set_mesh(mesh):
+ train_state, info = ptrain_step(train_rng, train_state, batch)
+ infos.append(info)
+ if step % config.log_interval == 0:
+ stacked_infos = common_utils.stack_forest(infos)
+ reduced_info = jax.device_get(
+ jax.tree.map(jnp.mean, stacked_infos)
+ )
+ info_str = ', '.join(
+ f'{k}={v:.4f}' for k, v in reduced_info.items()
+ )
+ pbar.write(f'Step {step}: {info_str}')
+ wandb.log(reduced_info, step=step)
+ infos = []
+ batch = next(data_iter)
+
+ if (
+ step % config.save_interval == 0 and step > start_step
+ ) or step == config.num_train_steps - 1:
+ _checkpoints.save_state(
+ checkpoint_manager, train_state, data_loader, step
+ )
+
+ logging.info('Waiting for checkpoint manager to finish')
+ checkpoint_manager.wait_until_finished()
+
+
+if __name__ == '__main__':
+ main(_config.cli())
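
One detail of `train_step` worth spelling out is the EMA branch: when `ema_decay` is set, a shadow copy of the parameters is nudged toward the freshly updated values on every step. In isolation the update is a tree-wide interpolation (toy pytree, shown only to make the formula concrete):

```python
import jax
import jax.numpy as jnp

ema_decay = 0.99
ema_params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
new_params = {'w': jnp.full((2, 2), 2.0), 'b': jnp.ones(2)}

new_ema = jax.tree.map(
    lambda old, new: ema_decay * old + (1 - ema_decay) * new,
    ema_params,
    new_params,
)
# Each leaf moves 1% of the way toward the new parameters per step.
```
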
diff --git a/vla_arena/models/openpi/scripts/train_pytorch.py b/vla_arena/models/openpi/scripts/train_pytorch.py
new file mode 100644
index 00000000..334a9c03
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/train_pytorch.py
@@ -0,0 +1,765 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
+This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
+entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
+pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
+
+Usage
+Single GPU:
+    python scripts/train_pytorch.py <config_name> --exp_name <exp_name> --save_interval <interval>
+    Example:
+        python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
+        python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
+Multi-GPU (single node):
+    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <exp_name>
+    Example:
+        torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+        torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+Multi-Node Training:
+    torchrun \
+        --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<node_rank> \
+        --master_addr=<master_addr> --master_port=<master_port> \
+        scripts/train_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <interval>
+
+"""
+
+import dataclasses
+import gc
+import logging
+import os
+import platform
+import shutil
+import time
+
+import jax
+import numpy as np
+import openpi.models.pi0_config
+import openpi.models_pytorch.pi0_pytorch
+import openpi.shared.normalize as _normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data
+import safetensors.torch
+import torch
+import torch.distributed as dist
+import torch.nn.parallel
+import tqdm
+import wandb
+
+
+def init_logging():
+ level_mapping = {
+ 'DEBUG': 'D',
+ 'INFO': 'I',
+ 'WARNING': 'W',
+ 'ERROR': 'E',
+ 'CRITICAL': 'C',
+ }
+
+ class CustomFormatter(logging.Formatter):
+ def format(self, record):
+ record.levelname = level_mapping.get(
+ record.levelname, record.levelname
+ )
+ return super().format(record)
+
+ formatter = CustomFormatter(
+ fmt='%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)',
+ datefmt='%H:%M:%S',
+ )
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ if not logger.handlers:
+ ch = logging.StreamHandler()
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+ else:
+ logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(
+ config: _config.TrainConfig, *, resuming: bool, enabled: bool = True
+):
+ """Initialize wandb logging."""
+ if not enabled:
+ wandb.init(mode='disabled')
+ return
+
+ ckpt_dir = config.checkpoint_dir
+ if not ckpt_dir.exists():
+ raise FileNotFoundError(
+ f'Checkpoint directory {ckpt_dir} does not exist.'
+ )
+
+ if resuming:
+ run_id = (ckpt_dir / 'wandb_id.txt').read_text().strip()
+ wandb.init(id=run_id, resume='must', project=config.project_name)
+ else:
+ wandb.init(
+ name=config.exp_name,
+ config=dataclasses.asdict(config),
+ project=config.project_name,
+ )
+ (ckpt_dir / 'wandb_id.txt').write_text(wandb.run.id)
+
+
+def setup_ddp():
+ world_size = int(os.environ.get('WORLD_SIZE', '1'))
+ use_ddp = world_size > 1
+ if use_ddp and not torch.distributed.is_initialized():
+ backend = 'nccl' if torch.cuda.is_available() else 'gloo'
+ torch.distributed.init_process_group(
+ backend=backend, init_method='env://'
+ )
+
+ # Set up debugging environment variables for DDP issues
+ if os.environ.get('TORCH_DISTRIBUTED_DEBUG') is None:
+ os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
+
+ local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('RANK', '0')))
+ device = torch.device(
+ f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu'
+ )
+ if torch.cuda.is_available():
+ torch.cuda.set_device(device)
+ return use_ddp, local_rank, device
+
+
+def cleanup_ddp():
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+ torch.distributed.destroy_process_group()
+
+
+def set_seed(seed: int, local_rank: int):
+ torch.manual_seed(seed + local_rank)
+ np.random.seed(seed + local_rank)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed + local_rank)
+
+
+def build_datasets(config: _config.TrainConfig):
+ # Use the unified data loader with PyTorch framework
+ data_loader = _data.create_data_loader(
+ config, framework='pytorch', shuffle=True
+ )
+ return data_loader, data_loader.data_config()
+
+
+def get_model_state_dict(model):
+ """Get state dict from model, handling DDP wrapper."""
+ return (
+ model.module.state_dict()
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model.state_dict()
+ )
+
+
+def get_model_parameters(model):
+ """Get parameters from model, handling DDP wrapper."""
+ return (
+ model.module.parameters()
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model.parameters()
+ )
+
+
+def save_checkpoint(
+ model, optimizer, global_step, config, is_main, data_config
+):
+ """Save a checkpoint with model state, optimizer state, and metadata."""
+ if not is_main:
+ return
+
+ # Only save if it's time to save or if it's the final step
+ if (
+ global_step % config.save_interval == 0 and global_step > 0
+ ) or global_step == config.num_train_steps - 1:
+ # Create temporary directory for atomic checkpoint saving
+ final_ckpt_dir = config.checkpoint_dir / f'{global_step}'
+ tmp_ckpt_dir = config.checkpoint_dir / f'tmp_{global_step}'
+
+ # Remove any existing temp directory and create new one
+ if tmp_ckpt_dir.exists():
+ shutil.rmtree(tmp_ckpt_dir)
+ tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save model state using safetensors (handle shared tensors)
+ model_to_save = (
+ model.module
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model
+ )
+ safetensors.torch.save_model(
+ model_to_save, tmp_ckpt_dir / 'model.safetensors'
+ )
+
+ # Save optimizer state using PyTorch format
+ torch.save(optimizer.state_dict(), tmp_ckpt_dir / 'optimizer.pt')
+
+ # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues)
+ metadata = {
+ 'global_step': global_step,
+ 'config': dataclasses.asdict(config),
+ 'timestamp': time.time(),
+ }
+ torch.save(metadata, tmp_ckpt_dir / 'metadata.pt')
+
+ # save norm stats
+ norm_stats = data_config.norm_stats
+ if norm_stats is not None and data_config.asset_id is not None:
+ _normalize.save(
+ tmp_ckpt_dir / 'assets' / data_config.asset_id, norm_stats
+ )
+
+ # Atomically move temp directory to final location
+ if final_ckpt_dir.exists():
+ shutil.rmtree(final_ckpt_dir)
+ tmp_ckpt_dir.rename(final_ckpt_dir)
+
+ logging.info(
+ f'Saved checkpoint at step {global_step} -> {final_ckpt_dir}'
+ )
+
+ # Log checkpoint to wandb
+ if config.wandb_enabled:
+ wandb.log({'checkpoint_step': global_step}, step=global_step)
+
+
+def load_checkpoint(model, optimizer, checkpoint_dir, device):
+ """Load the latest checkpoint and return the global step."""
+ checkpoint_steps = [
+ int(d.name)
+ for d in checkpoint_dir.iterdir()
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith('tmp_')
+ ]
+
+ if not checkpoint_steps:
+ raise FileNotFoundError(f'No checkpoints found in {checkpoint_dir}')
+
+ latest_step = max(checkpoint_steps)
+ ckpt_dir = checkpoint_dir / f'{latest_step}'
+
+ # Clear memory before loading checkpoints
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, 'before_loading_checkpoint')
+
+ try:
+ # Load model state with error handling
+ logging.info('Loading model state...')
+ safetensors_path = ckpt_dir / 'model.safetensors'
+
+ if safetensors_path.exists():
+ model_to_load = (
+ model.module
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model
+ )
+ safetensors.torch.load_model(
+ model_to_load, safetensors_path, device=str(device)
+ )
+ logging.info('Loaded model state from safetensors format')
+ else:
+ raise FileNotFoundError(f'No model checkpoint found at {ckpt_dir}')
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, 'after_loading_model')
+
+ # Load optimizer state with error handling
+ logging.info('Loading optimizer state...')
+ optimizer_path = ckpt_dir / 'optimizer.pt'
+
+ if optimizer_path.exists():
+ optimizer_state_dict = torch.load(
+ optimizer_path, map_location=device, weights_only=False
+ )
+ logging.info('Loaded optimizer state from pt format')
+ else:
+ raise FileNotFoundError(
+ f'No optimizer checkpoint found at {ckpt_dir}'
+ )
+
+ optimizer.load_state_dict(optimizer_state_dict)
+ del optimizer_state_dict
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, 'after_loading_optimizer')
+
+ # Load metadata
+ logging.info('Loading metadata...')
+ metadata = torch.load(
+ ckpt_dir / 'metadata.pt', map_location=device, weights_only=False
+ )
+ global_step = metadata.get('global_step', latest_step)
+ del metadata
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, 'after_loading_metadata')
+
+ logging.info(
+ f'Successfully loaded all checkpoint components from step {latest_step}'
+ )
+ return global_step
+
+ except RuntimeError as e:
+ if 'out of memory' in str(e):
+ # Clear memory and provide detailed error message
+ torch.cuda.empty_cache()
+ gc.collect()
+ logging.error(
+ f'Out of memory error while loading checkpoint: {e!s}'
+ )
+ log_memory_usage(device, latest_step, 'after_oom_error')
+ raise RuntimeError(
+ 'Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True'
+ ) from e
+ raise
+
+
+def get_latest_checkpoint_step(checkpoint_dir):
+ """Get the latest checkpoint step number from a checkpoint directory."""
+ checkpoint_steps = [
+ int(d.name)
+ for d in checkpoint_dir.iterdir()
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith('tmp_')
+ ]
+ return max(checkpoint_steps) if checkpoint_steps else None
+
+
+def log_memory_usage(device, step, phase='unknown'):
+ """Log detailed memory usage information."""
+ if not torch.cuda.is_available():
+ return
+
+ memory_allocated = torch.cuda.memory_allocated(device) / 1e9
+ memory_reserved = torch.cuda.memory_reserved(device) / 1e9
+    # "Free" here means memory reserved by the CUDA caching allocator but not
+    # currently allocated, not total free device memory.
+    memory_free = memory_reserved - memory_allocated
+
+ # Get more detailed memory info
+ memory_stats = torch.cuda.memory_stats(device)
+ max_memory_allocated = (
+ memory_stats.get('allocated_bytes.all.peak', 0) / 1e9
+ )
+ max_memory_reserved = memory_stats.get('reserved_bytes.all.peak', 0) / 1e9
+
+ # Get DDP info if available
+ ddp_info = ''
+ if dist.is_initialized():
+ ddp_info = f' | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}'
+
+ logging.info(
+ f'Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}'
+ )
+
+
+def train_loop(config: _config.TrainConfig):
+ use_ddp, local_rank, device = setup_ddp()
+ is_main = (not use_ddp) or (dist.get_rank() == 0)
+ set_seed(config.seed, local_rank)
+
+ # Initialize checkpoint directory and wandb
+ resuming = False
+ if config.resume:
+ # Find checkpoint directory based on experiment name
+ exp_checkpoint_dir = config.checkpoint_dir
+ if exp_checkpoint_dir.exists():
+ # Use validation to find the latest working checkpoint
+ latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
+ if latest_step is not None:
+ resuming = True
+ logging.info(
+ f'Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}'
+ )
+ else:
+ raise FileNotFoundError(
+ f'No valid checkpoints found in {exp_checkpoint_dir} for resume'
+ )
+ else:
+ raise FileNotFoundError(
+ f'Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume'
+ )
+ elif config.overwrite and config.checkpoint_dir.exists():
+ shutil.rmtree(config.checkpoint_dir)
+ logging.info(
+ f'Overwriting checkpoint directory: {config.checkpoint_dir}'
+ )
+
+ # Create checkpoint directory with experiment name
+ if not resuming:
+ # For new runs, create experiment-specific checkpoint directory
+ exp_checkpoint_dir = config.checkpoint_dir
+ exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ logging.info(
+ f'Created experiment checkpoint directory: {exp_checkpoint_dir}'
+ )
+ else:
+ # For resume, checkpoint_dir is already set to the experiment directory
+ logging.info(
+ f'Using existing experiment checkpoint directory: {config.checkpoint_dir}'
+ )
+
+ # Initialize wandb (only on main process)
+ if is_main:
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+
+ # Build data loader using the unified data loader
+ # Calculate effective batch size per GPU for DDP
+ # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
+ world_size = torch.distributed.get_world_size() if use_ddp else 1
+ effective_batch_size = config.batch_size // world_size
+ logging.info(
+ f'Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})'
+ )
+
+ # Pass the original batch size to data loader - it will handle DDP splitting internally
+ loader, data_config = build_datasets(config)
+
+ # Log sample images to wandb on first batch
+ if is_main and config.wandb_enabled and not resuming:
+ # Create a separate data loader for sample batch to avoid consuming the main loader
+ sample_data_loader = _data.create_data_loader(
+ config, framework='pytorch', shuffle=False
+ )
+ sample_batch = next(iter(sample_data_loader))
+        # Unpack the (observation, actions) tuple and flatten the observation into a dict
+ observation, actions = sample_batch
+ sample_batch = observation.to_dict()
+ sample_batch['actions'] = actions
+
+ # Create sample images for wandb
+ images_to_log = []
+ # Get batch size from the first image tensor
+ batch_size = next(iter(sample_batch['image'].values())).shape[0]
+ for i in range(min(5, batch_size)):
+            # Concatenate all camera views horizontally for this batch item,
+            # converting each image from CHW to HWC for wandb
+ img_concatenated = torch.cat(
+ [
+ img[i].permute(1, 2, 0)
+ for img in sample_batch['image'].values()
+ ],
+ axis=1,
+ )
+ img_concatenated = img_concatenated.cpu().numpy()
+ images_to_log.append(wandb.Image(img_concatenated))
+
+ wandb.log({'camera_views': images_to_log}, step=0)
+
+ # Clear sample batch from memory aggressively
+ del sample_batch, observation, actions, images_to_log, img_concatenated
+ del sample_data_loader # Also delete the sample data loader
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ logging.info('Cleared sample batch and data loader from memory')
+
+ # Build model
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
+ # Convert dataclass to Pi0Config if needed
+ model_cfg = openpi.models.pi0_config.Pi0Config(
+ dtype=config.pytorch_training_precision,
+ action_dim=config.model.action_dim,
+ action_horizon=config.model.action_horizon,
+ max_token_len=config.model.max_token_len,
+ paligemma_variant=getattr(
+ config.model, 'paligemma_variant', 'gemma_2b'
+ ),
+ action_expert_variant=getattr(
+ config.model, 'action_expert_variant', 'gemma_300m'
+ ),
+ pi05=getattr(config.model, 'pi05', False),
+ )
+ else:
+ model_cfg = config.model
+ # Update dtype to match pytorch_training_precision
+ object.__setattr__(
+ model_cfg, 'dtype', config.pytorch_training_precision
+ )
+
+ model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
+
+ if hasattr(model, 'gradient_checkpointing_enable'):
+ enable_gradient_checkpointing = True
+ model.gradient_checkpointing_enable()
+ logging.info('Enabled gradient checkpointing for memory optimization')
+ else:
+ enable_gradient_checkpointing = False
+ logging.info('Gradient checkpointing is not supported for this model')
+
+ # Log initial memory usage after model creation
+ if is_main and torch.cuda.is_available():
+ log_memory_usage(device, 0, 'after_model_creation')
+
+ # Enable memory optimizations for large-scale training
+ if world_size >= 8:
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ # Set memory allocation configuration
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = (
+ 'max_split_size_mb:128,expandable_segments:True'
+ )
+ logging.info('Enabled memory optimizations for 8+ GPU training')
+
+ if use_ddp:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[device.index] if device.type == 'cuda' else None,
+            find_unused_parameters=True, # Some parameters may not receive gradients every step
+ gradient_as_bucket_view=True, # Enable for memory efficiency
+ static_graph=world_size >= 8, # Enable for 8+ GPUs
+ )
+
+ # Load weights from weight_loader if specified (for fine-tuning)
+ if config.pytorch_weight_path is not None:
+ logging.info(f'Loading weights from: {config.pytorch_weight_path}')
+
+ model_path = os.path.join(
+ config.pytorch_weight_path, 'model.safetensors'
+ )
+ safetensors.torch.load_model(
+ (
+ model.module
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model
+ ),
+ model_path,
+ )
+ logging.info(
+ f'Loaded PyTorch weights from {config.pytorch_weight_path}'
+ )
+
+ # Optimizer + learning rate schedule from config
+ warmup_steps = config.lr_schedule.warmup_steps
+ peak_lr = config.lr_schedule.peak_lr
+ decay_steps = config.lr_schedule.decay_steps
+ end_lr = config.lr_schedule.decay_lr
+
+ # Create optimizer with config parameters
+ optim = torch.optim.AdamW(
+ model.parameters(),
+ lr=peak_lr,
+ betas=(config.optimizer.b1, config.optimizer.b2),
+ eps=config.optimizer.eps,
+ weight_decay=config.optimizer.weight_decay,
+ )
+
+ # Load checkpoint if resuming
+ global_step = 0
+ if resuming:
+ global_step = load_checkpoint(
+ model, optim, config.checkpoint_dir, device
+ )
+ logging.info(f'Resumed training from step {global_step}')
+
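+    # Schedule: linear warmup from peak_lr / (warmup_steps + 1) up to peak_lr over
+    # warmup_steps, then cosine decay from peak_lr down to end_lr, i.e.
+    # lr = end_lr + (peak_lr - end_lr) * 0.5 * (1 + cos(pi * progress)).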
+ def lr_schedule(step: int):
+ if step < warmup_steps:
+ # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
+ init_lr = peak_lr / (warmup_steps + 1)
+ return init_lr + (peak_lr - init_lr) * step / warmup_steps
+ # cosine decay
+ progress = min(
+ 1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)
+ )
+ cos = 0.5 * (1 + np.cos(np.pi * progress))
+ return end_lr + (peak_lr - end_lr) * cos
+
+ model.train()
+ start_time = time.time()
+ infos = [] # Collect stats over log interval
+ if is_main:
+ logging.info(
+ f'Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}'
+ )
+ logging.info(
+ f'Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}'
+ )
+ logging.info(
+ f'Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}'
+ )
+ logging.info(
+ f'LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}'
+ )
+ logging.info(
+ f'Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}'
+ )
+ logging.info('EMA is not supported for PyTorch training')
+ logging.info(f'Training precision: {model_cfg.dtype}')
+
+ # Training loop - iterate until we reach num_train_steps
+ pbar = (
+ tqdm.tqdm(
+ total=config.num_train_steps,
+ initial=global_step,
+ desc='Training',
+ disable=not is_main,
+ )
+ if is_main
+ else None
+ )
+
+ while global_step < config.num_train_steps:
+ # Set epoch for distributed training
+ if use_ddp and hasattr(loader, 'set_epoch'):
+ loader.set_epoch(global_step // len(loader))
+
+ for observation, actions in loader:
+ # Check if we've reached the target number of steps
+ if global_step >= config.num_train_steps:
+ break
+
+ # The unified data loader returns (observation, actions) tuple
+ observation = jax.tree.map(
+ lambda x: x.to(device), observation
+ ) # noqa: PLW2901
+ actions = actions.to(torch.float32) # noqa: PLW2901
+ actions = actions.to(device) # noqa: PLW2901
+
+ # Update LR
+ for pg in optim.param_groups:
+ pg['lr'] = lr_schedule(global_step)
+
+ # Forward pass
+ losses = model(observation, actions)
+ # Ensure losses is a tensor and handle different return types
+ if isinstance(losses, list | tuple):
+ losses = torch.stack(losses)
+ elif not isinstance(losses, torch.Tensor):
+ losses = torch.tensor(
+ losses, device=device, dtype=torch.float32
+ )
+
+ loss = losses.mean()
+
+ # Backward pass
+ loss.backward()
+
+ # Log memory usage after backward pass
+ if global_step < 5 and is_main and torch.cuda.is_available():
+ log_memory_usage(device, global_step, 'after_backward')
+
+ # Gradient clipping
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ model.parameters(),
+ max_norm=config.optimizer.clip_gradient_norm,
+ )
+
+ # Optimizer step
+ optim.step()
+ optim.zero_grad(set_to_none=True)
+
+            # Defensive cleanup: zero_grad(set_to_none=True) already frees gradients,
+            # this loop only guards against any remaining references
+ for param in model.parameters():
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad = None
+
+ # Collect stats
+ if is_main:
+ infos.append(
+ {
+ 'loss': loss.item(),
+ 'learning_rate': optim.param_groups[0]['lr'],
+ 'grad_norm': (
+ float(grad_norm)
+ if isinstance(grad_norm, torch.Tensor)
+ else grad_norm
+ ),
+ }
+ )
+
+ if is_main and (global_step % config.log_interval == 0):
+ elapsed = time.time() - start_time
+
+ # Average stats over log interval
+ avg_loss = sum(info['loss'] for info in infos) / len(infos)
+ avg_lr = sum(info['learning_rate'] for info in infos) / len(
+ infos
+ )
+
+ avg_grad_norm = None
+ if any('grad_norm' in info for info in infos):
+ vals = [
+ info['grad_norm']
+ for info in infos
+ if 'grad_norm' in info
+ and info['grad_norm'] is not None
+ ]
+ if len(vals) > 0:
+ avg_grad_norm = sum(vals) / len(vals)
+ logging.info(
+ f'step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s'
+ if avg_grad_norm is not None
+ else f'step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s'
+ )
+
+ # Log to wandb
+ if config.wandb_enabled and len(infos) > 0:
+ log_payload = {
+ 'loss': avg_loss,
+ 'learning_rate': avg_lr,
+ 'step': global_step,
+ 'time_per_step': elapsed / config.log_interval,
+ }
+ if avg_grad_norm is not None:
+ log_payload['grad_norm'] = avg_grad_norm
+ wandb.log(log_payload, step=global_step)
+
+ start_time = time.time()
+ infos = [] # Reset stats collection
+
+ global_step += 1
+ # Save checkpoint using the new mechanism
+ save_checkpoint(
+ model, optim, global_step, config, is_main, data_config
+ )
+
+ # Update progress bar
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_postfix(
+ {
+ 'loss': f'{loss.item():.4f}',
+ 'lr': f"{optim.param_groups[0]['lr']:.2e}",
+ 'step': global_step,
+ }
+ )
+
+ # Close progress bar
+ if pbar is not None:
+ pbar.close()
+
+ # Finish wandb run
+ if is_main and config.wandb_enabled:
+ wandb.finish()
+
+ cleanup_ddp()
+
+
+def main():
+ init_logging()
+ config = _config.cli()
+ train_loop(config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/vla_arena/models/openpi/scripts/train_test.py b/vla_arena/models/openpi/scripts/train_test.py
new file mode 100644
index 00000000..50a08b7c
--- /dev/null
+++ b/vla_arena/models/openpi/scripts/train_test.py
@@ -0,0 +1,45 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import os
+import pathlib
+
+import pytest
+
+
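+# Force the CPU backend before importing any JAX-backed openpi modules so the
+# test does not try to initialize a GPU.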
+os.environ['JAX_PLATFORMS'] = 'cpu'
+
+from openpi.training import config as _config
+
+from . import train
+
+
+@pytest.mark.parametrize('config_name', ['debug'])
+def test_train(tmp_path: pathlib.Path, config_name: str):
+ config = dataclasses.replace(
+ _config._CONFIGS_DICT[config_name], # noqa: SLF001
+ batch_size=2,
+ checkpoint_base_dir=str(tmp_path / 'checkpoint'),
+ exp_name='test',
+ overwrite=False,
+ resume=False,
+ num_train_steps=2,
+ log_interval=1,
+ )
+ train.main(config)
+
+ # test resuming
+ config = dataclasses.replace(config, resume=True, num_train_steps=4)
+ train.main(config)
diff --git a/vla_arena/models/openpi/src/openpi/__init__.py b/vla_arena/models/openpi/src/openpi/__init__.py
new file mode 100644
index 00000000..fb51bf31
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/vla_arena/models/openpi/src/openpi/conftest.py b/vla_arena/models/openpi/src/openpi/conftest.py
new file mode 100644
index 00000000..2281a712
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/conftest.py
@@ -0,0 +1,31 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pynvml
+import pytest
+
+
+def set_jax_cpu_backend_if_no_gpu() -> None:
+ try:
+ pynvml.nvmlInit()
+ pynvml.nvmlShutdown()
+ except pynvml.NVMLError:
+ # No GPU found.
+ os.environ["JAX_PLATFORMS"] = "cpu"
+
+
+def pytest_configure(config: pytest.Config) -> None:
+ set_jax_cpu_backend_if_no_gpu()
diff --git a/vla_arena/models/openpi/src/openpi/models/__init__.py b/vla_arena/models/openpi/src/openpi/models/__init__.py
new file mode 100644
index 00000000..fb51bf31
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/vla_arena/models/openpi/src/openpi/models/gemma.py b/vla_arena/models/openpi/src/openpi/models/gemma.py
new file mode 100644
index 00000000..3a5f9f5c
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/gemma.py
@@ -0,0 +1,573 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gemma adaptation for Pi, taken from big_vision.
+
+We follow this einsum axis naming convention:
+ B: batch
+ T: query length
+ S: k/v length
+ N: num query heads
+ K: num k/v heads
+ G: num query heads per k/v head
+ H: head dim
+ D: d_model ("features")
+"""
+
+import dataclasses
+from collections.abc import Sequence
+from typing import Literal, TypeAlias
+
+import einops
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import openpi.models.lora as lora
+import openpi.shared.array_typing as at
+import openpi.training.sharding as sharding
+
+
+PALIGEMMA_VOCAB_SIZE = 257_152
+
+
+@dataclasses.dataclass
+class Config:
+ width: int
+ depth: int
+ mlp_dim: int
+ num_heads: int
+ num_kv_heads: int
+ head_dim: int
+ lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(
+ default_factory=dict
+ )
+
+
+Variant = Literal[
+ 'dummy', 'gemma_300m', 'gemma_300m_lora', 'gemma_2b', 'gemma_2b_lora'
+]
+
+
+def get_config(variant: Variant) -> Config:
+ """Returns config for specified gemma variant."""
+ if variant == 'dummy':
+ return Config(
+ width=64,
+ depth=4,
+ mlp_dim=128,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=16,
+ )
+ if variant == 'gemma_300m':
+ # 311M params
+ return Config(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ if variant == 'gemma_2b':
+ return Config(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ if variant == 'gemma_2b_lora':
+ return Config(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ lora_configs={
+ 'attn': lora.LoRAConfig(rank=16, alpha=16.0),
+ 'ffn': lora.LoRAConfig(rank=16, alpha=16.0),
+ },
+ )
+ if variant == 'gemma_300m_lora':
+ # 311M params
+ return Config(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ lora_configs={
+ 'attn': lora.LoRAConfig(rank=32, alpha=32.0),
+ 'ffn': lora.LoRAConfig(rank=32, alpha=32.0),
+ },
+ )
+ raise ValueError(f'Unknown variant: {variant}')
+
+
+@at.typecheck
+class RMSNorm(nn.Module):
+ @nn.compact
+ def __call__(self, x, cond):
+ dtype = x.dtype # original dtype, could be half-precision
+ var = jnp.mean(
+ jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True
+ ) # compute variance in float32
+ normed_inputs = jnp.asarray(
+ x * jnp.reciprocal(jnp.sqrt(var + 1e-06))
+ ) # compute normalization in float32
+ if cond is None:
+ # regular RMSNorm
+ scale = self.param(
+ 'scale', nn.initializers.zeros_init(), (x.shape[-1])
+ )
+ normed_inputs = normed_inputs * (
+ 1 + scale
+ ) # scale by learned parameter in float32 (matches Flax implementation)
+ return (
+ normed_inputs.astype(dtype),
+ None,
+ ) # return in original dtype
+
+ # adaptive RMSNorm
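+        # The conditioning vector is projected by a zero-initialized Dense layer to
+        # per-channel (scale, shift, gate). At initialization the projection is zero,
+        # so this reduces to plain RMSNorm and the zero gate makes the gated residual
+        # pass its input through unchanged.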
+ modulation = nn.Dense(
+ x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype
+ )(cond)
+ scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
+ normed_inputs = (
+ normed_inputs * (1 + scale) + shift
+ ) # scale and shift in float32
+ return normed_inputs.astype(dtype), gate
+
+
+@at.typecheck
+class Embedder(nn.Module):
+ """Embedder module."""
+
+ vocab_size: int
+ embed_dim: int
+
+ def setup(self):
+ self.input_embedding_table = self.param(
+ 'input_embedding',
+ nn.initializers.normal(),
+ (self.vocab_size, self.embed_dim),
+ )
+
+ def encode(self, x):
+ x = self.input_embedding_table[(x,)]
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
+ return x
+
+ def decode(self, x):
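+        # Weight tying: output logits are produced with the transpose of the input
+        # embedding table.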
+ return jnp.dot(x, self.input_embedding_table.T)
+
+
+@at.typecheck
+class Attention(nn.Module):
+ """Attention module."""
+
+ configs: Sequence[Config]
+
+ @nn.compact
+ def __call__(self, xs, positions, attn_mask, kv_cache):
+ # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
+ assert all(
+ config.head_dim == self.configs[0].head_dim
+ for config in self.configs
+ )
+ assert all(
+ config.num_heads == self.configs[0].num_heads
+ for config in self.configs
+ )
+ assert all(
+ config.num_kv_heads == self.configs[0].num_kv_heads
+ for config in self.configs
+ )
+
+ dtype = next(
+ x.dtype for x in xs if x is not None
+ ) # original dtype, could be half-precision
+
+ qkvs = []
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is None:
+ continue
+ if config.num_kv_heads == config.num_heads:
+ qkv_einsum = lora.Einsum(
+ shape=(3, config.num_heads, config.width, config.head_dim),
+ name=_name('qkv_einsum', i),
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0, 1)
+ ),
+ lora_config=config.lora_configs.get('attn'),
+ )
+ qkvs.append(qkv_einsum('BSD,3KDH->3BSKH', x))
+ else:
+ q_einsum = lora.Einsum(
+ shape=(config.num_heads, config.width, config.head_dim),
+ name=_name('q_einsum', i),
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0,)
+ ),
+ lora_config=config.lora_configs.get('attn'),
+ )
+ q = q_einsum('BTD,NDH->BTNH', x)
+ kv_einsum = lora.Einsum(
+ shape=(
+ 2,
+ config.num_kv_heads,
+ config.width,
+ config.head_dim,
+ ),
+ name=_name('kv_einsum', i),
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0, 1)
+ ),
+ lora_config=config.lora_configs.get('attn'),
+ )
+ k, v = kv_einsum('BSD,2KDH->2BSKH', x)
+ qkvs.append((q, k, v))
+
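+        # Each expert projects its own token segment with its own weights; the
+        # projections are then concatenated along the sequence axis so all experts
+        # attend over the combined sequence.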
+ q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
+
+ q = _apply_rope(q, positions=positions)
+ q *= self.configs[0].head_dim ** -0.5
+
+ k = _apply_rope(k, positions=positions)
+
+ # should still be half-precision here (if input was half-precision)
+ assert q.dtype == k.dtype == v.dtype == dtype
+
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ k = jnp.concatenate([cache_k, k], axis=1)
+ v = jnp.concatenate([cache_v, v], axis=1)
+
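+        # Grouped-query attention: reshape the N = K * G query heads so that each of
+        # the K key/value heads serves its group of G query heads; attention logits
+        # are accumulated in float32.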
+ q = einops.rearrange(
+ q, 'B T (K G) H -> B T K G H', K=self.configs[0].num_kv_heads
+ )
+ logits = jnp.einsum(
+ 'BTKGH,BSKH->BKGTS', q, k, preferred_element_type=jnp.float32
+ )
+
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
+ raise ValueError(
+ f'Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}'
+ )
+
+ # big_neg = jnp.finfo(logits.dtype).min
+ big_neg = -2.3819763e38 # See gemma/modules.py
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
+
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
+
+ encoded = jnp.einsum('BKGTS,BSKH->BTKGH', probs, v)
+ encoded = einops.rearrange(encoded, 'B T K G H -> B T (K G) H')
+
+ out = []
+ start = 0
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is not None:
+ end = start + x.shape[1]
+ out_einsum = lora.Einsum(
+ shape=(config.num_heads, config.head_dim, config.width),
+ name=_name('attn_vec_einsum', i),
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=(-3, -2), out_axis=-1
+ ),
+ lora_config=config.lora_configs.get('attn'),
+ )
+ out.append(out_einsum('BTNH,NHD->BTD', encoded[:, start:end]))
+ start = end
+ else:
+ out.append(None)
+
+ return out, (k, v)
+
+
+@at.typecheck
+class FeedForward(nn.Module):
+ """Feed forward module."""
+
+ features: int
+ hidden_dim: int
+
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ w_gating = self.param(
+ 'gating_einsum',
+ nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0,)
+ ),
+ (2, self.features, self.hidden_dim),
+ ).astype(dtype)
+ ff_gate = jnp.dot(x, w_gating[0])
+ gate_value = nn.gelu(ff_gate)
+
+ ff1 = jnp.dot(x, w_gating[1])
+ activations = gate_value * ff1
+
+ w_linear = self.param(
+ 'linear',
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
+ (self.hidden_dim, self.features),
+ ).astype(dtype)
+ outputs = jnp.dot(activations, w_linear)
+ assert outputs.dtype == dtype
+ return outputs
+
+
+@at.typecheck
+class Block(nn.Module):
+ """Transformer block."""
+
+ configs: tuple[Config, ...]
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = ()
+
+ @nn.compact
+ def __call__(
+ self,
+ xs,
+ kv_cache,
+ positions,
+ attn_mask,
+ adarms_cond,
+ deterministic=True,
+ ):
+ xs = sharding.activation_sharding_constraint(xs)
+ drop = (
+ nn.Dropout(self.dropout, self.dropout_bdims)
+ if self.dropout
+ else lambda x, _: x
+ )
+
+ attn = Attention(configs=self.configs, name='attn')
+
+ pre_attn = []
+ gates = []
+ for i, x in enumerate(xs):
+ if x is not None:
+ x, gate = RMSNorm(name=_name('pre_attention_norm', i))(
+ x, adarms_cond[i]
+ )
+ pre_attn.append(x)
+ gates.append(gate if x is not None else None)
+
+ pre_attn = sharding.activation_sharding_constraint(pre_attn)
+ post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
+ post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
+ post_attn = sharding.activation_sharding_constraint(post_attn)
+ xs = [
+ _gated_residual(x, y, gate)
+ for x, y, gate in zip(xs, post_attn, gates, strict=True)
+ ]
+ xs = sharding.activation_sharding_constraint(xs)
+
+ out = []
+ gates = []
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is not None:
+ x, gate = RMSNorm(name=_name('pre_ffw_norm', i))(
+ x, adarms_cond[i]
+ ) # noqa: PLW2901
+ x = lora.FeedForward( # noqa: PLW2901
+ features=config.width,
+ hidden_dim=config.mlp_dim,
+ name=_name('mlp', i),
+ lora_config=config.lora_configs.get('ffn'),
+ )(x)
+ out.append(x)
+ gates.append(gate if x is not None else None)
+
+ out = sharding.activation_sharding_constraint(out)
+ out = jax.tree.map(lambda x: drop(x, deterministic), out)
+ xs = [
+ _gated_residual(x, y, gate)
+ for x, y, gate in zip(xs, out, gates, strict=True)
+ ]
+ xs = sharding.activation_sharding_constraint(xs)
+
+ return xs, kv_cache
+
+
+KVCache: TypeAlias = tuple[
+ at.Float[at.Array, 'l b _t _k _h'], at.Float[at.Array, 'l b _t _v _h']
+]
+
+
+@at.typecheck
+class Module(nn.Module):
+ """Transformer model, supporting a mixture of different weights for different tokens."""
+
+ configs: Sequence[Config] # list of configs, one for each expert
+ embed_dtype: str
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[
+ int, ...
+ ] = () # Every float is dropped independently.
+ adarms: bool = False
+
+ def setup(self):
+ # all experts must have the same depth
+ assert all(
+ config.depth == self.configs[0].depth for config in self.configs
+ )
+
+ self.embedder = Embedder(
+ vocab_size=PALIGEMMA_VOCAB_SIZE,
+ embed_dim=self.configs[0].width, # embedder for first expert only
+ name='embedder',
+ )
+ block_cls = nn.remat(
+ Block,
+ prevent_cse=False,
+ static_argnums=(5,), # 0=self, 6=deterministic
+ policy=jax.checkpoint_policies.nothing_saveable,
+ )
+ self.layers = nn.scan(
+ block_cls,
+ variable_axes={'params': 0},
+ split_rngs={'params': True, 'dropout': True},
+ in_axes=(
+ 0,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+ ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
+ length=self.configs[0].depth,
+ )(
+ configs=self.configs,
+ dropout=self.dropout,
+ dropout_bdims=self.dropout_bdims,
+ )
+ self.final_norms = [
+ RMSNorm(name=_name('final_norm', i))
+ for i in range(len(self.configs))
+ ]
+
+ @at.typecheck
+ def embed(
+ self, tokens: at.Int[at.Array, 'b t']
+ ) -> at.Float[at.Array, 'b t d']:
+ return self.embedder.encode(tokens).astype(self.embed_dtype)
+
+ @at.typecheck
+ def __call__(
+ self,
+ # list of token arrays, one for each expert, or None if that expert should not be run
+ embedded: Sequence[at.Float[at.Array, 'b _t _d'] | None],
+ positions: at.Int[at.Array, 'b t'],
+ mask: at.Bool[at.Array, 'b t s'],
+ adarms_cond: Sequence[at.Float[at.Array, 'b _d'] | None] | None = None,
+ *,
+ kv_cache: KVCache | None = None,
+ deterministic: bool = True,
+ ) -> tuple[Sequence[at.Float[at.Array, 'b _t _d'] | None], KVCache]:
+ embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
+ mask = jnp.asarray(mask)[:, None, :, :]
+ if adarms_cond is None:
+ adarms_cond = [None] * len(self.configs)
+
+ embedded, kv_cache = self.layers(
+ embedded, kv_cache, positions, mask, adarms_cond, deterministic
+ )
+
+ assert all(
+ e.dtype == jnp.dtype(self.embed_dtype)
+ for e in embedded
+ if e is not None
+ )
+
+ return [
+ f(e, a)[0] if e is not None else e
+ for f, e, a in zip(
+ self.final_norms, embedded, adarms_cond, strict=True
+ )
+ ], kv_cache
+
+ def init(self, use_adarms: Sequence[bool]):
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
+ self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
+ self(
+ [jnp.zeros((1, 1, c.width)) for c in self.configs],
+ jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
+ jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
+ adarms_cond=[
+ jnp.zeros((1, c.width)) if u else None
+ for u, c in zip(use_adarms, self.configs, strict=True)
+ ],
+ )
+
+
+def _apply_rope(x, *, positions, max_wavelength=10_000):
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
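+    # Each (x1, x2) feature pair is rotated by angle position / timescale, with
+    # timescales spaced geometrically from 1 up to ~max_wavelength across the head dim.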
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(
+ x.shape[-1] // 2, dtype=jnp.float32
+ )
+ timescale = max_wavelength**freq_exponents
+ radians = positions[..., None] / timescale[None, None, :]
+ radians = radians[..., None, :]
+ assert radians.dtype == jnp.float32
+ # radians.shape = [...,L,1,d=D/2]
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
+ x1, x2 = jnp.split(x, 2, axis=-1)
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
+ assert res.dtype == jnp.float32
+ # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
+ # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
+ # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
+ # here.
+ return res.astype(x.dtype)
+
+
+def _name(name, i):
+ # we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
+ # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
+ # "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
+ # and the action expert.
+ if i == 0:
+ return name
+ return f'{name}_{i}'
+
+
+def _gated_residual(x, y, gate):
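+    # Residual connection with an optional per-channel gate from adaptive RMSNorm;
+    # gate=None is a plain residual, and a zero gate suppresses the branch output.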
+ assert (x is None) == (y is None)
+ if x is None:
+ return None
+ if gate is None:
+ return x + y
+ return x + y * gate
diff --git a/vla_arena/models/openpi/src/openpi/models/gemma_fast.py b/vla_arena/models/openpi/src/openpi/models/gemma_fast.py
new file mode 100644
index 00000000..b0dd4453
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/gemma_fast.py
@@ -0,0 +1,515 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
+Used for FAST autoregressive policies.
+"""
+
+import dataclasses
+from typing import Literal, TypeAlias
+
+import einops
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import ml_collections
+import openpi.models.lora as lora
+import openpi.shared.array_typing as at
+
+
+Variant = Literal['gemma_2b', 'gemma_2b_lora']
+
+
+def get_config(variant):
+ """Returns config for specified gemma variant."""
+ if variant == 'gemma_2b':
+ return ml_collections.ConfigDict(
+ {
+ 'variant': variant,
+ 'width': 2048,
+ 'depth': 18,
+ 'mlp_dim': 16_384,
+ 'num_heads': 8,
+ 'num_kv_heads': 1,
+ 'head_dim': 256,
+ 'norm_eps': 1e-6,
+ 'vocab_size': 257_152,
+ 'scan': True,
+ 'remat_policy': 'nothing_saveable',
+ }
+ )
+ if variant == 'gemma_2b_lora':
+ return ml_collections.ConfigDict(
+ {
+ 'variant': variant,
+ 'width': 2048,
+ 'depth': 18,
+ 'mlp_dim': 16_384,
+ 'num_heads': 8,
+ 'num_kv_heads': 1,
+ 'head_dim': 256,
+ 'norm_eps': 1e-6,
+ 'vocab_size': 257_152,
+ 'scan': True,
+ 'remat_policy': 'nothing_saveable',
+ 'lora_configs': {
+ 'attn': lora.LoRAConfig(rank=16, alpha=16.0),
+ 'ffn': lora.LoRAConfig(rank=16, alpha=16.0),
+ },
+ }
+ )
+ raise ValueError(f'Unknown variant: {variant}')
+
+
+@at.typecheck
+class Einsum(nn.Module):
+ shape: tuple[int, ...]
+
+ @nn.compact
+ def __call__(self, eqn, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ w = self.param('w', nn.initializers.zeros_init(), self.shape).astype(
+ dtype
+ )
+ return jnp.einsum(eqn, x, w)
+
+
+@at.typecheck
+class RMSNorm(nn.Module):
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ scale = self.param(
+ 'scale', nn.initializers.zeros_init(), (x.shape[-1])
+ )
+ var = jnp.mean(
+ jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True
+ ) # compute variance in float32
+ normed_inputs = jnp.asarray(
+ x * jnp.reciprocal(jnp.sqrt(var + 1e-06))
+ ) # compute normalization in float32
+ normed_inputs = normed_inputs * (
+ 1 + scale
+ ) # scale by learned parameter in float32 (matches Flax implementation)
+ return normed_inputs.astype(dtype) # return in original dtype
+
+
+@at.typecheck
+class Embedder(nn.Module):
+ """Embedder module."""
+
+ vocab_size: int
+ embed_dim: int
+
+ def setup(self):
+ self.input_embedding_table = self.param(
+ 'input_embedding',
+ nn.initializers.zeros_init(),
+ (self.vocab_size, self.embed_dim),
+ )
+
+ def encode(self, x):
+ x = self.input_embedding_table[(x,)]
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
+ return x
+
+ def decode(self, x):
+ return jnp.dot(x, self.input_embedding_table.T)
+
+
+@at.typecheck
+class Attention(nn.Module):
+ """Attention module."""
+
+ num_heads: int
+ num_kv_heads: int
+ features: int
+ head_dim: int
+
+ cache_dtype: str | None = None
+
+ lora_config: lora.LoRAConfig | None = None
+
+ def setup(self):
+ if self.num_kv_heads == self.num_heads:
+ self.qkv_einsum = lora.Einsum(
+ shape=(3, self.num_heads, self.features, self.head_dim),
+ name='qkv_einsum',
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0, 1)
+ ),
+ lora_config=self.lora_config,
+ )
+ else:
+ self.q_einsum = lora.Einsum(
+ shape=(self.num_heads, self.features, self.head_dim),
+ name='q_einsum',
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0,)
+ ),
+ lora_config=self.lora_config,
+ )
+ self.kv_einsum = lora.Einsum(
+ shape=(2, self.num_kv_heads, self.features, self.head_dim),
+ name='kv_einsum',
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0, 1)
+ ),
+ lora_config=self.lora_config,
+ )
+ self.attn_vec_einsum = lora.Einsum(
+ shape=(self.num_heads, self.head_dim, self.features),
+ name='attn_vec_einsum',
+ init_fn=nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0,)
+ ),
+ lora_config=self.lora_config,
+ )
+
+ def _init_cache(self, k, v, cache_size):
+ """Initialize KV cache"""
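+        # Pre-allocate the cache to `cache_size` along the sequence axis and record
+        # the prefill length as the per-example write index for later decode steps.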
+ prefill_len = k.shape[1]
+ pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
+ cache_dtype = self.cache_dtype or k.dtype
+ k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
+ v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
+ idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
+ return idx, k_cache, v_cache
+
+ def _update_cache(self, k, v, idx, k_cache, v_cache):
+ """Update KV cache with new values"""
+ assert k.shape[1] == 1, 'Only support kv-cache updates of length 1'
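+        # All batch elements share the same write position (idx[0]); the cache is
+        # advanced in lockstep, one token per call.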
+ indices = (0, idx[0], 0, 0)
+ cache_dtype = self.cache_dtype or k.dtype
+ k_new = jax.lax.dynamic_update_slice(
+ k_cache, k.astype(cache_dtype), indices
+ )
+ v_new = jax.lax.dynamic_update_slice(
+ v_cache, v.astype(cache_dtype), indices
+ )
+ idx_new = idx + 1
+ return idx_new, k_new, v_new
+
+ @nn.compact
+ def __call__(
+ self, x, positions, attn_mask, kv_cache, decode, deterministic=True
+ ):
+ dtype = x.dtype # original dtype, could be half-precision
+ if self.num_kv_heads == self.num_heads:
+ q, k, v = self.qkv_einsum('BSD,3KDH->3BSKH', x)
+ else:
+ q = self.q_einsum('BTD,NDH->BTNH', x)
+ k, v = self.kv_einsum('BSD,2KDH->2BSKH', x)
+
+ q = _apply_rope(q, positions=positions) # promotes to float32
+ q *= self.head_dim**-0.5
+
+ k = _apply_rope(k, positions=positions) # promotes to float32
+
+ if kv_cache is None:
+ idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
+ else:
+ idx, k_cache, v_cache = kv_cache
+ idx, k_cache, v_cache = self._update_cache(
+ k, v, idx, k_cache, v_cache
+ )
+
+ k, v = k_cache, v_cache
+ kv_cache = (idx, k_cache, v_cache)
+
+ q = einops.rearrange(
+ q, 'B T (K G) H -> B T K G H', K=self.num_kv_heads
+ )
+ logits = jnp.einsum(
+ 'BTKGH,BSKH->BKGTS', q, k, preferred_element_type=jnp.float32
+ )
+
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
+ raise ValueError(
+ f'Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}'
+ )
+
+ # big_neg = jnp.finfo(logits.dtype).min
+ big_neg = -2.3819763e38 # See gemma/modules.py
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
+
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
+
+ encoded = jnp.einsum('BKGTS,BSKH->BTKGH', probs, v)
+ encoded = einops.rearrange(encoded, 'B T K G H -> B T (K G) H')
+ return self.attn_vec_einsum('BTNH,NHD->BTD', encoded), kv_cache
+
+
+@at.typecheck
+class Block(nn.Module):
+ """Transformer block."""
+
+ num_heads: int
+ num_kv_heads: int
+ embed_dim: int
+ head_dim: int
+ hidden_dim: int
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = ()
+ cache_dtype: str | None = None
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(
+ default_factory=ml_collections.ConfigDict
+ )
+
+ def setup(self):
+ self.pre_attention_norm = RMSNorm()
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ num_kv_heads=self.num_kv_heads,
+ features=self.embed_dim,
+ head_dim=self.head_dim,
+ cache_dtype=self.cache_dtype,
+ lora_config=self.lora_configs.get('attn'),
+ )
+ self.pre_ffw_norm = RMSNorm()
+ self.mlp = lora.FeedForward(
+ features=self.embed_dim,
+ hidden_dim=self.hidden_dim,
+ name='mlp',
+ lora_config=self.lora_configs.get('ffn'),
+ )
+ if self.dropout:
+ self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
+ else:
+ self.drop = lambda x, _: x
+
+ def __call__(
+ self, x, kv_cache, positions, attn_mask, decode, deterministic=True
+ ):
+ x = nn.with_logical_constraint(x, ('act_batch', 'act_len', 'act_emb'))
+ inputs_normalized = self.pre_attention_norm(x)
+ attn_output, kv_cache = self.attn(
+ inputs_normalized,
+ positions,
+ attn_mask,
+ kv_cache,
+ decode,
+ deterministic,
+ )
+ attn_output = self.drop(attn_output, deterministic)
+ attn_output += x
+ residual = attn_output
+ attn_output = self.pre_ffw_norm(attn_output)
+ outputs = self.mlp(attn_output)
+ outputs = self.drop(outputs, deterministic)
+ outputs = residual + outputs
+ return outputs, kv_cache
+
+
+KVCache: TypeAlias = tuple[
+ at.Int[at.Array, ' b'],
+ at.Float[at.Array, 'b _t _k _h'],
+ at.Float[at.Array, 'b _t _v _h'],
+]
+
+
+@at.typecheck
+class Module(nn.Module):
+ """gemma model."""
+
+ variant: str
+
+ width: int
+ depth: int
+ mlp_dim: int
+ num_heads: int
+ num_kv_heads: int
+ head_dim: int
+ norm_eps: float
+ vocab_size: int
+ embed_dtype: str
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[
+ int, ...
+ ] = () # Every float is dropped independently.
+ cache_dtype: str | None = None
+
+ scan: bool = False
+ remat_policy: str = 'none'
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(
+ default_factory=ml_collections.ConfigDict
+ )
+
+ @nn.compact
+ def __call__(
+ self,
+ tokens=None,
+ embedded_prefix=None,
+ embed_only=False, # noqa: FBT002
+ pre_logits=None,
+ positions=None,
+ mask=None,
+ decode=False, # noqa: FBT002
+ kv_cache=None,
+ deterministic=True, # noqa: FBT002
+ return_prelogits=False, # noqa: FBT002
+ ):
+ """Embed only, or complete forward pass.
+
+ Args:
+            tokens: Token ids that are embedded and then appended to
+                `embedded_prefix`. Can be None.
+            embedded_prefix: Optional prefix that is already embedded.
+            embed_only: Whether to compute and return embeddings only.
+            pre_logits: If present, logits are computed directly from these
+                pre-logits and returned.
+            positions: Optional `[B, T]` absolute positions of the tokens.
+            mask: Optional attention mask `[B, T, S]`.
+            decode: Whether to use the kv-cache. The caller must pass masks and
+                positions.
+            kv_cache: Optional kv-cache state to continue decoding from.
+            deterministic: Forwarded to all dropout layers.
+            return_prelogits: Whether to return the pre-logits.
+
+ Returns:
+ If `embed_only=False`, then `(logits, out)` will be returned.
+ If `embed_only=True`, then the embeddings will be returned.
+ If `return_prelogits=True`, then the pre-logits will be returned.
+ """
+ out = {}
+
+ embedder = Embedder(
+ vocab_size=self.vocab_size, embed_dim=self.width, name='embedder'
+ )
+
+ if pre_logits is not None:
+ x = out['pre_logits'] = pre_logits
+ logits = out['logits'] = embedder.decode(x)
+ return logits, out
+
+ x = []
+ if embedded_prefix is not None:
+ x.append(embedded_prefix)
+ if tokens is not None:
+ x.append(embedder.encode(tokens))
+
+ x = jnp.concatenate(x, axis=-2)
+ x = x.astype(self.embed_dtype)
+ batch_size, seq_len, width = x.shape
+
+ if embed_only:
+ return x
+
+ if decode:
+ assert (
+ positions is not None and mask is not None
+ ), 'Must explicitly pass positions and mask for decoding.'
+
+ if positions is None:
+ positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
+ assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
+
+ if mask is None:
+ mask = nn.attention.make_causal_mask(
+ jnp.ones([batch_size, seq_len])
+ )
+ if mask.ndim == 3:
+ mask = mask[:, None, :, :]
+ cache_size = max(seq_len, mask.shape[-1])
+ assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
+
+ if self.remat_policy == 'none':
+ block_cls = Block
+ else:
+ block_cls = nn.remat(
+ Block,
+ prevent_cse=not self.scan,
+ static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
+ policy=getattr(jax.checkpoint_policies, self.remat_policy),
+ )
+
+ block_kw = {
+ 'num_heads': self.num_heads,
+ 'head_dim': self.head_dim,
+ 'num_kv_heads': self.num_kv_heads,
+ 'embed_dim': width,
+ 'hidden_dim': self.mlp_dim,
+ 'dropout': self.dropout,
+ 'dropout_bdims': self.dropout_bdims,
+ 'cache_dtype': self.cache_dtype,
+ 'lora_configs': self.lora_configs,
+ }
+ layers = self.scope.push('layers')
+ blocks = [
+ nn.scan(
+ block_cls,
+ variable_axes={'params': 0},
+ split_rngs={'params': True, 'dropout': True},
+ in_axes=(
+ 0,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+                ), # 0=kv_cache, 1=positions, 2=mask, 3=decode, 4=deterministic
+ length=self.depth,
+ )(parent=layers, **block_kw)
+ ]
+ for block in blocks:
+ x, kv_cache = block(
+ x, kv_cache, positions, mask, decode, deterministic
+ )
+
+ assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
+ out['encoded'] = x
+
+ x = RMSNorm(name='final_norm')(x)
+ out['pre_logits'] = x
+ if return_prelogits:
+ return x, kv_cache, out
+
+ x = embedder.decode(x)
+ out['logits'] = x
+
+ return x, kv_cache, out
+
+ def init(self):
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
+ self(jnp.zeros((1, 1), dtype=jnp.int32))
+
+
+def _apply_rope(x, *, positions, max_wavelength=10_000):
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(
+ x.shape[-1] // 2, dtype=jnp.float32
+ )
+ timescale = max_wavelength**freq_exponents
+ radians = positions[..., None] / timescale[None, None, :]
+ radians = radians[..., None, :]
+ assert radians.dtype == jnp.float32
+ # radians.shape = [...,L,1,d=D/2]
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
+ x1, x2 = jnp.split(x, 2, axis=-1)
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
+ assert res.dtype == jnp.float32
+ return res
diff --git a/vla_arena/models/openpi/src/openpi/models/lora.py b/vla_arena/models/openpi/src/openpi/models/lora.py
new file mode 100644
index 00000000..06f286f2
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/lora.py
@@ -0,0 +1,197 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import re
+
+import flax.linen as nn
+import flax.struct as struct
+import jax.numpy as jnp
+import openpi.shared.array_typing as at
+
+
+@struct.dataclass
+class LoRAConfig:
+ """Configuration for LoRA."""
+
+ # LoRA rank.
+ rank: int
+ # LoRA scaling factor.
+ alpha: float = 1.0
+ # Initialization function for LoRA parameters.
+ init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
+ # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
+ rslora: bool = False
+ # Axes in the weight to apply LoRA to. Should typically be the last two axes.
+ axes: tuple[int, int] = (-2, -1)
+ # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
+ label: str = 'L'
+
+ @property
+ def scaling_value(self) -> float:
+ return (
+ self.alpha / math.sqrt(self.rank)
+ if self.rslora
+ else self.alpha / self.rank
+ )
+
+
+class Einsum(nn.Module):
+ """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
+
+ # Shape of the weight.
+ shape: tuple[int, ...]
+ # Initialization function for the weight.
+ init_fn: nn.initializers.Initializer = nn.initializers.zeros
+ # If not None, apply LoRA to the weight.
+ lora_config: LoRAConfig | None = None
+
+ def setup(self):
+ self.w = self.param('w', self.init_fn, self.shape)
+
+ if config := self.lora_config:
+ # Setup LoRA parameters.
+ shape_a, shape_b = list(self.shape), list(self.shape)
+ shape_a[config.axes[1]] = config.rank
+ shape_b[config.axes[0]] = config.rank
+ self.w_a = self.param('lora_a', config.init_fn, shape_a)
+ self.w_b = self.param('lora_b', config.init_fn, shape_b)
+
+ @nn.compact
+ def __call__(self, eqn: str, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ result = jnp.einsum(eqn, x, self.w.astype(dtype))
+
+ if config := self.lora_config:
+ eqn_a, eqn_b = self._make_lora_eqns(eqn)
+ lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
+ lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
+ result = result + lora * config.scaling_value
+
+ return result
+
+ def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
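+        # Example: for eqn 'BTD,NDH->BTNH' with the default axes (-2, -1), the
+        # low-rank factors contract through the fresh axis 'L':
+        #   eqn_a = 'BTD,NDL->BTNL' (x @ w_a), eqn_b = 'BTNL,NLH->BTNH' (@ w_b).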
+ if 'L' in eqn:
+ raise ValueError(f'L already in eqn: {eqn}')
+ if not (m := re.match('(.*),(.*)->(.*)', eqn)):
+ raise ValueError(f'Unsupported einsum eqn: {eqn}')
+ lhs, rhs, out = m.groups()
+
+ assert self.lora_config is not None
+ a_label, b_label = (rhs[x] for x in self.lora_config.axes)
+ label = self.lora_config.label
+
+ a_rhs = rhs.replace(b_label, label)
+ a_out = out.replace(b_label, label)
+ eqn_a = f'{lhs},{a_rhs}->{a_out}'
+
+ b_rhs = rhs.replace(a_label, label)
+ eqn_b = f'{a_out},{b_rhs}->{out}'
+
+ return eqn_a, eqn_b
+
+
+class FeedForward(nn.Module):
+ """Feed forward module."""
+
+ features: int
+ hidden_dim: int
+ # If not None, apply LoRA to the weight.
+ lora_config: LoRAConfig | None = None
+
+ def setup(self):
+ self.w_gating = self.param(
+ 'gating_einsum',
+ nn.initializers.lecun_normal(
+ in_axis=-2, out_axis=-1, batch_axis=(0,)
+ ),
+ (2, self.features, self.hidden_dim),
+ )
+ self.w_linear = self.param(
+ 'linear',
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
+ (self.hidden_dim, self.features),
+ )
+ self.w_gating_lora = None
+ self.w_linear_lora = None
+ if self.lora_config:
+ # Setup LoRA parameters.
+ # TODO: follow up with a simplified init_fn api.
+ self.w_gating_lora = (
+ self.param(
+ 'gating_einsum_lora_a',
+ self.lora_config.init_fn,
+ (2, self.features, self.lora_config.rank),
+ ),
+ self.param(
+ 'gating_einsum_lora_b',
+ self.lora_config.init_fn,
+ (2, self.lora_config.rank, self.hidden_dim),
+ ),
+ )
+ self.w_linear_lora = (
+ self.param(
+ 'linear_lora_a',
+ self.lora_config.init_fn,
+ (self.hidden_dim, self.lora_config.rank),
+ ),
+ self.param(
+ 'linear_lora_b',
+ self.lora_config.init_fn,
+ (self.lora_config.rank, self.features),
+ ),
+ )
+
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ ff_gate = self._dot(
+ x,
+ self.w_gating[0],
+ (
+ None
+ if self.w_gating_lora is None
+ else (self.w_gating_lora[0][0], self.w_gating_lora[1][0])
+ ),
+ )
+ gate_value = nn.gelu(ff_gate)
+
+ ff1 = self._dot(
+ x,
+ self.w_gating[1],
+ (
+ None
+ if self.w_gating_lora is None
+ else (self.w_gating_lora[0][1], self.w_gating_lora[1][1])
+ ),
+ )
+ activations = gate_value * ff1
+
+ outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
+ assert outputs.dtype == dtype
+ return outputs
+
+ def _dot(
+ self,
+ x: at.Array,
+ w: at.Array,
+ lora_weights: tuple[at.Array, at.Array] | None,
+ ) -> at.Array:
+ base = jnp.dot(x, w.astype(x.dtype))
+ if lora_weights is None:
+ return base
+ return base + jnp.dot(
+ jnp.dot(x, lora_weights[0].astype(x.dtype)),
+ lora_weights[1].astype(x.dtype),
+ )
diff --git a/vla_arena/models/openpi/src/openpi/models/lora_test.py b/vla_arena/models/openpi/src/openpi/models/lora_test.py
new file mode 100644
index 00000000..392a73d1
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/lora_test.py
@@ -0,0 +1,112 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import openpi.models.lora as lora
+
+
+def test_lora_einsum_params_shape():
+ shape = (3, 8, 32, 4) # (3KDH)
+ einsum = lora.Einsum(shape)
+ lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
+ lora1 = lora.Einsum(
+ shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
+ eqn = 'BSD,3KDH->3BSKH'
+
+ # Ensure that lora parameters are not initialized when LoRA is not used.
+ params = einsum.init(key, eqn, x)
+ assert 'lora_a' not in params['params']
+ assert 'lora_b' not in params['params']
+
+ # Check that default axes work.
+ params_lora0 = lora0.init(key, eqn, x)
+ assert params_lora0['params']['lora_a'].shape == (3, 8, 32, 2)
+ assert params_lora0['params']['lora_b'].shape == (3, 8, 2, 4)
+
+ # Check that user provided axes work.
+ params_lora1 = lora1.init(key, eqn, x)
+ assert params_lora1['params']['lora_a'].shape == (3, 8, 2, 4)
+ assert params_lora1['params']['lora_b'].shape == (3, 2, 32, 4)
+
+
+def test_lora_einsum_same_output():
+ shape = (3, 8, 32, 4) # (3KDH)
+ einsum = lora.Einsum(shape)
+ einsum_lora = lora.Einsum(
+ shape,
+ lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
+ eqn = 'BSD,3KDH->3BSKH'
+
+ params = einsum.init(key, eqn, x)
+ output = einsum.apply(params, eqn, x)
+
+ params_lora = einsum_lora.init(key, eqn, x)
+ output_lora = einsum_lora.apply(params_lora, eqn, x)
+
+ # Results are the same since the LoRA parameters are initialized to zeros.
+ assert jnp.allclose(output, output_lora)
+
+
+def test_lora_ffn_params_shape():
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
+ ffn_lora = lora.FeedForward(
+ features=8,
+ hidden_dim=32,
+ lora_config=lora.LoRAConfig(rank=2),
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (2, 8))
+
+ params = ffn.init(key, x)
+ assert params['params']['gating_einsum'].shape == (2, 8, 32)
+ assert params['params']['linear'].shape == (32, 8)
+
+ params_lora = ffn_lora.init(key, x)
+ assert params_lora['params']['gating_einsum'].shape == (2, 8, 32)
+ assert params_lora['params']['linear'].shape == (32, 8)
+ assert params_lora['params']['gating_einsum_lora_a'].shape == (2, 8, 2)
+ assert params_lora['params']['gating_einsum_lora_b'].shape == (2, 2, 32)
+ assert params_lora['params']['linear_lora_a'].shape == (32, 2)
+ assert params_lora['params']['linear_lora_b'].shape == (2, 8)
+
+
+def test_lora_ffn_same_output():
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
+ ffn_lora = lora.FeedForward(
+ features=8,
+ hidden_dim=32,
+ lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (2, 8))
+
+ params = ffn.init(key, x)
+ output = ffn.apply(params, x)
+
+ params_lora = ffn_lora.init(key, x)
+ output_lora = ffn_lora.apply(params_lora, x)
+
+ assert jnp.allclose(output, output_lora)
diff --git a/vla_arena/models/openpi/src/openpi/models/model.py b/vla_arena/models/openpi/src/openpi/models/model.py
new file mode 100644
index 00000000..fffc9de8
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/model.py
@@ -0,0 +1,388 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import dataclasses
+import enum
+import logging
+import pathlib
+from collections.abc import Sequence
+from typing import Generic, TypeVar
+
+import augmax
+import jax
+import jax.numpy as jnp
+import numpy as np
+import openpi.shared.array_typing as at
+import orbax.checkpoint as ocp
+import safetensors
+import torch
+from flax import nnx, struct, traverse_util
+from openpi.models_pytorch import pi0_pytorch
+from openpi.shared import image_tools
+
+
+logger = logging.getLogger('openpi')
+
+# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
+ArrayT = TypeVar('ArrayT', bound=jax.Array | torch.Tensor | np.ndarray)
+
+
+class ModelType(enum.Enum):
+ """Supported model types."""
+
+ PI0 = 'pi0'
+ PI0_FAST = 'pi0_fast'
+ PI05 = 'pi05'
+
+
+# The model always expects these images
+IMAGE_KEYS = (
+ 'base_0_rgb',
+ 'left_wrist_0_rgb',
+ 'right_wrist_0_rgb',
+)
+
+
+# This may need to change if we release a small model.
+IMAGE_RESOLUTION = (224, 224)
+
+
+# Data format
+#
+# Data transforms produce the model input as a nested dictionary, which is later converted
+# into `Observation` and `Actions` objects. See below.
+#
+# In dictionary form, this data should look like:
+# {
+# # Observation data.
+# "image": {
+# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
+# ... # Additional camera views
+# },
+# "image_mask": {
+# "base_0_rgb": bool[*b], # True if image is valid
+# ... # Masks for additional views
+# },
+# "state": float32[*b, s], # Low-dimensional robot state
+# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
+# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
+# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
+# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
+#
+# # Actions data.
+# "actions": float32[*b ah ad]
+# }
+# where:
+#     *b = batch dimensions
+#     h,w = image height/width
+#     s = state dimension
+#     l = token sequence length
+#     ah = action horizon (number of action steps)
+#     ad = action dimension
+#
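+# A minimal, illustrative example (hypothetical shapes, single camera, batch of 1;
+# a real batch would normally contain every key in IMAGE_KEYS):
+#
+#     obs = Observation.from_dict({
+#         "image": {"base_0_rgb": np.zeros((1, 224, 224, 3), dtype=np.uint8)},
+#         "image_mask": {"base_0_rgb": np.ones((1,), dtype=bool)},
+#         "state": np.zeros((1, 32), dtype=np.float32),
+#     })
+#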
+@at.typecheck
+@struct.dataclass
+class Observation(Generic[ArrayT]):
+ """Holds observations, i.e., inputs to the model.
+
+    See `Observation.from_dict` for the expected dictionary form; this is the format
+    that the data transforms should produce.
+ """
+
+ # Images, in [-1, 1] float32.
+ images: dict[str, at.Float[ArrayT, '*b h w c']]
+ # Image masks, with same keys as images.
+ image_masks: dict[str, at.Bool[ArrayT, '*b']]
+ # Low-dimensional robot state.
+ state: at.Float[ArrayT, '*b s']
+
+ # Tokenized prompt.
+ tokenized_prompt: at.Int[ArrayT, '*b l'] | None = None
+ # Tokenized prompt mask.
+ tokenized_prompt_mask: at.Bool[ArrayT, '*b l'] | None = None
+
+ # pi0-fast model specific fields.
+
+ # Token auto-regressive mask (for FAST autoregressive model).
+ token_ar_mask: at.Int[ArrayT, '*b l'] | None = None
+ # Token loss mask (for FAST autoregressive model).
+ token_loss_mask: at.Bool[ArrayT, '*b l'] | None = None
+
+ @classmethod
+ def from_dict(cls, data: at.PyTree[ArrayT]) -> 'Observation[ArrayT]':
+ """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
+ # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
+ if ('tokenized_prompt' in data) != ('tokenized_prompt_mask' in data):
+ raise ValueError(
+ 'tokenized_prompt and tokenized_prompt_mask must be provided together.'
+ )
+ # If images are uint8, convert them to [-1, 1] float32.
+ for key in data['image']:
+ if data['image'][key].dtype == np.uint8:
+ data['image'][key] = (
+ data['image'][key].astype(np.float32) / 255.0 * 2.0 - 1.0
+ )
+ elif (
+ hasattr(data['image'][key], 'dtype')
+ and data['image'][key].dtype == torch.uint8
+ ):
+ data['image'][key] = (
+ data['image'][key].to(torch.float32).permute(0, 3, 1, 2)
+ / 255.0
+ * 2.0
+ - 1.0
+ )
+ return cls(
+ images=data['image'],
+ image_masks=data['image_mask'],
+ state=data['state'],
+ tokenized_prompt=data.get('tokenized_prompt'),
+ tokenized_prompt_mask=data.get('tokenized_prompt_mask'),
+ token_ar_mask=data.get('token_ar_mask'),
+ token_loss_mask=data.get('token_loss_mask'),
+ )
+
+ def to_dict(self) -> at.PyTree[ArrayT]:
+ """Convert the Observation to a nested dict."""
+ result = dataclasses.asdict(self)
+ result['image'] = result.pop('images')
+ result['image_mask'] = result.pop('image_masks')
+ return result
+
+
+# Defines the format of the actions. This field is included as "actions" inside the dictionary
+# produced by the data transforms.
+Actions = at.Float[ArrayT, '*b ah ad']
+
+
+def preprocess_observation(
+ rng: at.KeyArrayLike | None,
+ observation: Observation,
+ *,
+ train: bool = False,
+ image_keys: Sequence[str] = IMAGE_KEYS,
+ image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
+) -> Observation:
+ """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
+ filling in a default image mask (if necessary).
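+
+    Example (illustrative): `preprocess_observation(jax.random.key(0), obs, train=True)`
+    applies crop/resize/rotation augmentations to the non-wrist cameras of `obs` and
+    color jitter to all of them.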
+ """
+
+ if not set(image_keys).issubset(observation.images):
+ raise ValueError(
+ f'images dict missing keys: expected {image_keys}, got {list(observation.images)}'
+ )
+
+ batch_shape = observation.state.shape[:-1]
+
+ out_images = {}
+ for key in image_keys:
+ image = observation.images[key]
+ if image.shape[1:3] != image_resolution:
+ logger.info(
+ f'Resizing image {key} from {image.shape[1:3]} to {image_resolution}'
+ )
+ image = image_tools.resize_with_pad(image, *image_resolution)
+
+ if train:
+ # Convert from [-1, 1] to [0, 1] for augmax.
+ image = image / 2.0 + 0.5
+
+ transforms = []
+ if 'wrist' not in key:
+ height, width = image.shape[1:3]
+ transforms += [
+ augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
+ augmax.Resize(width, height),
+ augmax.Rotate((-5, 5)),
+ ]
+ transforms += [
+ augmax.ColorJitter(
+ brightness=0.3, contrast=0.4, saturation=0.5
+ ),
+ ]
+ sub_rngs = jax.random.split(rng, image.shape[0])
+ image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
+
+ # Back to [-1, 1].
+ image = image * 2.0 - 1.0
+
+ out_images[key] = image
+
+    # Obtain image masks, filling in an all-valid mask when one is missing.
+ out_masks = {}
+ for key in out_images:
+ if key not in observation.image_masks:
+            # Treat the image as valid when no mask is provided.
+ out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
+ else:
+ out_masks[key] = jnp.asarray(observation.image_masks[key])
+
+ return Observation(
+ images=out_images,
+ image_masks=out_masks,
+ state=observation.state,
+ tokenized_prompt=observation.tokenized_prompt,
+ tokenized_prompt_mask=observation.tokenized_prompt_mask,
+ token_ar_mask=observation.token_ar_mask,
+ token_loss_mask=observation.token_loss_mask,
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class BaseModelConfig(abc.ABC):
+ """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
+ method to create the corresponding model.
+ """
+
+ # Action space dimension.
+ action_dim: int
+ # Action sequence length.
+ action_horizon: int
+ # Tokenized prompt maximum length.
+ max_token_len: int
+
+ @property
+ @abc.abstractmethod
+ def model_type(self) -> ModelType:
+ """The model type."""
+
+ @abc.abstractmethod
+ def create(self, rng: at.KeyArrayLike) -> 'BaseModel':
+ """Create a new model, initializing parameters."""
+
+ def load(
+ self, params: at.Params, *, remove_extra_params: bool = True
+ ) -> 'BaseModel':
+ """Create a model with the given parameters."""
+ model = nnx.eval_shape(self.create, jax.random.key(0))
+ graphdef, state = nnx.split(model)
+ if remove_extra_params:
+ params = ocp.transform_utils.intersect_trees(
+ state.to_pure_dict(), params
+ )
+ at.check_pytree_equality(
+ expected=state.to_pure_dict(),
+ got=params,
+ check_shapes=True,
+ check_dtypes=False,
+ )
+ state.replace_by_pure_dict(params)
+ return nnx.merge(graphdef, state)
+
+ def load_pytorch(self, train_config, weight_path: str):
+ logger.info(f'train_config: {train_config}')
+ model = pi0_pytorch.PI0Pytorch(config=train_config.model)
+ safetensors.torch.load_model(model, weight_path)
+ return model
+
+ @abc.abstractmethod
+ def inputs_spec(
+ self, *, batch_size: int = 1
+ ) -> tuple[Observation, Actions]:
+ """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
+
+ def fake_obs(self, batch_size: int = 1) -> Observation:
+ observation_spec, _ = self.inputs_spec(batch_size=batch_size)
+ return jax.tree.map(
+ lambda x: jnp.ones(x.shape, x.dtype), observation_spec
+ )
+
+ def fake_act(self, batch_size: int = 1) -> Actions:
+ _, action_spec = self.inputs_spec(batch_size=batch_size)
+ return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
+
+
+@dataclasses.dataclass
+class BaseModel(nnx.Module, abc.ABC):
+ """Base class for all model implementations. Specific models should inherit from this class. They should call
+ super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
+ """
+
+ action_dim: int
+ action_horizon: int
+ max_token_len: int
+
+ @abc.abstractmethod
+ def compute_loss(
+ self,
+ rng: at.KeyArrayLike,
+ observation: Observation,
+ actions: Actions,
+ *,
+ train: bool = False,
+ ) -> at.Float[at.Array, '*b ah']: ...
+
+ @abc.abstractmethod
+ def sample_actions(
+ self, rng: at.KeyArrayLike, observation: Observation, **kwargs
+ ) -> Actions: ...
+
+
+def restore_params(
+ params_path: pathlib.Path | str,
+ *,
+ restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
+ dtype: jnp.dtype | None = None,
+ sharding: jax.sharding.Sharding | None = None,
+) -> at.Params:
+ """Restores unstructured params PyTree from a checkpoint.
+
+ This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
+ well as pre-trained checkpoints released for openpi.
+
+ Args:
+ params_path: The local path to the checkpoint directory.
+ restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
+ dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
+ sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
+
+ Returns:
+ The restored params.
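+
+    Example (illustrative path):
+        params = restore_params('./checkpoints/pi0_base/params', dtype=jnp.bfloat16)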
+ """
+ params_path = (
+ pathlib.Path(params_path).resolve()
+ if not str(params_path).startswith('gs://')
+ else params_path
+ )
+
+ if restore_type is jax.Array and sharding is None:
+ mesh = jax.sharding.Mesh(jax.devices(), ('x',))
+ sharding = jax.sharding.NamedSharding(
+ mesh, jax.sharding.PartitionSpec()
+ )
+
+ with ocp.PyTreeCheckpointer() as ckptr:
+ metadata = ckptr.metadata(params_path)
+ item = {'params': metadata['params']}
+
+ params = ckptr.restore(
+ params_path,
+ ocp.args.PyTreeRestore(
+ item=item,
+ restore_args=jax.tree.map(
+ lambda _: ocp.ArrayRestoreArgs(
+ sharding=sharding,
+ restore_type=restore_type,
+ dtype=dtype,
+ ),
+ item,
+ ),
+ ),
+ )['params']
+
+ # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
+ # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
+ flat_params = traverse_util.flatten_dict(params)
+ if all(kp[-1] == 'value' for kp in flat_params):
+ flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
+ return traverse_util.unflatten_dict(flat_params)
diff --git a/vla_arena/models/openpi/src/openpi/models/model_test.py b/vla_arena/models/openpi/src/openpi/models/model_test.py
new file mode 100644
index 00000000..c44c4137
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/model_test.py
@@ -0,0 +1,125 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import pytest
+from flax import nnx
+from openpi.models import model as _model
+from openpi.models import pi0_config, pi0_fast
+from openpi.shared import download, nnx_utils
+
+
+def test_pi0_model():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config()
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(
+ key, obs, num_steps=10
+ )
+ assert actions.shape == (
+ batch_size,
+ model.action_horizon,
+ model.action_dim,
+ )
+
+
+def test_pi0_lora_model():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config(paligemma_variant='gemma_2b_lora')
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(
+ key, obs, num_steps=10
+ )
+ assert actions.shape == (
+ batch_size,
+ model.action_horizon,
+ model.action_dim,
+ )
+
+
+def test_pi0_fast_model():
+ key = jax.random.key(0)
+ config = pi0_fast.Pi0FASTConfig()
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size,)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+ assert actions.shape == (batch_size, 256)
+
+
+def test_pi0_fast_lora_model():
+ key = jax.random.key(0)
+ config = pi0_fast.Pi0FASTConfig(paligemma_variant='gemma_2b_lora')
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size,)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+ assert actions.shape == (batch_size, 256)
+
+ lora_filter = nnx_utils.PathRegex('.*lora.*')
+ model_state = nnx.state(model)
+
+ lora_state_elems = list(model_state.filter(lora_filter))
+ assert len(lora_state_elems) > 0
+
+
+@pytest.mark.manual
+def test_model_restore():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config()
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ model = config.load(
+ _model.restore_params(
+ download.maybe_download(
+ 'gs://openpi-assets/checkpoints/pi0_base/params'
+ )
+ )
+ )
+
+ loss = model.compute_loss(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = model.sample_actions(key, obs, num_steps=10)
+ assert actions.shape == (
+ batch_size,
+ model.action_horizon,
+ model.action_dim,
+ )
diff --git a/vla_arena/models/openpi/src/openpi/models/pi0.py b/vla_arena/models/openpi/src/openpi/models/pi0.py
new file mode 100644
index 00000000..1e7b980c
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/pi0.py
@@ -0,0 +1,384 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing_extensions import override
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+import openpi.models.gemma as _gemma
+import openpi.models.siglip as _siglip
+from openpi.models import model as _model
+from openpi.models import pi0_config
+from openpi.shared import array_typing as at
+
+
+logger = logging.getLogger('openpi')
+
+
+def make_attn_mask(input_mask, mask_ar):
+ """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+      input_mask: bool[B, N] true if it's part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        the token, and false where the token shares the same attention mask as the
+        previous token.
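+
+    Example (illustrative): `make_attn_mask(jnp.ones((1, 4), bool), jnp.array([[0, 0, 1, 1]]))`
+    returns a prefix-lm pattern in which the first two tokens attend to each other and the
+    last two attend causally to everything before them.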
+ """
+ mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+ cumsum = jnp.cumsum(mask_ar, axis=1)
+ attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+ valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+ return jnp.logical_and(attn_mask, valid_mask)
+
+
+@at.typecheck
+def posemb_sincos(
+ pos: at.Real[at.Array, ' b'],
+ embedding_dim: int,
+ min_period: float,
+ max_period: float,
+) -> at.Float[at.Array, 'b {embedding_dim}']:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
+ if embedding_dim % 2 != 0:
+ raise ValueError(
+ f'embedding_dim ({embedding_dim}) must be divisible by 2'
+ )
+
+ fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
+ period = min_period * (max_period / min_period) ** fraction
+ sinusoid_input = jnp.einsum(
+ 'i,j->ij',
+ pos,
+ 1.0 / period * 2 * jnp.pi,
+ precision=jax.lax.Precision.HIGHEST,
+ )
+ return jnp.concatenate(
+ [jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1
+ )
+
+
+class Pi0(_model.BaseModel):
+ def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
+ super().__init__(
+ config.action_dim, config.action_horizon, config.max_token_len
+ )
+ self.pi05 = config.pi05
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ action_expert_config = _gemma.get_config(config.action_expert_variant)
+ # TODO: rewrite gemma in NNX. For now, use bridge.
+ llm = nnx_bridge.ToNNX(
+ _gemma.Module(
+ configs=[paligemma_config, action_expert_config],
+ embed_dtype=config.dtype,
+ adarms=config.pi05,
+ )
+ )
+ llm.lazy_init(
+ rngs=rngs,
+ method='init',
+ use_adarms=[False, True] if config.pi05 else [False, False],
+ )
+ img = nnx_bridge.ToNNX(
+ _siglip.Module(
+ num_classes=paligemma_config.width,
+ variant='So400m/14',
+ pool_type='none',
+ scan=True,
+ dtype_mm=config.dtype,
+ )
+ )
+ img.lazy_init(
+ next(iter(config.fake_obs().images.values())),
+ train=False,
+ rngs=rngs,
+ )
+ self.PaliGemma = nnx.Dict(llm=llm, img=img)
+ self.action_in_proj = nnx.Linear(
+ config.action_dim, action_expert_config.width, rngs=rngs
+ )
+ if config.pi05:
+ self.time_mlp_in = nnx.Linear(
+ action_expert_config.width,
+ action_expert_config.width,
+ rngs=rngs,
+ )
+ self.time_mlp_out = nnx.Linear(
+ action_expert_config.width,
+ action_expert_config.width,
+ rngs=rngs,
+ )
+ else:
+ self.state_proj = nnx.Linear(
+ config.action_dim, action_expert_config.width, rngs=rngs
+ )
+ self.action_time_mlp_in = nnx.Linear(
+ 2 * action_expert_config.width,
+ action_expert_config.width,
+ rngs=rngs,
+ )
+ self.action_time_mlp_out = nnx.Linear(
+ action_expert_config.width,
+ action_expert_config.width,
+ rngs=rngs,
+ )
+ self.action_out_proj = nnx.Linear(
+ action_expert_config.width, config.action_dim, rngs=rngs
+ )
+
+ # This attribute gets automatically set by model.train() and model.eval().
+ self.deterministic = True
+
+ @at.typecheck
+ def embed_prefix(self, obs: _model.Observation) -> tuple[
+ at.Float[at.Array, 'b s emb'],
+ at.Bool[at.Array, 'b s'],
+ at.Bool[at.Array, ' s'],
+ ]:
+ input_mask = []
+ ar_mask = []
+ tokens = []
+ # embed images
+ for name in obs.images:
+ image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
+
+ tokens.append(image_tokens)
+ input_mask.append(
+ einops.repeat(
+ obs.image_masks[name],
+ 'b -> b s',
+ s=image_tokens.shape[1],
+ )
+ )
+ # image tokens attend to each other
+ ar_mask += [False] * image_tokens.shape[1]
+
+ # add language (aka tokenized inputs)
+ if obs.tokenized_prompt is not None:
+ tokenized_inputs = self.PaliGemma.llm(
+ obs.tokenized_prompt, method='embed'
+ )
+ tokens.append(tokenized_inputs)
+ input_mask.append(obs.tokenized_prompt_mask)
+ # full attention between image and language inputs
+ ar_mask += [False] * tokenized_inputs.shape[1]
+ tokens = jnp.concatenate(tokens, axis=1)
+ input_mask = jnp.concatenate(input_mask, axis=1)
+ ar_mask = jnp.array(ar_mask)
+ return tokens, input_mask, ar_mask
+
+ @at.typecheck
+ def embed_suffix(
+ self,
+ obs: _model.Observation,
+ noisy_actions: _model.Actions,
+ timestep: at.Float[at.Array, ' b'],
+ ) -> tuple[
+ at.Float[at.Array, 'b s emb'],
+ at.Bool[at.Array, 'b s'],
+ at.Bool[at.Array, ' s'],
+ at.Float[at.Array, 'b emb'] | None,
+ ]:
+ input_mask = []
+ ar_mask = []
+ tokens = []
+ if not self.pi05:
+ # add a single state token
+ state_token = self.state_proj(obs.state)[:, None, :]
+ tokens.append(state_token)
+ input_mask.append(
+ jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_)
+ )
+ # image/language inputs do not attend to state or actions
+ ar_mask += [True]
+
+ action_tokens = self.action_in_proj(noisy_actions)
+ # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = posemb_sincos(
+ timestep,
+ self.action_in_proj.out_features,
+ min_period=4e-3,
+ max_period=4.0,
+ )
+ if self.pi05:
+ # time MLP (for adaRMS)
+ time_emb = self.time_mlp_in(time_emb)
+ time_emb = nnx.swish(time_emb)
+ time_emb = self.time_mlp_out(time_emb)
+ time_emb = nnx.swish(time_emb)
+ action_expert_tokens = action_tokens
+ adarms_cond = time_emb
+ else:
+ # mix timestep + action information using an MLP (no adaRMS)
+ time_tokens = einops.repeat(
+ time_emb, 'b emb -> b s emb', s=self.action_horizon
+ )
+ action_time_tokens = jnp.concatenate(
+ [action_tokens, time_tokens], axis=-1
+ )
+ action_time_tokens = self.action_time_mlp_in(action_time_tokens)
+ action_time_tokens = nnx.swish(action_time_tokens)
+ action_time_tokens = self.action_time_mlp_out(action_time_tokens)
+ action_expert_tokens = action_time_tokens
+ adarms_cond = None
+ tokens.append(action_expert_tokens)
+ input_mask.append(
+ jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_)
+ )
+ # image/language/state inputs do not attend to action tokens
+ ar_mask += [True] + ([False] * (self.action_horizon - 1))
+ tokens = jnp.concatenate(tokens, axis=1)
+ input_mask = jnp.concatenate(input_mask, axis=1)
+ ar_mask = jnp.array(ar_mask)
+ return tokens, input_mask, ar_mask, adarms_cond
+
+ @override
+ def compute_loss(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ actions: _model.Actions,
+ *,
+ train: bool = False,
+ ) -> at.Float[at.Array, '*b ah']:
+ preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
+ observation = _model.preprocess_observation(
+ preprocess_rng, observation, train=train
+ )
+
+ batch_shape = actions.shape[:-2]
+ noise = jax.random.normal(noise_rng, actions.shape)
+ time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
+ time_expanded = time[..., None, None]
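+        # Flow-matching objective: x_t interpolates linearly between actions (t=0) and
+        # noise (t=1), so the target velocity u_t = d/dt x_t = noise - actions.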
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ # one big forward pass of prefix + suffix at once
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(
+ observation
+ )
+ suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = (
+ self.embed_suffix(observation, x_t, time)
+ )
+ input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
+ ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
+ attn_mask = make_attn_mask(input_mask, ar_mask)
+ positions = jnp.cumsum(input_mask, axis=1) - 1
+ (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+ [prefix_tokens, suffix_tokens],
+ mask=attn_mask,
+ positions=positions,
+ adarms_cond=[None, adarms_cond],
+ )
+ v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+ return jnp.mean(jnp.square(v_t - u_t), axis=-1)
+
+ @override
+ def sample_actions(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ *,
+ num_steps: int | at.Int[at.Array, ''] = 10,
+ noise: at.Float[at.Array, 'b ah ad'] | None = None,
+ ) -> _model.Actions:
+ observation = _model.preprocess_observation(
+ None, observation, train=False
+ )
+ # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
+ # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
+ dt = -1.0 / num_steps
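+        # e.g. the default num_steps=10 gives dt = -0.1: ten Euler steps from t=1 (noise)
+        # down to t=0 (the action distribution).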
+ batch_size = observation.state.shape[0]
+ if noise is None:
+ noise = jax.random.normal(
+ rng, (batch_size, self.action_horizon, self.action_dim)
+ )
+
+ # first fill KV cache with a forward pass of the prefix
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(
+ observation
+ )
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+ positions = jnp.cumsum(prefix_mask, axis=1) - 1
+ _, kv_cache = self.PaliGemma.llm(
+ [prefix_tokens, None], mask=prefix_attn_mask, positions=positions
+ )
+
+ def step(carry):
+ x_t, time = carry
+ suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = (
+ self.embed_suffix(
+ observation, x_t, jnp.broadcast_to(time, batch_size)
+ )
+ )
+ # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
+ # other
+ suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
+ # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
+ # prefix tokens
+ prefix_attn_mask = einops.repeat(
+ prefix_mask, 'b p -> b s p', s=suffix_tokens.shape[1]
+ )
+            # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
+            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
+ full_attn_mask = jnp.concatenate(
+ [prefix_attn_mask, suffix_attn_mask], axis=-1
+ )
+ assert full_attn_mask.shape == (
+ batch_size,
+ suffix_tokens.shape[1],
+ prefix_tokens.shape[1] + suffix_tokens.shape[1],
+ )
+ # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
+ positions = (
+ jnp.sum(prefix_mask, axis=-1)[:, None]
+ + jnp.cumsum(suffix_mask, axis=-1)
+ - 1
+ )
+
+ (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+ [None, suffix_tokens],
+ mask=full_attn_mask,
+ positions=positions,
+ kv_cache=kv_cache,
+ adarms_cond=[None, adarms_cond],
+ )
+ assert prefix_out is None
+ v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+ return x_t + dt * v_t, time + dt
+
+ def cond(carry):
+ x_t, time = carry
+ # robust to floating-point error
+ return time >= -dt / 2
+
+ x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
+ return x_0
diff --git a/vla_arena/models/openpi/src/openpi/models/pi0_config.py b/vla_arena/models/openpi/src/openpi/models/pi0_config.py
new file mode 100644
index 00000000..3928d74a
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/pi0_config.py
@@ -0,0 +1,134 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from typing import TYPE_CHECKING
+from typing_extensions import override
+
+import flax.nnx as nnx
+import jax
+import jax.numpy as jnp
+import openpi.models.gemma as _gemma
+import openpi.shared.nnx_utils as nnx_utils
+from openpi.models import model as _model
+from openpi.shared import array_typing as at
+
+
+if TYPE_CHECKING:
+ from openpi.models.pi0 import Pi0
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0Config(_model.BaseModelConfig):
+ dtype: str = 'bfloat16'
+ paligemma_variant: _gemma.Variant = 'gemma_2b'
+ action_expert_variant: _gemma.Variant = 'gemma_300m'
+
+ # Set the model specific defaults.
+ action_dim: int = 32
+ action_horizon: int = 50
+ max_token_len: int = None # type: ignore
+ # Pi05 has two differences from Pi0:
+ # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
+ # - the action expert uses adaRMSNorm to inject the flow matching timestep
+ pi05: bool = False
+ # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
+ discrete_state_input: bool = None # type: ignore
+
+ def __post_init__(self):
+ if self.max_token_len is None:
+ object.__setattr__(self, 'max_token_len', 200 if self.pi05 else 48)
+ if self.discrete_state_input is None:
+ object.__setattr__(self, 'discrete_state_input', self.pi05)
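+
+    # For reference: Pi0Config() defaults to max_token_len=48 and discrete_state_input=False,
+    # while Pi0Config(pi05=True) uses 200 and True respectively.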
+
+ @property
+ @override
+ def model_type(self) -> _model.ModelType:
+ if self.pi05:
+ return _model.ModelType.PI05
+ return _model.ModelType.PI0
+
+ @override
+ def create(self, rng: at.KeyArrayLike) -> 'Pi0':
+ from openpi.models.pi0 import Pi0
+
+ return Pi0(self, rngs=nnx.Rngs(rng))
+
+ @override
+ def inputs_spec(
+ self, *, batch_size: int = 1
+ ) -> tuple[_model.Observation, _model.Actions]:
+ image_spec = jax.ShapeDtypeStruct(
+ [batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32
+ )
+ image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+ with at.disable_typechecking():
+ observation_spec = _model.Observation(
+ images={
+ 'base_0_rgb': image_spec,
+ 'left_wrist_0_rgb': image_spec,
+ 'right_wrist_0_rgb': image_spec,
+ },
+ image_masks={
+ 'base_0_rgb': image_mask_spec,
+ 'left_wrist_0_rgb': image_mask_spec,
+ 'right_wrist_0_rgb': image_mask_spec,
+ },
+ state=jax.ShapeDtypeStruct(
+ [batch_size, self.action_dim], jnp.float32
+ ),
+ tokenized_prompt=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], jnp.int32
+ ),
+ tokenized_prompt_mask=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], bool
+ ),
+ )
+ action_spec = jax.ShapeDtypeStruct(
+ [batch_size, self.action_horizon, self.action_dim], jnp.float32
+ )
+
+ return observation_spec, action_spec
+
+ def get_freeze_filter(self) -> nnx.filterlib.Filter:
+ """Returns the freeze filter based on the model config."""
+ filters = []
+ has_lora = False
+ gemma_params_filter = nnx_utils.PathRegex('.*llm.*')
+ action_expert_params_filter = nnx_utils.PathRegex('.*llm.*_1.*')
+ if 'lora' in self.paligemma_variant:
+ filters.append(
+ gemma_params_filter,
+ )
+ if 'lora' not in self.action_expert_variant:
+ # If only freeze gemma params, exclude action expert params.
+ filters.append(
+ nnx.Not(action_expert_params_filter),
+ )
+ has_lora = True
+ elif 'lora' in self.action_expert_variant:
+ filters.append(
+ action_expert_params_filter,
+ )
+ has_lora = True
+
+ if has_lora:
+ # If any lora is used, exclude all lora params.
+ filters.append(
+ nnx.Not(nnx_utils.PathRegex('.*lora.*')),
+ )
+ if not filters:
+ return nnx.Nothing
+ return nnx.All(*filters)
diff --git a/vla_arena/models/openpi/src/openpi/models/pi0_fast.py b/vla_arena/models/openpi/src/openpi/models/pi0_fast.py
new file mode 100644
index 00000000..445293c0
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/pi0_fast.py
@@ -0,0 +1,409 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import logging
+from typing import Any
+from typing_extensions import override
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+import openpi.models.gemma_fast as _gemma
+import openpi.models.siglip as _siglip
+import openpi.shared.nnx_utils as nnx_utils
+from openpi.models import model as _model
+from openpi.shared import array_typing as at
+
+
+logger = logging.getLogger('openpi')
+
+PALIGEMMA_EOS_TOKEN = 1
+
+
+def make_attn_mask(input_mask, mask_ar):
+ """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+      input_mask: bool[B, N] true if it's part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        the token, and false where the token shares the same attention mask as the
+        previous token.
+ """
+ mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+ cumsum = jnp.cumsum(mask_ar, axis=1)
+ attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+ valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+ return jnp.logical_and(attn_mask, valid_mask)
+
+
+@jax.vmap
+def left_to_right_align(x, input_mask, attn_mask):
+ """Converts input from left-align to right-aligned."""
+ # Due to vmap, this is operating in a single example (not batch level).
+ assert x.ndim == 2
+ assert input_mask.ndim == 1
+ assert attn_mask.ndim == 2
+ assert x.shape[0] == input_mask.shape[0]
+ assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
+ seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
+ x = jnp.roll(x, -seqlen, axis=0)
+ input_mask = jnp.roll(input_mask, -seqlen, axis=0)
+ attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
+ return x, input_mask, attn_mask
+
+
+def put_along_last_axis(arr, indices, values):
+ """Like np.put_along_axis(..., axis=-1), since jax is missing it."""
+ assert arr.ndim == indices.ndim == values.ndim, (
+ arr.ndim,
+ indices.ndim,
+ values.ndim,
+ )
+ onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
+ put_mask = jnp.einsum(
+ '...i,...in->...n', jnp.ones(values.shape, jnp.int32), onehot
+ )
+ put_values = jnp.einsum('...i,...in->...n', values, onehot)
+ return jnp.where(put_mask, put_values, arr)
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0FASTConfig(_model.BaseModelConfig):
+ dtype: str = 'bfloat16'
+ paligemma_variant: _gemma.Variant = 'gemma_2b'
+
+ # Set the model specific defaults.
+ action_dim: int = 32
+ action_horizon: int = 32
+ max_token_len: int = 250
+
+ # Tokenizer for the fast model.
+ fast_model_tokenizer: Any | None = None
+ # Keyword arguments for the fast model tokenizer.
+ fast_model_tokenizer_kwargs: dict[str, Any] | None = None
+
+ @property
+ @override
+ def model_type(self) -> _model.ModelType:
+ return _model.ModelType.PI0_FAST
+
+ @override
+ def create(self, rng: at.KeyArrayLike) -> 'Pi0FAST':
+ return Pi0FAST(self, rngs=nnx.Rngs(rng))
+
+ @override
+ def inputs_spec(
+ self, *, batch_size: int = 1
+ ) -> tuple[_model.Observation, _model.Actions]:
+ image_spec = jax.ShapeDtypeStruct(
+ [batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32
+ )
+ image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+ with at.disable_typechecking():
+ observation_spec = _model.Observation(
+ images={
+ 'base_0_rgb': image_spec,
+ 'base_1_rgb': image_spec,
+ 'wrist_0_rgb': image_spec,
+ },
+ image_masks={
+ 'base_0_rgb': image_mask_spec,
+ 'base_1_rgb': image_mask_spec,
+ 'wrist_0_rgb': image_mask_spec,
+ },
+ state=jax.ShapeDtypeStruct(
+ [batch_size, self.action_dim], jnp.float32
+ ),
+ tokenized_prompt=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], jnp.int32
+ ),
+ tokenized_prompt_mask=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], bool
+ ),
+ token_ar_mask=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], jnp.int32
+ ),
+ token_loss_mask=jax.ShapeDtypeStruct(
+ [batch_size, self.max_token_len], jnp.bool_
+ ),
+ )
+ action_spec = jax.ShapeDtypeStruct(
+ [batch_size, self.action_horizon, self.action_dim], jnp.float32
+ )
+
+ return observation_spec, action_spec
+
+ def get_freeze_filter(self) -> nnx.filterlib.Filter:
+ """Returns the freeze filter based on the model config."""
+ if 'lora' in self.paligemma_variant:
+ return nnx.All(
+ nnx_utils.PathRegex('.*llm.*'),
+ nnx.Not(nnx_utils.PathRegex('.*lora.*')),
+ )
+ return nnx.Nothing
+
+
+class Pi0FAST(_model.BaseModel):
+ def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
+ super().__init__(
+ config.action_dim, config.action_horizon, config.max_token_len
+ )
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ # TODO: rewrite gemma in NNX. For now, use bridge.
+ llm = nnx_bridge.ToNNX(
+ _gemma.Module(
+ **paligemma_config,
+ embed_dtype=config.dtype,
+ cache_dtype=config.dtype,
+ )
+ )
+ llm.lazy_init(rngs=rngs, method='init')
+ img = nnx_bridge.ToNNX(
+ _siglip.Module(
+ num_classes=paligemma_config.width,
+ variant='So400m/14',
+ pool_type='none',
+ scan=True,
+ dtype_mm=config.dtype,
+ )
+ )
+ img.lazy_init(
+ next(iter(config.fake_obs().images.values())),
+ train=False,
+ rngs=rngs,
+ )
+ self.PaliGemma = nnx.Dict(llm=llm, img=img)
+
+ @at.typecheck
+ def embed_inputs(self, obs: _model.Observation) -> tuple[
+ at.Float[at.Array, 'b s emb'],
+ at.Bool[at.Array, 'b s'],
+ at.Int[at.Array, 'b s'],
+ ]:
+ input_mask = []
+ ar_mask = []
+ token_embeddings = []
+ # embed images
+ for name in obs.images:
+ image_token_embeddings, _ = self.PaliGemma.img(
+ obs.images[name], train=False
+ )
+
+ token_embeddings.append(image_token_embeddings)
+ input_mask.append(
+ einops.repeat(
+ obs.image_masks[name],
+ 'b -> b s',
+ s=image_token_embeddings.shape[1],
+ )
+ )
+ # image tokens attend to each other --> AR mask = 0
+ ar_mask.append(0 * input_mask[-1])
+
+ # add tokenized inputs
+ assert obs.tokenized_prompt is not None, 'Tokenized prompt is required'
+ assert (
+ obs.tokenized_prompt_mask is not None
+ ), 'Tokenized prompt mask is required'
+ assert (
+ obs.token_ar_mask is not None
+ ), 'Token auto-regressive mask is required'
+ tokenized_inputs_embeddings = self.PaliGemma.llm(
+ obs.tokenized_prompt, embed_only=True
+ )
+ token_embeddings.append(tokenized_inputs_embeddings)
+ input_mask.append(obs.tokenized_prompt_mask)
+ ar_mask.append(obs.token_ar_mask)
+
+ # return embeddings, input mask, and ar mask
+ return (
+ jnp.concatenate(token_embeddings, axis=1),
+ jnp.concatenate(input_mask, axis=1),
+ jnp.concatenate(ar_mask, axis=1),
+ )
+
+ @override
+ def compute_loss(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ actions: _model.Actions,
+ *,
+ train: bool = False,
+ ) -> at.Float[at.Array, '*b ah']:
+ observation = _model.preprocess_observation(
+ rng,
+ observation,
+ train=train,
+ image_keys=list(observation.images.keys()),
+ )
+
+ # Compute inputs: one big forward pass of prefix + suffix at once
+ input_token_embeddings, input_mask, ar_mask = self.embed_inputs(
+ observation
+ )
+ attn_mask = make_attn_mask(input_mask, ar_mask)
+
+ # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
+ targets = jax.nn.one_hot(
+ observation.tokenized_prompt[:, 1:],
+ self.PaliGemma.llm.module.vocab_size,
+ )
+
+ # Each input predicts *next* token, so we don't input the last token.
+ pre_logits, _, _ = self.PaliGemma.llm(
+ embedded_prefix=input_token_embeddings[:, :-1],
+ mask=attn_mask[:, :-1, :-1],
+ return_prelogits=True,
+ )
+
+ # Only decode logits for the target tokens to save memory
+ # (decoding matmul is large because it is a seq_len x vocab_size dense layer).
+ logits, _ = self.PaliGemma.llm(
+ pre_logits=pre_logits[:, -targets.shape[1] :],
+ )
+ logp = jax.nn.log_softmax(logits, axis=-1)
+
+ # Compute CE loss on token targets
+ assert (
+ observation.token_loss_mask is not None
+ ), 'Token loss mask is required'
+ loss_mask = observation.token_loss_mask[:, 1:]
+ token_pplx = jnp.sum(targets * logp, axis=-1)
+ return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(
+ jnp.sum(loss_mask, -1), 1
+ )
+
+ @override
+ def sample_actions(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ *,
+ max_decoding_steps: int | at.Int[at.Array, ''] = 256,
+ temperature: float = 0.0,
+ ) -> _model.Actions:
+ # TODO: this is a hack to get the image keys.
+ observation = _model.preprocess_observation(
+ None,
+ observation,
+ train=False,
+ image_keys=list(observation.images.keys()),
+ )
+
+ # embed inputs
+ prefix_token_embeddings, prefix_mask, prefix_ar_mask = (
+ self.embed_inputs(observation)
+ )
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+
+ # left to right align all input token sequences
+ prefix_token_embeddings, prefix_mask, prefix_attn_mask = (
+ left_to_right_align(
+ prefix_token_embeddings, prefix_mask, prefix_attn_mask
+ )
+ )
+ prefill_size = prefix_token_embeddings.shape[1]
+ prefill_len = jnp.sum(prefix_mask, axis=-1)
+ prefix_start = prefill_size - prefill_len
+
+ # first fill KV cache with a forward pass of the prefix
+ # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
+ prefix_attn_mask = jnp.pad(
+ prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps))
+ )
+ prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
+ prefix_logits, kv_cache, _ = self.PaliGemma.llm(
+ embedded_prefix=prefix_token_embeddings,
+ mask=prefix_attn_mask,
+ positions=prefix_positions,
+ decode=True,
+ )
+
+ # prepare decoding -- final logit decodes the first token
+ last_logit = prefix_logits[:, -1:]
+ output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
+
+ def step(carry):
+ rng, last_logit, output_tokens, cache, _, step = carry
+
+ # Sample token from last logit
+ # Split RNG for this step
+ rng, rng_step = jax.random.split(rng)
+ token = jax.lax.cond(
+ temperature > 0.0,
+ lambda _: jax.random.categorical(
+ rng_step, last_logit / temperature, axis=-1
+ ),
+ lambda _: jnp.argmax(last_logit, axis=-1),
+ operand=None,
+ )
+ output_tokens = put_along_last_axis(
+ output_tokens,
+ jnp.broadcast_to(step, (token.shape[0], 1)),
+ token,
+ )
+
+ # Check for early stopping --> stop if all batch elements have EOS token
+ has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
+ all_eos = jnp.all(has_eos)
+
+ # Decode one step
+ token_embedding = self.PaliGemma.llm(token, embed_only=True)
+ positions = prefill_len[:, None] + step + 1
+ mask = jnp.logical_and(
+ jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
+ >= prefix_start[:, None, None],
+ jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
+ < (
+ jnp.broadcast_to(
+ prefill_size + step + 1, (prefix_start.shape[0], 1, 1)
+ )
+ ),
+ )
+ last_logit, kv_cache, _ = self.PaliGemma.llm(
+ embedded_prefix=token_embedding,
+ mask=mask,
+ positions=positions,
+ decode=True,
+ kv_cache=cache,
+ )
+
+ return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1
+
+ def cond(carry):
+ _, _, _, _, all_eos, step = carry
+ return (~all_eos) & (step < max_decoding_steps)
+
+ # Use lax.while_loop so we can jit the full decoding loop.
+ _, _, output_tokens, _, _, _ = jax.lax.while_loop(
+ cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
+ )
+ return output_tokens
diff --git a/vla_arena/models/openpi/src/openpi/models/pi0_test.py b/vla_arena/models/openpi/src/openpi/models/pi0_test.py
new file mode 100644
index 00000000..cede2bd5
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/pi0_test.py
@@ -0,0 +1,64 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flax.nnx as nnx
+import jax
+import openpi.models.pi0_config as _pi0_config
+
+
+def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
+ abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
+
+ freeze_filter = config.get_freeze_filter()
+ return nnx.state(
+ abstract_model, nnx.All(nnx.Param, freeze_filter)
+ ).flat_state()
+
+
+def test_pi0_full_finetune():
+ config = _pi0_config.Pi0Config()
+ state = _get_frozen_state(config)
+ assert len(state) == 0
+
+
+def test_pi0_gemma_lora():
+ config = _pi0_config.Pi0Config(paligemma_variant='gemma_2b_lora')
+ state = _get_frozen_state(config)
+ assert len(state) == 9
+ assert all('lora' not in p for p in state)
+ assert all('llm' in p for p in state)
+ assert all('_1' not in p for p in state)
+
+
+def test_pi0_action_expert_lora():
+ config = _pi0_config.Pi0Config(action_expert_variant='gemma_300m_lora')
+ state = _get_frozen_state(config)
+ # excluding embedder, rest of the params should be same as gemma_lora.
+ assert len(state) == 8
+ assert all('lora' not in p for p in state)
+ assert all('llm' in p for p in state)
+ # all frozen params should have _1 in their path since it's the action expert.
+ assert all(any('_1' in p for p in path) for path in state)
+
+
+def test_pi0_all_lora():
+ config = _pi0_config.Pi0Config(
+ paligemma_variant='gemma_2b_lora',
+ action_expert_variant='gemma_300m_lora',
+ )
+ state = _get_frozen_state(config)
+ # sum of gemma_lora and action_expert_lora's frozen params.
+ assert len(state) == 17
+ assert all('lora' not in p for p in state)
+ assert all('llm' in p for p in state)
diff --git a/vla_arena/models/openpi/src/openpi/models/siglip.py b/vla_arena/models/openpi/src/openpi/models/siglip.py
new file mode 100644
index 00000000..5ce346e8
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/siglip.py
@@ -0,0 +1,408 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
+
+from collections.abc import Sequence
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import openpi.training.sharding as sharding
+
+
+def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
+ """Follows the MoCo v3 logic."""
+ y, x = jnp.mgrid[:h, :w]
+
+ assert width % 4 == 0, 'Width must be mult of 4 for sincos posemb'
+ omega = jnp.arange(width // 4) / (width // 4 - 1)
+ omega = 1.0 / (temperature**omega)
+ y = jnp.einsum('m,d->md', y.flatten(), omega)
+ x = jnp.einsum('m,d->md', x.flatten(), omega)
+ pe = jnp.concatenate(
+ [jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1
+ )
+ return jnp.asarray(pe, dtype)[None, :, :]
+
+
+def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
+ if typ == 'learn':
+ return self.param(
+ name,
+ nn.initializers.normal(stddev=1 / np.sqrt(width)),
+ (1, np.prod(seqshape), width),
+ dtype,
+ )
+ if typ == 'sincos2d':
+ return posemb_sincos_2d(*seqshape, width, dtype=dtype)
+ raise ValueError(f'Unknown posemb type: {typ}')
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ dropout: float = 0.0
+ dtype_mm: str = 'float32'
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ """Applies Transformer MlpBlock module."""
+ inits = {
+ 'kernel_init': nn.initializers.xavier_uniform(),
+ 'bias_init': nn.initializers.normal(stddev=1e-6),
+ }
+
+ _, _, d = x.shape # n,l,d
+ x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
+ x = nn.gelu(x)
+ x = nn.Dropout(rate=self.dropout)(x, deterministic)
+ return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
+
+
+class Encoder1DBlock(nn.Module):
+ """Single transformer encoder block (MHSA + MLP)."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dropout: float = 0.0
+ dtype_mm: str = 'float32'
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ out = {}
+ x = sharding.activation_sharding_constraint(x)
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+ y = out['sa'] = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=deterministic,
+ dtype=self.dtype_mm,
+ )(y, y)
+ y = sharding.activation_sharding_constraint(y)
+ y = nn.Dropout(rate=self.dropout)(y, deterministic)
+ x = out['+sa'] = x + y
+
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+ y = out['mlp'] = MlpBlock(
+ mlp_dim=self.mlp_dim,
+ dropout=self.dropout,
+ dtype_mm=self.dtype_mm,
+ )(y, deterministic)
+ y = sharding.activation_sharding_constraint(y)
+ y = nn.Dropout(rate=self.dropout)(y, deterministic)
+ x = out['+mlp'] = x + y
+ x = sharding.activation_sharding_constraint(x)
+ return x, out
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation."""
+
+ depth: int
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dropout: float = 0.0
+ scan: bool = False
+ remat_policy: str = 'nothing_saveable'
+ dtype_mm: str = 'float32'
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ out = {}
+
+ if self.scan:
+ block = nn.remat(
+ Encoder1DBlock,
+ prevent_cse=False,
+ static_argnums=(2,), # 0=self, 2=deterministic
+ policy=getattr(
+ jax.checkpoint_policies, self.remat_policy, None
+ ),
+ )
+ x, scan_out = nn.scan(
+ block,
+ variable_axes={'params': 0},
+ split_rngs={'params': True, 'dropout': True},
+ in_axes=nn.broadcast,
+ length=self.depth,
+ )(
+ name='encoderblock',
+ dtype_mm=self.dtype_mm,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ )(
+ x, deterministic
+ )
+ for lyr in range(self.depth):
+ out[f'block{lyr:02d}'] = jax.tree.map(
+ lambda o, lyr=lyr: o[lyr], scan_out
+ )
+ else:
+ # Input Encoder
+ for lyr in range(self.depth):
+ block_cur = Encoder1DBlock(
+ name=f'encoderblock_{lyr}',
+ dtype_mm=self.dtype_mm,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ )
+ x, out[f'block{lyr:02d}'] = block_cur(x, deterministic)
+ out['pre_ln'] = (
+ x # Alias for last block, but without the number in it.
+ )
+
+ return nn.LayerNorm(name='encoder_norm', dtype=self.dtype_mm)(x), out
+
+
+class MAPHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dtype_mm: str = 'float32'
+
+ @nn.compact
+ def __call__(self, x):
+ n, _, d = x.shape # n,l,d
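+        # A single learned "probe" query attends over all tokens, pooling the whole
+        # sequence into one vector per example (returned as x[:, 0] below).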
+ probe = self.param(
+ 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype
+ )
+ probe = jnp.tile(probe, [n, 1, 1])
+
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ dtype=self.dtype_mm,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(probe, x)
+
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+        x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype_mm=self.dtype_mm)(y)
+ return x[:, 0]
+
+
+class _Module(nn.Module):
+ """ViT model."""
+
+ num_classes: int | None = None
+ patch_size: Sequence[int] = (16, 16)
+ width: int = 768
+ depth: int = 12
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ posemb: str = 'learn' # Can also be "sincos2d"
+ rep_size: int | bool = False
+ dropout: float = 0.0
+ pool_type: str = 'gap' # Can also be "map" or "tok"
+ head_zeroinit: bool = True
+ scan: bool = False
+ # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
+ remat_policy: str = 'nothing_saveable'
+ dtype_mm: str = 'float32'
+
+ @nn.compact
+ def __call__(self, image, *, train=False):
+ out = {}
+
+ # Kevin edit: do patch extraction and posemb in float32,
+ # because I feel like it's a bit safer.
+ image = jnp.asarray(image, jnp.float32)
+
+ # Patch extraction
+ x = out['stem'] = nn.Conv(
+ self.width,
+ self.patch_size,
+ strides=self.patch_size,
+ padding='VALID',
+ name='embedding',
+ dtype=jnp.float32,
+ )(image)
+
+ n, h, w, c = x.shape
+ x = jnp.reshape(x, [n, h * w, c])
+
+ # Add posemb before adding extra token.
+ x = out['with_posemb'] = x + get_posemb(
+ self, self.posemb, (h, w), c, 'pos_embedding', jnp.float32
+ )
+
+ if self.pool_type == 'tok':
+ cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype)
+ x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
+
+ n, _, c = x.shape # n,l,d
+ x = nn.Dropout(rate=self.dropout)(x, not train)
+
+ # Kevin edit: now cast back to dtype_mm (potentially half precision)
+ x = x.astype(self.dtype_mm)
+
+ x, out['encoder'] = Encoder(
+ depth=self.depth,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scan=self.scan,
+ remat_policy=self.remat_policy,
+ dtype_mm=self.dtype_mm,
+ name='Transformer',
+ )(x, deterministic=not train)
+ encoded = out['encoded'] = x
+
+ if self.pool_type == 'map':
+ x = out['head_input'] = MAPHead(
+ num_heads=self.num_heads,
+ mlp_dim=self.mlp_dim,
+                dtype_mm=self.dtype_mm,
+ )(x)
+ elif self.pool_type == 'gap':
+ x = out['head_input'] = jnp.mean(x, axis=1)
+ elif self.pool_type == '0':
+ x = out['head_input'] = x[:, 0]
+ elif self.pool_type == 'tok':
+ x = out['head_input'] = x[:, 0]
+ encoded = encoded[:, 1:]
+ elif self.pool_type == 'none':
+ pass
+ else:
+ raise ValueError(f"Unknown pool type: '{self.pool_type}'")
+
+ x_2d = jnp.reshape(encoded, [n, h, w, -1])
+
+ if self.rep_size:
+ rep_size = self.width if self.rep_size is True else self.rep_size
+ hid = nn.Dense(rep_size, dtype=self.dtype_mm, name='pre_logits')
+ # NOTE: In the past we did not include tanh in pre_logits.
+            # For few-shot, it should not matter much, as it whitens anyway.
+ x_2d = nn.tanh(hid(x_2d))
+ x = nn.tanh(hid(x))
+
+ out['pre_logits_2d'] = x_2d
+ out['pre_logits'] = x
+
+ if self.num_classes:
+ kw = (
+ {'kernel_init': nn.initializers.zeros}
+ if self.head_zeroinit
+ else {}
+ )
+ head = nn.Dense(
+ self.num_classes, dtype=self.dtype_mm, name='head', **kw
+ )
+ x_2d = out['logits_2d'] = head(x_2d)
+ x = out['logits'] = head(x)
+
+ return x, out
+
+
+def Module(
+ num_classes=None, *, variant=None, **kw
+): # pylint: disable=invalid-name # noqa: N802
+ """Factory function, because linen really don't like what I'm doing!"""
+ return _Module(num_classes, **{**decode_variant(variant), **kw})
+
+
+def decode_variant(variant):
+ """Converts a string like "B" or "B/32" into a params dict."""
+ if variant is None:
+ return {}
+
+ v, patch = variant, {}
+ if '/' in variant:
+ v, patch = variant.split('/')
+ patch = {'patch_size': (int(patch), int(patch))}
+
+ return {
+ # pylint:disable=line-too-long
+ # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
+ 'width': {
+ 'mu': 32,
+ 'Ti': 192,
+ 'S': 384,
+ 'M': 512,
+ 'B': 768,
+ 'L': 1024,
+ 'So400m': 1152,
+ 'H': 1280,
+ 'g': 1408,
+ 'g-opt': 1536,
+ 'G': 1664,
+ 'G-opt': 1536,
+ 'e': 1792,
+ }[v],
+ 'depth': {
+ 'mu': 1,
+ 'Ti': 12,
+ 'S': 12,
+ 'M': 12,
+ 'B': 12,
+ 'L': 24,
+ 'So400m': 27,
+ 'H': 32,
+ 'g': 40,
+ 'g-opt': 40,
+ 'G': 48,
+ 'G-opt': 48,
+ 'e': 56,
+ }[v],
+ 'mlp_dim': {
+ 'mu': 128,
+ 'Ti': 768,
+ 'S': 1536,
+ 'M': 2048,
+ 'B': 3072,
+ 'L': 4096,
+ 'So400m': 4304,
+ 'H': 5120,
+ 'g': 6144,
+ 'g-opt': 6144,
+ 'G': 8192,
+ 'G-opt': 8192,
+ 'e': 15360,
+ }[v],
+ 'num_heads': {
+ 'mu': 2,
+ 'Ti': 3,
+ 'S': 6,
+ 'M': 8,
+ 'B': 12,
+ 'L': 16,
+ 'So400m': 16,
+ 'H': 16,
+ 'g': 16,
+ 'g-opt': 16,
+ 'G': 16,
+ 'G-opt': 16,
+ 'e': 16,
+ }[v],
+ # pylint:enable=line-too-long
+ **patch,
+ }
diff --git a/vla_arena/models/openpi/src/openpi/models/tokenizer.py b/vla_arena/models/openpi/src/openpi/models/tokenizer.py
new file mode 100644
index 00000000..2d7b0029
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/tokenizer.py
@@ -0,0 +1,500 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+import jax
+import numpy as np
+import openpi.models.utils.fsq_tokenizer as fsq_tokenizer
+import openpi.shared.download as download
+import orbax.checkpoint as ocp
+import sentencepiece
+from transformers import AutoProcessor
+
+
+class PaligemmaTokenizer:
+ def __init__(self, max_len: int = 48):
+ self._max_len = max_len
+
+ path = download.maybe_download(
+ 'gs://big_vision/paligemma_tokenizer.model', gs={'token': 'anon'}
+ )
+ with path.open('rb') as f:
+ self._tokenizer = sentencepiece.SentencePieceProcessor(
+ model_proto=f.read()
+ )
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray | None = None
+ ) -> tuple[np.ndarray, np.ndarray]:
+ cleaned_text = prompt.strip().replace('_', ' ').replace('\n', ' ')
+ if state is not None:
+ # This is the Pi05 format, where the state is part of the discrete language input.
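+            # np.linspace(-1, 1, 257)[:-1] yields 256 bin edges, so np.digitize(...) - 1
+            # maps each normalized state value to an integer bin index in [0, 255].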
+ discretized_state = (
+ np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+ )
+ state_str = ' '.join(map(str, discretized_state))
+ full_prompt = (
+ f'Task: {cleaned_text}, State: {state_str};\nAction: '
+ )
+ tokens = self._tokenizer.encode(full_prompt, add_bos=True)
+ else:
+ # This is the Pi0 format, where the state is part of the continuous action expert input.
+ # tokenize "\n" separately as the "start of answer" token
+ tokens = self._tokenizer.encode(
+ cleaned_text, add_bos=True
+ ) + self._tokenizer.encode('\n')
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ mask = [True] * tokens_len + padding
+ tokens = tokens + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f'Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. '
+ 'Consider increasing the `max_token_len` in your model config if this happens frequently.'
+ )
+ tokens = tokens[: self._max_len]
+ mask = [True] * self._max_len
+
+ return np.asarray(tokens), np.asarray(mask)
+
+
+class FASTTokenizer:
+ def __init__(
+ self,
+ max_len: int = 256,
+ fast_tokenizer_path: str = 'physical-intelligence/fast',
+ ):
+ self._max_len = max_len
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download(
+ 'gs://big_vision/paligemma_tokenizer.model', gs={'token': 'anon'}
+ )
+ with path.open('rb') as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(
+ model_proto=f.read()
+ )
+
+ # Instantiate FAST tokenizer
+ self._fast_tokenizer = AutoProcessor.from_pretrained(
+ fast_tokenizer_path, trust_remote_code=True
+ )
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ cleaned_text = prompt.lower().strip().replace('_', ' ')
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = (
+ np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+ )
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = ' '.join(map(str, discretized_state))
+ prefix = f'Task: {cleaned_text}, State: {state_str};\n'
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
+ action_tokens = self._fast_tokenizer(actions[None])[0]
+ action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(
+ action_tokens
+ )
+
+ # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
+ postfix_tokens = (
+ self._paligemma_tokenizer.encode('Action: ')
+ + action_tokens_in_pg.tolist()
+ + self._paligemma_tokenizer.encode('|', add_eos=True)
+ )
+ else:
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(
+ postfix_tokens
+ ) # Loss on postfix only
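+        # Example: with 3 prefix and 2 postfix tokens, ar_mask = [0, 0, 0, 1, 1] and
+        # loss_mask = [False, False, False, True, True] before padding.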
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f'Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. '
+ 'Consider increasing the `max_token_len` in your model config if this happens frequently.'
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return (
+ np.asarray(tokens),
+ np.asarray(token_mask),
+ np.asarray(ar_mask),
+ np.asarray(loss_mask),
+ )
+
+ def extract_actions(
+ self, tokens: np.ndarray, action_horizon: int, action_dim: int
+ ) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+ # Extract actions from FAST model outputs
+ if 'Action: ' not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(
+ decoded_tokens.split('Action: ')[1].split('|')[0].strip()
+ )
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ return self._fast_tokenizer.decode(
+ [action_tokens.tolist()],
+ time_horizon=action_horizon,
+ action_dim=action_dim,
+ )[0]
+
+ def _act_tokens_to_paligemma_tokens(
+ self, tokens: np.ndarray | list[int]
+ ) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
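+        # Action token i maps to PaliGemma id (vocab_size - 1 - fast_skip_tokens - i),
+        # i.e. the tail of the PaliGemma vocabulary (excluding the reserved special
+        # tokens), counted backwards. The mapping is its own inverse, so it also
+        # recovers action tokens from decoded PaliGemma ids in `extract_actions`.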
+ return (
+ self._paligemma_tokenizer.vocab_size()
+ - 1
+ - self._fast_skip_tokens
+ - tokens
+ )
+
+
+###########################################################################
+## The tokenizers below are used for RoboArena baseline implementations. ##
+## They are *not* used for pi0-style models. ##
+###########################################################################
+
+
+class BinningTokenizer:
+ """
+ Standard RT-2 / OpenVLA style binning tokenizer.
+ """
+
+ def __init__(self, max_len: int = 256, n_bins: int = 256):
+ self._max_len = max_len
+ self._n_bins = n_bins
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download(
+ 'gs://big_vision/paligemma_tokenizer.model', gs={'token': 'anon'}
+ )
+ with path.open('rb') as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(
+ model_proto=f.read()
+ )
+
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Tokenize a prompt and state into a sequence of tokens.
+
+ Args:
+ prompt: The text prompt to tokenize.
+ state: The state array to discretize and tokenize.
+ actions: Must be None. Action encoding is not currently supported.
+
+ Returns:
+ A tuple of (tokens, token_mask, ar_mask, targets).
+
+ Raises:
+ NotImplementedError: If actions is not None.
+ """
+ cleaned_text = prompt.lower().strip().replace('_', ' ')
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = (
+ np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+ )
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = ' '.join(map(str, discretized_state))
+ prefix = f'Task: {cleaned_text}, State: {state_str};\n'
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ raise NotImplementedError(
+ 'BinningTokenizer does not support encoding actions atm (only for inference use)'
+ )
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(
+ postfix_tokens
+ ) # Loss on postfix only
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f'Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. '
+ 'Consider increasing the `max_token_len` in your model config if this happens frequently.'
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return (
+ np.asarray(tokens),
+ np.asarray(token_mask),
+ np.asarray(ar_mask),
+ np.asarray(loss_mask),
+ )
+
+ def extract_actions(
+ self, tokens: np.ndarray, action_horizon: int, action_dim: int
+ ) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+        # If the decoded output contains no action section, return zeros
+ if 'Action: ' not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(
+ decoded_tokens.split('Action: ')[1].split('|')[0].strip()
+ )
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ if len(action_tokens) < action_horizon * action_dim:
+ return np.zeros([action_horizon, action_dim], dtype=np.float32)
+ action_tokens = action_tokens[: (action_horizon * action_dim)].reshape(
+ [action_horizon, action_dim]
+ )
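+        # Map discrete bin indices back to the normalized action range [-1, 1).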
+ return action_tokens / self._n_bins * 2 - 1
+
+ def _act_tokens_to_paligemma_tokens(
+ self, tokens: np.ndarray | list[int]
+ ) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
+ return (
+ self._paligemma_tokenizer.vocab_size()
+ - 1
+ - self._fast_skip_tokens
+ - tokens
+ )
+
+
+class FSQTokenizer:
+ """
+ FSQ tokenizer from the FAST paper baselines.
+ """
+
+ def __init__(
+ self, max_len: int = 256, fsq_tokenizer_path: str | None = None
+ ):
+ self._max_len = max_len
+
+ assert (
+ fsq_tokenizer_path is not None
+ ), 'fsq_tokenizer_path must be provided'
+ # Download tokenizer
+ path = download.maybe_download(fsq_tokenizer_path)
+ tok_path = os.path.join(path, os.listdir(path)[0])
+
+ # Split step from path
+ step = int(tok_path.split('/')[-1])
+ base_path = tok_path.rsplit('/', 1)[0]
+
+ mgr = ocp.CheckpointManager(
+ base_path,
+ item_handlers={
+ 'params': ocp.StandardCheckpointHandler(),
+ 'opt_state': ocp.StandardCheckpointHandler(),
+ 'config': ocp.JsonCheckpointHandler(),
+ },
+ options=ocp.CheckpointManagerOptions(max_to_keep=1),
+ )
+
+ try:
+ restored = mgr.restore(
+ step,
+ args=ocp.args.Composite(
+ config=ocp.args.JsonRestore(),
+ params=ocp.args.StandardRestore(),
+ ),
+ )
+ config = restored['config']
+ self._params = restored['params']
+ self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
+ except Exception as e:
+ raise RuntimeError(
+ f'Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}'
+ ) from e
+
+ # Compile tokenize and detokenize functions
+ self._tokenize_fn = jax.jit(
+ lambda params, x: self._fsq_tokenizer.apply(
+ {'params': params}, x, method=self._fsq_tokenizer.tokenize
+ )
+ )
+ self._detokenize_fn = jax.jit(
+ lambda params, x: self._fsq_tokenizer.apply(
+ {'params': params}, x, method=self._fsq_tokenizer.detokenize
+ )
+ )
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download(
+ 'gs://big_vision/paligemma_tokenizer.model', gs={'token': 'anon'}
+ )
+ with path.open('rb') as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(
+ model_proto=f.read()
+ )
+
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ cleaned_text = prompt.lower().strip().replace('_', ' ')
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = (
+ np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+ )
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = ' '.join(map(str, discretized_state))
+ prefix = f'Task: {cleaned_text}, State: {state_str};\n'
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ raise NotImplementedError(
+ 'FSQTokenizer does not support encoding actions atm (only for inference use)'
+ )
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(
+ postfix_tokens
+ ) # Loss on postfix only
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f'Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. '
+ 'Consider increasing the `max_token_len` in your model config if this happens frequently.'
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return (
+ np.asarray(tokens),
+ np.asarray(token_mask),
+ np.asarray(ar_mask),
+ np.asarray(loss_mask),
+ )
+
+ def extract_actions(
+ self, tokens: np.ndarray, action_horizon: int, action_dim: int
+ ) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+        # If the decoded output contains no action section, return zeros
+ if 'Action: ' not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(
+ decoded_tokens.split('Action: ')[1].split('|')[0].strip()
+ )
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ try:
+ # Move computation to CPU and compile on-demand
+ device = jax.devices('cpu')[0]
+ with jax.default_device(device):
+ detok_act = self._detokenize_fn(
+ self._params, action_tokens[None, ...]
+ )[0]
+ return detok_act[: action_horizon * action_dim].reshape(
+ [action_horizon, action_dim]
+ )
+ except Exception as e:
+ logging.warning(f'Error decoding FSQ: {e}')
+ return np.zeros((action_horizon, action_dim))
+
+ def _act_tokens_to_paligemma_tokens(
+ self, tokens: np.ndarray | list[int]
+ ) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
+ return (
+ self._paligemma_tokenizer.vocab_size()
+ - 1
+ - self._fast_skip_tokens
+ - tokens
+ )
diff --git a/vla_arena/models/openpi/src/openpi/models/tokenizer_test.py b/vla_arena/models/openpi/src/openpi/models/tokenizer_test.py
new file mode 100644
index 00000000..49974b5f
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/tokenizer_test.py
@@ -0,0 +1,42 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from openpi.models import tokenizer as _tokenizer
+
+
+def test_tokenize():
+ tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
+ tokens, masks = tokenizer.tokenize('Hello, world!')
+
+ assert tokens.shape == (10,)
+ assert masks.shape == (10,)
+
+
+def test_fast_tokenizer():
+ prompt = 'Hello, world!'
+ state = np.random.rand(5).astype(np.float32)
+ action = np.random.rand(3, 2).astype(np.float32)
+ tokenizer = _tokenizer.FASTTokenizer(max_len=256)
+ tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(
+ prompt, state, action
+ )
+
+ assert tokens.shape == (256,)
+ assert token_masks.shape == (256,)
+ assert ar_masks.shape == (256,)
+ assert loss_masks.shape == (256,)
+
+ act = tokenizer.extract_actions(tokens, 3, 2)
+ assert act.shape == (3, 2)
diff --git a/vla_arena/models/openpi/src/openpi/models/utils/fsq_tokenizer.py b/vla_arena/models/openpi/src/openpi/models/utils/fsq_tokenizer.py
new file mode 100644
index 00000000..440badca
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/utils/fsq_tokenizer.py
@@ -0,0 +1,542 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Literal
+
+import chex
+import einops
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from flax.linen.module import Module, compact
+from flax.struct import dataclass
+from flax.typing import Array
+
+
+class FsqCodebook(nn.Module):
+ input_dim: int
+ target_codebook_size: int
+    codebook_type: Literal['fsq', 'lfq', 'custom']
+
+ _bins_per_dim: tuple[int] | None = None
+
+ @property
+ def bins_per_dim(self) -> tuple[int]:
+ if self._bins_per_dim is not None:
+ return self._bins_per_dim
+
+ if self.codebook_type == 'fsq':
+ return self._get_bins_fsq(self.target_codebook_size)
+ elif self.codebook_type == 'lfq': # noqa: RET505
+ return self._get_bins_lfq(self.target_codebook_size)
+ elif self.codebook_type == 'custom':
+ return self._get_bins_custom(self.target_codebook_size)
+ else:
+ raise ValueError(
+ f'Codebook type {self.codebook_type} not supported.'
+ )
+
+ @property
+ def place_values(self) -> jnp.ndarray:
+ place_values = [1]
+ for b in self.bins_per_dim[:-1]:
+ place_values.append(place_values[-1] * b)
+ return jnp.array(place_values)
+
+ @staticmethod
+ def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
+ """
+ Get bins per dimension based on codebook size, from the original FSQ paper.
+ """
+ if target_codebook_size == 2**8:
+ return (8, 6, 5)
+ elif target_codebook_size == 2**10: # noqa: RET505
+ return (8, 5, 5, 5)
+ elif target_codebook_size == 2**12:
+ return (7, 5, 5, 5, 5)
+ elif target_codebook_size == 2**14:
+ return (8, 8, 8, 6, 5)
+ elif target_codebook_size == 2**16:
+ return (8, 8, 8, 5, 5, 5)
+ else:
+ raise ValueError(
+ f'Codebook size {target_codebook_size} not supported.'
+ )
+
+ @staticmethod
+ def _get_bins_custom(target_codebook_size: int) -> tuple[int]:
+ if target_codebook_size == 2**8:
+ return (16, 16)
+ elif target_codebook_size == 2**10: # noqa: RET505
+ return (32, 32)
+ elif target_codebook_size == 2**12:
+ return (64, 64)
+ elif target_codebook_size == 2**14:
+ return (128, 128)
+ elif target_codebook_size == 2**16:
+ return (256, 256)
+        raise ValueError(
+            f'Codebook size {target_codebook_size} not supported.'
+        )
+
+ @staticmethod
+ def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:
+ """
+ Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
+ """
+ assert (
+ target_codebook_size & (target_codebook_size - 1) == 0
+ ), 'Codebook size should be a power of two for LFQ'
+
+ return (2,) * int(math.log2(target_codebook_size))
+
+ def setup(self):
+ self.proj_down = nn.Dense(len(self.bins_per_dim))
+ self.proj_up = nn.Dense(self.input_dim)
+
+ def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+ tokens, z = self.encode(inputs)
+ output = self.decode(tokens, z_grad=z)
+ return tokens, output
+
+ def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+ bases = jnp.array(self.bins_per_dim)
+
+ x = self.proj_down(inputs)
+ z = jnp.tanh(x)
+
+ # Quantize
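+        # tanh bounds z to [-1, 1]; rounding (z + 1) * (bins - 1) / 2 then gives an
+        # integer digit in [0, bins - 1] for every latent dimension.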
+ digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
+ tokens = self.undigitize(digits)
+
+ return tokens, z
+
+ def decode(
+ self, tokens: jnp.ndarray, z_grad: jax.Array | None = None
+ ) -> jnp.ndarray:
+ bases = jnp.array(self.bins_per_dim)
+ digits = self.digitize(tokens)
+
+ z_q = digits / (bases - 1) * 2 - 1
+
+ if z_grad is not None:
+ chex.assert_equal_shape([z_q, z_grad])
+ z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
+
+ return self.proj_up(z_q)
+
+ def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
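+        # Mixed-radix packing: digit d_i (base bins_per_dim[i]) contributes
+        # d_i * place_values[i]. E.g. with bins (8, 6, 5), place values are (1, 8, 48),
+        # so digits (3, 2, 4) encode token 3*1 + 2*8 + 4*48 = 211.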
+ return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
+
+ def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
+ return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(
+ self.bins_per_dim
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return math.prod(self.bins_per_dim)
+
+
+class ResNetDownBlock(nn.Module):
+ stride: int = 1
+ n_filters: int = 64
+ dropout_rate: float = 0.0
+ group_size: int = 32
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
+ skip = x
+
+ if self.stride > 1 or x.shape[-1] != self.n_filters:
+ skip = nn.Conv(
+ self.n_filters, (self.stride,), (self.stride,), 'SAME'
+ )(skip)
+
+ x = nn.Conv(self.n_filters, (3,), (self.stride,), 'SAME')(x)
+ x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = nn.relu(x)
+ x = nn.Conv(self.n_filters, (3,), (1,), 'SAME')(x)
+
+ return skip + x
+
+
+class ResNetUpBlock(nn.Module):
+ stride: int = 1
+ n_filters: int = 64
+ dropout_rate: float = 0.0
+ group_size: int = 32
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
+ skip = x
+
+ if self.stride > 1:
+ skip = nn.ConvTranspose(
+ self.n_filters, (self.stride,), (self.stride,), 'SAME'
+ )(skip)
+
+ x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), 'SAME')(x)
+ x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = nn.relu(x)
+ x = nn.ConvTranspose(self.n_filters, (3,), (1,), 'SAME')(x)
+
+ return skip + x
+
+
+@dataclass
+class LfqCodebookOutput:
+ tokens: jnp.ndarray
+ z: jnp.ndarray
+ z_q: jnp.ndarray
+ token_log_probs: jnp.ndarray
+ commit_loss: jnp.ndarray
+
+
+class LookupFreeQuantization(nn.Module):
+ num_dims: int
+ latent_dim: int
+
+ def setup(self):
+ self.codebook = jnp.array([-1, 1])
+ self.activation = nn.tanh
+
+ self.project_down = nn.Dense(self.num_dims)
+ self.project_up = nn.Dense(self.latent_dim)
+
+ def encode(self, z: jnp.ndarray) -> jnp.ndarray:
+ z = self.project_down(z)
+ token_squared_distances = jnp.square(z[..., None] - self.codebook)
+ token_bits = jnp.argmin(token_squared_distances, axis=-1)
+ return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
+
+ def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
+ token_bits = (
+ tokens[..., None] & (2 ** jnp.arange(self.num_dims))
+ ).astype(jnp.int32)
+ return self.project_up(self.codebook[token_bits])
+
+ def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
+ z = self.project_down(x)
+ z = self.activation(z)
+
+ token_squared_distances = jnp.square(z[..., None] - self.codebook)
+ tokens = jnp.argmin(token_squared_distances, axis=-1)
+
+ token_bit_log_probs = -token_squared_distances
+ # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
+ token_bit_expansions = jnp.bitwise_and(
+ jnp.arange(2**self.num_dims)[None, :],
+ 2 ** jnp.arange(self.num_dims)[:, None],
+ ).astype(jnp.int32)
+ token_log_probs = (
+ token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+ + token_bit_log_probs[..., 1] @ token_bit_expansions
+ ) # (batch_size, num_tokens, 2 ** num_dims)
+ token_log_probs = jax.lax.stop_gradient(
+ jax.nn.log_softmax(token_log_probs, axis=-1)
+ )
+ chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
+
+ z_q = self.codebook[tokens]
+ commit_loss = jnp.square(z - z_q).mean()
+ z_q = jax.lax.stop_gradient(z_q - z) + z
+
+ z_q = self.project_up(z_q)
+ z = self.project_up(z)
+
+ tokens = jnp.sum(
+ tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1
+ )
+ return LfqCodebookOutput(
+ tokens=tokens,
+ z=z,
+ z_q=z_q,
+ token_log_probs=jnp.zeros(()),
+ commit_loss=commit_loss,
+ )
+
+
+def make_block_causal_attention_matrix(
+ q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int
+) -> jnp.ndarray:
+ return nn.make_attention_mask(
+ q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q)
+ )
+
+
+class GeGLU(Module):
+ """Gated Linear Unit with GELU (GeGLU) activation function.
+ GeGLU is a Flax layer that combines a linear transformation with a GELU
+ activation function in a gating mechanism. It is often used in Transformer models
+ to provide non-linear capabilities while preserving a strong linear component.
+
+ Attributes:
+ features: the number of output features (default: None).
+ """
+
+ output_dim: int = -1
+
+ @compact
+ def __call__(self, inputs: Array) -> Array:
+ """Applies the GeGLU activation to the inputs.
+ Args:
+ inputs: the nd-array to apply the GeGLU activation function to.
+ Returns:
+ The transformed input.
+ """
+ output_dim = (
+ inputs.shape[-1] if self.output_dim == -1 else self.output_dim
+ )
+
+ x = nn.Dense(output_dim * 2)(inputs)
+ x, gate = x[..., :output_dim], x[..., output_dim:]
+ return x * nn.gelu(gate)
+
+
+class CrossAttentionLayer(nn.Module):
+ dropout_rate: float = 0.0
+    num_heads: int | None = None
+ causal: bool = False
+ mlp_ratio: float = 4.0
+
+ @nn.compact
+ def __call__(
+ self,
+ x: jnp.ndarray,
+ y: jnp.ndarray,
+ *,
+ mask_self: jnp.ndarray | None = None,
+ mask_cross: jnp.ndarray | None = None,
+ train: bool = True,
+ ) -> jnp.ndarray:
+ d_embed = x.shape[-1]
+ seq_len_q = x.shape[-2]
+ seq_len_k = y.shape[-2]
+
+ if self.causal:
+ # One block size will be 1
+ bs_q = max(seq_len_q // seq_len_k, 1)
+ bs_k = max(seq_len_k // seq_len_q, 1)
+
+ mask_self = nn.make_causal_mask(x[..., 0])
+ mask_cross = make_block_causal_attention_matrix(
+ x[..., 0], y[..., 0], bs_q, bs_k
+ )
+
+ # Self-attention block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads or d_embed // 64,
+ dropout_rate=self.dropout_rate,
+ deterministic=not train,
+ )(x, x, x, mask=mask_self)
+ x = skip + x
+
+ # Cross-attention block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads or d_embed // 64,
+ dropout_rate=self.dropout_rate,
+ deterministic=not train,
+ )(x, y, y, mask=mask_cross)
+ x = skip + x
+
+ # MLP block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = GeGLU()(x)
+ x = nn.Dense(d_embed)(x)
+ return skip + x
+
+
+def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
+ seq_len, d_embed = shape
+
+ position = jnp.arange(0, seq_len, 1)
+ div_term = jnp.exp(
+ jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed)
+ )
+ return jnp.concatenate(
+ [
+ jnp.sin(position[:, jnp.newaxis] * div_term),
+ jnp.cos(position[:, jnp.newaxis] * div_term),
+ ],
+ axis=-1,
+ )
+
+
+class TokenizerEncoderDecoder(nn.Module):
+ num_tokens: int
+ num_cross_tokens: int
+ num_layers: int
+ causal: bool
+
+ mlp_ratio: float = 4.0
+ use_state_conditioning: bool = False
+
+ @nn.compact
+ def __call__(
+ self,
+ y: jnp.ndarray,
+ *,
+ train: bool = True,
+ state_conditioning: jnp.ndarray | None = None,
+ mask: jnp.ndarray | None = None,
+ ) -> jnp.ndarray:
+ x = self.param(
+ 'q_embed', sinusoidal_pe_init, (self.num_tokens, y.shape[-1])
+ )
+ x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
+
+ if mask is not None:
+ # mask is (batch_dims..., num_cross_tokens)
+ chex.assert_equal_shape([y[..., 0], mask])
+ attn_mask = einops.repeat(
+ mask, '... kv -> ... 1 q kv', q=self.num_tokens
+ )
+ else:
+ attn_mask = jnp.ones(
+ (*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens)
+ )
+
+ if self.use_state_conditioning:
+ assert (
+ state_conditioning is not None
+ ), 'State conditioning is required for this model.'
+ state_embed = nn.Dense(y.shape[-1], name='state_proj')(
+ state_conditioning
+ )[..., None, :]
+ y = jnp.concatenate([y, state_embed], axis=-2)
+ attn_mask = jnp.concatenate(
+ [attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1
+ )
+
+ y = y + self.param('y_pos_enc', sinusoidal_pe_init, y.shape[-2:])
+
+ for _ in range(self.num_layers):
+ x = CrossAttentionLayer(
+ causal=self.causal, mlp_ratio=self.mlp_ratio
+ )(x, y, train=train, mask_self=None, mask_cross=attn_mask)
+
+ return x
+
+
+class FsqAttentionTokenizer(nn.Module):
+ embed_dim: int
+ data_dim: int
+ data_horizon: int
+ num_tokens: int
+ num_layers: int
+ target_codebook_size: int
+ causal: bool = False
+ mlp_ratio: float = 2.0
+
+ bound: float | None = None
+
+ use_state_conditioning: bool = False
+
+ @property
+ def vocab_size(self) -> int:
+ return math.prod(
+ FsqCodebook._get_bins_fsq(self.target_codebook_size)
+ ) # noqa: SLF001
+
+ def setup(self):
+ self.proj = nn.Dense(self.embed_dim)
+ self.encoder = TokenizerEncoderDecoder(
+ num_tokens=self.num_tokens,
+ num_cross_tokens=self.data_horizon,
+ num_layers=self.num_layers,
+ causal=self.causal,
+ use_state_conditioning=self.use_state_conditioning,
+ mlp_ratio=self.mlp_ratio,
+ )
+ self.codebook = FsqCodebook(
+ input_dim=self.embed_dim,
+ target_codebook_size=self.target_codebook_size,
+ codebook_type='custom',
+ )
+ self.decoder = TokenizerEncoderDecoder(
+ num_tokens=self.data_horizon,
+ num_cross_tokens=self.num_tokens,
+ num_layers=self.num_layers,
+ causal=self.causal,
+ use_state_conditioning=self.use_state_conditioning,
+ mlp_ratio=self.mlp_ratio,
+ )
+
+ self.proj_mean = nn.Dense(self.data_dim)
+ self.out_scale = self.param('out_scale', lambda _: jnp.full((), 1.0))
+
+ def tokenize(
+ self,
+ action: jnp.ndarray,
+ *,
+ obs: jnp.ndarray | None = None,
+ train: bool = False,
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+ if self.bound is not None:
+ action = jnp.clip(action, -self.bound, self.bound)
+
+ x = self.proj(action)
+ x = self.encoder(x, train=train, state_conditioning=obs)
+
+ return self.codebook.encode(x)
+
+ def detokenize(
+ self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None
+ ) -> jnp.ndarray:
+ x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
+ mean = self.proj_mean(x)
+ return mean * self.out_scale
+
+ def loss(
+ self,
+ action: jnp.ndarray,
+ *,
+ obs: jnp.ndarray | None = None,
+ train: bool = True,
+ ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+ # Encode
+ x = self.proj(action)
+ z = self.encoder(x, train=train, state_conditioning=obs)
+
+ # Quantize
+ tokens, z = self.codebook(z)
+
+ # Decode
+ x = self.decoder(z, train=train, state_conditioning=obs)
+ mean = self.proj_mean(x) * self.out_scale
+
+ mse = jnp.mean(jnp.square(action - mean))
+ mae = jnp.mean(jnp.abs(action - mean))
+
+ return mse, {
+ 'mse': mse,
+ 'mae': mae,
+ }
+
+ def __call__(
+ self, *args: Any, **kwargs: Any
+ ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+ """
+ Dummy for .init
+ """
+ return self.loss(*args, **kwargs)
diff --git a/vla_arena/models/openpi/src/openpi/models/vit.py b/vla_arena/models/openpi/src/openpi/models/vit.py
new file mode 100644
index 00000000..b6e6564b
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models/vit.py
@@ -0,0 +1,357 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
+
+from collections.abc import Callable
+from typing import Any
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from openpi.models import resnet as models_resnet
+
+
+Array = Any
+PRNGKey = Any
+Shape = tuple[int]
+Dtype = Any
+
+
+class IdentityLayer(nn.Module):
+ """Identity layer, convenient for giving a name to an array."""
+
+ @nn.compact
+ def __call__(self, x):
+ return x
+
+
+class AddPositionEmbs(nn.Module):
+ """Adds learned positional embeddings to the inputs.
+
+ Attributes:
+ posemb_init: positional embedding initializer.
+ """
+
+ posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
+ param_dtype: Dtype = jnp.float32
+
+ @nn.compact
+ def __call__(self, inputs):
+ """Applies the AddPositionEmbs module.
+
+ Args:
+ inputs: Inputs to the layer.
+
+ Returns:
+ Output tensor with shape `(bs, timesteps, in_dim)`.
+ """
+ # inputs.shape is (batch_size, seq_len, emb_dim).
+ assert (
+ inputs.ndim == 3
+ ), f'Number of dimensions should be 3, but it is: {inputs.ndim}'
+ pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
+ pe = self.param(
+ 'pos_embedding', self.posemb_init, pos_emb_shape, self.param_dtype
+ )
+ return inputs + pe
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+
+ mlp_dim: int
+ dtype: Dtype = jnp.float32
+ param_dtype: Dtype = jnp.float32
+ out_dim: int | None = None
+ dropout_rate: float = 0.1
+ kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+ nn.initializers.xavier_uniform()
+ )
+ bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
+ nn.initializers.normal(stddev=1e-6)
+ )
+
+ @nn.compact
+ def __call__(self, inputs, *, deterministic):
+ """Applies Transformer MlpBlock module."""
+ actual_out_dim = (
+ inputs.shape[-1] if self.out_dim is None else self.out_dim
+ )
+ x = nn.Dense(
+ features=self.mlp_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ kernel_init=self.kernel_init,
+ bias_init=self.bias_init,
+ )( # pytype: disable=wrong-arg-types
+ inputs
+ )
+ x = nn.gelu(x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
+ output = nn.Dense(
+ features=actual_out_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ kernel_init=self.kernel_init,
+ bias_init=self.bias_init,
+ )( # pytype: disable=wrong-arg-types
+ x
+ )
+ return nn.Dropout(rate=self.dropout_rate)(
+ output, deterministic=deterministic
+ )
+
+
+class Encoder1DBlock(nn.Module):
+ """Transformer encoder layer.
+
+ Attributes:
+ inputs: input data.
+ mlp_dim: dimension of the mlp on top of attention block.
+ dtype: the dtype of the computation (default: float32).
+ dropout_rate: dropout rate.
+ attention_dropout_rate: dropout for attention heads.
+ deterministic: bool, deterministic or not (to apply dropout).
+ num_heads: Number of heads in nn.MultiHeadDotProductAttention
+ """
+
+ mlp_dim: int
+ num_heads: int
+ dtype: Dtype = jnp.float32
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+
+ @nn.compact
+ def __call__(self, inputs, deterministic):
+ """Applies Encoder1DBlock module.
+
+ Args:
+ inputs: Inputs to the layer.
+ deterministic: Dropout will not be applied when set to true.
+
+ Returns:
+ output after transformer encoder block.
+ """
+
+ # Attention block.
+ assert (
+ inputs.ndim == 3
+ ), f'Expected (batch, seq, hidden) got {inputs.shape}'
+ x = nn.LayerNorm(dtype=self.dtype)(inputs)
+ x = nn.MultiHeadDotProductAttention(
+ dtype=self.dtype,
+ kernel_init=nn.initializers.xavier_uniform(),
+ broadcast_dropout=False,
+ deterministic=deterministic,
+ dropout_rate=self.attention_dropout_rate,
+ num_heads=self.num_heads,
+ # why isn't this true by default???
+ force_fp32_for_softmax=True,
+ )(x, x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
+ x = x + inputs
+
+ # MLP block.
+ y = nn.LayerNorm(dtype=self.dtype)(x)
+ y = MlpBlock(
+ mlp_dim=self.mlp_dim,
+ dtype=self.dtype,
+ dropout_rate=self.dropout_rate,
+ )(y, deterministic=deterministic)
+
+ return x + y, None
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation.
+
+ Attributes:
+ num_layers: number of layers
+ mlp_dim: dimension of the mlp on top of attention block
+ num_heads: Number of heads in nn.MultiHeadDotProductAttention
+ dropout_rate: dropout rate.
+ attention_dropout_rate: dropout rate in self attention.
+ """
+
+ dtype: jax.typing.DTypeLike
+ num_layers: int
+ mlp_dim: int
+ num_heads: int
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ add_position_embedding: bool = True
+
+ @nn.compact
+ def __call__(self, x, *, train):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ x: Inputs to the layer.
+ train: Set to `True` when training.
+
+ Returns:
+ output of a transformer encoder.
+ """
+ assert x.ndim == 3 # (batch, len, emb)
+
+ if self.add_position_embedding:
+ x = AddPositionEmbs(
+ posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
+ name='posembed_input',
+ )(x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
+
+ x = x.astype(self.dtype)
+ # Input Encoder
+ block = nn.remat(
+ Encoder1DBlock, prevent_cse=False, static_argnums=(2,)
+ )
+ x, _ = nn.scan(
+ block,
+ variable_axes={'params': 0},
+ split_rngs={'params': True, 'dropout': True},
+ in_axes=nn.broadcast,
+ length=self.num_layers,
+ )(
+ name='encoderblock',
+ mlp_dim=self.mlp_dim,
+ dropout_rate=self.dropout_rate,
+ attention_dropout_rate=self.attention_dropout_rate,
+ dtype=self.dtype,
+ num_heads=self.num_heads,
+ )(
+ x, not train
+ )
+ return nn.LayerNorm(name='encoder_norm', dtype=self.dtype)(x)
+
+
+class VisionTransformer(nn.Module):
+ """VisionTransformer."""
+
+ dtype: jax.typing.DTypeLike
+ num_classes: int
+ patches: Any
+ transformer: Any
+ hidden_size: int
+ resnet: Any | None = None
+ representation_size: int | None = None
+ classifier: str = 'token'
+ head_bias_init: float = 0.0
+ encoder: type[nn.Module] = Encoder
+ model_name: str | None = None
+
+ @nn.compact
+ def __call__(self, inputs, *, train):
+ x = inputs
+ # (Possibly partial) ResNet root.
+ if self.resnet is not None:
+ width = int(64 * self.resnet.width_factor)
+
+ # Root block.
+ x = models_resnet.StdConv(
+ features=width,
+ kernel_size=(7, 7),
+ strides=(2, 2),
+ use_bias=False,
+ name='conv_root',
+ )(x)
+ x = nn.GroupNorm(name='gn_root')(x)
+ x = nn.relu(x)
+ x = nn.max_pool(
+ x, window_shape=(3, 3), strides=(2, 2), padding='SAME'
+ )
+
+ # ResNet stages.
+ if self.resnet.num_layers:
+ x = models_resnet.ResNetStage(
+ block_size=self.resnet.num_layers[0],
+ nout=width,
+ first_stride=(1, 1),
+ name='block1',
+ )(x)
+ for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
+ x = models_resnet.ResNetStage(
+ block_size=block_size,
+ nout=width * 2**i,
+ first_stride=(2, 2),
+ name=f'block{i + 1}',
+ )(x)
+
+ n, h, w, c = x.shape
+
+ # We can merge s2d+emb into a single conv; it's the same.
+ x = nn.Conv(
+ features=self.hidden_size,
+ kernel_size=self.patches.size,
+ strides=self.patches.size,
+ padding='VALID',
+ name='embedding',
+ )(x)
+
+ # Here, x is a grid of embeddings.
+
+ # (Possibly partial) Transformer.
+ if self.transformer is not None:
+ n, h, w, c = x.shape
+ x = jnp.reshape(x, [n, h * w, c])
+
+ # If we want to add a class token, add it here.
+ if self.classifier in ['token', 'token_unpooled']:
+ cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
+ cls = jnp.tile(cls, [n, 1, 1])
+ x = jnp.concatenate([cls, x], axis=1)
+
+ x = self.encoder(
+ name='Transformer', **self.transformer, dtype=self.dtype
+ )(x, train=train)
+
+ if self.classifier == 'token':
+ x = x[:, 0]
+ elif self.classifier == 'gap':
+ x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
+ elif self.classifier in ['unpooled', 'token_unpooled']:
+ pass
+ else:
+ raise ValueError(f'Invalid classifier={self.classifier}')
+
+ if self.representation_size is not None:
+ x = nn.Dense(features=self.representation_size, name='pre_logits')(
+ x
+ )
+ x = nn.tanh(x)
+ else:
+ x = IdentityLayer(name='pre_logits')(x)
+
+ if self.num_classes:
+ x = nn.Dense(
+ features=self.num_classes,
+ name='head',
+ kernel_init=nn.initializers.zeros,
+ bias_init=nn.initializers.constant(self.head_bias_init),
+ )(x)
+ return x
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/gemma_pytorch.py b/vla_arena/models/openpi/src/openpi/models_pytorch/gemma_pytorch.py
new file mode 100644
index 00000000..8e76fbb6
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/gemma_pytorch.py
@@ -0,0 +1,372 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+import torch
+from torch import nn
+from transformers import Cache, GemmaForCausalLM, PaliGemmaForConditionalGeneration
+from transformers.models.auto import CONFIG_MAPPING
+from transformers.models.gemma import modeling_gemma
+
+
+class PaliGemmaWithExpertModel(nn.Module):
+ def __init__(
+ self,
+ vlm_config,
+ action_expert_config,
+ use_adarms=None,
+ precision: Literal['bfloat16', 'float32'] = 'bfloat16',
+ ):
+ if use_adarms is None:
+ use_adarms = [False, False]
+ super().__init__()
+
+ vlm_config_hf = CONFIG_MAPPING['paligemma']()
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
+ vlm_config_hf.image_token_index = 257152
+ vlm_config_hf.text_config.hidden_size = vlm_config.width
+ vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
+ vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
+ vlm_config_hf.text_config.head_dim = vlm_config.head_dim
+ vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
+ vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
+ vlm_config_hf.text_config.hidden_activation = 'gelu_pytorch_tanh'
+ vlm_config_hf.text_config.torch_dtype = 'float32'
+ vlm_config_hf.text_config.vocab_size = 257152
+ vlm_config_hf.text_config.use_adarms = use_adarms[0]
+ vlm_config_hf.text_config.adarms_cond_dim = (
+ vlm_config.width if use_adarms[0] else None
+ )
+ vlm_config_hf.vision_config.intermediate_size = 4304
+ vlm_config_hf.vision_config.projection_dim = 2048
+ vlm_config_hf.vision_config.projector_hidden_act = 'gelu_fast'
+ vlm_config_hf.vision_config.torch_dtype = 'float32'
+
+ action_expert_config_hf = CONFIG_MAPPING['gemma'](
+ head_dim=action_expert_config.head_dim,
+ hidden_size=action_expert_config.width,
+ intermediate_size=action_expert_config.mlp_dim,
+ num_attention_heads=action_expert_config.num_heads,
+ num_hidden_layers=action_expert_config.depth,
+ num_key_value_heads=action_expert_config.num_kv_heads,
+ vocab_size=257152,
+ hidden_activation='gelu_pytorch_tanh',
+ torch_dtype='float32',
+ use_adarms=use_adarms[1],
+ adarms_cond_dim=(
+ action_expert_config.width if use_adarms[1] else None
+ ),
+ )
+
+ self.paligemma = PaliGemmaForConditionalGeneration(
+ config=vlm_config_hf
+ )
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+ self.gemma_expert.model.embed_tokens = None
+
+ self.to_bfloat16_for_selected_params(precision)
+
+ def to_bfloat16_for_selected_params(
+ self, precision: Literal['bfloat16', 'float32'] = 'bfloat16'
+ ):
+ if precision == 'bfloat16':
+ self.to(dtype=torch.bfloat16)
+ elif precision == 'float32':
+ self.to(dtype=torch.float32)
+ return
+ else:
+ raise ValueError(f'Invalid precision: {precision}')
+
+ params_to_keep_float32 = [
+ 'vision_tower.vision_model.embeddings.patch_embedding.weight',
+ 'vision_tower.vision_model.embeddings.patch_embedding.bias',
+ 'vision_tower.vision_model.embeddings.position_embedding.weight',
+ 'input_layernorm',
+ 'post_attention_layernorm',
+ 'model.norm',
+ ]
+
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_keep_float32):
+ param.data = param.data.to(dtype=torch.float32)
+
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.model.get_image_features(image)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.embed_tokens(tokens)
+
+ def forward(
+ self,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+        past_key_values: list[torch.FloatTensor] | Cache | None = None,
+ inputs_embeds: list[torch.FloatTensor] | None = None,
+ use_cache: bool | None = None,
+ adarms_cond: list[torch.Tensor] | None = None,
+ ):
+ if adarms_cond is None:
+ adarms_cond = [None, None]
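+        # Three cases: only prefix (PaliGemma VLM) embeddings, only suffix
+        # (action-expert) embeddings, or both, in which case the two token streams
+        # are processed jointly, layer by layer, with shared attention.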
+ if inputs_embeds[1] is None:
+ prefix_output = self.paligemma.language_model.forward(
+ inputs_embeds=inputs_embeds[0],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=(
+ adarms_cond[0] if adarms_cond is not None else None
+ ),
+ )
+ prefix_past_key_values = prefix_output.past_key_values
+ prefix_output = prefix_output.last_hidden_state
+ suffix_output = None
+ elif inputs_embeds[0] is None:
+ suffix_output = self.gemma_expert.model.forward(
+ inputs_embeds=inputs_embeds[1],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=(
+ adarms_cond[1] if adarms_cond is not None else None
+ ),
+ )
+ suffix_output = suffix_output.last_hidden_state
+ prefix_output = None
+ prefix_past_key_values = None
+ else:
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
+
+ # Check if gradient checkpointing is enabled for any of the models
+ use_gradient_checkpointing = (
+ hasattr(self.gemma_expert.model, 'gradient_checkpointing')
+ and self.gemma_expert.model.gradient_checkpointing
+ and self.training
+ ) or (
+ hasattr(self, 'gradient_checkpointing')
+ and self.gradient_checkpointing
+ and self.training
+ )
+
+ # Force enable gradient checkpointing if we're in training mode and the model supports it
+ if self.training and hasattr(
+ self.gemma_expert.model, 'gradient_checkpointing'
+ ):
+ if not self.gemma_expert.model.gradient_checkpointing:
+ print(
+ 'Forcing gradient checkpointing to be enabled for Gemma expert model'
+ )
+ self.gemma_expert.model.gradient_checkpointing = True
+ use_gradient_checkpointing = True
+
+ # Debug gradient checkpointing status
+ if (
+ hasattr(self, '_debug_gc_printed')
+ and not self._debug_gc_printed
+ ):
+ print(
+ f'Gemma expert model gradient checkpointing: {use_gradient_checkpointing}'
+ )
+ print(f'Model training mode: {self.training}')
+ print(
+ f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
+ )
+ if hasattr(self.gemma_expert.model, 'gradient_checkpointing'):
+ print(
+ f'Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}'
+ )
+ self._debug_gc_printed = True
+
+ # Define the complete layer computation function for gradient checkpointing
+ def compute_layer_complete(
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ ):
+ models = [
+ self.paligemma.language_model,
+ self.gemma_expert.model,
+ ]
+
+ query_states = []
+ key_states = []
+ value_states = []
+ gates = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ hidden_states, gate = layer.input_layernorm(
+ hidden_states, cond=adarms_cond[i]
+ )
+ gates.append(gate)
+
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+ query_state = (
+ layer.self_attn.q_proj(hidden_states)
+ .view(hidden_shape)
+ .transpose(1, 2)
+ )
+ key_state = (
+ layer.self_attn.k_proj(hidden_states)
+ .view(hidden_shape)
+ .transpose(1, 2)
+ )
+ value_state = (
+ layer.self_attn.v_proj(hidden_states)
+ .view(hidden_shape)
+ .transpose(1, 2)
+ )
+
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+
+ # Concatenate and process attention
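+                # Queries/keys/values from both streams are concatenated along the
+                # sequence axis so prefix and suffix tokens attend to each other in a
+                # single attention call; the output is split back per stream below.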
+ query_states = torch.cat(query_states, dim=2)
+ key_states = torch.cat(key_states, dim=2)
+ value_states = torch.cat(value_states, dim=2)
+
+ dummy_tensor = torch.zeros(
+ query_states.shape[0],
+ query_states.shape[2],
+ query_states.shape[-1],
+ device=query_states.device,
+ dtype=query_states.dtype,
+ )
+ cos, sin = self.paligemma.model.language_model.rotary_emb(
+ dummy_tensor, position_ids
+ )
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, unsqueeze_dim=1
+ )
+
+ batch_size = query_states.shape[0]
+ scaling = self.paligemma.language_model.layers[
+ layer_idx
+ ].self_attn.scaling
+
+ # Attention computation
+ att_output, _ = modeling_gemma.eager_attention_forward(
+ self.paligemma.language_model.layers[layer_idx].self_attn,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling,
+ )
+ # Get head_dim from the current layer, not from the model
+ head_dim = self.paligemma.language_model.layers[
+ layer_idx
+ ].self_attn.head_dim
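+                # The reshape below flattens the head dimension back into the feature
+                # axis; the literal `1 * 8` appears to hard-code the number of attention
+                # heads used by these Gemma variants.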
+ att_output = att_output.reshape(
+ batch_size, -1, 1 * 8 * head_dim
+ )
+
+ # Process layer outputs
+ outputs_embeds = []
+ start_pos = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ end_pos = start_pos + hidden_states.shape[1]
+
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(
+ layer.self_attn.o_proj.weight.dtype
+ )
+ out_emb = layer.self_attn.o_proj(
+ att_output[:, start_pos:end_pos]
+ )
+
+ # first residual
+ out_emb = modeling_gemma._gated_residual(
+ hidden_states, out_emb, gates[i]
+ )
+ after_first_residual = out_emb.clone()
+ out_emb, gate = layer.post_attention_layernorm(
+ out_emb, cond=adarms_cond[i]
+ )
+ # Convert to bfloat16 if the next layer (mlp) uses bfloat16
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+ out_emb = out_emb.to(dtype=torch.bfloat16)
+
+ out_emb = layer.mlp(out_emb)
+ # second residual
+ out_emb = modeling_gemma._gated_residual(
+ after_first_residual, out_emb, gate
+ )
+ outputs_embeds.append(out_emb)
+ start_pos = end_pos
+
+ return outputs_embeds
+
+ # Process all layers with gradient checkpointing if enabled
+ for layer_idx in range(num_layers):
+ if use_gradient_checkpointing:
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_layer_complete,
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ )
+ else:
+ inputs_embeds = compute_layer_complete(
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ )
+
+ # final norm
+ # Define final norm computation function for gradient checkpointing
+ def compute_final_norms(inputs_embeds, adarms_cond):
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ out_emb, _ = models[i].norm(
+ hidden_states, cond=adarms_cond[i]
+ )
+ outputs_embeds.append(out_emb)
+ return outputs_embeds
+
+ # Apply gradient checkpointing to final norm if enabled
+ if use_gradient_checkpointing:
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_final_norms,
+ inputs_embeds,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ )
+ else:
+ outputs_embeds = compute_final_norms(
+ inputs_embeds, adarms_cond
+ )
+
+ prefix_output = outputs_embeds[0]
+ suffix_output = outputs_embeds[1]
+ prefix_past_key_values = None
+
+ return [prefix_output, suffix_output], prefix_past_key_values
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/pi0_pytorch.py b/vla_arena/models/openpi/src/openpi/models_pytorch/pi0_pytorch.py
new file mode 100644
index 00000000..554af3d8
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/pi0_pytorch.py
@@ -0,0 +1,572 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+
+import openpi.models.gemma as _gemma
+import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
+import torch
+import torch.nn.functional as F # noqa: N812
+from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
+from torch import Tensor, nn
+
+
+def get_safe_dtype(target_dtype, device_type):
+ """Get a safe dtype for the given device type."""
+ if device_type == 'cpu':
+        # bfloat16 has limited operator support on CPU, so fall back to float32
+ if target_dtype == torch.bfloat16:
+ return torch.float32
+ if target_dtype == torch.float64:
+ return torch.float64
+ return target_dtype
+
+
+def create_sinusoidal_pos_embedding(
+ time: torch.tensor,
+ dimension: int,
+ min_period: float,
+ max_period: float,
+ device='cpu',
+) -> Tensor:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
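+    # The embedding uses dimension // 2 periods spaced geometrically between
+    # min_period and max_period; each scalar time t is mapped to
+    # [sin(2*pi*t/p_1), ..., sin(2*pi*t/p_{D/2}), cos(2*pi*t/p_1), ...],
+    # mirroring the standard transformer sinusoidal encoding.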
+ if dimension % 2 != 0:
+ raise ValueError(f'dimension ({dimension}) must be divisible by 2')
+
+ if time.ndim != 1:
+ raise ValueError(
+ 'The time tensor is expected to be of shape `(batch_size, )`.'
+ )
+
+ dtype = get_safe_dtype(torch.float64, device.type)
+ fraction = torch.linspace(
+ 0.0, 1.0, dimension // 2, dtype=dtype, device=device
+ )
+ period = min_period * (max_period / min_period) ** fraction
+
+ # Compute the outer product
+ scaling_factor = 1.0 / period * 2 * math.pi
+ sin_input = scaling_factor[None, :] * time[:, None]
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
+
+def sample_beta(alpha, beta, bsize, device):
+ alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+ beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+ dist = torch.distributions.Beta(alpha_t, beta_t)
+ return dist.sample((bsize,))
+
+
+def make_att_2d_masks(pad_masks, att_masks):
+ """Copied from big_vision.
+
+    Tokens can attend to valid input tokens whose cumulative mask_ar value is
+    smaller than or equal to theirs. This way `mask_ar` (the `att_masks` argument)
+    int[B, N] can be used to set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+        pad_masks: bool[B, N] true if the token is part of the input, false if padding.
+        att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
+            it and 0 where it shares the same attention mask as the previous token.
+ """
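+    # Worked example (illustrative): for att_masks = [[0, 0, 1, 0]] the cumulative
+    # sums are [0, 0, 1, 1], so tokens 0-1 attend only to each other (a bidirectional
+    # prefix) while tokens 2-3 attend to all four tokens; padding positions are then
+    # zeroed out by the pad-mask product below.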
+ if att_masks.ndim != 2:
+ raise ValueError(att_masks.ndim)
+ if pad_masks.ndim != 2:
+ raise ValueError(pad_masks.ndim)
+
+ cumsum = torch.cumsum(att_masks, dim=1)
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
+ return att_2d_masks & pad_2d_masks
+
+
+class PI0Pytorch(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pi05 = config.pi05
+
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ action_expert_config = _gemma.get_config(config.action_expert_variant)
+
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
+ paligemma_config,
+ action_expert_config,
+ use_adarms=[False, True] if self.pi05 else [False, False],
+ precision=config.dtype,
+ )
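+        # For the pi0.5 variant the action expert uses adaptive RMSNorm conditioned on
+        # the flow-matching timestep (see embed_suffix), while the PaliGemma prefix
+        # never does; pi0 uses plain RMSNorm in both experts.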
+
+ self.action_in_proj = nn.Linear(32, action_expert_config.width)
+ self.action_out_proj = nn.Linear(action_expert_config.width, 32)
+
+ if self.pi05:
+ self.time_mlp_in = nn.Linear(
+ action_expert_config.width, action_expert_config.width
+ )
+ self.time_mlp_out = nn.Linear(
+ action_expert_config.width, action_expert_config.width
+ )
+ else:
+ self.state_proj = nn.Linear(32, action_expert_config.width)
+ self.action_time_mlp_in = nn.Linear(
+ 2 * action_expert_config.width, action_expert_config.width
+ )
+ self.action_time_mlp_out = nn.Linear(
+ action_expert_config.width, action_expert_config.width
+ )
+
+ torch.set_float32_matmul_precision('high')
+ self.sample_actions = torch.compile(
+ self.sample_actions, mode='max-autotune'
+ )
+
+ # Initialize gradient checkpointing flag
+ self.gradient_checkpointing_enabled = False
+
+ msg = 'transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.'
+ try:
+ from transformers.models.siglip import check
+
+ if (
+ not check.check_whether_transformers_replace_is_installed_correctly()
+ ):
+ raise ValueError(msg)
+ except ImportError:
+ raise ValueError(msg) from None
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.gradient_checkpointing_enabled = True
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = (
+ True
+ )
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = (
+ True
+ )
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = (
+ True
+ )
+
+ logging.info('Enabled gradient checkpointing for PI0Pytorch model')
+
+ def gradient_checkpointing_disable(self):
+ """Disable gradient checkpointing."""
+ self.gradient_checkpointing_enabled = False
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = (
+ False
+ )
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = (
+ False
+ )
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = (
+ False
+ )
+
+ logging.info('Disabled gradient checkpointing for PI0Pytorch model')
+
+ def is_gradient_checkpointing_enabled(self):
+ """Check if gradient checkpointing is enabled."""
+ return self.gradient_checkpointing_enabled
+
+ def _apply_checkpoint(self, func, *args, **kwargs):
+ """Helper method to apply gradient checkpointing if enabled."""
+ if self.gradient_checkpointing_enabled and self.training:
+ return torch.utils.checkpoint.checkpoint(
+ func,
+ *args,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ **kwargs,
+ )
+ return func(*args, **kwargs)
+
+ def _prepare_attention_masks_4d(self, att_2d_masks):
+ """Helper method to prepare 4D attention masks for transformer."""
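+        # Expand the boolean [B, N, N] mask to [B, 1, N, N] and convert it to an
+        # additive float mask: allowed positions become 0.0 and blocked positions a
+        # large negative value, so they vanish after the attention softmax.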
+ att_2d_masks_4d = att_2d_masks[:, None, :, :]
+ return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
+
+ def _preprocess_observation(self, observation, *, train=True):
+ """Helper method to preprocess observation."""
+ observation = _preprocessing.preprocess_observation_pytorch(
+ observation, train=train
+ )
+ return (
+ list(observation.images.values()),
+ list(observation.image_masks.values()),
+ observation.tokenized_prompt,
+ observation.tokenized_prompt_mask,
+ observation.state,
+ )
+
+ def sample_noise(self, shape, device):
+ return torch.normal(
+ mean=0.0,
+ std=1.0,
+ size=shape,
+ dtype=torch.float32,
+ device=device,
+ )
+
+ def sample_time(self, bsize, device):
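+        # Flow-matching time is drawn from Beta(1.5, 1.0), which biases samples toward
+        # 1 (the high-noise end), then affinely mapped into (0.001, 1.0] so t never
+        # reaches exactly 0.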
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
+ time = time_beta * 0.999 + 0.001
+ return time.to(dtype=torch.float32, device=device)
+
+ def embed_prefix(
+ self, images, img_masks, lang_tokens, lang_masks
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
+ for PaliGemma transformer processing.
+ """
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ # Process images
+ for img, img_mask in zip(images, img_masks, strict=True):
+
+ def image_embed_func(img):
+ return self.paligemma_with_expert.embed_image(img)
+
+ img_emb = self._apply_checkpoint(image_embed_func, img)
+
+ bsize, num_img_embs = img_emb.shape[:2]
+
+ embs.append(img_emb)
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
+
+ # Create attention masks so that image tokens attend to each other
+ att_masks += [0] * num_img_embs
+
+ # Process language tokens
+ def lang_embed_func(lang_tokens):
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(
+ lang_tokens
+ )
+ lang_emb_dim = lang_emb.shape[-1]
+ return lang_emb * math.sqrt(lang_emb_dim)
+
+ lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
+
+ embs.append(lang_emb)
+ pad_masks.append(lang_masks)
+
+ # full attention between image and language inputs
+ num_lang_embs = lang_emb.shape[1]
+ att_masks += [0] * num_lang_embs
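+        # All prefix tokens (image patches and language) carry att_mask 0, so under
+        # make_att_2d_masks they form a single block with full bidirectional attention
+        # (a prefix-LM style mask).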
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(
+ att_masks, dtype=torch.bool, device=pad_masks.device
+ )
+
+ # Get batch size from the first dimension of the concatenated tensors
+ bsize = pad_masks.shape[0]
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks
+
+ def embed_suffix(self, state, noisy_actions, timestep):
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ if not self.pi05:
+ if self.state_proj.weight.dtype == torch.float32:
+ state = state.to(torch.float32)
+
+ # Embed state
+ def state_proj_func(state):
+ return self.state_proj(state)
+
+ state_emb = self._apply_checkpoint(state_proj_func, state)
+
+ embs.append(state_emb[:, None, :])
+ bsize = state_emb.shape[0]
+ device = state_emb.device
+
+ state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
+ pad_masks.append(state_mask)
+
+ # Set attention masks so that image and language inputs do not attend to state or actions
+ att_masks += [1]
+
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = create_sinusoidal_pos_embedding(
+ timestep,
+ self.action_in_proj.out_features,
+ min_period=4e-3,
+ max_period=4.0,
+ device=timestep.device,
+ )
+ time_emb = time_emb.type(dtype=timestep.dtype)
+
+ # Fuse timestep + action information using an MLP
+ def action_proj_func(noisy_actions):
+ return self.action_in_proj(noisy_actions)
+
+ action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
+
+ if not self.pi05:
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
+
+ # Apply MLP layers
+ def mlp_func(action_time_emb):
+ x = self.action_time_mlp_in(action_time_emb)
+ x = F.silu(x) # swish == silu
+ return self.action_time_mlp_out(x)
+
+ action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
+ adarms_cond = None
+ else:
+ # time MLP (for adaRMS)
+ def time_mlp_func(time_emb):
+ x = self.time_mlp_in(time_emb)
+ x = F.silu(x) # swish == silu
+ x = self.time_mlp_out(x)
+ return F.silu(x)
+
+ time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
+ action_time_emb = action_emb
+ adarms_cond = time_emb
+
+ # Add to input tokens
+ embs.append(action_time_emb)
+
+ bsize, action_time_dim = action_time_emb.shape[:2]
+ action_time_mask = torch.ones(
+ bsize, action_time_dim, dtype=torch.bool, device=timestep.device
+ )
+ pad_masks.append(action_time_mask)
+
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
+ att_masks += [1] + ([0] * (self.config.action_horizon - 1))
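+        # The leading 1 opens a new attention block: prefix (and state) tokens cannot
+        # attend to action tokens, while the action tokens attend to everything before
+        # them and to each other.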
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(
+ att_masks, dtype=embs.dtype, device=embs.device
+ )
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks, adarms_cond
+
+ def forward(self, observation, actions, noise=None, time=None) -> Tensor:
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
+ images, img_masks, lang_tokens, lang_masks, state = (
+ self._preprocess_observation(observation, train=True)
+ )
+
+ if noise is None:
+ noise = self.sample_noise(actions.shape, actions.device)
+
+ if time is None:
+ time = self.sample_time(actions.shape[0], actions.device)
+
+ time_expanded = time[:, None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
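+        # Linear-interpolation flow matching: x_t moves from the clean actions (t=0)
+        # to pure noise (t=1), and the regression target u_t = dx_t/dt = noise - actions
+        # is the constant velocity of that path; the loss below is the MSE against it.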
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks
+ )
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = (
+ self.embed_suffix(state, x_t, time)
+ )
+ if (
+ self.paligemma_with_expert.paligemma.language_model.layers[
+ 0
+ ].self_attn.q_proj.weight.dtype
+ == torch.bfloat16
+ ):
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
+
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
+
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
+
+ # Prepare attention masks
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+
+ # Apply gradient checkpointing if enabled
+ def forward_func(
+ prefix_embs,
+ suffix_embs,
+ att_2d_masks_4d,
+ position_ids,
+ adarms_cond,
+ ):
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
+ attention_mask=att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+ return suffix_out
+
+ suffix_out = self._apply_checkpoint(
+ forward_func,
+ prefix_embs,
+ suffix_embs,
+ att_2d_masks_4d,
+ position_ids,
+ adarms_cond,
+ )
+
+ suffix_out = suffix_out[:, -self.config.action_horizon :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+
+ # Apply gradient checkpointing to final action projection if enabled
+ def action_out_proj_func(suffix_out):
+ return self.action_out_proj(suffix_out)
+
+ v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
+
+ return F.mse_loss(u_t, v_t, reduction='none')
+
+ @torch.no_grad()
+ def sample_actions(
+ self, device, observation, noise=None, num_steps=10
+ ) -> Tensor:
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
+ bsize = observation.state.shape[0]
+ if noise is None:
+ actions_shape = (
+ bsize,
+ self.config.action_horizon,
+ self.config.action_dim,
+ )
+ noise = self.sample_noise(actions_shape, device)
+
+ images, img_masks, lang_tokens, lang_masks, state = (
+ self._preprocess_observation(observation, train=False)
+ )
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks
+ )
+ prefix_att_2d_masks = make_att_2d_masks(
+ prefix_pad_masks, prefix_att_masks
+ )
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ # Compute image and language key value cache
+ prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
+ prefix_att_2d_masks
+ )
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = (
+ 'eager'
+ )
+
+ _, past_key_values = self.paligemma_with_expert.forward(
+ attention_mask=prefix_att_2d_masks_4d,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, None],
+ use_cache=True,
+ )
+
+ dt = -1.0 / num_steps
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
+
+ x_t = noise
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
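+        # Integrate the learned velocity field with explicit Euler from t=1 (noise)
+        # down to t=0 (actions) using num_steps steps of size dt = -1/num_steps; the
+        # `>= -dt / 2` guard keeps the step count exact despite float accumulation.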
+ while time >= -dt / 2:
+ expanded_time = time.expand(bsize)
+ v_t = self.denoise_step(
+ state,
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ expanded_time,
+ )
+
+ # Euler step - use new tensor assignment instead of in-place operation
+ x_t = x_t + dt * v_t
+ time += dt
+ return x_t
+
+ def denoise_step(
+ self,
+ state,
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ timestep,
+ ):
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = (
+ self.embed_suffix(state, x_t, timestep)
+ )
+
+ suffix_len = suffix_pad_masks.shape[1]
+ batch_size = prefix_pad_masks.shape[0]
+ prefix_len = prefix_pad_masks.shape[1]
+
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
+ batch_size, suffix_len, prefix_len
+ )
+
+ suffix_att_2d_masks = make_att_2d_masks(
+ suffix_pad_masks, suffix_att_masks
+ )
+
+ full_att_2d_masks = torch.cat(
+ [prefix_pad_2d_masks, suffix_att_2d_masks], dim=2
+ )
+
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+ position_ids = (
+ prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+ )
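+        # Suffix tokens may attend to every non-padded cached prefix token plus a
+        # block-wise mask within the suffix itself; position ids continue from the end
+        # of the prefix so the RoPE phases line up with the cached keys.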
+
+ # Prepare attention masks
+ full_att_2d_masks_4d = self._prepare_attention_masks_4d(
+ full_att_2d_masks
+ )
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = (
+ 'eager'
+ )
+
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
+ attention_mask=full_att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[None, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+
+ suffix_out = outputs_embeds[1]
+ suffix_out = suffix_out[:, -self.config.action_horizon :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+ return self.action_out_proj(suffix_out)
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py b/vla_arena/models/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py
new file mode 100644
index 00000000..6fe402fc
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py
@@ -0,0 +1,222 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections.abc import Sequence
+
+import torch
+from openpi.shared import image_tools
+
+
+logger = logging.getLogger('openpi')
+
+# Constants moved from model.py
+IMAGE_KEYS = (
+ 'base_0_rgb',
+ 'left_wrist_0_rgb',
+ 'right_wrist_0_rgb',
+)
+
+IMAGE_RESOLUTION = (224, 224)
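+# Note: 224x224 matches the SigLIP input resolution used by the PaliGemma backbone.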
+
+
+def preprocess_observation_pytorch(
+ observation,
+ *,
+ train: bool = False,
+ image_keys: Sequence[str] = IMAGE_KEYS,
+ image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
+):
+    """torch.compile-compatible observation preprocessing with simplified type annotations.
+
+ This function avoids complex type annotations that can cause torch.compile issues.
+ """
+ if not set(image_keys).issubset(observation.images):
+ raise ValueError(
+ f'images dict missing keys: expected {image_keys}, got {list(observation.images)}'
+ )
+
+ batch_shape = observation.state.shape[:-1]
+
+ out_images = {}
+ for key in image_keys:
+ image = observation.images[key]
+
+        # TODO: hack to handle both [B, C, H, W] and [B, H, W, C] image layouts
+ is_channels_first = (
+ image.shape[1] == 3
+ ) # Check if channels are in dimension 1
+
+ if is_channels_first:
+ # Convert [B, C, H, W] to [B, H, W, C] for processing
+ image = image.permute(0, 2, 3, 1)
+
+ if image.shape[1:3] != image_resolution:
+ logger.info(
+ f'Resizing image {key} from {image.shape[1:3]} to {image_resolution}'
+ )
+ image = image_tools.resize_with_pad_torch(image, *image_resolution)
+
+ if train:
+ # Convert from [-1, 1] to [0, 1] for PyTorch augmentations
+ image = image / 2.0 + 0.5
+
+ # Apply PyTorch-based augmentations
+ if 'wrist' not in key:
+ # Geometric augmentations for non-wrist cameras
+ height, width = image.shape[1:3]
+
+ # Random crop and resize
+ crop_height = int(height * 0.95)
+ crop_width = int(width * 0.95)
+
+ # Random crop
+ max_h = height - crop_height
+ max_w = width - crop_width
+ if max_h > 0 and max_w > 0:
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ start_h = torch.randint(
+ 0, max_h + 1, (1,), device=image.device
+ )
+ start_w = torch.randint(
+ 0, max_w + 1, (1,), device=image.device
+ )
+ image = image[
+ :,
+ start_h : start_h + crop_height,
+ start_w : start_w + crop_width,
+ :,
+ ]
+
+ # Resize back to original size
+ image = torch.nn.functional.interpolate(
+ image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
+ size=(height, width),
+ mode='bilinear',
+ align_corners=False,
+ ).permute(
+ 0, 2, 3, 1
+ ) # [b, c, h, w] -> [b, h, w, c]
+
+ # Random rotation (small angles)
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ angle = (
+ torch.rand(1, device=image.device) * 10 - 5
+ ) # Random angle between -5 and 5 degrees
+ if (
+ torch.abs(angle) > 0.1
+ ): # Only rotate if angle is significant
+ # Convert to radians
+ angle_rad = angle * torch.pi / 180.0
+
+ # Create rotation matrix
+ cos_a = torch.cos(angle_rad)
+ sin_a = torch.sin(angle_rad)
+
+ # Apply rotation using grid_sample
+ grid_x = torch.linspace(-1, 1, width, device=image.device)
+ grid_y = torch.linspace(-1, 1, height, device=image.device)
+
+ # Create meshgrid
+ grid_y, grid_x = torch.meshgrid(
+ grid_y, grid_x, indexing='ij'
+ )
+
+ # Expand to batch dimension
+ grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
+ grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
+
+ # Apply rotation transformation
+ grid_x_rot = grid_x * cos_a - grid_y * sin_a
+ grid_y_rot = grid_x * sin_a + grid_y * cos_a
+
+ # Stack and reshape for grid_sample
+ grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
+
+ image = torch.nn.functional.grid_sample(
+ image.permute(
+ 0, 3, 1, 2
+ ), # [b, h, w, c] -> [b, c, h, w]
+ grid,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False,
+ ).permute(
+ 0, 2, 3, 1
+ ) # [b, c, h, w] -> [b, h, w, c]
+
+ # Color augmentations for all cameras
+ # Random brightness
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ brightness_factor = (
+ 0.7 + torch.rand(1, device=image.device) * 0.6
+ ) # Random factor between 0.7 and 1.3
+ image = image * brightness_factor
+
+ # Random contrast
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ contrast_factor = (
+ 0.6 + torch.rand(1, device=image.device) * 0.8
+ ) # Random factor between 0.6 and 1.4
+ mean = image.mean(dim=[1, 2, 3], keepdim=True)
+ image = (image - mean) * contrast_factor + mean
+
+            # Random saturation: instead of a full HSV round-trip, blend each pixel
+            # with its grayscale value by a random factor
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ saturation_factor = (
+ 0.5 + torch.rand(1, device=image.device) * 1.0
+ ) # Random factor between 0.5 and 1.5
+ gray = image.mean(dim=-1, keepdim=True)
+ image = gray + (image - gray) * saturation_factor
+
+ # Clamp values to [0, 1]
+ image = torch.clamp(image, 0, 1)
+
+ # Back to [-1, 1]
+ image = image * 2.0 - 1.0
+
+ # Convert back to [B, C, H, W] format if it was originally channels-first
+ if is_channels_first:
+ image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
+
+ out_images[key] = image
+
+ # obtain mask
+ out_masks = {}
+ for key in out_images:
+ if key not in observation.image_masks:
+ # do not mask by default
+ out_masks[key] = torch.ones(
+ batch_shape, dtype=torch.bool, device=observation.state.device
+ )
+ else:
+ out_masks[key] = observation.image_masks[key]
+
+ # Create a simple object with the required attributes instead of using the complex Observation class
+ class SimpleProcessedObservation:
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ return SimpleProcessedObservation(
+ images=out_images,
+ image_masks=out_masks,
+ state=observation.state,
+ tokenized_prompt=observation.tokenized_prompt,
+ tokenized_prompt_mask=observation.tokenized_prompt_mask,
+ token_ar_mask=observation.token_ar_mask,
+ token_loss_mask=observation.token_loss_mask,
+ )
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py
new file mode 100644
index 00000000..867cba9c
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py
@@ -0,0 +1,188 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class GemmaConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Gemma-7B.
+ e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GemmaModel`]
+ hidden_size (`int`, *optional*, defaults to 3072):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 24576):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The legacy activation function. It is overwritten by the `hidden_activation`.
+ hidden_activation (`str` or `function`, *optional*):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+ if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ use_adarms (`bool`, *optional*, defaults to `False`):
+ Whether to use ADARMS.
+ adarms_cond_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the ADARMS condition.
+ ```python
+ >>> from transformers import GemmaModel, GemmaConfig
+ >>> # Initializing a Gemma gemma-7b style configuration
+ >>> configuration = GemmaConfig()
+ >>> # Initializing a model from the gemma-7b style configuration
+ >>> model = GemmaModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = 'gemma'
+ keys_to_ignore_at_inference = ['past_key_values']
+ base_model_tp_plan = {
+ 'layers.*.self_attn.q_proj': 'colwise',
+ 'layers.*.self_attn.k_proj': 'colwise',
+ 'layers.*.self_attn.v_proj': 'colwise',
+ 'layers.*.self_attn.o_proj': 'rowwise',
+ 'layers.*.mlp.gate_proj': 'colwise',
+ 'layers.*.mlp.up_proj': 'colwise',
+ 'layers.*.mlp.down_proj': 'rowwise',
+ }
+ base_model_pp_plan = {
+ 'embed_tokens': (['input_ids'], ['inputs_embeds']),
+ 'layers': (['hidden_states', 'attention_mask'], ['hidden_states']),
+ 'norm': (['hidden_states'], ['hidden_states']),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act='gelu_pytorch_tanh',
+ hidden_activation=None,
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ use_adarms: bool = False,
+ adarms_cond_dim: int | None = None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.hidden_activation = hidden_activation
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.use_adarms = use_adarms
+ self.adarms_cond_dim = adarms_cond_dim
+
+ # Set default for adarms_cond_dim if use_adarms is True
+ if self.use_adarms and self.adarms_cond_dim is None:
+ self.adarms_cond_dim = self.hidden_size
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ['GemmaConfig']
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py
new file mode 100644
index 00000000..88d86a41
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py
@@ -0,0 +1,1030 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from .configuration_gemma import GemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GemmaRMSNorm(nn.Module):
+ def __init__(
+ self, dim: int, eps: float = 1e-6, cond_dim: int | None = None
+ ):
+ super().__init__()
+ self.eps = eps
+ self.dim = dim
+ self.cond_dim = cond_dim
+
+ # Dense layer for adaptive normalization (if cond_dim is provided)
+ if cond_dim is not None:
+ self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
+ # Initialize with zeros (matches source implementation)
+ nn.init.zeros_(self.dense.weight)
+ else:
+ self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
+ self.dense = None
+
+ def _norm(self, x):
+ # Compute variance in float32 (like the source implementation)
+ var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
+ # Compute normalization in float32
+ normed_inputs = x * torch.rsqrt(var + self.eps)
+ return normed_inputs
+
+ def forward(self, x, cond=None):
+ dtype = x.dtype # original dtype, could be half-precision
+ normed_inputs = self._norm(x)
+
+ if cond is None or self.dense is None:
+ # regular RMSNorm
+ # scale by learned parameter in float32 (matches source implementation)
+ normed_inputs = normed_inputs * (1.0 + self.weight.float())
+ return (
+ normed_inputs.to(dtype),
+ None,
+ ) # return in original dtype with None gate
+
+ # adaptive RMSNorm (if cond is provided and dense layer exists)
+ if cond.shape[-1] != self.cond_dim:
+ raise ValueError(
+ f'Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}'
+ )
+
+ modulation = self.dense(cond)
+ # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
+ if len(x.shape) == 3: # [batch, seq, features]
+ modulation = modulation.unsqueeze(1)
+
+ scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
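+        # Adaptive RMSNorm (adaRMS): the condition vector is projected to
+        # (scale, shift, gate); the normalized activations are modulated as
+        # x * (1 + scale) + shift, and the gate is returned so the caller can apply a
+        # gated residual connection.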
+
+        # Apply the adaptive modulation in float32 for numerical stability
+ normed_inputs = normed_inputs * (
+ 1 + scale.to(torch.float32)
+ ) + shift.to(torch.float32)
+
+ return normed_inputs.to(dtype), gate.to(dtype)
+
+    def extra_repr(self):
+        if self.dense is not None:
+            return f'dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}'
+        return f'{tuple(self.weight.shape)}, eps={self.eps}'
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False
+ )
+ self.up_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False
+ )
+ self.down_proj = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=False
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(
+ self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+ )
+ return down_proj
+
+
+class GemmaRotaryEmbedding(nn.Module):
+ def __init__(self, config: GemmaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get(
+ 'rope_type', config.rope_scaling.get('type')
+ )
+ else:
+ self.rope_type = 'default'
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device
+ )
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None]
+ .float()
+ .expand(position_ids.shape[0], -1, 1)
+ .to(x.device)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = (
+ x.device.type
+ if isinstance(x.device.type, str) and x.device.type != 'mps'
+ else 'cpu'
+ )
+ with torch.autocast(
+ device_type=device_type, enabled=False
+ ): # Force float32
+ freqs = (
+ inv_freq_expanded.float() @ position_ids_expanded.float()
+ ).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(
+ batch, num_key_value_heads * n_rep, slen, head_dim
+ )
+
+
+def _gated_residual(x, y, gate):
+ """
+ Applies gated residual connection with optional gate parameter.
+
+ Args:
+ x: Input tensor (residual)
+ y: Output tensor to be added
+ gate: Optional gate tensor to modulate the addition
+
+ Returns:
+ x + y if gate is None, otherwise x + y * gate
+ """
+ if x is None and y is None:
+ return None
+ if x is None or y is None:
+ return x if x is not None else y
+ if gate is None:
+ return x + y
+ return x + y * gate
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=dropout, training=module.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GemmaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(
+ config,
+ 'head_dim',
+ config.hidden_size // config.num_attention_heads,
+ )
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size,
+ config.num_attention_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim,
+ config.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_value: Cache | None = None,
+ cache_position: torch.LongTensor | None = None,
+ use_cache: bool = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = (
+ self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ )
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ # Use cache if provided
+ if past_key_value is not None:
+ if use_cache:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position,
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ else:
+ key_states = torch.cat(
+ [past_key_value[self.layer_idx][0], key_states], dim=2
+ )
+ value_states = torch.cat(
+ [past_key_value[self.layer_idx][1], value_states], dim=2
+ )
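+                # Note: with a precomputed prefix cache and use_cache=False (as in the
+                # pi0 denoising loop), the cached keys/values are concatenated on the
+                # fly so suffix tokens can attend to the prefix without mutating the
+                # cache.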
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != 'eager':
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
+ self.config._attn_implementation
+ ]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GemmaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = GemmaMLP(config)
+ cond_dim = (
+ getattr(config, 'adarms_cond_dim', None)
+ if getattr(config, 'use_adarms', False)
+ else None
+ )
+ self.input_layernorm = GemmaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
+ )
+ self.post_attention_layernorm = GemmaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_value: Cache | None = None,
+ output_attentions: bool | None = False,
+ use_cache: bool | None = False,
+ cache_position: torch.LongTensor | None = None,
+ position_embeddings: (
+ tuple[torch.Tensor, torch.Tensor] | None
+ ) = None, # necessary, but kept here for BC
+ adarms_cond: torch.Tensor | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[
+ torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None
+ ]:
+ residual = hidden_states
+ hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = _gated_residual(residual, hidden_states, gate)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states, gate = self.post_attention_layernorm(
+ hidden_states, adarms_cond
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = _gated_residual(residual, hidden_states, gate)
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class GemmaPreTrainedModel(PreTrainedModel):
+ config_class = GemmaConfig
+ base_model_prefix = 'model'
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['GemmaDecoderLayer']
+ _skip_keys_device_placement = ['past_key_values']
+ _supports_flash_attn_3 = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GemmaRMSNorm):
+ if hasattr(module, 'weight'):
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class GemmaModel(GemmaPreTrainedModel):
+ def __init__(self, config: GemmaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [
+ GemmaDecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+
+ cond_dim = (
+ getattr(config, 'adarms_cond_dim', None)
+ if getattr(config, 'use_adarms', False)
+ else None
+ )
+ self.norm = GemmaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
+ )
+ self.rotary_emb = GemmaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ cache_position: torch.LongTensor | None = None,
+ adarms_cond: torch.Tensor | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ """
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+ Condition for ADARMS.
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = (
+ use_cache if use_cache is not None else self.config.use_cache
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ 'You must specify exactly one of input_ids or inputs_embeds'
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length()
+ if past_key_values is not None
+ else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+ # Convert to bfloat16 if the first layer uses bfloat16
+ if (
+ len(self.layers) > 0
+ and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16
+ ):
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # normalized
+ # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+ # See https://github.com/huggingface/transformers/pull/29402
+ normalizer = torch.tensor(
+ self.config.hidden_size**0.5, dtype=hidden_states.dtype
+ )
+ # hidden_states = hidden_states * normalizer
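+        # NOTE: upstream Gemma multiplies the embeddings by sqrt(hidden_size) here; in
+        # this port the scaling is left disabled and `normalizer` is kept only for
+        # reference.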
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ adarms_cond=adarms_cond,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states, _ = self.norm(hidden_states, adarms_cond)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring
+class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ['lm_head.weight']
+ _tp_plan = {'lm_head': 'colwise_rep'}
+ _pp_plan = {'lm_head': (['hidden_states'], ['logits'])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GemmaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ cache_position: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ adarms_cond: torch.Tensor | None = None,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+            Conditioning vector for the adaptive RMSNorm (AdaRMS) layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+ >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ adarms_cond=adarms_cond,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = (
+ slice(-logits_to_keep, None)
+ if isinstance(logits_to_keep, int)
+ else logits_to_keep
+ )
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Gemma Model transformer with a sequence classification head on top (linear layer).
+
+ [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GemmaForSequenceClassification(GemmaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = GemmaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ adarms_cond: torch.Tensor | None = None,
+ ) -> SequenceClassifierOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+            Conditioning vector for the adaptive RMSNorm (AdaRMS) layers.
+ """
+
+ transformer_outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ adarms_cond=adarms_cond,
+ )
+ hidden_states = transformer_outputs.last_hidden_state
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ 'Cannot handle batch sizes > 1 if no padding token is defined.'
+ )
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(
+ logits.device, torch.int32
+ )
+ token_indices = torch.arange(
+ input_ids.shape[-1], device=logits.device, dtype=torch.int32
+ )
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f'{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be '
+ 'unexpected if using padding tokens in conjunction with `inputs_embeds.`'
+ )
+
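+        # Pool the logits at the last non-padding token of each sequence, e.g. for a
+        # right-padded row [tok, tok, tok, pad, pad] the logits at index 2 are used.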
+ pooled_logits = logits[
+ torch.arange(batch_size, device=logits.device), last_non_pad_token
+ ]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ pooled_logits=pooled_logits,
+ config=self.config,
+ )
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GemmaForTokenClassification(GemmaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = GemmaModel(config)
+ if getattr(config, 'classifier_dropout', None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, 'hidden_dropout', None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ adarms_cond: torch.Tensor | None = None,
+ ) -> TokenClassifierOutput:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+            Conditioning vector for the adaptive RMSNorm (AdaRMS) layers.
+ """
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ adarms_cond=adarms_cond,
+ )
+ sequence_output = outputs.last_hidden_state
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ 'GemmaModel',
+ 'GemmaForCausalLM',
+ 'GemmaForSequenceClassification',
+ 'GemmaForTokenClassification',
+ 'GemmaPreTrainedModel',
+]
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
new file mode 100644
index 00000000..e62aed00
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
@@ -0,0 +1,766 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PaliGemmamodel."""
+
+from dataclasses import dataclass
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...cache_utils import Cache, HybridCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ LossKwargs,
+ ModelOutput,
+ auto_docstring,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+)
+from ..auto import AutoModel
+from .configuration_paligemma import PaliGemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Paligemma outputs, with hidden states and attentions.
+ """
+)
+class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: torch.FloatTensor | None = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PaliGemma causal language model (or autoregressive) outputs.
+ """
+)
+class PaliGemmaCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: list[torch.FloatTensor] | Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ image_hidden_states: torch.FloatTensor | None = None
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
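+    """Projects vision-tower features to `vision_config.projection_dim` so they can
+    be merged with the text token embeddings."""
+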
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__()
+ self.linear = nn.Linear(
+ config.vision_config.hidden_size,
+ config.vision_config.projection_dim,
+ bias=True,
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear(image_features)
+
+ return hidden_states
+
+
+@auto_docstring
+class PaliGemmaPreTrainedModel(PreTrainedModel):
+ config_class = PaliGemmaConfig
+ base_model_prefix = ''
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['PaliGemmaMultiModalProjector']
+ _skip_keys_device_placement = 'past_key_values'
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+        # important: this ported version of PaliGemma isn't meant for training from scratch - only
+ # inference and fine-tuning
+ std = getattr(
+ self.config,
+ 'initializer_range',
+ self.config.get_text_config().initializer_range,
+ )
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+    The base PaliGemma model, consisting of a vision backbone and a language model without a language modeling head.
+ """
+)
+class PaliGemmaModel(PaliGemmaPreTrainedModel):
+ _checkpoint_conversion_mapping = {'language_model.model': 'language_model'}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = AutoModel.from_config(config=config.text_config)
+ self.language_model = language_model
+
+ self.pad_token_id = (
+ self.config.pad_token_id
+ if self.config.pad_token_id is not None
+ else -1
+ )
+ self.post_init()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def _update_causal_mask(
+ self,
+ attention_mask,
+ token_type_ids=None,
+ past_key_values=None,
+ cache_position=None,
+ input_tensor=None,
+ is_training: bool | None = None,
+ ):
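+        # Builds the 4D additive mask implementing PaliGemma's prefix-LM attention:
+        # during training, every token may attend to the image + prompt prefix
+        # (token_type_ids == 0) while the suffix remains causal; at inference the
+        # whole prefix is visible and only padding / not-yet-valid cache positions
+        # are filled with the dtype minimum.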
+ if self.config.text_config._attn_implementation == 'flash_attention_2':
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ is_training = is_training if is_training is not None else self.training
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ min_dtype = torch.finfo(self.dtype).min
+ if input_tensor is None:
+ input_tensor = attention_mask
+
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ elif isinstance(past_key_values, HybridCache):
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[0] + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ return attention_mask
+
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=self.dtype,
+ device=cache_position.device,
+ )
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+ if sequence_length != 1:
+ if is_training:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ else:
+ causal_mask[:, :sequence_length] = 0.0
+
+ causal_mask *= torch.arange(
+ target_length, device=cache_position.device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ inputs_lead_dim, 1, -1, -1
+ )
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+
+ # First unmask prefix tokens during training
+ if is_training:
+ if token_type_ids is None:
+ raise ValueError(
+ 'Token type ids must be provided during training'
+ )
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(
+ token_type_ids[:, None, None, :].to(causal_mask.device)
+ == 0,
+ 0,
+ )
+
+ # Then apply padding mask (will mask pad tokens)
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+ :, None, None, :
+ ].to(causal_mask.device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+
+ return causal_mask
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ """
+        Obtains the last hidden states of the images from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values)
+ selected_image_feature = image_outputs.last_hidden_state
+ image_features = self.multi_modal_projector(selected_image_feature)
+ return image_features
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+        input_ids: torch.LongTensor | None = None,
+        pixel_values: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | Cache | None = None,
+ token_type_ids: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple | PaligemmaModelOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ 'You must specify exactly one of input_ids or inputs_embeds'
+ )
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict
+ if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ is_training = token_type_ids is not None and labels is not None
+
+        # Replace the image token id with PAD (0) if it is out of vocabulary, to avoid index errors
+ if (
+ input_ids is not None
+ and self.config.image_token_id >= self.vocab_size
+ ):
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length()
+ if past_key_values is not None
+ else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = (
+ cache_position.unsqueeze(0) + 1
+ ) # Paligemma positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(
+ self.config.image_token_id,
+ dtype=torch.long,
+ device=inputs_embeds.device,
+ )
+ )
+ )
+ else:
+ special_image_mask = (
+ input_ids == self.config.image_token_id
+ ).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(
+ inputs_embeds
+ ).to(inputs_embeds.device)
+
+ if (
+ not is_torchdynamo_compiling()
+ and inputs_embeds[special_image_mask].numel()
+ != image_features.numel()
+ ):
+ image_tokens_in_text = (
+ (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ )
+ raise ValueError(
+ f'Number of images does not match number of special image tokens in the input text. '
+ f'Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} '
+ 'tokens from image embeddings.'
+ )
+ image_features = image_features.to(
+ inputs_embeds.device, inputs_embeds.dtype
+ )
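+            # Scatter the projected image features into the embedding slots occupied
+            # by the special image placeholder tokens.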
+ inputs_embeds = inputs_embeds.masked_scatter(
+ special_image_mask, image_features
+ )
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ token_type_ids,
+ past_key_values,
+ cache_position,
+ inputs_embeds,
+ is_training,
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return PaligemmaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=(
+ image_features if pixel_values is not None else None
+ ),
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring(
+ custom_intro="""
+    The PaliGemma model, consisting of a vision backbone and a language model,
+    with a language modeling head on top for conditional generation.
+ """
+)
+class PaliGemmaForConditionalGeneration(
+ PaliGemmaPreTrainedModel, GenerationMixin
+):
+ _checkpoint_conversion_mapping = {
+ '^language_model.model': 'model.language_model',
+ '^vision_tower': 'model.vision_tower',
+ '^multi_modal_projector': 'model.multi_modal_projector',
+ '^language_model.lm_head': 'lm_head',
+ }
+ _tied_weights_keys = ['lm_head.weight']
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.model = PaliGemmaModel(config)
+ self.lm_head = nn.Linear(
+ config.text_config.hidden_size,
+ config.text_config.vocab_size,
+ bias=False,
+ )
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+        input_ids: torch.LongTensor | None = None,
+        pixel_values: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | Cache | None = None,
+ token_type_ids: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> tuple | PaliGemmaCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict
+ if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = (
+ slice(-logits_to_keep, None)
+ if isinstance(logits_to_keep, int)
+ else logits_to_keep
+ )
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **kwargs,
+ )
+
+ return PaliGemmaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # position_ids in Paligemma are 1-indexed
+ if model_inputs.get('position_ids') is not None:
+ model_inputs['position_ids'] += 1
+        # In the cached decoding stage, pixel_values should be None because the input
+        # ids no longer contain the special image token; otherwise pixel_values must be
+        # passed to the model. NOTE: with use_cache=False, pixel_values are always needed.
+ if cache_position[0] == 0:
+ model_inputs['pixel_values'] = pixel_values
+ is_training = token_type_ids is not None and labels is not None
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+ input_tensor = (
+ inputs_embeds if inputs_embeds is not None else input_ids
+ )
+ causal_mask = self.model._update_causal_mask(
+ attention_mask,
+ token_type_ids,
+ past_key_values,
+ cache_position,
+ input_tensor,
+ is_training,
+ )
+ model_inputs['attention_mask'] = causal_mask
+
+ return model_inputs
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=cache_position.device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=cache_position.device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ batch_size, 1, -1, -1
+ )
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[
+ :, :, :, :mask_length
+ ] + attention_mask[:, None, None, :].to(causal_mask.device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+
+ return causal_mask
+
+
+__all__ = [
+ 'PaliGemmaForConditionalGeneration',
+ 'PaliGemmaPreTrainedModel',
+ 'PaliGemmaModel',
+]
diff --git a/vla_arena/configs/task_suite/long_horizon.yaml b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
similarity index 66%
rename from vla_arena/configs/task_suite/long_horizon.yaml
rename to vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
index 33ad5353..22f15e81 100644
--- a/vla_arena/configs/task_suite/long_horizon.yaml
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 VLA-Arena Team. All Rights Reserved.
+# Copyright 2025 The VLA-Arena Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-task_suite_name: LONG_HORIZON
-num_steps_wait: 10
-num_trials_per_task: 50
-initial_states_path: DEFAULT
-max_episode_length: 600
+import transformers
+
+
+def check_whether_transformers_replace_is_installed_correctly():
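+    """Return True if the installed transformers version is exactly 4.53.2, the
+    version these replacement modules are written against."""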
+ return transformers.__version__ == '4.53.2'
diff --git a/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
new file mode 100644
index 00000000..3e052b79
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
@@ -0,0 +1,1413 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Siglip model."""
+
+import math
+import warnings
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+ torch_int,
+)
+from .configuration_siglip import (
+ SiglipConfig,
+ SiglipTextConfig,
+ SiglipVisionConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor,
+ mean: float = 0.0,
+ std: float = 1.0,
+ a: float = -2.0,
+ b: float = 2.0,
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \\leq \\text{mean} \\leq b`.
+
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == 'fan_in':
+ denom = fan_in
+ elif mode == 'fan_out':
+ denom = fan_out
+ elif mode == 'fan_avg':
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == 'truncated_normal':
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == 'normal':
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == 'uniform':
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f'invalid distribution {distribution}')
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode='fan_in', distribution='normal')
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+ """
+)
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
+class SiglipVisionModelOutput(ModelOutput):
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ image_embeds: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
+ attentions: tuple[torch.FloatTensor, ...] | None = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
+class SiglipTextModelOutput(ModelOutput):
+ r"""
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ text_embeds: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
+ attentions: tuple[torch.FloatTensor, ...] | None = None
+
+
+@dataclass
+@auto_docstring
+# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
+class SiglipOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+        The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+        The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipVisionModel`].
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits_per_image: torch.FloatTensor | None = None
+ logits_per_text: torch.FloatTensor | None = None
+ text_embeds: torch.FloatTensor | None = None
+ image_embeds: torch.FloatTensor | None = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ (
+ self[k]
+ if k not in ['text_model_output', 'vision_model_output']
+ else getattr(self, k).to_tuple()
+ )
+ for k in self.keys()
+ )
+
+
+class SiglipVisionEmbeddings(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding='valid',
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(
+ self.num_positions, self.embed_dim
+ )
+ self.register_buffer(
+ 'position_ids',
+ torch.arange(self.num_positions).expand((1, -1)),
+ persistent=False,
+ )
+
+ def interpolate_pos_encoding(
+ self, embeddings: torch.Tensor, height: int, width: int
+ ) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing and no class embeddings.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embedding.weight.shape[0]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if (
+ not torch.jit.is_tracing()
+ and num_patches == num_positions
+ and height == width
+ ):
+ return self.position_embedding(self.position_ids)
+
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(
+ 1, sqrt_num_positions, sqrt_num_positions, dim
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode='bicubic',
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False
+ ) -> torch.Tensor:
+ _, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype)
+ ) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(
+ embeddings, height, width
+ )
+ else:
+ embeddings = embeddings + self.position_embedding(
+ self.position_ids
+ )
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
+class SiglipTextEmbeddings(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(
+ config.max_position_embeddings, embed_dim
+ )
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ 'position_ids',
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ ) -> torch.Tensor:
+ seq_length = (
+ input_ids.shape[-1]
+ if input_ids is not None
+ else inputs_embeds.shape[-2]
+ )
+ max_position_embedding = self.position_embedding.weight.shape[0]
+
+ if seq_length > max_position_embedding:
+ raise ValueError(
+                f'Sequence length must be less than max_position_embeddings (got `sequence length`: '
+                f'{seq_length} and max_position_embeddings: {max_position_embedding})'
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
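+    # Plain (eager) scaled dot-product attention: softmax(Q @ K^T * scaling + mask) @ V,
+    # with dropout on the attention weights during training.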
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=dropout, training=module.training
+ )
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+ f' {self.num_heads}).'
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(
+ batch_size, seq_length, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ keys = keys.view(
+ batch_size, seq_length, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ values = values.view(
+ batch_size, seq_length, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != 'eager':
+ if (
+ self.config._attn_implementation == 'sdpa'
+ and output_attentions
+ ):
+ logger.warning_once(
+ '`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to '
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
+ self.config._attn_implementation
+ ]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, seq_length, embed_dim
+ ).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SiglipEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: SiglipVisionConfig | SiglipTextConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(
+ self.embed_dim, eps=config.layer_norm_eps
+ )
+ self.self_attn = SiglipAttention(config)
+ self.layer_norm2 = nn.LayerNorm(
+ self.embed_dim, eps=config.layer_norm_eps
+ )
+ self.mlp = SiglipMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: bool | None = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class SiglipPreTrainedModel(PreTrainedModel):
+ config_class = SiglipConfig
+ base_model_prefix = 'siglip'
+ supports_gradient_checkpointing = True
+
+ _no_split_modules = [
+ 'SiglipTextEmbeddings',
+ 'SiglipEncoderLayer',
+ 'SiglipVisionEmbeddings',
+ 'SiglipMultiheadAttentionPoolingHead',
+ ]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, SiglipVisionEmbeddings):
+ width = (
+ self.config.vision_config.hidden_size
+ if isinstance(self.config, SiglipConfig)
+ else self.config.hidden_size
+ )
+ nn.init.normal_(
+ module.position_embedding.weight, std=1 / np.sqrt(width)
+ )
+ elif isinstance(module, nn.Embedding):
+ default_flax_embed_init(module.weight)
+ elif isinstance(module, SiglipAttention):
+ nn.init.xavier_uniform_(module.q_proj.weight)
+ nn.init.xavier_uniform_(module.k_proj.weight)
+ nn.init.xavier_uniform_(module.v_proj.weight)
+ nn.init.xavier_uniform_(module.out_proj.weight)
+ nn.init.zeros_(module.q_proj.bias)
+ nn.init.zeros_(module.k_proj.bias)
+ nn.init.zeros_(module.v_proj.bias)
+ nn.init.zeros_(module.out_proj.bias)
+ elif isinstance(module, SiglipMLP):
+ nn.init.xavier_uniform_(module.fc1.weight)
+ nn.init.xavier_uniform_(module.fc2.weight)
+ nn.init.normal_(module.fc1.bias, std=1e-6)
+ nn.init.normal_(module.fc2.bias, std=1e-6)
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
+ nn.init.xavier_uniform_(module.probe.data)
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+ nn.init.zeros_(module.attention.in_proj_bias.data)
+ elif isinstance(module, SiglipModel):
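+            # log(1.0) == 0, so the logit scale starts at exp(0) == 1 (no scaling) and the bias at 0.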
+ logit_scale_init = torch.log(torch.tensor(1.0))
+ module.logit_scale.data.fill_(logit_scale_init)
+ module.logit_bias.data.zero_()
+ elif isinstance(module, SiglipForImageClassification):
+ nn.init.normal_(
+ module.classifier.weight,
+ std=self.config.vision_config.hidden_size**-0.5
+ * self.config.initializer_factor,
+ )
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ SiglipEncoderLayer(config)
+ for _ in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @can_return_tuple
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ ) -> BaseModelOutput:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+class SiglipTextTransformer(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = SiglipTextEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(
+ embed_dim, eps=config.layer_norm_eps
+ )
+
+ self.head = nn.Linear(embed_dim, config.projection_size)
+ self._use_flash_attention_2 = (
+ config._attn_implementation == 'flash_attention_2'
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ ) -> BaseModelOutputWithPooling:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ if input_ids is None:
+ raise ValueError('You have to specify input_ids')
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(
+ input_ids=input_ids, position_ids=position_ids
+ )
+
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
+ # expand attention_mask
+ if attention_mask is not None and not self._use_flash_attention_2:
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
+ pooled_output = last_hidden_state[:, -1, :]
+ pooled_output = self.head(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The text model from SigLIP without any head or projection on top.
+ """
+)
+class SiglipTextModel(SiglipPreTrainedModel):
+ config_class = SiglipTextConfig
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__(config)
+ self.text_model = SiglipTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, SiglipTextModel
+
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class SiglipVisionTransformer(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SiglipVisionEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(
+ embed_dim, eps=config.layer_norm_eps
+ )
+ self.use_head = (
+ True
+ if not hasattr(config, 'vision_use_head')
+ else config.vision_use_head
+ )
+ if self.use_head:
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ interpolate_pos_encoding: bool | None = False,
+ ) -> BaseModelOutputWithPooling:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+ # Convert to bfloat16 if the encoder uses bfloat16
+ if (
+ len(self.encoder.layers) > 0
+ and self.encoder.layers[0].self_attn.q_proj.weight.dtype
+ == torch.bfloat16
+ ):
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooler_output = self.head(last_hidden_state) if self.use_head else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(
+ config.hidden_size, config.num_attention_heads, batch_first=True
+ )
+ self.layernorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
+ self.mlp = SiglipMLP(config)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
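+        # A single learned probe token cross-attends over all patch tokens, pooling the sequence into one vector.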
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from SigLIP without any head or projection on top.
+ """
+)
+class SiglipVisionModel(SiglipPreTrainedModel):
+ config_class = SiglipVisionConfig
+ main_input_name = 'pixel_values'
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = SiglipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, SiglipVisionModel
+
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+
+@auto_docstring
+class SiglipModel(SiglipPreTrainedModel):
+ config_class = SiglipConfig
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, SiglipTextConfig):
+ raise TypeError(
+ 'config.text_config is expected to be of type SiglipTextConfig but is of type'
+ f' {type(config.text_config)}.'
+ )
+
+ if not isinstance(config.vision_config, SiglipVisionConfig):
+ raise TypeError(
+ 'config.vision_config is expected to be of type SiglipVisionConfig but is of type'
+ f' {type(config.vision_config)}.'
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ # First, initialize the text and vision models with proper attention implementation
+ text_model = SiglipTextModel._from_config(text_config)
+ vision_model = SiglipVisionModel._from_config(vision_config)
+
+ # Second, get the text and vision submodules (for backward compatibility)
+ self.text_model = text_model.text_model
+ self.vision_model = vision_model.vision_model
+
+ self.logit_scale = nn.Parameter(torch.randn(1))
+ self.logit_bias = nn.Parameter(torch.randn(1))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = text_outputs.pooler_output
+
+ return pooled_output
+
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ pooled_output = vision_outputs.pooler_output
+
+ return pooled_output
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ return_loss: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> SiglipOutput:
+ r"""
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+ >>> # important: we pass `padding=max_length` since the model was trained with this
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ image_embeds = vision_outputs.pooler_output
+ text_embeds = text_outputs.pooler_output
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(
+ p=2, dim=-1, keepdim=True
+ )
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logits_per_text = torch.matmul(
+ text_embeds, image_embeds.t().to(text_embeds.device)
+ )
+
+ logit_scale, logit_bias = self.logit_scale.to(
+ text_embeds.device
+ ), self.logit_bias.to(text_embeds.device)
+ logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
+
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+ if return_loss:
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
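+            # Pairwise sigmoid loss: matching (diagonal) image-text pairs get label +1, all other pairs label -1.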
+ eye = torch.eye(
+ logits_per_text.size(0), device=logits_per_text.device
+ )
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+ nll = -torch.sum(loglik, dim=-1)
+ loss = nll.mean()
+
+ return SiglipOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
+ the patch tokens) e.g. for ImageNet.
+ """
+)
+class SiglipForImageClassification(SiglipPreTrainedModel):
+ main_input_name = 'pixel_values'
+
+ def __init__(self, config: SiglipConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+
+ # Create the vision model with proper attention
+ # and take only vision_model submodule (for backward compatibility)
+ vision_model = SiglipVisionModel._from_config(config.vision_config)
+ self.vision_model = vision_model.vision_model
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.vision_config.hidden_size, config.num_labels)
+ if config.num_labels > 0
+ else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, SiglipForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a `SiglipModel` from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
+ >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the two classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: LABEL_1
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # average pool the patch tokens
+ sequence_output = torch.mean(sequence_output, dim=1)
+ # apply classifier
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = 'regression'
+ elif self.num_labels > 1 and (
+ labels.dtype == torch.long or labels.dtype == torch.int
+ ):
+ self.config.problem_type = 'single_label_classification'
+ else:
+ self.config.problem_type = 'multi_label_classification'
+
+ if self.config.problem_type == 'regression':
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == 'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1)
+ )
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ 'SiglipModel',
+ 'SiglipPreTrainedModel',
+ 'SiglipTextModel',
+ 'SiglipVisionModel',
+ 'SiglipForImageClassification',
+]
diff --git a/vla_arena/models/openpi/src/openpi/policies/aloha_policy.py b/vla_arena/models/openpi/src/openpi/policies/aloha_policy.py
new file mode 100644
index 00000000..2ebb8388
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/aloha_policy.py
@@ -0,0 +1,242 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from typing import ClassVar
+
+import einops
+import numpy as np
+from openpi import transforms
+
+
+def make_aloha_example() -> dict:
+ """Creates a random input example for the Aloha policy."""
+ return {
+ 'state': np.ones((14,)),
+ 'images': {
+ 'cam_high': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_low': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_left_wrist': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ 'cam_right_wrist': np.random.randint(
+ 256, size=(3, 224, 224), dtype=np.uint8
+ ),
+ },
+ 'prompt': 'do something',
+ }
+
+
+@dataclasses.dataclass(frozen=True)
+class AlohaInputs(transforms.DataTransformFn):
+ """Inputs for the Aloha policy.
+
+ Expected inputs:
+ - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
+ - state: [14]
+ - actions: [action_horizon, 14]
+ """
+
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
+ # the space used by the pi internal runtime which was used to train the base model.
+ adapt_to_pi: bool = True
+
+    # The expected camera names. All input cameras must be in this set. Missing cameras will be
+ # replaced with black images and the corresponding `image_mask` will be set to False.
+ EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = (
+ 'cam_high',
+ 'cam_low',
+ 'cam_left_wrist',
+ 'cam_right_wrist',
+ )
+
+ def __call__(self, data: dict) -> dict:
+ data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)
+
+ in_images = data['images']
+ if set(in_images) - set(self.EXPECTED_CAMERAS):
+ raise ValueError(
+ f'Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}'
+ )
+
+ # Assume that base image always exists.
+ base_image = in_images['cam_high']
+
+ images = {
+ 'base_0_rgb': base_image,
+ }
+ image_masks = {
+ 'base_0_rgb': np.True_,
+ }
+
+ # Add the extra images.
+ extra_image_names = {
+ 'left_wrist_0_rgb': 'cam_left_wrist',
+ 'right_wrist_0_rgb': 'cam_right_wrist',
+ }
+ for dest, source in extra_image_names.items():
+ if source in in_images:
+ images[dest] = in_images[source]
+ image_masks[dest] = np.True_
+ else:
+ images[dest] = np.zeros_like(base_image)
+ image_masks[dest] = np.False_
+
+ inputs = {
+ 'image': images,
+ 'image_mask': image_masks,
+ 'state': data['state'],
+ }
+
+ # Actions are only available during training.
+ if 'actions' in data:
+ actions = np.asarray(data['actions'])
+ actions = _encode_actions_inv(
+ actions, adapt_to_pi=self.adapt_to_pi
+ )
+ inputs['actions'] = actions
+
+ if 'prompt' in data:
+ inputs['prompt'] = data['prompt']
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class AlohaOutputs(transforms.DataTransformFn):
+ """Outputs for the Aloha policy."""
+
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
+ # the space used by the pi internal runtime which was used to train the base model.
+ adapt_to_pi: bool = True
+
+ def __call__(self, data: dict) -> dict:
+ # Only return the first 14 dims.
+ actions = np.asarray(data['actions'][:, :14])
+ return {
+ 'actions': _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)
+ }
+
+
+def _joint_flip_mask() -> np.ndarray:
+ """Used to convert between aloha and pi joint angles."""
+ return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
+
+
+def _normalize(x, min_val, max_val):
+ return (x - min_val) / (max_val - min_val)
+
+
+def _unnormalize(x, min_val, max_val):
+ return x * (max_val - min_val) + min_val
+
+
+def _gripper_to_angular(value):
+ # Aloha transforms the gripper positions into a linear space. The following code
+ # reverses this transformation to be consistent with pi0 which is pretrained in
+ # angular space.
+ #
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
+ value = _unnormalize(value, min_val=0.01844, max_val=0.05800)
+
+ # This is the inverse of the angular to linear transformation inside the Interbotix code.
+ def linear_to_radian(linear_position, arm_length, horn_radius):
+ value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
+ 2 * horn_radius * linear_position
+ )
+ return np.arcsin(np.clip(value, -1.0, 1.0))
+
+ # The constants are taken from the Interbotix code.
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
+
+ # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
+ # There are 4096 total encoder counts and aloha uses a zero of 2048.
+ # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)
+ return _normalize(value, min_val=0.5476, max_val=1.6296)
+
+
+def _gripper_from_angular(value):
+ # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
+ # Note that the units are still angular but the range is different.
+
+ # We do not scale the output since the trossen model predictions are already in radians.
+ # See the comment in _gripper_to_angular for a derivation of the constant
+ value = value + 0.5476
+
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
+ return _normalize(value, min_val=-0.6213, max_val=1.4910)
+
+
+def _gripper_from_angular_inv(value):
+ # Directly inverts the gripper_from_angular function.
+ value = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
+ return value - 0.5476
+
+
+def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
+ # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
+ # dim sizes: [6, 1, 6, 1]
+ state = np.asarray(data['state'])
+ state = _decode_state(state, adapt_to_pi=adapt_to_pi)
+
+ def convert_image(img):
+ img = np.asarray(img)
+ # Convert to uint8 if using float images.
+ if np.issubdtype(img.dtype, np.floating):
+ img = (255 * img).astype(np.uint8)
+ # Convert from [channel, height, width] to [height, width, channel].
+ return einops.rearrange(img, 'c h w -> h w c')
+
+ images = data['images']
+ images_dict = {name: convert_image(img) for name, img in images.items()}
+
+ data['images'] = images_dict
+ data['state'] = state
+ return data
+
+
+def _decode_state(
+ state: np.ndarray, *, adapt_to_pi: bool = False
+) -> np.ndarray:
+ if adapt_to_pi:
+ # Flip the joints.
+ state = _joint_flip_mask() * state
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
+ state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
+ return state
+
+
+def _encode_actions(
+ actions: np.ndarray, *, adapt_to_pi: bool = False
+) -> np.ndarray:
+ if adapt_to_pi:
+ # Flip the joints.
+ actions = _joint_flip_mask() * actions
+ actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
+ return actions
+
+
+def _encode_actions_inv(
+ actions: np.ndarray, *, adapt_to_pi: bool = False
+) -> np.ndarray:
+ if adapt_to_pi:
+ actions = _joint_flip_mask() * actions
+ actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
+ return actions
diff --git a/vla_arena/models/openpi/src/openpi/policies/droid_policy.py b/vla_arena/models/openpi/src/openpi/policies/droid_policy.py
new file mode 100644
index 00000000..0e2bb956
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/droid_policy.py
@@ -0,0 +1,100 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+
+import einops
+import numpy as np
+from openpi import transforms
+from openpi.models import model as _model
+
+
+def make_droid_example() -> dict:
+ """Creates a random input example for the Droid policy."""
+ return {
+ 'observation/exterior_image_1_left': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/wrist_image_left': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/joint_position': np.random.rand(7),
+ 'observation/gripper_position': np.random.rand(1),
+ 'prompt': 'do something',
+ }
+
+
+def _parse_image(image) -> np.ndarray:
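+    """Convert an image to uint8 (H, W, C), handling float-valued and channel-first inputs."""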
+ image = np.asarray(image)
+ if np.issubdtype(image.dtype, np.floating):
+ image = (255 * image).astype(np.uint8)
+ if image.shape[0] == 3:
+ image = einops.rearrange(image, 'c h w -> h w c')
+ return image
+
+
+@dataclasses.dataclass(frozen=True)
+class DroidInputs(transforms.DataTransformFn):
+ # Determines which model will be used.
+ model_type: _model.ModelType
+
+ def __call__(self, data: dict) -> dict:
+ gripper_pos = np.asarray(data['observation/gripper_position'])
+ if gripper_pos.ndim == 0:
+ # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions
+ gripper_pos = gripper_pos[np.newaxis]
+ state = np.concatenate(
+ [data['observation/joint_position'], gripper_pos]
+ )
+
+        # Parse images to uint8 (H,W,C) if needed, since LeRobot automatically stores
+        # them as float32 (C,H,W); this step is effectively skipped during policy inference.
+ base_image = _parse_image(data['observation/exterior_image_1_left'])
+ wrist_image = _parse_image(data['observation/wrist_image_left'])
+
+ match self.model_type:
+ case _model.ModelType.PI0 | _model.ModelType.PI05:
+ names = ('base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb')
+ images = (base_image, wrist_image, np.zeros_like(base_image))
+ image_masks = (np.True_, np.True_, np.False_)
+ case _model.ModelType.PI0_FAST:
+ names = ('base_0_rgb', 'base_1_rgb', 'wrist_0_rgb')
+ # We don't mask out padding images for FAST models.
+ images = (base_image, np.zeros_like(base_image), wrist_image)
+ image_masks = (np.True_, np.True_, np.True_)
+ case _:
+ raise ValueError(f'Unsupported model type: {self.model_type}')
+
+ inputs = {
+ 'state': state,
+ 'image': dict(zip(names, images, strict=True)),
+ 'image_mask': dict(zip(names, image_masks, strict=True)),
+ }
+
+ if 'actions' in data:
+ inputs['actions'] = np.asarray(data['actions'])
+
+ if 'prompt' in data:
+ if isinstance(data['prompt'], bytes):
+ data['prompt'] = data['prompt'].decode('utf-8')
+ inputs['prompt'] = data['prompt']
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class DroidOutputs(transforms.DataTransformFn):
+ def __call__(self, data: dict) -> dict:
+ # Only return the first 8 dims.
+ return {'actions': np.asarray(data['actions'][:, :8])}
diff --git a/vla_arena/models/openpi/src/openpi/policies/libero_policy.py b/vla_arena/models/openpi/src/openpi/policies/libero_policy.py
new file mode 100644
index 00000000..4549cf16
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/libero_policy.py
@@ -0,0 +1,121 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+
+import einops
+import numpy as np
+from openpi import transforms
+from openpi.models import model as _model
+
+
+def make_libero_example() -> dict:
+ """Creates a random input example for the Libero policy."""
+ return {
+ 'observation/state': np.random.rand(8),
+ 'observation/image': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'observation/wrist_image': np.random.randint(
+ 256, size=(224, 224, 3), dtype=np.uint8
+ ),
+ 'prompt': 'do something',
+ }
+
+
+def _parse_image(image) -> np.ndarray:
+ image = np.asarray(image)
+ if np.issubdtype(image.dtype, np.floating):
+ image = (255 * image).astype(np.uint8)
+ if image.shape[0] == 3:
+ image = einops.rearrange(image, 'c h w -> h w c')
+ return image
+
+
+@dataclasses.dataclass(frozen=True)
+class LiberoInputs(transforms.DataTransformFn):
+ """
+    This class converts dataset inputs into the format expected by the model. It is used for both training and inference.
+
+ For your own dataset, you can copy this class and modify the keys based on the comments below to pipe
+ the correct elements of your dataset into the model.
+ """
+
+ # Determines which model will be used.
+ # Do not change this for your own dataset.
+ model_type: _model.ModelType
+
+ def __call__(self, data: dict) -> dict:
+        # Parse images to uint8 (H,W,C) if needed, since LeRobot automatically stores
+        # them as float32 (C,H,W); this step is effectively skipped during policy inference.
+ # Keep this for your own dataset, but if your dataset stores the images
+ # in a different key than "observation/image" or "observation/wrist_image",
+ # you should change it below.
+ # Pi0 models support three image inputs at the moment: one third-person view,
+ # and two wrist views (left and right). If your dataset does not have a particular type
+ # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the
+ # right wrist image below.
+ base_image = _parse_image(data['observation/image'])
+ wrist_image = _parse_image(data['observation/wrist_image'])
+
+ # Create inputs dict. Do not change the keys in the dict below.
+ inputs = {
+ 'state': data['observation/state'],
+ 'image': {
+ 'base_0_rgb': base_image,
+ 'left_wrist_0_rgb': wrist_image,
+ # Pad any non-existent images with zero-arrays of the appropriate shape.
+ 'right_wrist_0_rgb': np.zeros_like(base_image),
+ },
+ 'image_mask': {
+ 'base_0_rgb': np.True_,
+ 'left_wrist_0_rgb': np.True_,
+ # We only mask padding images for pi0 model, not pi0-FAST. Do not change this for your own dataset.
+ 'right_wrist_0_rgb': (
+ np.True_
+ if self.model_type == _model.ModelType.PI0_FAST
+ else np.False_
+ ),
+ },
+ }
+
+ # Pad actions to the model action dimension. Keep this for your own dataset.
+ # Actions are only available during training.
+ if 'actions' in data:
+ inputs['actions'] = data['actions']
+
+ # Pass the prompt (aka language instruction) to the model.
+ # Keep this for your own dataset (but modify the key if the instruction is not
+ # stored in "prompt"; the output dict always needs to have the key "prompt").
+ if 'prompt' in data:
+ inputs['prompt'] = data['prompt']
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class LiberoOutputs(transforms.DataTransformFn):
+ """
+    This class converts model outputs back to the dataset-specific format. It is
+ used for inference only.
+
+ For your own dataset, you can copy this class and modify the action dimension based on the comments below.
+ """
+
+ def __call__(self, data: dict) -> dict:
+ # Only return the first N actions -- since we padded actions above to fit the model action
+ # dimension, we need to now parse out the correct number of actions in the return dict.
+ # For Libero, we only return the first 7 actions (since the rest is padding).
+ # For your own dataset, replace `7` with the action dimension of your dataset.
+ return {'actions': np.asarray(data['actions'][:, :7])}
diff --git a/vla_arena/models/openpi/src/openpi/policies/policy.py b/vla_arena/models/openpi/src/openpi/policies/policy.py
new file mode 100644
index 00000000..a8e26964
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/policy.py
@@ -0,0 +1,170 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import pathlib
+import time
+from collections.abc import Sequence
+from typing import Any, TypeAlias
+from typing_extensions import override
+
+import flax
+import flax.traverse_util
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from openpi import transforms as _transforms
+from openpi.models import model as _model
+from openpi.shared import array_typing as at
+from openpi.shared import nnx_utils
+from openpi_client import base_policy as _base_policy
+
+
+BasePolicy: TypeAlias = _base_policy.BasePolicy
+
+
+class Policy(BasePolicy):
+ def __init__(
+ self,
+ model: _model.BaseModel,
+ *,
+ rng: at.KeyArrayLike | None = None,
+ transforms: Sequence[_transforms.DataTransformFn] = (),
+ output_transforms: Sequence[_transforms.DataTransformFn] = (),
+ sample_kwargs: dict[str, Any] | None = None,
+ metadata: dict[str, Any] | None = None,
+ pytorch_device: str = 'cpu',
+ is_pytorch: bool = False,
+ ):
+ """Initialize the Policy.
+
+ Args:
+ model: The model to use for action sampling.
+ rng: Random number generator key for JAX models. Ignored for PyTorch models.
+ transforms: Input data transformations to apply before inference.
+ output_transforms: Output data transformations to apply after inference.
+ sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
+ metadata: Additional metadata to store with the policy.
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
+ Only relevant when is_pytorch=True.
+ is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
+ """
+ self._model = model
+ self._input_transform = _transforms.compose(transforms)
+ self._output_transform = _transforms.compose(output_transforms)
+ self._sample_kwargs = sample_kwargs or {}
+ self._metadata = metadata or {}
+ self._is_pytorch_model = is_pytorch
+ self._pytorch_device = pytorch_device
+
+ if self._is_pytorch_model:
+ self._model = self._model.to(pytorch_device)
+ self._model.eval()
+ self._sample_actions = model.sample_actions
+ else:
+ # JAX model setup
+ self._sample_actions = nnx_utils.module_jit(model.sample_actions)
+ self._rng = rng or jax.random.key(0)
+
+ @override
+ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: ignore[misc]
+ # Make a copy since transformations may modify the inputs in place.
+ inputs = jax.tree.map(lambda x: x, obs)
+ inputs = self._input_transform(inputs)
+ if not self._is_pytorch_model:
+ # Make a batch and convert to jax.Array.
+ inputs = jax.tree.map(
+ lambda x: jnp.asarray(x)[np.newaxis, ...], inputs
+ )
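+            # Split the stored RNG so each call to infer() samples with a fresh key.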
+ self._rng, sample_rng_or_pytorch_device = jax.random.split(
+ self._rng
+ )
+ else:
+ # Convert inputs to PyTorch tensors and move to correct device
+ inputs = jax.tree.map(
+ lambda x: torch.from_numpy(np.array(x)).to(
+ self._pytorch_device
+ )[None, ...],
+ inputs,
+ )
+ sample_rng_or_pytorch_device = self._pytorch_device
+
+ # Prepare kwargs for sample_actions
+ sample_kwargs = dict(self._sample_kwargs)
+ if noise is not None:
+ noise = (
+ torch.from_numpy(noise).to(self._pytorch_device)
+ if self._is_pytorch_model
+ else jnp.asarray(noise)
+ )
+
+ if (
+ noise.ndim == 2
+ ): # If noise is (action_horizon, action_dim), add batch dimension
+ noise = noise[
+ None, ...
+ ] # Make it (1, action_horizon, action_dim)
+ sample_kwargs['noise'] = noise
+
+ observation = _model.Observation.from_dict(inputs)
+ start_time = time.monotonic()
+ outputs = {
+ 'state': inputs['state'],
+ 'actions': self._sample_actions(
+ sample_rng_or_pytorch_device, observation, **sample_kwargs
+ ),
+ }
+ model_time = time.monotonic() - start_time
+ if self._is_pytorch_model:
+ outputs = jax.tree.map(
+ lambda x: np.asarray(x[0, ...].detach().cpu()), outputs
+ )
+ else:
+ outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)
+
+ outputs = self._output_transform(outputs)
+ outputs['policy_timing'] = {
+ 'infer_ms': model_time * 1000,
+ }
+ return outputs
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ return self._metadata
+
+
+class PolicyRecorder(_base_policy.BasePolicy):
+ """Records the policy's behavior to disk."""
+
+ def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
+ self._policy = policy
+
+ logging.info(f'Dumping policy records to: {record_dir}')
+ self._record_dir = pathlib.Path(record_dir)
+ self._record_dir.mkdir(parents=True, exist_ok=True)
+ self._record_step = 0
+
+ @override
+ def infer(self, obs: dict) -> dict: # type: ignore[misc]
+ results = self._policy.infer(obs)
+
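+        # Flatten nested inputs/outputs into a flat dict with '/'-separated keys before saving.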
+ data = {'inputs': obs, 'outputs': results}
+ data = flax.traverse_util.flatten_dict(data, sep='/')
+
+ output_path = self._record_dir / f'step_{self._record_step}'
+ self._record_step += 1
+
+ np.save(output_path, np.asarray(data))
+ return results
diff --git a/vla_arena/models/openpi/src/openpi/policies/policy_config.py b/vla_arena/models/openpi/src/openpi/policies/policy_config.py
new file mode 100644
index 00000000..87578333
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/policy_config.py
@@ -0,0 +1,119 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import pathlib
+from typing import Any
+
+import jax.numpy as jnp
+import openpi.models.model as _model
+import openpi.policies.policy as _policy
+import openpi.shared.download as download
+import openpi.transforms as transforms
+from openpi.training import checkpoints as _checkpoints
+from openpi.training import config as _config
+
+
+def create_trained_policy(
+ train_config: _config.TrainConfig,
+ checkpoint_dir: pathlib.Path | str,
+ *,
+ repack_transforms: transforms.Group | None = None,
+ sample_kwargs: dict[str, Any] | None = None,
+ default_prompt: str | None = None,
+ norm_stats: dict[str, transforms.NormStats] | None = None,
+ pytorch_device: str | None = None,
+) -> _policy.Policy:
+ """Create a policy from a trained checkpoint.
+
+ Args:
+ train_config: The training config to use to create the model.
+ checkpoint_dir: The directory to load the model from.
+ repack_transforms: Optional transforms that will be applied before any other transforms.
+ sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
+ kwargs will be used.
+ default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
+ data if it doesn't already exist.
+ norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
+ from the checkpoint directory.
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
+            If None and the checkpoint is detected as a PyTorch model, "cuda" is used if available, otherwise "cpu".
+
+ Note:
+ The function automatically detects whether the model is PyTorch-based by checking for the
+ presence of "model.safensors" in the checkpoint directory.
+ """
+ repack_transforms = repack_transforms or transforms.Group()
+ checkpoint_dir = download.maybe_download(str(checkpoint_dir))
+
+ # Check if this is a PyTorch model by looking for model.safetensors
+ weight_path = os.path.join(checkpoint_dir, 'model.safetensors')
+ is_pytorch = os.path.exists(weight_path)
+
+ logging.info('Loading model...')
+ if is_pytorch:
+ model = train_config.model.load_pytorch(train_config, weight_path)
+ model.paligemma_with_expert.to_bfloat16_for_selected_params('bfloat16')
+ else:
+ model = train_config.model.load(
+ _model.restore_params(
+ checkpoint_dir / 'params', dtype=jnp.bfloat16
+ )
+ )
+ data_config = train_config.data.create(
+ train_config.assets_dirs, train_config.model
+ )
+ if norm_stats is None:
+ # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
+ # that the policy is using the same normalization stats as the original training process.
+ if data_config.asset_id is None:
+ raise ValueError('Asset id is required to load norm stats.')
+ norm_stats = _checkpoints.load_norm_stats(
+ checkpoint_dir / 'assets', data_config.asset_id
+ )
+
+ # Determine the device to use for PyTorch models
+ if is_pytorch and pytorch_device is None:
+ try:
+ import torch
+
+ pytorch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ except ImportError:
+ pytorch_device = 'cpu'
+
+ return _policy.Policy(
+ model,
+ transforms=[
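+            # Input pipeline: repack -> inject default prompt -> data transforms -> normalize -> model transforms.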
+ *repack_transforms.inputs,
+ transforms.InjectDefaultPrompt(default_prompt),
+ *data_config.data_transforms.inputs,
+ transforms.Normalize(
+ norm_stats, use_quantiles=data_config.use_quantile_norm
+ ),
+ *data_config.model_transforms.inputs,
+ ],
+ output_transforms=[
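+            # Output pipeline runs in reverse: model transforms -> unnormalize -> data transforms -> repack.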
+ *data_config.model_transforms.outputs,
+ transforms.Unnormalize(
+ norm_stats, use_quantiles=data_config.use_quantile_norm
+ ),
+ *data_config.data_transforms.outputs,
+ *repack_transforms.outputs,
+ ],
+ sample_kwargs=sample_kwargs,
+ metadata=train_config.policy_metadata,
+ is_pytorch=is_pytorch,
+ pytorch_device=pytorch_device if is_pytorch else None,
+ )
diff --git a/vla_arena/models/openpi/src/openpi/policies/policy_test.py b/vla_arena/models/openpi/src/openpi/policies/policy_test.py
new file mode 100644
index 00000000..38237d7b
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/policies/policy_test.py
@@ -0,0 +1,51 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from openpi.policies import aloha_policy
+from openpi.policies import policy_config as _policy_config
+from openpi.training import config as _config
+from openpi_client import action_chunk_broker
+
+
+@pytest.mark.manual
+def test_infer():
+ config = _config.get_config('pi0_aloha_sim')
+ policy = _policy_config.create_trained_policy(
+ config, 'gs://openpi-assets/checkpoints/pi0_aloha_sim'
+ )
+
+ example = aloha_policy.make_aloha_example()
+ result = policy.infer(example)
+
+ assert result['actions'].shape == (config.model.action_horizon, 14)
+
+
+@pytest.mark.manual
+def test_broker():
+ config = _config.get_config('pi0_aloha_sim')
+ policy = _policy_config.create_trained_policy(
+ config, 'gs://openpi-assets/checkpoints/pi0_aloha_sim'
+ )
+
+ broker = action_chunk_broker.ActionChunkBroker(
+ policy,
+ # Only execute the first half of the chunk.
+ action_horizon=config.model.action_horizon // 2,
+ )
+
+ example = aloha_policy.make_aloha_example()
+ for _ in range(config.model.action_horizon):
+ outputs = broker.infer(example)
+ assert outputs['actions'].shape == (14,)
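
The broker test above exercises chunked execution: the wrapped policy returns a full action chunk, while each `infer` call on the broker yields a single step. Below is a hedged, self-contained sketch of that idea; it is not the `openpi_client.action_chunk_broker` implementation, just an illustration of the caching behaviour the test relies on.

import numpy as np


class SimpleChunkBroker:
    """Serve one action per call from a cached chunk, re-querying when exhausted."""

    def __init__(self, policy, action_horizon: int):
        self._policy = policy
        self._action_horizon = action_horizon
        self._chunk = None  # cached (chunk_len, action_dim) array
        self._step = 0

    def infer(self, obs):
        if self._chunk is None or self._step >= self._action_horizon:
            # Ask the wrapped policy for a fresh chunk of actions.
            self._chunk = np.asarray(self._policy.infer(obs)['actions'])
            self._step = 0
        action = self._chunk[self._step]  # a single (action_dim,) step
        self._step += 1
        return {'actions': action}
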
diff --git a/vla_arena/models/openpi/src/openpi/py.typed b/vla_arena/models/openpi/src/openpi/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/vla_arena/models/openpi/src/openpi/serving/websocket_policy_server.py b/vla_arena/models/openpi/src/openpi/serving/websocket_policy_server.py
new file mode 100644
index 00000000..364068cc
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/serving/websocket_policy_server.py
@@ -0,0 +1,111 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import http
+import logging
+import time
+import traceback
+
+import websockets.asyncio.server as _server
+import websockets.frames
+from openpi_client import base_policy as _base_policy
+from openpi_client import msgpack_numpy
+
+
+logger = logging.getLogger(__name__)
+
+
+class WebsocketPolicyServer:
+ """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
+
+ Currently only implements the `load` and `infer` methods.
+ """
+
+ def __init__(
+ self,
+ policy: _base_policy.BasePolicy,
+ host: str = '0.0.0.0',
+ port: int | None = None,
+ metadata: dict | None = None,
+ ) -> None:
+ self._policy = policy
+ self._host = host
+ self._port = port
+ self._metadata = metadata or {}
+ logging.getLogger('websockets.server').setLevel(logging.INFO)
+
+ def serve_forever(self) -> None:
+ asyncio.run(self.run())
+
+ async def run(self):
+ async with _server.serve(
+ self._handler,
+ self._host,
+ self._port,
+ compression=None,
+ max_size=None,
+ process_request=_health_check,
+ ) as server:
+ await server.serve_forever()
+
+ async def _handler(self, websocket: _server.ServerConnection):
+        # Protocol: send metadata once, then answer each observation frame with an action frame.
+ logger.info(f'Connection from {websocket.remote_address} opened')
+ packer = msgpack_numpy.Packer()
+
+ await websocket.send(packer.pack(self._metadata))
+
+ prev_total_time = None
+ while True:
+ try:
+ start_time = time.monotonic()
+ obs = msgpack_numpy.unpackb(await websocket.recv())
+
+ infer_time = time.monotonic()
+ action = self._policy.infer(obs)
+ infer_time = time.monotonic() - infer_time
+
+ action['server_timing'] = {
+ 'infer_ms': infer_time * 1000,
+ }
+ if prev_total_time is not None:
+                    # Only the previous request's total time is known here, since it must include the send time.
+ action['server_timing']['prev_total_ms'] = (
+ prev_total_time * 1000
+ )
+
+ await websocket.send(packer.pack(action))
+ prev_total_time = time.monotonic() - start_time
+
+ except websockets.ConnectionClosed:
+ logger.info(
+ f'Connection from {websocket.remote_address} closed'
+ )
+ break
+ except Exception:
+ await websocket.send(traceback.format_exc())
+ await websocket.close(
+ code=websockets.frames.CloseCode.INTERNAL_ERROR,
+ reason='Internal server error. Traceback included in previous frame.',
+ )
+ raise
+
+
+def _health_check(
+ connection: _server.ServerConnection, request: _server.Request
+) -> _server.Response | None:
+ if request.path == '/healthz':
+ return connection.respond(http.HTTPStatus.OK, 'OK\n')
+ # Continue with the normal request handling.
+ return None
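
For reference, a hedged sketch of the wire protocol from the client side. `openpi_client` already ships a client (see `websocket_client_policy.py`), so this only illustrates the frame order the server above expects: a metadata frame on connect, then one observation/action pair per round trip. The port number and observation contents are assumptions.

import numpy as np
import websockets.sync.client
from openpi_client import msgpack_numpy

packer = msgpack_numpy.Packer()
with websockets.sync.client.connect(
    'ws://localhost:8000', compression=None, max_size=None
) as ws:
    # The first frame sent by the server is its metadata dict.
    metadata = msgpack_numpy.unpackb(ws.recv())

    # Send a packed observation and receive the packed action response.
    obs = {'state': np.zeros(14, dtype=np.float32)}  # illustrative observation
    ws.send(packer.pack(obs))
    action = msgpack_numpy.unpackb(ws.recv())
    print(metadata, action['server_timing'])
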
diff --git a/vla_arena/models/openpi/src/openpi/shared/__init__.py b/vla_arena/models/openpi/src/openpi/shared/__init__.py
new file mode 100644
index 00000000..fb51bf31
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/shared/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/vla_arena/models/openpi/src/openpi/shared/array_typing.py b/vla_arena/models/openpi/src/openpi/shared/array_typing.py
new file mode 100644
index 00000000..2a1567e3
--- /dev/null
+++ b/vla_arena/models/openpi/src/openpi/shared/array_typing.py
@@ -0,0 +1,115 @@
+# Copyright 2025 The VLA-Arena Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools as ft
+import inspect
+from typing import TypeAlias, TypeVar, cast
+
+import beartype
+import jax
+import jax._src.tree_util as private_tree_util
+import jax.core
+import jaxtyping._decorator
+import torch
+from jaxtyping import Bool # noqa: F401
+from jaxtyping import DTypeLike # noqa: F401
+from jaxtyping import Int # noqa: F401
+from jaxtyping import Key # noqa: F401
+from jaxtyping import Num # noqa: F401
+from jaxtyping import Real # noqa: F401
+from jaxtyping import UInt8 # noqa: F401
+from jaxtyping import ArrayLike, Float, PyTree, config, jaxtyped
+
+
+# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
+# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
+# `jax.Sharding`, or even