diff --git a/examples/pi0/README.md b/examples/pi0/README.md index 13f02422dd..ffcdcba9d2 100644 --- a/examples/pi0/README.md +++ b/examples/pi0/README.md @@ -157,7 +157,7 @@ Configure the following fields: - `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) - `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) - `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) -- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}/ckpt`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}`) **Model settings**: - `model.model_name` - Model name: `"pi0"` or `"pi0.5"` @@ -186,7 +186,7 @@ python run.py --config-path ./examples/pi0/conf --config-name train action=run Training logs are saved to `outputs/pi0_train/logs/host_0_localhost.output` by default. -Checkpoints are saved to `${experiment.exp_dir}/ckpt` (default: `outputs/pi0_train/ckpt`). +Checkpoints are saved to `${experiment.exp_dir}/checkpoints` (default: `outputs/pi0_train/checkpoints`). 
### Stop Training ```sh diff --git a/examples/pi0/conf/train/pi0.yaml b/examples/pi0/conf/train/pi0.yaml index 88448f489c..75ad610dd9 100644 --- a/examples/pi0/conf/train/pi0.yaml +++ b/examples/pi0/conf/train/pi0.yaml @@ -20,7 +20,7 @@ system: decay_lr: 2.5e-6 checkpoint: - output_directory: ${experiment.exp_dir}/ckpt + output_directory: ${experiment.exp_dir} # Whether to save checkpoint save_checkpoint: true # Number of steps between checkpoints diff --git a/examples/pi0_5/README.md b/examples/pi0_5/README.md index e345d25e45..4c177e5ed1 100644 --- a/examples/pi0_5/README.md +++ b/examples/pi0_5/README.md @@ -164,7 +164,7 @@ Configure the following fields: - `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) - `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) - `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) -- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}/ckpt`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}`) **Model settings**: - `model.model_name` - Model name: `"pi0.5"` @@ -193,7 +193,7 @@ python run.py --config-path ./examples/pi0_5/conf --config-name train action=run Training logs are saved to `outputs/pi0_5_train/logs/host_0_localhost.output` by default. -Checkpoints are saved to `${experiment.exp_dir}/ckpt` (default: `outputs/pi0_5_train/ckpt`). +Checkpoints are saved to `${experiment.exp_dir}/checkpoints` (default: `outputs/pi0_5_train/checkpoints`). 
### Stop Training ```sh diff --git a/examples/pi0_5/conf/train/pi0_5.yaml b/examples/pi0_5/conf/train/pi0_5.yaml index 2027c1a6cb..4599321e6b 100644 --- a/examples/pi0_5/conf/train/pi0_5.yaml +++ b/examples/pi0_5/conf/train/pi0_5.yaml @@ -20,7 +20,7 @@ system: decay_lr: 2.5e-6 checkpoint: - output_directory: ${experiment.exp_dir}/ckpt + output_directory: ${experiment.exp_dir} # Whether to save checkpoint save_checkpoint: true # Number of steps between checkpoints diff --git a/examples/qwen_gr00t/README.md b/examples/qwen_gr00t/README.md new file mode 100644 index 0000000000..13f02422dd --- /dev/null +++ b/examples/qwen_gr00t/README.md @@ -0,0 +1,343 @@ +# PI0: Training, Inference, and Serving + +This guide covers how to train, run inference, and serve PI0 models using FlagScale. + +## Installation + +### Clone Repository + +```sh +git clone https://github.com/FlagOpen/FlagScale.git +cd FlagScale/ +``` + +### Setup Conda Environment + +Create a new conda environment for robotics training: + +```sh +conda create -n flagos-robo python=3.12 +conda activate flagos-robo +``` + +Install FlagScale and robotics dependencies: + +```sh +cd FlagScale/ +pip install . --verbose +pip install -r requirements/train/robotics/requirements.txt +``` + +Install additional dependencies for downloading models/datasets: + +```sh +# For HuggingFace Hub +pip install huggingface_hub + +# For ModelScope (optional) +pip install modelscope +``` + +## Download Models and Tokenizers + +Download models and tokenizers using the provided script. 
Choose either HuggingFace Hub or ModelScope based on your preference: + +**Using HuggingFace Hub:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/pi0_base \ + --output_dir /workspace/models \ + --source huggingface + +python examples/pi0/download.py \ + --repo_id google/paligemma-3b-pt-224 \ + --output_dir /workspace/models \ + --source huggingface +``` + +**Using ModelScope:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/pi0_base \ + --output_dir /workspace/models \ + --source modelscope + +python examples/pi0/download.py \ + --repo_id google/paligemma-3b-pt-224 \ + --output_dir /workspace/models \ + --source modelscope +``` + +The models will be downloaded to (example with `/workspace/models`): +- `/workspace/models/lerobot/pi0_base` +- `/workspace/models/google/paligemma-3b-pt-224` + + +## Training + +### Prepare Dataset + +FlagScale uses the **LeRobotDataset v3.0** format. For detailed information about the format structure, see the [LeRobotDataset v3.0 documentation](https://huggingface.co/docs/lerobot/en/lerobot-dataset-v3). + +For example, to download the `aloha_mobile_cabinet` dataset: + +**Using HuggingFace Hub:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/aloha_mobile_cabinet \ + --output_dir /workspace/datasets \ + --repo_type dataset \ + --source huggingface +``` + +**Using ModelScope:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/aloha_mobile_cabinet \ + --output_dir /workspace/datasets \ + --repo_type dataset \ + --source modelscope +``` + +The dataset will be downloaded to (example with `/workspace/datasets`): +- `/workspace/datasets/lerobot/aloha_mobile_cabinet` + +### Edit Config + +FlagScale uses a two-level configuration system: + +1. **Experiment-level config** (`examples/pi0/conf/train.yaml`): Defines experiment settings, environment variables, and resource allocation +2. 
**Task-level config** (`examples/pi0/conf/train/pi0.yaml`): Defines model, dataset, and training hyperparameters + +#### Experiment-Level Config + +Edit the experiment-level config for multi-GPU training: + +```sh +cd FlagScale/ +vim examples/pi0/conf/train.yaml +``` + +Configure the following fields: + +- `experiment.envs.CUDA_VISIBLE_DEVICES` - GPU devices to use (e.g., `"0,1,2,3"` for 4 GPUs, `"0,1"` for 2 GPUs) +- `experiment.envs.CUDA_DEVICE_MAX_CONNECTIONS` - Connection limit (typically `1`) +- `experiment.exp_name` - Experiment name +- `experiment.exp_dir` - Output directory for checkpoints and logs + +#### Task-Level Config + +Edit the task-level config for model and training settings: + +```sh +cd FlagScale/ +vim examples/pi0/conf/train/pi0.yaml +``` + +Configure the following fields: + +**System settings** (training hyperparameters): +- `system.batch_size` - Batch size per GPU +- `system.train_steps` - Total training steps +- `system.optimizer.name` - Optimizer name (default: `"AdamW"`) +- `system.optimizer.lr` - Learning rate (default: `2.5e-5`) +- `system.optimizer.betas` - Optimizer betas (default: `[0.9, 0.95]`) +- `system.optimizer.eps` - Optimizer epsilon (default: `1.0e-8`) +- `system.optimizer.weight_decay` - Weight decay (default: `0.01`) +- `system.scheduler.warmup_steps` - Warmup steps (default: `1000`) +- `system.scheduler.decay_steps` - Decay steps (default: `30000`) +- `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) +- `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) +- `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}`) + +**Model settings**: +- `model.model_name` - Model name: `"pi0"` or `"pi0.5"` +- `model.checkpoint_dir` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `model.tokenizer_path` - Path to tokenizer 
(e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `model.tokenizer_max_length` - Maximum tokenizer sequence length +- `model.action_steps` - Number of action steps to predict + +**Data settings**: +- `data.data_path` - Path to LeRobot dataset root (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet`) +- `data.use_imagenet_stats` - Whether to use ImageNet normalization stats (default: `true`) +- `data.rename_map` - Dictionary mapping dataset keys to policy keys (optional). Check the `features` key in your dataset's `meta/info.json` file to determine the correct mapping: + ```yaml + rename_map: + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + ``` +- `data.use_quantiles` - Whether to use quantile normalization (for `pi0.5`, set to `false` to use MEAN_STD normalization) + +### Start Training +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name train action=run +``` + +Training logs are saved to `outputs/pi0_train/logs/host_0_localhost.output` by default. + +Checkpoints are saved to `${experiment.exp_dir}/checkpoints` (default: `outputs/pi0_train/checkpoints`). 
+ +### Stop Training +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name train action=stop +``` + +## Inference + +### Prepare Inference Inputs + +You can extract inference inputs (images, state, task) from a dataset using the provided script: + +```sh +cd FlagScale/ +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --frame_index 100 +``` + +This will create: +- `frame_100_observation_images_*.jpg` - Image files +- `frame_100_state.pt` - State tensor +- `frame_100_task.txt` - Task prompt +- `extraction_summary.json` - Summary of extracted files + +Alternatively, you can extract from a specific episode and frame: + +```sh +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --episode_index 0 \ + --frame_in_episode 50 +``` + +Or extract multiple samples at once: + +```sh +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --frame_indices 100 200 300 +``` + +### Edit Config + +```sh +cd FlagScale/ +vim examples/pi0/conf/inference/pi0.yaml +``` + +Configure the following fields: + +**Engine settings:** +- `engine.model` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `engine.tokenizer` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `engine.stat_path` - Path to dataset statistics (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet/meta/stats.json`) +- `engine.device` - Device to use (e.g., `"cuda"`) + +**Generate settings:** +- `generate.images` - Dictionary mapping image keys to file paths: + ```yaml + images: + observation.images.cam_high: /path/to/image1.jpg + observation.images.cam_left_wrist: /path/to/image2.jpg + observation.images.cam_right_wrist: /path/to/image3.jpg + ``` 
+- `generate.state_path` - Path to state tensor file (`.pt` file) +- `generate.task_path` - Path to task prompt file (`.txt` file) +- `generate.rename_map` (optional) - Map input keys to policy expected keys: + ```yaml + rename_map: + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + ``` + +### Run Inference + +```sh +cd FlagScale/ +python run.py \ + --config-path ./examples/pi0/conf \ + --config-name inference \ + action=run +``` + +Inference logs are saved to `outputs/pi0_inference/inference_logs/host_0_localhost.output` by default. + +The predicted action tensor is printed to the console and saved in the log file. + +## Serving + +### Edit Config + +```sh +cd FlagScale/ +vim examples/pi0/conf/serve/pi0.yaml +``` + +Configure the following fields: + +**Engine arguments:** +- `engine_args.host` - Server host (default: `"0.0.0.0"`) +- `engine_args.port` - Server port (default: `5000`) +- `engine_args.model` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `engine_args.tokenizer` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `engine_args.stat_path` - Path to dataset statistics (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet/meta/stats.json`) +- `engine_args.device` - Device to use (e.g., `"cuda"`) +- `engine_args.images_keys` - List of image keys expected by the model (do not change): + ```yaml + images_keys: + - observation.images.base_0_rgb + - observation.images.left_wrist_0_rgb + - observation.images.right_wrist_0_rgb + ``` +- `engine_args.images_shape` - Image shape `[C, H, W]` for warmup (e.g., `[3, 480, 640]`) +- `engine_args.state_key` - Key for state in the batch (e.g., `"observation.state"`) + +### Run Serving + +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name serve action=run +``` + +Serving logs 
are saved to `outputs/pi0_serve/logs/host_0_localhost.output` by default. + +### Stop Serving + +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name serve action=stop +``` + +### Test Server with Client + +The client should send images using keys that match the `images_keys` in the config. For example, if using the default config: + +```sh +cd FlagScale/ +python examples/pi0/client_pi0.py \ + --host 127.0.0.1 \ + --port 5000 \ + --img1 ./inference_inputs/frame_100_observation_images_cam_high.jpg \ + --img2 ./inference_inputs/frame_100_observation_images_cam_left_wrist.jpg \ + --img3 ./inference_inputs/frame_100_observation_images_cam_right_wrist.jpg \ + --state-path ./inference_inputs/frame_100_state.pt \ + --instruction "Grab the orange and put it into the basket." +``` + +**Note**: The client must send image keys that match the `engine_args.images_keys` in the config. diff --git a/examples/qwen_gr00t/client_pi0.py b/examples/qwen_gr00t/client_pi0.py new file mode 100644 index 0000000000..4074ad3839 --- /dev/null +++ b/examples/qwen_gr00t/client_pi0.py @@ -0,0 +1,129 @@ +import argparse +import base64 +import json +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np +import requests +import torch + + +def encode_image(path: str) -> str: + """Read image as base64 string.""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Image not found: {path.resolve()}") + return base64.b64encode(path.read_bytes()).decode("utf-8") + + +def check_health(base_url: str) -> None: + """Ping /health; raise RuntimeError if unhealthy.""" + try: + r = requests.get(f"{base_url}/health", timeout=5) + r.raise_for_status() + except Exception as e: + raise RuntimeError(f"Health-check request failed: {e}") from e + + data = r.json() + if not (data.get("status") == "healthy" and data.get("model_loaded")): + raise RuntimeError(f"Server not ready: {json.dumps(data, indent=2)}") + print(f"[√] Server 
healthy - GPU: {data['gpu_info']['device_name']}") + + +def load_state_from_file(state_path: str) -> np.ndarray: + """Load state tensor from file and convert to numpy array. + + Args: + state_path: Path to state file (.pt file) + + Returns: + State array with shape (1, state_dim) + """ + state = torch.load(state_path, map_location="cpu") + if isinstance(state, torch.Tensor): + state = state.numpy() + # Ensure shape is (1, state_dim) + if state.ndim == 1: + state = state[np.newaxis, :] + return state + + +def build_payload(args) -> dict[str, Any]: + """Construct JSON payload for /infer. + + The client must send images with keys matching the config's images_keys. + Default keys are: + - observation.images.base_0_rgb + - observation.images.left_wrist_0_rgb + - observation.images.right_wrist_0_rgb + """ + # Encode images with keys matching config images_keys + img_sample = { + "observation.images.base_0_rgb": encode_image(args.img1), + "observation.images.left_wrist_0_rgb": encode_image(args.img2), + "observation.images.right_wrist_0_rgb": encode_image(args.img3), + } + # Load state from file + state = load_state_from_file(args.state_path) + state = state.tolist() + + return {"instruction": args.instruction, "state": state, "images": [img_sample]} + + +def pretty_print_resp(resp: requests.Response) -> None: + """Nicely print JSON or raw content.""" + try: + print(json.dumps(resp.json(), indent=2, ensure_ascii=False)) + except ValueError: + print(resp.text) + + +def main(): + parser = argparse.ArgumentParser(description="Client for RoboBrain-Robotics inference API") + parser.add_argument( + "--host", default="127.0.0.1", help="Host of local SSH tunnel (default: 127.0.0.1)" + ) + parser.add_argument( + "--port", type=int, default=5000, help="Port of local SSH tunnel (default: 5000)" + ) + parser.add_argument("--img1", required=True, help="Path to first camera RGB image") + parser.add_argument("--img2", required=True, help="Path to second camera RGB image") + 
parser.add_argument("--img3", required=True, help="Path to third camera RGB image") + parser.add_argument( + "--state-path", + required=True, + help="Path to state tensor file (.pt file) with shape (1, state_dim)", + ) + parser.add_argument( + "--instruction", + default="Grab the orange and put it into the basket.", + help="Task instruction for the robot", + ) + args = parser.parse_args() + + base_url = f"http://{args.host}:{args.port}" + print(f"-> Using endpoint: {base_url}") + + payload = build_payload(args) + try: + t0 = time.time() + resp = requests.post( + f"{base_url}/infer", + headers={"Content-Type": "application/json"}, + data=json.dumps(payload), + timeout=300, + ) + elapsed = (time.time() - t0) * 1000 + resp.raise_for_status() + except requests.RequestException as e: + print(f"[Error] HTTP request failed: {e}") + sys.exit(1) + print(f"[√] Response OK ({resp.status_code}) - {elapsed:.1f}ms") + pretty_print_resp(resp) + + +if __name__ == "__main__": + main() diff --git a/examples/qwen_gr00t/conf/inference.yaml b/examples/qwen_gr00t/conf/inference.yaml new file mode 100644 index 0000000000..36d3686823 --- /dev/null +++ b/examples/qwen_gr00t/conf/inference.yaml @@ -0,0 +1,26 @@ +defaults: + - _self_ + - inference: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_inference + exp_dir: outputs/${experiment.exp_name} + model: /models/qwen_gr00t + task: + type: inference + backend: vllm # TODO: Remove this restriction + entrypoint: flagscale/inference/inference_qwen_gr00t.py + runner: + hostfile: null + cmds: + before_start: null + envs: + CUDA_VISIBLE_DEVICES: 2 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + # Optionally, set HF_HOME and HF_ENDPOINT + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml new file mode 100644 index 0000000000..43890a76f6 --- /dev/null +++ b/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml @@ -0,0 +1,11 @@ 
+engine: + model_variant: "QwenGr00t" + model: /share/project/fengyupu/github/FlagScale_2/outputs/qwen_gr00t_train/20260207_110505.701567_ckpt/last + device: "cuda" + +generate: + images: + observation.images.wrist_image: qwen_gr00t_inference_inputs/frame_100_observation_images_wrist_image.jpg + observation.images.image: qwen_gr00t_inference_inputs/frame_100_observation_images_image.jpg + state_path: qwen_gr00t_inference_inputs/frame_100_state.pt + task_path: qwen_gr00t_inference_inputs/frame_100_task.txt diff --git a/examples/qwen_gr00t/conf/serve.yaml b/examples/qwen_gr00t/conf/serve.yaml new file mode 100644 index 0000000000..d8f02ef771 --- /dev/null +++ b/examples/qwen_gr00t/conf/serve.yaml @@ -0,0 +1,23 @@ +defaults: +- _self_ +- serve: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_serve_2 + exp_dir: outputs/${experiment.exp_name} + task: + type: serve + entrypoint: flagscale/serve/run_serve_qwen_gr00t.py + runner: + hostfile: null + deploy: + use_fs_serve: false + envs: + CUDA_VISIBLE_DEVICES: 3 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml new file mode 100644 index 0000000000..2f2dd73173 --- /dev/null +++ b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml @@ -0,0 +1,14 @@ +- serve_id: vllm_model # Not in use + engine_args: + host: 0.0.0.0 + port: 6000 + model_variant: QwenGr00t + model: /share/project/fengyupu/github/FlagScale_2/outputs/qwen_gr00t_train/20260208_112711.741406_ckpt/last + device: "cuda" + images_keys: + - observation.images.base_0_rgb + - observation.images.left_wrist_0_rgb + - observation.images.right_wrist_0_rgb + # Only used for warmup + images_shape: [3, 480, 640] + state_key: observation.state diff --git a/examples/qwen_gr00t/conf/train.yaml b/examples/qwen_gr00t/conf/train.yaml new file mode 100644 index 0000000000..2f537a9a58 --- /dev/null +++ 
b/examples/qwen_gr00t/conf/train.yaml @@ -0,0 +1,35 @@ +defaults: + - _self_ + - train: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_train + seed: 42 + save_steps: 10000 + load: null + exp_dir: outputs/${experiment.exp_name} + ckpt_format: torch + task: + type: train + backend: native + entrypoint: flagscale/train/train_qwen_gr00t.py + runner: + per_node_task: false + no_shared_fs: false + rdzv_backend: static + hostfile: null + cmds: + before_start: echo "Starting Qwen-GR00T Training" + envs: + LOGLEVEL: "INFO" + # CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" + CUDA_VISIBLE_DEVICES: "2" + CUDA_DEVICE_MAX_CONNECTIONS: 1 + WANDB_MODE: offline + OTEL_SDK_DISABLED: true + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml new file mode 100644 index 0000000000..952885d8d2 --- /dev/null +++ b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml @@ -0,0 +1,168 @@ +system: + batch_size: 16 + train_steps: 30000 + log_freq: 1 + grad_clip_norm: 1.0 + use_amp: true + shuffle: true + num_workers: 4 + + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + param_groups: + vlm: + lr: 1.0e-05 + action_model: + lr: 1.0e-04 + + scheduler: + name: cosine_with_min_lr + warmup_steps: 5000 + scheduler_kwargs: + min_lr: 1.0e-06 + # Legacy fields kept for BC + decay_steps: 30000 + decay_lr: 2.5e-6 + + checkpoint: + output_directory: ${experiment.exp_dir} + # Whether to save checkpoint + save_checkpoint: true + # Number of steps between checkpoints + save_freq: 1000 + # TODO(yupu): Support resuming from checkpoint + +model: + # TODO: (yupu) the config layout is still a mess + model_name: qwen_gr00t + # Path to the checkpoint of the pretrained base VLM model, e.g. 
Qwen3-VL-4B-Instruct + checkpoint_dir: /share/project/fengyupu/models/Qwen/Qwen3-VL-4B-Instruct/ + # checkpoint_dir: /workspace/models/Qwen/Qwen2.5-VL-3B-Instruct/ + vlm: + type: qwen3-vl + # type: qwen2.5-vl + qwenvl: + base_vlm: /share/project/fengyupu/models/Qwen/Qwen3-VL-4B-Instruct/ + # base_vlm: /workspace/models/Qwen/Qwen2.5-VL-3B-Instruct/ + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + dino: + dino_backbone: dinov2_vits14 + action_model: + type: flow_matching + action_model_type: DiT-B + action_hidden_dim: 1024 + hidden_size: 1024 + add_pos_embed: True + max_seq_len: 1024 + action_dim: 7 + state_dim: 7 + future_action_window_size: 7 + action_horizon: 8 + past_action_window_size: 0 + repeated_diffusion_steps: 4 + noise_beta_alpha: 1.5 + noise_beta_beta: 1.0 + noise_s: 0.999 + num_timestep_buckets: 1000 + num_inference_timesteps: 4 + num_target_vision_tokens: 32 + diffusion_model_cfg: + cross_attention_dim: 2048 + dropout: 0.2 + final_dropout: True + # # FIXME: Debug only + # dropout: 0 + # final_dropout: False + interleave_self_attention: True + norm_type: ada_norm + num_layers: 16 + output_dim: 1024 + positional_embeddings: None + reduce_in_full_precision: True + + # ============================================================ + # Module Freezing Configuration + # ============================================================ + # Freezing logic: freeze_patterns are applied first, then keep_patterns override. + # Patterns are regex matched against full parameter names. + # + # Common patterns for QwenGR00T: + # - "qwen_vl_interface\\..*" # Entire VLM + # - "qwen_vl_interface\\.model\\.visual\\..*" # Vision encoder + # - "qwen_vl_interface\\.model\\.model\\..*" # Language model + # - "qwen_vl_interface\\.model\\.model\\.layers\\.[0-9]\\." # LLM layers 0-9 + # - "action_model\\..*" # Action head + # - "action_model\\.model\\.transformer_blocks\\.[0-7]\\." 
# DiT blocks 0-7 + # + # freeze: + # # SCENARIO A: Freeze VLM, train only action head + # freeze_patterns: + # - "qwen_vl_interface\\..*" + # + # # SCENARIO B: Freeze VLM but keep projector trainable + # # freeze_patterns: + # # - "qwen_vl_interface\\..*" + # # keep_patterns: + # # - "qwen_vl_interface\\.model\\.visual\\.merger\\..*" + # + # # SCENARIO C: Freeze everything except action decoder + # # freeze_patterns: + # # - ".*" + # # keep_patterns: + # # - "action_model\\.action_decoder\\..*" + +data: + # TODO: (yupu) Remove this once we have a proper dataset config + vla_data: + dataset_py: lerobot_datasets + data_root_dir: playground/Datasets/ + data_mix: libero_goal_old + action_type: delta_qpos + CoT_prompt: Your task is {instruction}. To identify the key objects for your task. Locate their bounding boxes in [x1,y1,x2,y2] format. + CoT_answer: bbox + default_image_resolution: [3, 224, 224] + load_all_data_for_training: True + obs: ["image_0"] + video_backend: torchvision_av + # Path to the training data + data_path: /share/project/fengyupu/datasets/IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot/ + tolerance_s: 0.0001 + use_imagenet_stats: False + # To match the input features naming from the dataset to the policy config + # For example, for the aloha_mobile_cabinet dataset, the rename_map is: + rename_map: + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + use_quantiles: false + # TODO: (yupu) I think these indices should belong to the policy config, maybe put it in the model config? 
+ observation_delta_indices: [0] + action_delta_indices: [0,1,2,3,4,5,6,7] + preprocessor: + name: policy_preprocessor + steps: + - registry_name: rename_observations_processor + config: + rename_map: {} + - registry_name: to_batch_processor + config: {} + - registry_name: device_processor + config: + device: cuda + float_dtype: null + - registry_name: normalizer_processor + config: + eps: 1e-8 + features: {} + # Only normalize first 6 action dims (x,y,z,roll,pitch,yaw). + # Gripper (dim 6) is left raw, matching starVLA's Libero4in1DataConfig. + normalize_action_dims: 6 + norm_map: + VISUAL: IDENTITY + STATE: MIN_MAX + ACTION: MIN_MAX diff --git a/examples/qwen_gr00t/download.py b/examples/qwen_gr00t/download.py new file mode 100755 index 0000000000..48391c1e72 --- /dev/null +++ b/examples/qwen_gr00t/download.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +""" +Download models or datasets from HuggingFace Hub or ModelScope to a user-defined folder. + +Usage: + # Download model from HuggingFace + python download.py \ + --repo_id lerobot/pi0_base \ + --output_dir ~/models \ + --source huggingface + # Downloads to: ~/models/lerobot/pi0_base + + # Download dataset from HuggingFace + python download.py \ + --repo_id lerobot/aloha_mobile_cabinet \ + --output_dir ~/datasets \ + --repo_type dataset \ + --source huggingface + # Downloads to: ~/datasets/lerobot/aloha_mobile_cabinet +""" + +import argparse +import sys +from pathlib import Path + + +def _prepare_download(repo_id: str, output_dir: Path, repo_type: str, source_name: str) -> Path: + """Prepare download directory and print info. 
+ + Returns: + Final output directory path + """ + final_output_dir = output_dir / repo_id + print(f"Downloading {repo_type} {repo_id} from {source_name}...") + print(f"Output directory: {final_output_dir}") + final_output_dir.mkdir(parents=True, exist_ok=True) + return final_output_dir + + +def _handle_download_error(e: Exception, repo_id: str, source: str) -> None: + """Handle download errors with helpful tips.""" + print(f"✗ Error downloading from {source}: {e}") + if "401" in str(e) or "authentication" in str(e).lower(): + if source == "HuggingFace": + print("\nTip: You may need to set a HuggingFace token:") + print(" export HF_TOKEN=your_token_here") + print(" or run: huggingface-cli login") + else: + print("\nTip: You may need to set ModelScope credentials:") + print(" export MODELSCOPE_API_TOKEN=your_token_here") + elif "404" in str(e) or "not found" in str(e).lower(): + print(f"\nTip: Repository '{repo_id}' not found. Check the repo ID.") + sys.exit(1) + + +def download_from_huggingface( + repo_id: str, + output_dir: Path, + repo_type: str = "model", + revision: str | None = None, + token: str | None = None, +) -> Path: + """Download model or dataset from HuggingFace Hub. + + Args: + repo_id: HuggingFace repository ID (e.g., "lerobot/pi0_base") + output_dir: Base directory to save the repository + (will be saved to output_dir/repo_id) + repo_type: Type of repository - "model" or "dataset" (default: "model") + revision: Git revision (branch, tag, or commit hash). Defaults to "main" + token: HuggingFace token for private repos. 
If None, uses cached token + + Returns: + Path to downloaded repository directory + """ + try: + from huggingface_hub import snapshot_download + except ImportError: + print("Error: huggingface_hub is not installed.") + print("Install it with: pip install huggingface_hub") + sys.exit(1) + + final_output_dir = _prepare_download(repo_id, output_dir, repo_type, "HuggingFace Hub") + + try: + downloaded_path = snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + local_dir=str(final_output_dir), + local_dir_use_symlinks=False, + token=token, + ) + downloaded_path = Path(downloaded_path) + print(f"✓ Successfully downloaded to: {downloaded_path}") + return downloaded_path + except Exception as e: + _handle_download_error(e, repo_id, "HuggingFace") + return Path() # Never reached, but satisfies type checker + + +def download_from_modelscope( + repo_id: str, output_dir: Path, repo_type: str = "model", revision: str | None = None +) -> Path: + """Download model or dataset from ModelScope. + + Args: + repo_id: ModelScope repository ID (e.g., "lerobot/pi0_base") + output_dir: Base directory to save the repository + (will be saved to output_dir/repo_id) + repo_type: Type of repository - "model" or "dataset" (default: "model") + revision: Git revision (branch, tag, or commit hash). 
Defaults to "master" + + Returns: + Path to downloaded repository directory + """ + try: + from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download + except ImportError: + try: + from modelscope import snapshot_download as ms_snapshot_download + except ImportError: + print("Error: modelscope is not installed.") + print("Install it with: pip install modelscope") + sys.exit(1) + + final_output_dir = _prepare_download(repo_id, output_dir, repo_type, "ModelScope") + + try: + downloaded_path = ms_snapshot_download( + model_id=repo_id, + repo_type=repo_type, + local_dir=str(final_output_dir), + revision=revision, + ) + downloaded_path = Path(downloaded_path) + print(f"✓ Successfully downloaded to: {downloaded_path}") + return downloaded_path + except Exception as e: + _handle_download_error(e, repo_id, "ModelScope") + return Path() # Never reached, but satisfies type checker + + +def main(): + parser = argparse.ArgumentParser( + description="Download models or datasets from HuggingFace Hub or ModelScope", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Download model from HuggingFace (saves to ~/models/lerobot/pi0_base) + python download.py --repo_id lerobot/pi0_base \\ + --output_dir ~/models --source huggingface + + # Download dataset from HuggingFace (saves to ~/datasets/lerobot/aloha_mobile_cabinet) + python download.py --repo_id lerobot/aloha_mobile_cabinet \\ + --output_dir ~/datasets --repo_type dataset --source huggingface + + # Download from ModelScope (China users, saves to ~/models/lerobot/pi0_base) + python download.py --repo_id lerobot/pi0_base \\ + --output_dir ~/models --source modelscope + + # Download tokenizer (saves to ~/models/google/paligemma-3b-pt-224) + python download.py --repo_id google/paligemma-3b-pt-224 \\ + --output_dir ~/models --source huggingface + +Note: For private repositories, set HF_TOKEN environment variable: + export HF_TOKEN=your_token_here + """, + ) + + 
parser.add_argument( + "--repo_id", + type=str, + required=True, + help="Repository ID (e.g., 'lerobot/pi0_base' or 'lerobot/aloha_mobile_cabinet')", + ) + + parser.add_argument( + "--output_dir", + type=str, + required=True, + help=( + "Base output directory (repository will be saved to output_dir/repo_id, " + "e.g., '~/models' -> '~/models/lerobot/pi0_base')" + ), + ) + + parser.add_argument( + "--repo_type", + type=str, + choices=["model", "dataset"], + default="model", + help="Type of repository: 'model' or 'dataset' (default: model)", + ) + + parser.add_argument( + "--source", + type=str, + choices=["huggingface", "modelscope"], + default="huggingface", + help="Source to download from: 'huggingface' or 'modelscope' (default: huggingface)", + ) + + args = parser.parse_args() + + output_dir = Path(args.output_dir).expanduser().resolve() + + if args.source == "huggingface": + downloaded_path = download_from_huggingface( + repo_id=args.repo_id, + output_dir=output_dir, + repo_type=args.repo_type, + revision=None, + token=None, + ) + elif args.source == "modelscope": + downloaded_path = download_from_modelscope( + repo_id=args.repo_id, output_dir=output_dir, repo_type=args.repo_type, revision=None + ) + else: + raise ValueError(f"Unknown source: {args.source}") + + repo_type_name = "Dataset" if args.repo_type == "dataset" else "Model" + print(f"\n{repo_type_name} downloaded successfully to: {downloaded_path}") + print("You can now use this path in your config file:") + if args.repo_type == "dataset": + print(f" data_path: {downloaded_path}") + else: + print(f" checkpoint_dir: {downloaded_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/qwen_gr00t/dump_dataset_inputs.py b/examples/qwen_gr00t/dump_dataset_inputs.py new file mode 100644 index 0000000000..1cb9c61acb --- /dev/null +++ b/examples/qwen_gr00t/dump_dataset_inputs.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python +""" +Extract inference inputs (images, state, task) from a LeRobotDataset. 
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor to a PIL Image.

    Supported layouts:
        - (C, H, W) with C in {1, 3} - single image, channel first
        - (H, W, C) with C in {1, 3} - single image, channel last
        - (H, W) - single grayscale image
        - any of the above with a leading batch dimension; the first
          element of the batch is used

    Value handling:
        - uint8 tensors are used as-is.
        - Other dtypes are assumed to be in [0, 1]; if the maximum exceeds
          1.0 they are treated as [0, 255] and rescaled first.

    Args:
        tensor: Image tensor in one of the layouts above.

    Returns:
        An "RGB" image for 3-channel input, an "L" image for 1-channel input.

    Raises:
        ValueError: If the tensor cannot be interpreted as a 1- or 3-channel
            image.
    """
    # .numpy() below requires a CPU tensor that is not part of an autograd graph.
    tensor = tensor.detach().cpu()

    # Remove batch dimension if present
    if tensor.dim() == 4:
        tensor = tensor[0]

    # Normalize every supported layout to channel-last (H, W, C)
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(-1)
    elif tensor.dim() == 3 and tensor.shape[0] in (1, 3):
        # (C, H, W) -> (H, W, C)
        tensor = tensor.permute(1, 2, 0)

    if tensor.dim() != 3 or tensor.shape[2] not in (1, 3):
        raise ValueError(
            f"Cannot interpret tensor of shape {tuple(tensor.shape)} as a 1- or 3-channel image"
        )

    # Convert to [0, 255] uint8. uint8 input is left untouched: sending it
    # through a float divide/multiply and a truncating cast can lose one
    # intensity level to rounding.
    if tensor.dtype != torch.uint8:
        if tensor.max() > 1.0:
            tensor = tensor / 255.0
        tensor = (tensor.clamp(0, 1) * 255).byte()

    if tensor.shape[2] == 1:
        # Grayscale: a 2-D uint8 array is read by Pillow as mode "L".
        return Image.fromarray(tensor[:, :, 0].contiguous().numpy())

    # RGB: ToPILImage expects channel-first input.
    return ToPILImage()(tensor.permute(2, 0, 1))
def _save_task_text(task, output_dir: Path, sample_name: str) -> str:
    """Write ``task`` as UTF-8 text to ``<sample_name>_task.txt`` and return the path."""
    task_path = output_dir / f"{sample_name}_task.txt"
    with open(task_path, "w", encoding="utf-8") as f:
        f.write(str(task))
    print(f"Saved task: {task_path} (content: '{task}')")
    return str(task_path)


def dump_sample(
    sample: dict,
    output_dir: Path,
    sample_name: str = "sample",
    image_format: str = "jpg",
    dataset=None,
) -> dict:
    """Save sample data to files.

    Args:
        sample: Sample dictionary from dataset
        output_dir: Directory to save files (created if it does not exist)
        sample_name: Base name for output files
        image_format: Image format ('jpg' or 'png')
        dataset: Optional dataset used to resolve ``task_index`` into a task
            string through ``dataset.meta.tasks`` when the sample carries no
            "task" entry

    Returns:
        Dictionary with paths to saved files: ``{"images": {key: path},
        "state": path or None, "task": path or None}``
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    saved_paths = {"images": {}, "state": None, "task": None}

    # TODO: A little bit hacky
    # Image entries are recognized purely by "images" appearing in the key.
    image_keys = [k for k in sample if "images" in k]
    print(f"Found {len(image_keys)} image key(s): {image_keys}")

    for img_key in image_keys:
        img = tensor_to_image(sample[img_key])

        # Dots in the key are flattened so the key is usable as a file name.
        filename = img_key.replace(".", "_")
        img_path = output_dir / f"{sample_name}_{filename}.{image_format}"

        img.save(img_path)
        print(f"Saved image: {img_path}")
        saved_paths["images"][img_key] = str(img_path)

    # Extract and save state
    state_keys = [k for k in sample if "state" in k and "images" not in k]
    if state_keys:
        state_tensor = sample[state_keys[0]]  # Use first state key

        # Ensure it's 2D (batch, dim)
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        state_path = output_dir / f"{sample_name}_state.pt"
        torch.save(state_tensor, state_path)
        print(f"Saved state: {state_path} (shape: {state_tensor.shape})")
        saved_paths["state"] = str(state_path)
    else:
        print("Warning: No state found in sample")

    # Extract and save task
    if "task" in sample:
        task = sample["task"]
        if isinstance(task, torch.Tensor):
            task = task.item() if task.numel() == 1 else str(task.tolist())
        elif isinstance(task, list) and len(task) > 0:
            task = task[0] if isinstance(task[0], str) else str(task[0])

        saved_paths["task"] = _save_task_text(task, output_dir, sample_name)
    elif "task_index" in sample:
        # Try to get task from task_index
        task_idx = sample["task_index"]
        if isinstance(task_idx, torch.Tensor):
            task_idx = task_idx.item()

        # Get task from dataset metadata
        tasks_df = getattr(getattr(dataset, "meta", None), "tasks", None)
        # The lower bound matters: a negative index would otherwise silently
        # select a task counted from the end of the table.
        if tasks_df is not None and 0 <= task_idx < len(tasks_df):
            # NOTE(review): assumes the tasks table has a "task" column - confirm
            # against the LeRobotDataset metadata layout in use.
            task = tasks_df.iloc[task_idx]["task"]
            saved_paths["task"] = _save_task_text(task, output_dir, sample_name)
        else:
            print(f"Warning: Could not resolve task_index {task_idx} to a task; no task saved")
    else:
        print("Warning: No task found in sample")

    return saved_paths
def main():
    """Load the dataset, dump the requested sample(s), and write a JSON summary.

    The samples to extract are chosen from the CLI arguments in this order of
    precedence: ``--frame_indices``, ``--frame_index``, then the
    ``--episode_index`` / ``--frame_in_episode`` pair.

    Raises:
        ValueError: If none of the selection arguments were provided.
    """
    args = get_args()

    # Create output directory early
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load dataset
    print(f"Loading dataset: {args.dataset_root}")
    dataset = LeRobotDataset(root=args.dataset_root, video_backend=args.video_backend)
    print(f"Dataset loaded: {len(dataset)} frames, {dataset.num_episodes} episodes")

    # Determine which samples to extract
    if args.frame_indices:
        indices = args.frame_indices
        sample_names = [f"frame_{idx}" for idx in indices]
    elif args.frame_index is not None:
        indices = [args.frame_index]
        sample_names = [f"frame_{args.frame_index}"]
    elif args.episode_index is not None and args.frame_in_episode is not None:
        # Calculate global index
        # NOTE(review): extract_sample() reads the same metadata through
        # `.iloc[...]`; confirm which accessor matches the type of
        # `dataset.meta.episodes` and align the two call sites.
        episode_info = dataset.meta.episodes[args.episode_index]
        # The stored offset may be a NumPy/pandas scalar, which json.dump()
        # below cannot serialize; normalize to a built-in int.
        global_idx = int(episode_info["dataset_from_index"] + args.frame_in_episode)
        indices = [global_idx]
        sample_names = [f"episode_{args.episode_index}_frame_{args.frame_in_episode}"]
    else:
        raise ValueError(
            "Must provide --frame_index, --frame_indices, or (--episode_index + --frame_in_episode)"
        )

    # Extract and save samples
    all_paths = []

    # Both lists are built together above, so a length mismatch is a bug.
    for idx, sample_name in zip(indices, sample_names, strict=True):
        print(f"\n{'=' * 60}")
        print(f"Extracting sample {idx} ({sample_name})")
        print(f"{'=' * 60}")

        sample = extract_sample(dataset, frame_index=idx)
        paths = dump_sample(sample, output_dir, sample_name, args.image_format, dataset=dataset)
        all_paths.append({"index": idx, "sample_name": sample_name, "paths": paths})

    summary_path = output_dir / "extraction_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(all_paths, f, indent=2)
    print(f"Extraction complete! Summary saved to: {summary_path}")
a/flagscale/logger.py +++ b/flagscale/logger.py @@ -22,19 +22,19 @@ def __init__(self, name, level=logging.INFO): self.logger.addHandler(stream_handler) def info(self, message): - self.logger.info(message) + self.logger.info(message, stacklevel=2) def warning(self, message): - self.logger.warning(message) + self.logger.warning(message, stacklevel=2) def error(self, message): - self.logger.error(message) + self.logger.error(message, stacklevel=2) def critical(self, message): - self.logger.critical(message) + self.logger.critical(message, stacklevel=2) def debug(self, message): - self.logger.debug(message) + self.logger.debug(message, stacklevel=2) GLOBAL_LOGGER = None diff --git a/flagscale/models/action_model/__init__.py b/flagscale/models/action_model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flagscale/models/action_model/flow_matching_head/__init__.py b/flagscale/models/action_model/flow_matching_head/__init__.py new file mode 100644 index 0000000000..3159bfe656 --- /dev/null +++ b/flagscale/models/action_model/flow_matching_head/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def swish(x):
    """Apply the swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
class ActionEncoder(nn.Module):
    """Embed an action sequence together with one time value per batch item.

    Three linear layers are used: W1 lifts actions to ``hidden_size``, W2
    mixes the action embedding with a sinusoidal encoding of the time value
    (followed by swish), and W3 produces the final embedding.
    """

    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = nn.Linear(action_dim, hidden_size)  # (d -> w)
        self.W2 = nn.Linear(2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = nn.Linear(hidden_size, hidden_size)  # (w -> w)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions:   shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        returns:   shape (B, T, hidden_size)

        Raises ValueError when `timesteps` is not a 1-D tensor of length B.
        """
        batch_size, horizon, _ = actions.shape

        # Only a per-item scalar time is accepted; reject anything else up front.
        if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Broadcast each item's time across its T steps: (B,) -> (B, T)
        per_step_time = timesteps.unsqueeze(1).expand(-1, horizon)

        # Action embedding, (B, T, w)
        action_emb = self.W1(actions)

        # Sinusoidal time embedding, (B, T, w), cast to the action embedding dtype
        time_emb = self.pos_encoding(per_step_time).to(dtype=action_emb.dtype)

        # Fuse along the feature axis -> (B, T, 2w), then W2 + swish -> (B, T, w)
        fused = torch.cat([action_emb, time_emb], dim=-1)
        hidden = swish(self.W2(fused))

        # Final projection, (B, T, w)
        return self.W3(hidden)
+ + +import torch +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import ( + SinusoidalPositionalEmbedding, + TimestepEmbedding, + Timesteps, +) +from torch import nn + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim, compute_dtype=torch.float32): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps): + dtype = next(self.parameters()).dtype + timesteps_proj = self.time_proj(timesteps).to(dtype) + timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + return timesteps_emb + + +class AdaLayerNorm(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + self.chunk_dim = chunk_dim + output_dim = embedding_dim * 2 + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + temb = self.linear(self.silu(temb)) + scale, shift = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 
'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: str | None = None, + num_positional_embeddings: int | None = None, + ff_inner_dim: int | None = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.norm_type = norm_type + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positional_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding( + dim, max_seq_length=num_positional_embeddings + ) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + if final_dropout: + self.final_dropout = nn.Dropout(dropout) + else: + self.final_dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + temb: torch.LongTensor | None = None, + ) -> torch.Tensor: + # 0. Self-Attention + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, temb) + else: + norm_hidden_states = self.norm1(hidden_states) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, # @JinhuiYE original attention_mask=attention_mask + ) + if self.final_dropout: + attn_output = self.final_dropout(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 4. 
class DiT(ModelMixin, ConfigMixin):
    """Timestep-conditioned transformer stack with a modulated output head.

    ``forward`` encodes ``timestep`` with ``TimestepEncoder``, runs the input
    through ``num_layers`` ``BasicTransformerBlock`` layers, applies a
    LayerNorm modulated by shift/scale predicted from the timestep embedding
    (``proj_out_1``), and projects to ``output_dim`` (``proj_out_2``).

    When ``interleave_self_attention`` is set, odd-indexed blocks are built
    without a cross-attention dimension and are called without
    ``encoder_hidden_states``; all other blocks receive the encoder states.
    """

    _supports_gradient_checkpointing = True

    # register_to_config automatically records the arguments passed at construction
    # time into the config, so later code reads them as self.config.xxx rather than self.xxx
    @register_to_config  # Registers the passed arguments into the config. TODO: switch to our singleton pattern and write a mergeable @merge_pram_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: int | None = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: str | None = "sinusoidal",
        interleave_self_attention=False,
        cross_attention_dim: int | None = None,
        **kwargs,
    ):
        """Build the timestep encoder, the transformer blocks, and the output head.

        The hidden width is ``num_attention_heads * attention_head_dim``.
        ``num_embeds_ada_norm`` and ``**kwargs`` are accepted but not read
        anywhere in this constructor body.
        """
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        # Timestep encoder
        # self.config.compute_dtype may not exist, so fall back to a default up front
        compute_dtype = getattr(self.config, "compute_dtype", torch.float32)
        self.timestep_encoder = TimestepEncoder(  # TODO BUG: reading self.config.compute_dtype does not fail during training, but it does during eval
            embedding_dim=self.inner_dim, compute_dtype=compute_dtype
        )

        all_blocks = []
        for idx in range(self.config.num_layers):
            # Odd-indexed blocks become pure self-attention when interleaving is on.
            use_self_attn = idx % 2 == 1 and interleave_self_attention
            curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None

            all_blocks += [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                )
            ]
        self.transformer_blocks = nn.ModuleList(all_blocks)

        # Output blocks
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        # proj_out_1 doubles the width so its output can be split into shift and scale.
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
        print(
            "Total number of DiT parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: torch.LongTensor | None = None,
        return_all_hidden_states: bool = False,
        encoder_attention_mask=None,
    ):
        """Run the blocks and the modulated output projection.

        Args:
            hidden_states: Input sequence, shape (B, T, D).
            encoder_hidden_states: Conditioning sequence, shape (B, S, D).
            timestep: Passed to the timestep encoder; the default of ``None``
                is forwarded as-is.
            return_all_hidden_states: If True, also return the list holding
                the input and the output of every block.
            encoder_attention_mask: Forwarded to the blocks that receive
                ``encoder_hidden_states``.

        Returns:
            The projected output, or ``(output, all_hidden_states)`` when
            ``return_all_hidden_states`` is True.
        """
        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        # Process through transformer blocks - single pass through the blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        all_hidden_states = [hidden_states]

        # Process through transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1 and self.config.interleave_self_attention:
                # Self-attention block: no encoder states or encoder mask.
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                )
            all_hidden_states.append(hidden_states)

        # Output processing
        conditioning = temb
        # The first half of the projection is the shift, the second half the scale.
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)
class CategorySpecificLinear(nn.Module):
    """Linear layer that keeps one weight matrix and one bias per category.

    Each batch item picks the parameters of its own category, so items of
    different categories are transformed by different weights in one call.
    """

    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # One (input_dim, hidden_dim) weight and one (hidden_dim,) bias per category.
        weight_shape = (num_categories, input_dim, hidden_dim)
        self.W = nn.Parameter(0.02 * torch.randn(*weight_shape))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        """Apply each item's category-specific affine map.

        x:       batched input consumed by ``torch.bmm`` (3-D, last dim input_dim)
        cat_ids: index selecting one category per batch item
        returns: ``x @ W[cat_ids] + b[cat_ids]`` with the bias broadcast over dim 1
        """
        weights = self.W[cat_ids]
        biases = self.b[cat_ids].unsqueeze(1)
        projected = torch.bmm(x, weights)
        return projected + biases
class MultiEmbodimentActionEncoder(nn.Module):
    """Action encoder whose three projections are selected per embodiment.

    The action trajectory is projected (``W1``), concatenated with a sinusoidal
    encoding of the timestep, mixed (``W2`` followed by swish) and projected
    once more (``W3``).
    """

    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        width = hidden_size
        # d -> w, 2w -> w, w -> w; one weight set per embodiment.
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, width)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * width, width)
        self.W3 = CategorySpecificLinear(num_embodiments, width, width)
        self.pos_encoding = SinusoidalPositionalEncoding(width)

    def forward(self, actions, timesteps, cat_ids):
        """Encode a noised action trajectory.

        Args:
            actions: ``(B, T, action_dim)``.
            timesteps: ``(B,)`` -- one scalar time per batch item.
            cat_ids: ``(B,)`` embodiment ids.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.

        Raises:
            ValueError: If ``timesteps`` is not a 1-D tensor of length ``B``.
        """
        batch, horizon, _ = actions.shape

        if timesteps.dim() != 1 or timesteps.shape[0] != batch:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )
        # Replicate each sample's scalar time over the T action steps: (B,) -> (B, T).
        per_step_time = timesteps.unsqueeze(1).expand(-1, horizon)

        action_emb = self.W1(actions, cat_ids)
        time_emb = self.pos_encoding(per_step_time).to(dtype=action_emb.dtype)

        # (B, T, 2w) -> (B, T, w) with swish, then the final (w -> w) projection.
        fused = torch.cat([action_emb, time_emb], dim=-1)
        fused = swish(self.W2(fused, cat_ids))
        return self.W3(fused, cat_ids)
class FlowmatchingActionHead(nn.Module):
    """Flow-matching action head built around a DiT backbone.

    Training (``forward``) interpolates between Gaussian noise and the target
    actions and regresses the velocity ``actions - noise`` with an MSE loss.
    Inference (``predict_action``) starts from noise and integrates the
    predicted velocity with fixed-step Euler updates.

    Args:
        full_config: Global config object; the settings used here are read from
            ``full_config.model.action_model``.
    """

    def __init__(
        self,
        full_config,
    ):
        super().__init__()
        config = full_config.model.action_model
        self.hidden_size = config.hidden_size
        self.full_config = full_config
        action_model_type = config.action_model_type
        action_model_cfg = DiTConfig[action_model_type]

        self.input_embedding_dim = action_model_cfg["input_embedding_dim"]
        # Preset values come first so explicit `diffusion_model_cfg` entries win.
        diffusion_model_cfg = {**action_model_cfg, **config.diffusion_model_cfg}
        self.model = DiT(**diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.future_action_window_size + 1
        self.num_inference_timesteps = config.num_inference_timesteps

        # The state encoder is optional: it is only built when `state_dim` is set.
        self.state_encoder = (
            MLP(
                input_dim=config.state_dim,
                hidden_dim=self.hidden_size,
                output_dim=self.input_embedding_dim,
            )
            if config.state_dim
            else None
        )

        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        self.action_decoder = MLP(
            input_dim=self.model.config.output_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.config = config

    def sample_time(self, batch_size, device, dtype):
        """Draw one flow-matching time per batch item.

        A Beta sample ``s`` is mapped to ``(noise_s - s) / noise_s``.
        """
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s

    def prepare_input(self, batch: dict) -> BatchFeature:
        """Wrap a raw batch dict in a ``BatchFeature``."""
        return BatchFeature(data=batch)

    def _join_state_and_actions(self, action_features, state_features, batch_size, device):
        """Build the DiT input sequence ``[state?, future_tokens, actions]``.

        Adds the learned position embedding to the action features when
        ``config.add_pos_embed`` is set, then concatenates along the sequence axis.
        """
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(batch_size, -1, -1)
        pieces = (future_tokens, action_features)
        if state_features is not None:
            pieces = (state_features, *pieces)
        return torch.cat(pieces, dim=1)

    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        state: torch.Tensor = None,
        encoder_attention_mask=None,
    ):
        """Compute the flow-matching training loss.

        Args:
            vl_embs: Conditioning embeddings, ``(B, seq_length, feature_dim)``.
            actions: Target actions, ``(B, T, action_dim)``.
            state: Optional state tensor fed to the state encoder.
            encoder_attention_mask: Optional mask forwarded to the DiT.

        Returns:
            Scalar tensor: mean squared error between the predicted velocity and
            ``actions - noise``.

        Raises:
            ValueError: If the last dimension of ``actions`` differs from the
                configured ``action_dim``.
        """
        device = vl_embs.device

        if actions.shape[-1] != self.action_dim:
            raise ValueError(
                f"Action dimension mismatch: model expects {self.action_dim} dimensions "
                f"(from config), but received actions with {actions.shape[-1]} dimensions. "
                f"Please update config.model.action_model.action_dim to match your data."
            )

        # Noised trajectory: t = 0 is pure noise, t = 1 is the clean action.
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]  # (B, 1, 1) for broadcasting

        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise

        # The encoders take a discrete timestep bucket.
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy_trajectory, t_discretized)

        state_features = self.state_encoder(state) if state is not None else None

        sa_embs = self._join_state_and_actions(
            action_features, state_features, vl_embs.shape[0], device
        )

        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            encoder_attention_mask=encoder_attention_mask,
            timestep=t_discretized,
            return_all_hidden_states=False,  # NOTE (YL): not using flare now
        )
        pred = self.action_decoder(model_output)
        # Only the trailing action positions carry the velocity prediction.
        pred_actions = pred[:, -actions.shape[1] :]

        return ((pred_actions - velocity) ** 2).mean()

    @torch.no_grad()
    def predict_action(self, vl_embs: torch.Tensor, state: torch.Tensor = None) -> torch.Tensor:
        """Sample an action trajectory by Euler-integrating the predicted velocity.

        Args:
            vl_embs: Conditioning embeddings, ``(B, seq_length, feature_dim)``.
            state: Optional state tensor fed to the state encoder.

        Returns:
            Tensor of shape ``(B, action_horizon, action_dim)``.

        Raises:
            ValueError: If ``num_inference_timesteps`` is unset or not positive.
        """
        num_steps = self.num_inference_timesteps
        if not num_steps or num_steps < 1:
            raise ValueError(
                "`num_inference_timesteps` must be a positive integer to run "
                f"predict_action, got {num_steps!r}."
            )

        batch_size = vl_embs.shape[0]
        device = vl_embs.device
        # Use the same horizon/dim the velocity slice below relies on, so the
        # Euler update always adds tensors of identical shape.
        actions = torch.randn(
            size=(batch_size, self.action_horizon, self.action_dim),
            dtype=vl_embs.dtype,
            device=device,
        )

        dt = 1.0 / num_steps

        state_features = self.state_encoder(state) if state is not None else None

        for step in range(num_steps):
            t_cont = step / float(num_steps)  # 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            timesteps_tensor = torch.full(
                size=(batch_size,), fill_value=t_discretized, device=device
            )
            action_features = self.action_encoder(actions, timesteps_tensor)

            sa_embs = self._join_state_and_actions(
                action_features, state_features, batch_size, device
            )

            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                timestep=timesteps_tensor,
            )
            pred = self.action_decoder(model_output)
            pred_velocity = pred[:, -self.action_horizon :]

            # Euler step along the predicted velocity field.
            actions = actions + dt * pred_velocity
        return actions

    @property
    def device(self):
        """Device of the first parameter."""
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(iter(self.parameters())).dtype
+ config.data.vla_data.get("CoT_prompt", str) may be used later in build_qwenvl_inputs. **kwargs: Ignored currently; placeholder for future extension (e.g., override device_map, dtype). @@ -74,7 +74,7 @@ def __init__(self, config: dict | None = None, **kwargs): """ super().__init__() - qwenvl_config = config.framework.get("qwenvl", {}) + qwenvl_config = config.model.qwenvl model_id = qwenvl_config.get("base_vlm", "Qwen/Qwen2.5-VL-3B-Instruct") model = Qwen2_5_VLForConditionalGeneration.from_pretrained( @@ -191,7 +191,7 @@ def build_qwenvl_inputs(self, images, instructions, solutions=None, **kwargs): Reserved for future extensions (e.g., system prompts, style controls, additional metadata). Config Dependencies: - self.config.datasets.vla_data.get("CoT_prompt", str): + self.config.data.vla_data.get("CoT_prompt", str): If present, each instruction string is injected into the template by replacing "{instruction}". Returns: @@ -230,8 +230,8 @@ def build_qwenvl_inputs(self, images, instructions, solutions=None, **kwargs): for imgs, instruction in zip(images, instructions): content = [{"type": "image", "image": img} for img in imgs] - if "CoT_prompt" in self.config.datasets.vla_data: # If using a grounding prompt to task - CoT_prompt = self.config.datasets.vla_data.get("CoT_prompt", "") + if "CoT_prompt" in self.config.data.vla_data: # If using a grounding prompt to task + CoT_prompt = self.config.data.vla_data.get("CoT_prompt", "") prompt = CoT_prompt.replace("{instruction}", instruction) else: prompt = instruction diff --git a/flagscale/models/robobrain_x/qwen_groot.py b/flagscale/models/robobrain_x/qwen_groot.py index 5388aeb3e7..d5994ab2d9 100644 --- a/flagscale/models/robobrain_x/qwen_groot.py +++ b/flagscale/models/robobrain_x/qwen_groot.py @@ -55,14 +55,14 @@ def __init__(self, config: dict | None = None, **kwargs) -> None: self.config = config self.qwen_vl_interface = _QWen_VL_Interface(config=self.config) # align dims --> we should put them to config or no? 
- self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = ( + self.config.model.action_model.diffusion_model_cfg.cross_attention_dim = ( self.qwen_vl_interface.model.config.hidden_size ) self.action_model = FlowmatchingActionHead(full_config=self.config) - self.future_action_window_size = config.framework.action_model.future_action_window_size - self.past_action_window_size = config.framework.action_model.past_action_window_size + self.future_action_window_size = config.model.action_model.future_action_window_size + self.past_action_window_size = config.model.action_model.past_action_window_size self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size def forward(self, examples: list[dict] | None = None, **kwargs) -> tuple: @@ -101,8 +101,8 @@ def forward(self, examples: list[dict] | None = None, **kwargs) -> tuple: ] # (B, chunk_len, action_dim) repeated_diffusion_steps = ( - self.config.trainer.get("repeated_diffusion_steps", 4) - if self.config and self.config.trainer + self.config.system.get("repeated_diffusion_steps", 4) + if self.config and self.config.system else 4 ) actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1) @@ -152,7 +152,7 @@ def predict_action( dict: normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions. 
""" - train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None) + train_obs_image_size = getattr(self.config.data.vla_data, "image_size", None) if train_obs_image_size: batch_images = resize_images(batch_images, target_size=train_obs_image_size) diff --git a/flagscale/models/vla/__init__.py b/flagscale/models/vla/__init__.py new file mode 100644 index 0000000000..fe5d035505 --- /dev/null +++ b/flagscale/models/vla/__init__.py @@ -0,0 +1,30 @@ +from .action_models.flow_matching import FlowMatchingHead +from .protocols import ActionModel, VLMBackbone +from .qwen_gr00t import QwenGr00t +from .registry import ( + ACTION_MODEL_REGISTRY, + VLM_REGISTRY, + build_action_model, + build_vlm, + register_action_model, + register_vlm, +) +from .utils import get_vlm_config + +# Explicit registration +from .vlm.qwen_vl import Qwen3VLBackbone, Qwen25VLBackbone + +VLM_REGISTRY["qwen2.5-vl"] = Qwen25VLBackbone +VLM_REGISTRY["qwen3-vl"] = Qwen3VLBackbone +ACTION_MODEL_REGISTRY["flow_matching"] = FlowMatchingHead + +__all__ = [ + "VLMBackbone", + "ActionModel", + "register_vlm", + "register_action_model", + "build_vlm", + "build_action_model", + "get_vlm_config", + "QwenGr00t", +] diff --git a/flagscale/models/vla/action_models/__init__.py b/flagscale/models/vla/action_models/__init__.py new file mode 100644 index 0000000000..91408b343c --- /dev/null +++ b/flagscale/models/vla/action_models/__init__.py @@ -0,0 +1,3 @@ +from .flow_matching import FlowMatchingHead + +__all__ = ["FlowMatchingHead"] diff --git a/flagscale/models/vla/action_models/flow_matching.py b/flagscale/models/vla/action_models/flow_matching.py new file mode 100644 index 0000000000..8b36721e61 --- /dev/null +++ b/flagscale/models/vla/action_models/flow_matching.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn + +from flagscale.models.action_model.gr00t_action_header import ( + FlowmatchingActionHead as _FlowmatchingActionHead, +) +from flagscale.models.utils.constants import ACTION 
class FlowMatchingHead(nn.Module):
    """
    Flow matching action head wrapper for VLA framework.

    Args:
        vlm_config: HF config object from VLM (used to get hidden_size).
        action_config: dict with action model settings (accepted for interface
            compatibility; not read by this wrapper).
        full_config: TrainConfig for initializing the underlying FlowmatchingActionHead.
            Required despite the ``None`` default.

    Raises:
        ValueError: If ``full_config`` is not provided.
    """

    def __init__(self, vlm_config, action_config: dict, full_config: TrainConfig = None):
        super().__init__()
        if full_config is None:
            raise ValueError(
                "FlowMatchingHead requires `full_config` to build the underlying "
                "FlowmatchingActionHead."
            )
        vlm_info = get_vlm_config(vlm_config)
        self.hidden_size = vlm_info["hidden_size"]

        # TODO: pass cross_attention_dim directly to action head instead of mutating full_config
        full_config.model.action_model.diffusion_model_cfg.cross_attention_dim = self.hidden_size

        self._head = _FlowmatchingActionHead(full_config=full_config)

    def forward(
        self, vlm_output: dict[str, torch.Tensor], action_input: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            vlm_output: From VLM; ``vlm_output['hidden_states']`` is passed to the
                head as the conditioning embeddings.
            action_input: Raw batch with 'actions', and optionally 'state' and
                'attention_mask'.
        Returns:
            dict with 'loss'.
        """
        loss = self._head.forward(
            vl_embs=vlm_output["hidden_states"],
            actions=action_input["actions"],
            state=action_input.get("state"),
            encoder_attention_mask=action_input.get("attention_mask"),
        )
        return {"loss": loss}

    def predict_action(
        self, vlm_output: dict[str, torch.Tensor], action_input: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            vlm_output: From VLM, contains 'hidden_states'.
            action_input: Raw batch with 'state', etc.
        Returns:
            dict mapping ``ACTION`` to a Tensor [B, horizon, action_dim].
        """
        actions = self._head.predict_action(
            vl_embs=vlm_output["hidden_states"],
            state=action_input.get("state"),
        )
        return {ACTION: actions}

    def predict(
        self, vlm_output: dict[str, torch.Tensor], action_input: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Alias of ``predict_action`` under the method name the ``ActionModel`` protocol declares."""
        return self.predict_action(vlm_output, action_input, **kwargs)
class QwenGr00t(PreTrainedModel):
    """
    Multimodal vision-language-action model.

    Components:
      - A VLM backbone built through ``build_vlm`` (default type ``"qwen3-vl"``)
      - An action head built through ``build_action_model`` (default type
        ``"flow_matching"``)

    Focus: Predict future continuous actions conditioned on images + instruction.
    """

    config_class = PretrainedConfig

    def __init__(self, config: TrainConfig, **kwargs):
        # The HF base class gets an empty PretrainedConfig; the training config
        # is kept separately on `_config`.
        super().__init__(PretrainedConfig())
        self._config = config

        vlm_type = config.model.vlm.get("type", "qwen3-vl")
        self.vlm = build_vlm(vlm_type, config=config)

        action_model_type = config.model.action_model.get("type", "flow_matching")
        self.action_model = build_action_model(
            action_model_type,
            vlm_config=self.vlm.model_config,
            action_config={},
            full_config=config,
        )

        self.future_action_window_size = config.model.action_model.future_action_window_size

    def forward(self, examples: dict, **kwargs):
        """Run the VLM and the action head and return the action loss.

        Args:
            examples: Training batch; ``examples[ACTION]`` holds the action tensor and
                the whole dict is handed to ``self.vlm.prepare_input``.

        Returns:
            The value stored under ``'loss'`` in the action head output.
        """
        actions = examples[ACTION]
        state = None  # State conditioning is not wired in yet (would be examples[OBS_STATE]).

        # NOTE: (yupu) The order of the images differs from starVLA, which is [image, wrist_image]
        qwen_inputs = self.vlm.prepare_input(examples)

        # TODO: (yupu) Hard-coded autocast and dtype, matches starVLA
        with torch.autocast("cuda", dtype=torch.bfloat16):
            vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
            last_hidden = vlm_output["hidden_states"][-1]  # [B, L, H]

        with torch.autocast("cuda", dtype=torch.float32):
            # TODO: (yupu) Is this a bug or a feature? The action dtype would stay as bf16 under this autocast.
            actions = actions.to(device=last_hidden.device, dtype=last_hidden.dtype)

            # TODO: does not match RoboBrainX, need to check
            # Keep only the trailing (future_action_window_size + 1) steps.
            actions_target = actions[:, -(self.future_action_window_size + 1) :, :]

            # TODO: (yupu) I believe there is a bug in starVLA, the
            # `repeated_diffusion_steps` is not properly set in the config.
            repeated_diffusion_steps = self._config.model.action_model.get(
                "repeated_diffusion_steps", 4
            )

            # Tile the batch so each sample is seen with several noise/time draws.
            actions_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
            last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)

            state_repeated = None
            if state is not None:
                state = state.to(device=last_hidden.device, dtype=last_hidden.dtype)
                state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

            vlm_output_repeated = {"hidden_states": last_hidden_repeated}
            action_input = {"actions": actions_repeated, "state": state_repeated}

            output = self.action_model.forward(vlm_output_repeated, action_input)

        return output["loss"]

    @torch.inference_mode()
    def predict_action(self, examples: list[dict], **kwargs) -> dict:
        """Sample actions for a batch of observations.

        Steps:
          1. Build VLM inputs with ``self.vlm.prepare_input`` (images are assumed to
             be already resized during preprocessing).
          2. Run the VLM and take the last entry of its ``hidden_states``.
          3. Run the action head's ``predict_action`` on those hidden states.

        Returns:
            The action head's output, unchanged. It is assumed to be a dict that
            maps ``ACTION`` to the normalized actions.
        """
        # TODO: (yupu) Fix inference input format to use constants (OBS_IMAGE, OBS_LANGUAGE, OBS_STATE)
        # instead of hardcoded keys. The current keys are inconsistent with training batch format.
        qwen_inputs = self.vlm.prepare_input(examples)
        state = None  # State conditioning is not wired in yet (would be examples[OBS_STATE]).

        with torch.autocast("cuda", dtype=torch.bfloat16):
            vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
            last_hidden = vlm_output["hidden_states"][-1]  # [B, L, H]

        if state is not None:
            state = state.to(device=last_hidden.device, dtype=last_hidden.dtype)

        with torch.autocast("cuda", dtype=torch.float32):
            vlm_output_for_action = {"hidden_states": last_hidden}
            action_input = {"state": state}
            output = self.action_model.predict_action(vlm_output_for_action, action_input)

        return output
VLM_REGISTRY[name] = cls + return cls + + return decorator + + +def register_action_model(name: str): + def decorator(cls): + ACTION_MODEL_REGISTRY[name] = cls + return cls + + return decorator + + +def build_vlm(name: str, **kwargs): + if name not in VLM_REGISTRY: + raise ValueError(f"Unknown VLM: {name}. Available: {list(VLM_REGISTRY.keys())}") + return VLM_REGISTRY[name](**kwargs) + + +def build_action_model(name: str, vlm_config, action_config: dict, **kwargs): + if name not in ACTION_MODEL_REGISTRY: + raise ValueError( + f"Unknown ActionModel: {name}. Available: {list(ACTION_MODEL_REGISTRY.keys())}" + ) + return ACTION_MODEL_REGISTRY[name](vlm_config=vlm_config, action_config=action_config, **kwargs) diff --git a/flagscale/models/vla/utils.py b/flagscale/models/vla/utils.py new file mode 100644 index 0000000000..b0ee5fdd12 --- /dev/null +++ b/flagscale/models/vla/utils.py @@ -0,0 +1,29 @@ +def get_vlm_config(vlm_config) -> dict: + """ + Extract common fields from any VLM config, handling structural differences. + + Args: + vlm_config: HF config object (may have hidden_size directly or via text_config). + Returns: + dict with 'hidden_size' and 'num_hidden_layers'. 
+ """ + return { + "hidden_size": _get_hidden_size(vlm_config), + "num_hidden_layers": _get_num_layers(vlm_config), + } + + +def _get_hidden_size(config) -> int: + if hasattr(config, "hidden_size"): + return config.hidden_size + if hasattr(config, "text_config"): + return config.text_config.hidden_size + raise ValueError(f"Cannot determine hidden_size from config: {type(config)}") + + +def _get_num_layers(config) -> int: + if hasattr(config, "num_hidden_layers"): + return config.num_hidden_layers + if hasattr(config, "text_config"): + return config.text_config.num_hidden_layers + raise ValueError(f"Cannot determine num_hidden_layers from config: {type(config)}") diff --git a/flagscale/models/vla/vlm/__init__.py b/flagscale/models/vla/vlm/__init__.py new file mode 100644 index 0000000000..aa73dba2e0 --- /dev/null +++ b/flagscale/models/vla/vlm/__init__.py @@ -0,0 +1,3 @@ +from .qwen_vl import Qwen3VLBackbone, Qwen25VLBackbone, QwenVLBackbone + +__all__ = ["QwenVLBackbone", "Qwen25VLBackbone", "Qwen3VLBackbone"] diff --git a/flagscale/models/vla/vlm/qwen_vl.py b/flagscale/models/vla/vlm/qwen_vl.py new file mode 100644 index 0000000000..b6d45e38db --- /dev/null +++ b/flagscale/models/vla/vlm/qwen_vl.py @@ -0,0 +1,349 @@ +import torch +import torch.nn as nn +from transformers import ( + AutoProcessor, + PretrainedConfig, + Qwen2_5_VLForConditionalGeneration, + Qwen3VLForConditionalGeneration, +) + +from flagscale.train.train_config import TrainConfig +from flagscale.train.utils.image_tools import to_pil_preserve + +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = 151655 +VIDEO_TOKEN_INDEX = 151656 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_VIDEO_TOKEN = "