diff --git a/examples/pi0/README.md b/examples/pi0/README.md index 13f02422d..166d32833 100644 --- a/examples/pi0/README.md +++ b/examples/pi0/README.md @@ -147,17 +147,9 @@ Configure the following fields: **System settings** (training hyperparameters): - `system.batch_size` - Batch size per GPU - `system.train_steps` - Total training steps -- `system.optimizer.name` - Optimizer name (default: `"AdamW"`) -- `system.optimizer.lr` - Learning rate (default: `2.5e-5`) -- `system.optimizer.betas` - Optimizer betas (default: `[0.9, 0.95]`) -- `system.optimizer.eps` - Optimizer epsilon (default: `1.0e-8`) -- `system.optimizer.weight_decay` - Weight decay (default: `0.01`) -- `system.scheduler.warmup_steps` - Warmup steps (default: `1000`) -- `system.scheduler.decay_steps` - Decay steps (default: `30000`) -- `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) - `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) - `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) -- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}/ckpt`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}`) **Model settings**: - `model.model_name` - Model name: `"pi0"` or `"pi0.5"` @@ -165,6 +157,14 @@ Configure the following fields: - `model.tokenizer_path` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) - `model.tokenizer_max_length` - Maximum tokenizer sequence length - `model.action_steps` - Number of action steps to predict +- `model.optimizer.name` - Optimizer name (for example: `"AdamW"`) +- `model.optimizer.lr` - Learning rate (for example: `2.5e-5`) +- `model.optimizer.betas` - Optimizer betas (for example: `[0.9, 0.95]`) +- `model.optimizer.eps` - Optimizer epsilon (for example: `1.0e-8`) +- `model.optimizer.weight_decay` - Weight decay (for example: `0.01`) +- 
`model.optimizer.scheduler.warmup_steps` - Warmup steps (for example: `1000`) +- `model.optimizer.scheduler.decay_steps` - Decay steps (for example: `30000`) +- `model.optimizer.scheduler.decay_lr` - Final learning rate after decay (for example: `2.5e-6`) **Data settings**: - `data.data_path` - Path to LeRobot dataset root (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet`) @@ -186,7 +186,7 @@ python run.py --config-path ./examples/pi0/conf --config-name train action=run Training logs are saved to `outputs/pi0_train/logs/host_0_localhost.output` by default. -Checkpoints are saved to `${experiment.exp_dir}/ckpt` (default: `outputs/pi0_train/ckpt`). +Checkpoints are saved to `${experiment.exp_dir}/checkpoints` (default: `outputs/pi0_train/checkpoints`). ### Stop Training ```sh diff --git a/examples/pi0/conf/train/pi0.yaml b/examples/pi0/conf/train/pi0.yaml index 88448f489..4bbc2e5e7 100644 --- a/examples/pi0/conf/train/pi0.yaml +++ b/examples/pi0/conf/train/pi0.yaml @@ -7,20 +7,8 @@ system: shuffle: false num_workers: 4 - optimizer: - name: AdamW - lr: 2.5e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 0.01 - - scheduler: - warmup_steps: 1000 - decay_steps: 30000 - decay_lr: 2.5e-6 - checkpoint: - output_directory: ${experiment.exp_dir}/ckpt + output_directory: ${experiment.exp_dir} # Whether to save checkpoint save_checkpoint: true # Number of steps between checkpoints @@ -36,6 +24,17 @@ model: tokenizer_max_length: 48 action_steps: 50 + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 0.01 + scheduler: + warmup_steps: 1000 + decay_steps: 30000 + decay_lr: 2.5e-6 + data: # Path to the training data data_path: /workspace/datasets/lerobot/aloha_mobile_cabinet diff --git a/examples/pi0_5/README.md b/examples/pi0_5/README.md index e345d25e4..ca79eaf7e 100644 --- a/examples/pi0_5/README.md +++ b/examples/pi0_5/README.md @@ -154,17 +154,9 @@ Configure the following fields: **System settings** (training 
hyperparameters): - `system.batch_size` - Batch size per GPU - `system.train_steps` - Total training steps -- `system.optimizer.name` - Optimizer name (default: `"AdamW"`) -- `system.optimizer.lr` - Learning rate (default: `2.5e-5`) -- `system.optimizer.betas` - Optimizer betas (default: `[0.9, 0.95]`) -- `system.optimizer.eps` - Optimizer epsilon (default: `1.0e-8`) -- `system.optimizer.weight_decay` - Weight decay (default: `0.01`) -- `system.scheduler.warmup_steps` - Warmup steps (default: `1000`) -- `system.scheduler.decay_steps` - Decay steps (default: `30000`) -- `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) - `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) - `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) -- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}/ckpt`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}`) **Model settings**: - `model.model_name` - Model name: `"pi0.5"` @@ -172,6 +164,14 @@ Configure the following fields: - `model.tokenizer_path` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) - `model.tokenizer_max_length` - Maximum tokenizer sequence length (default: `200` for pi0.5) - `model.action_steps` - Number of action steps to predict +- `model.optimizer.name` - Optimizer name (for example: `"AdamW"`) +- `model.optimizer.lr` - Learning rate (for example: `2.5e-5`) +- `model.optimizer.betas` - Optimizer betas (for example: `[0.9, 0.95]`) +- `model.optimizer.eps` - Optimizer epsilon (for example: `1.0e-8`) +- `model.optimizer.weight_decay` - Weight decay (for example: `0.01`) +- `model.optimizer.scheduler.warmup_steps` - Warmup steps (for example: `1000`) +- `model.optimizer.scheduler.decay_steps` - Decay steps (for example: `30000`) +- `model.optimizer.scheduler.decay_lr` - Final learning rate after decay (for 
example: `2.5e-6`) **Data settings**: - `data.data_path` - Path to LeRobot dataset root (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet`) @@ -193,7 +193,7 @@ python run.py --config-path ./examples/pi0_5/conf --config-name train action=run Training logs are saved to `outputs/pi0_5_train/logs/host_0_localhost.output` by default. -Checkpoints are saved to `${experiment.exp_dir}/ckpt` (default: `outputs/pi0_5_train/ckpt`). +Checkpoints are saved to `${experiment.exp_dir}/checkpoints` (default: `outputs/pi0_5_train/checkpoints`). ### Stop Training ```sh diff --git a/examples/pi0_5/conf/train/pi0_5.yaml b/examples/pi0_5/conf/train/pi0_5.yaml index 2027c1a6c..4827e981e 100644 --- a/examples/pi0_5/conf/train/pi0_5.yaml +++ b/examples/pi0_5/conf/train/pi0_5.yaml @@ -7,20 +7,8 @@ system: shuffle: false num_workers: 4 - optimizer: - name: AdamW - lr: 2.5e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 0.01 - - scheduler: - warmup_steps: 1000 - decay_steps: 30000 - decay_lr: 2.5e-6 - checkpoint: - output_directory: ${experiment.exp_dir}/ckpt + output_directory: ${experiment.exp_dir} # Whether to save checkpoint save_checkpoint: true # Number of steps between checkpoints @@ -36,6 +24,17 @@ model: tokenizer_max_length: 200 action_steps: 50 + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 0.01 + scheduler: + warmup_steps: 1000 + decay_steps: 30000 + decay_lr: 2.5e-6 + data: # Path to the training data data_path: /workspace/datasets/lerobot/aloha_mobile_cabinet diff --git a/examples/qwen_gr00t/README.md b/examples/qwen_gr00t/README.md new file mode 100644 index 000000000..13f02422d --- /dev/null +++ b/examples/qwen_gr00t/README.md @@ -0,0 +1,343 @@ +# PI0: Training, Inference, and Serving + +This guide covers how to train, run inference, and serve PI0 models using FlagScale. 
+ +## Installation + +### Clone Repository + +```sh +git clone https://github.com/FlagOpen/FlagScale.git +cd FlagScale/ +``` + +### Setup Conda Environment + +Create a new conda environment for robotics training: + +```sh +conda create -n flagos-robo python=3.12 +conda activate flagos-robo +``` + +Install FlagScale and robotics dependencies: + +```sh +cd FlagScale/ +pip install . --verbose +pip install -r requirements/train/robotics/requirements.txt +``` + +Install additional dependencies for downloading models/datasets: + +```sh +# For HuggingFace Hub +pip install huggingface_hub + +# For ModelScope (optional) +pip install modelscope +``` + +## Download Models and Tokenizers + +Download models and tokenizers using the provided script. Choose either HuggingFace Hub or ModelScope based on your preference: + +**Using HuggingFace Hub:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/pi0_base \ + --output_dir /workspace/models \ + --source huggingface + +python examples/pi0/download.py \ + --repo_id google/paligemma-3b-pt-224 \ + --output_dir /workspace/models \ + --source huggingface +``` + +**Using ModelScope:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/pi0_base \ + --output_dir /workspace/models \ + --source modelscope + +python examples/pi0/download.py \ + --repo_id google/paligemma-3b-pt-224 \ + --output_dir /workspace/models \ + --source modelscope +``` + +The models will be downloaded to (example with `/workspace/models`): +- `/workspace/models/lerobot/pi0_base` +- `/workspace/models/google/paligemma-3b-pt-224` + + +## Training + +### Prepare Dataset + +FlagScale uses the **LeRobotDataset v3.0** format. For detailed information about the format structure, see the [LeRobotDataset v3.0 documentation](https://huggingface.co/docs/lerobot/en/lerobot-dataset-v3). 
+ +For example, to download the `aloha_mobile_cabinet` dataset: + +**Using HuggingFace Hub:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/aloha_mobile_cabinet \ + --output_dir /workspace/datasets \ + --repo_type dataset \ + --source huggingface +``` + +**Using ModelScope:** + +```sh +cd FlagScale/ +python examples/pi0/download.py \ + --repo_id lerobot/aloha_mobile_cabinet \ + --output_dir /workspace/datasets \ + --repo_type dataset \ + --source modelscope +``` + +The dataset will be downloaded to (example with `/workspace/datasets`): +- `/workspace/datasets/lerobot/aloha_mobile_cabinet` + +### Edit Config + +FlagScale uses a two-level configuration system: + +1. **Experiment-level config** (`examples/pi0/conf/train.yaml`): Defines experiment settings, environment variables, and resource allocation +2. **Task-level config** (`examples/pi0/conf/train/pi0.yaml`): Defines model, dataset, and training hyperparameters + +#### Experiment-Level Config + +Edit the experiment-level config for multi-GPU training: + +```sh +cd FlagScale/ +vim examples/pi0/conf/train.yaml +``` + +Configure the following fields: + +- `experiment.envs.CUDA_VISIBLE_DEVICES` - GPU devices to use (e.g., `"0,1,2,3"` for 4 GPUs, `"0,1"` for 2 GPUs) +- `experiment.envs.CUDA_DEVICE_MAX_CONNECTIONS` - Connection limit (typically `1`) +- `experiment.exp_name` - Experiment name +- `experiment.exp_dir` - Output directory for checkpoints and logs + +#### Task-Level Config + +Edit the task-level config for model and training settings: + +```sh +cd FlagScale/ +vim examples/pi0/conf/train/pi0.yaml +``` + +Configure the following fields: + +**System settings** (training hyperparameters): +- `system.batch_size` - Batch size per GPU +- `system.train_steps` - Total training steps +- `system.optimizer.name` - Optimizer name (default: `"AdamW"`) +- `system.optimizer.lr` - Learning rate (default: `2.5e-5`) +- `system.optimizer.betas` - Optimizer betas (default: `[0.9, 0.95]`) +- 
`system.optimizer.eps` - Optimizer epsilon (default: `1.0e-8`) +- `system.optimizer.weight_decay` - Weight decay (default: `0.01`) +- `system.scheduler.warmup_steps` - Warmup steps (default: `1000`) +- `system.scheduler.decay_steps` - Decay steps (default: `30000`) +- `system.scheduler.decay_lr` - Final learning rate after decay (default: `2.5e-6`) +- `system.checkpoint.save_checkpoint` - Whether to save checkpoints (default: `true`) +- `system.checkpoint.save_freq` - Steps between checkpoints (default: `1000`) +- `system.checkpoint.output_directory` - Checkpoint output directory (default: `${experiment.exp_dir}/ckpt`) + +**Model settings**: +- `model.model_name` - Model name: `"pi0"` or `"pi0.5"` +- `model.checkpoint_dir` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `model.tokenizer_path` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `model.tokenizer_max_length` - Maximum tokenizer sequence length +- `model.action_steps` - Number of action steps to predict + +**Data settings**: +- `data.data_path` - Path to LeRobot dataset root (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet`) +- `data.use_imagenet_stats` - Whether to use ImageNet normalization stats (default: `true`) +- `data.rename_map` - Dictionary mapping dataset keys to policy keys (optional). 
Check the `features` key in your dataset's `meta/info.json` file to determine the correct mapping: + ```yaml + rename_map: + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + ``` +- `data.use_quantiles` - Whether to use quantile normalization (for `pi0.5`, set to `false` to use MEAN_STD normalization) + +### Start Training +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name train action=run +``` + +Training logs are saved to `outputs/pi0_train/logs/host_0_localhost.output` by default. + +Checkpoints are saved to `${experiment.exp_dir}/ckpt` (default: `outputs/pi0_train/ckpt`). + +### Stop Training +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name train action=stop +``` + +## Inference + +### Prepare Inference Inputs + +You can extract inference inputs (images, state, task) from a dataset using the provided script: + +```sh +cd FlagScale/ +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --frame_index 100 +``` + +This will create: +- `frame_100_observation_images_*.jpg` - Image files +- `frame_100_state.pt` - State tensor +- `frame_100_task.txt` - Task prompt +- `extraction_summary.json` - Summary of extracted files + +Alternatively, you can extract from a specific episode and frame: + +```sh +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --episode_index 0 \ + --frame_in_episode 50 +``` + +Or extract multiple samples at once: + +```sh +python examples/pi0/dump_dataset_inputs.py \ + --dataset_root /workspace/datasets/lerobot/aloha_mobile_cabinet \ + --output_dir ./inference_inputs \ + --frame_indices 100 200 300 +``` + +### Edit Config + 
+```sh +cd FlagScale/ +vim examples/pi0/conf/inference/pi0.yaml +``` + +Configure the following fields: + +**Engine settings:** +- `engine.model` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `engine.tokenizer` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `engine.stat_path` - Path to dataset statistics (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet/meta/stats.json`) +- `engine.device` - Device to use (e.g., `"cuda"`) + +**Generate settings:** +- `generate.images` - Dictionary mapping image keys to file paths: + ```yaml + images: + observation.images.cam_high: /path/to/image1.jpg + observation.images.cam_left_wrist: /path/to/image2.jpg + observation.images.cam_right_wrist: /path/to/image3.jpg + ``` +- `generate.state_path` - Path to state tensor file (`.pt` file) +- `generate.task_path` - Path to task prompt file (`.txt` file) +- `generate.rename_map` (optional) - Map input keys to policy expected keys: + ```yaml + rename_map: + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + ``` + +### Run Inference + +```sh +cd FlagScale/ +python run.py \ + --config-path ./examples/pi0/conf \ + --config-name inference \ + action=run +``` + +Inference logs are saved to `outputs/pi0_inference/inference_logs/host_0_localhost.output` by default. + +The predicted action tensor is printed to the console and saved in the log file. 
+ +## Serving + +### Edit Config + +```sh +cd FlagScale/ +vim examples/pi0/conf/serve/pi0.yaml +``` + +Configure the following fields: + +**Engine arguments:** +- `engine_args.host` - Server host (default: `"0.0.0.0"`) +- `engine_args.port` - Server port (default: `5000`) +- `engine_args.model` - Path to pretrained model (e.g., `/workspace/models/lerobot/pi0_base`) +- `engine_args.tokenizer` - Path to tokenizer (e.g., `/workspace/models/google/paligemma-3b-pt-224`) +- `engine_args.stat_path` - Path to dataset statistics (e.g., `/workspace/datasets/lerobot/aloha_mobile_cabinet/meta/stats.json`) +- `engine_args.device` - Device to use (e.g., `"cuda"`) +- `engine_args.images_keys` - List of image keys expected by the model (do not change): + ```yaml + images_keys: + - observation.images.base_0_rgb + - observation.images.left_wrist_0_rgb + - observation.images.right_wrist_0_rgb + ``` +- `engine_args.images_shape` - Image shape `[C, H, W]` for warmup (e.g., `[3, 480, 640]`) +- `engine_args.state_key` - Key for state in the batch (e.g., `"observation.state"`) + +### Run Serving + +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name serve action=run +``` + +Serving logs are saved to `outputs/pi0_serve/logs/host_0_localhost.output` by default. + +### Stop Serving + +```sh +cd FlagScale/ +python run.py --config-path ./examples/pi0/conf --config-name serve action=stop +``` + +### Test Server with Client + +The client should send images using keys that match the `images_keys` in the config. 
For example, if using the default config: + +```sh +cd FlagScale/ +python examples/pi0/client_pi0.py \ + --host 127.0.0.1 \ + --port 5000 \ + --img1 ./inference_inputs/frame_100_observation_images_cam_high.jpg \ + --img2 ./inference_inputs/frame_100_observation_images_cam_left_wrist.jpg \ + --img3 ./inference_inputs/frame_100_observation_images_cam_right_wrist.jpg \ + --state-path ./inference_inputs/frame_100_state.pt \ + --instruction "Grab the orange and put it into the basket." +``` + +**Note**: The client must send image keys that match the `engine_args.images_keys` in the config. diff --git a/examples/qwen_gr00t/conf/inference.yaml b/examples/qwen_gr00t/conf/inference.yaml new file mode 100644 index 000000000..36d368682 --- /dev/null +++ b/examples/qwen_gr00t/conf/inference.yaml @@ -0,0 +1,26 @@ +defaults: + - _self_ + - inference: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_inference + exp_dir: outputs/${experiment.exp_name} + model: /models/qwen_gr00t + task: + type: inference + backend: vllm # TODO: Remove this restriction + entrypoint: flagscale/inference/inference_qwen_gr00t.py + runner: + hostfile: null + cmds: + before_start: null + envs: + CUDA_VISIBLE_DEVICES: 2 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + # Optionally, set HF_HOME and HF_ENDPOINT + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml new file mode 100644 index 000000000..43890a76f --- /dev/null +++ b/examples/qwen_gr00t/conf/inference/qwen_gr00t.yaml @@ -0,0 +1,11 @@ +engine: + model_variant: "QwenGr00t" + model: outputs/qwen_gr00t_train/checkpoints/last # NOTE(review): replaced user-specific absolute path with a portable example path + device: "cuda" + +generate: + images: + observation.images.wrist_image: qwen_gr00t_inference_inputs/frame_100_observation_images_wrist_image.jpg + observation.images.image: qwen_gr00t_inference_inputs/frame_100_observation_images_image.jpg + 
state_path: qwen_gr00t_inference_inputs/frame_100_state.pt + task_path: qwen_gr00t_inference_inputs/frame_100_task.txt diff --git a/examples/qwen_gr00t/conf/serve.yaml b/examples/qwen_gr00t/conf/serve.yaml new file mode 100644 index 000000000..3c3422c99 --- /dev/null +++ b/examples/qwen_gr00t/conf/serve.yaml @@ -0,0 +1,23 @@ +defaults: +- _self_ +- serve: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_serve + exp_dir: outputs/${experiment.exp_name} + task: + type: serve + entrypoint: flagscale/serve/run_serve_qwen_gr00t.py + runner: + hostfile: null + deploy: + use_fs_serve: false + envs: + CUDA_VISIBLE_DEVICES: 3 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml new file mode 100644 index 000000000..00b64b1e5 --- /dev/null +++ b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml @@ -0,0 +1,14 @@ +- serve_id: vllm_model # Not in use + engine_args: + host: 0.0.0.0 + port: 6000 + model_variant: QwenGr00t + model: outputs/qwen_gr00t_train/checkpoints/last # NOTE(review): replaced user-specific absolute path with a portable example path + device: "cuda" + images_keys: # NOTE(review): looks copied from pi0 — the inference config for this model uses observation.images.image / observation.images.wrist_image; confirm against the model's expected keys + - observation.images.base_0_rgb + - observation.images.left_wrist_0_rgb + - observation.images.right_wrist_0_rgb + # Only used for warmup + images_shape: [3, 480, 640] + state_key: observation.state diff --git a/examples/qwen_gr00t/conf/train.yaml b/examples/qwen_gr00t/conf/train.yaml new file mode 100644 index 000000000..c2c1c4212 --- /dev/null +++ b/examples/qwen_gr00t/conf/train.yaml @@ -0,0 +1,35 @@ +defaults: + - _self_ + - train: qwen_gr00t + +experiment: + exp_name: qwen_gr00t_train + seed: 42 + save_steps: 10000 + load: null + exp_dir: outputs/${experiment.exp_name} + ckpt_format: torch + task: + type: train + backend: native + entrypoint: flagscale/train/train_qwen_gr00t.py + runner: + per_node_task: false + no_shared_fs: false + 
rdzv_backend: static + hostfile: null + cmds: + before_start: echo "Starting Qwen-GR00T Training" + envs: + LOGLEVEL: "INFO" + CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" + # CUDA_VISIBLE_DEVICES: "2" + CUDA_DEVICE_MAX_CONNECTIONS: 1 + WANDB_MODE: offline + OTEL_SDK_DISABLED: true + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml new file mode 100644 index 000000000..192383c89 --- /dev/null +++ b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml @@ -0,0 +1,175 @@ +system: + batch_size: 1 + train_steps: 30000 + log_freq: 1 + grad_clip_norm: 1.0 + use_amp: true + shuffle: true + num_workers: 4 + + checkpoint: + output_directory: ${experiment.exp_dir} + # Whether to save checkpoint + save_checkpoint: true + # Number of steps between checkpoints + save_freq: 10000 + # TODO(yupu): Support resuming from checkpoint + +model: + # TODO: (yupu) the config layout is still a mess + model_name: qwen_gr00t + # Path to the checkpoint of the pretrained base VLM model, e.g. 
Qwen3-VL-4B-Instruct + checkpoint_dir: /workspace/models/Qwen/Qwen3-VL-4B-Instruct/ + # checkpoint_dir: /workspace/models/Qwen/Qwen2.5-VL-3B-Instruct/ + vlm: + type: qwen3-vl + # type: qwen2.5-vl + qwenvl: + base_vlm: /workspace/models/Qwen/Qwen3-VL-4B-Instruct/ + # base_vlm: /workspace/models/Qwen/Qwen2.5-VL-3B-Instruct/ + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + dino: + dino_backbone: dinov2_vits14 + action_model: + type: flow_matching + action_model_type: DiT-B + action_hidden_dim: 1024 + hidden_size: 1024 + add_pos_embed: True + max_seq_len: 1024 + action_dim: 7 + state_dim: 7 + future_action_window_size: 7 + action_horizon: 8 + past_action_window_size: 0 + repeated_diffusion_steps: 4 + noise_beta_alpha: 1.5 + noise_beta_beta: 1.0 + noise_s: 0.999 + num_timestep_buckets: 1000 + num_inference_timesteps: 4 + num_target_vision_tokens: 32 + diffusion_model_cfg: + cross_attention_dim: 2048 + dropout: 0.2 + final_dropout: True + # # FIXME: Debug only + # dropout: 0 + # final_dropout: False + interleave_self_attention: True + norm_type: ada_norm + num_layers: 16 + output_dim: 1024 + positional_embeddings: None # NOTE(review): YAML parses this as the string "None", not null — use null if Python None is intended; confirm downstream handling + reduce_in_full_precision: True + + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + param_groups: + vlm: + lr: 1.0e-05 + action_model: + lr: 1.0e-04 + scheduler: + name: cosine_with_min_lr + warmup_steps: 5000 + scheduler_kwargs: + min_lr: 1.0e-06 + # Legacy fields kept for BC + decay_steps: 30000 + decay_lr: 2.5e-6 + + # ============================================================ + # Module Freezing Configuration + # ============================================================ + # Freezing logic: freeze_patterns are applied first, then keep_patterns override. + # Patterns are regex matched against full parameter names. 
+ # + # Common patterns for QwenGR00T: + # - "qwen_vl_interface\\..*" # Entire VLM + # - "qwen_vl_interface\\.model\\.visual\\..*" # Vision encoder + # - "qwen_vl_interface\\.model\\.model\\..*" # Language model + # - "qwen_vl_interface\\.model\\.model\\.layers\\.[0-9]\\." # LLM layers 0-9 + # - "action_model\\..*" # Action head + # - "action_model\\.model\\.transformer_blocks\\.[0-7]\\." # DiT blocks 0-7 + # + # freeze: + # # SCENARIO A: Freeze VLM, train only action head + # freeze_patterns: + # - "qwen_vl_interface\\..*" + # + # # SCENARIO B: Freeze VLM but keep projector trainable + # # freeze_patterns: + # # - "qwen_vl_interface\\..*" + # # keep_patterns: + # # - "qwen_vl_interface\\.model\\.visual\\.merger\\..*" + # + # # SCENARIO C: Freeze everything except action decoder + # # freeze_patterns: + # # - ".*" + # # keep_patterns: + # # - "action_model\\.action_decoder\\..*" + +data: + # TODO: (yupu) Remove this once we have a proper dataset config + vla_data: + dataset_py: lerobot_datasets + data_root_dir: playground/Datasets/ + data_mix: libero_goal_old + action_type: delta_qpos + CoT_prompt: Your task is {instruction}. To identify the key objects for your task. Locate their bounding boxes in [x1,y1,x2,y2] format. 
+ CoT_answer: bbox + default_image_resolution: [3, 224, 224] + load_all_data_for_training: True + obs: ["image_0"] + video_backend: torchvision_av + # Path to the training data + data_path: /workspace/datasets/IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot/ + tolerance_s: 0.0001 + use_imagenet_stats: False + # To match the input features naming from the dataset to the policy config + # For example, for the aloha_mobile_cabinet dataset, the rename_map is: + rename_map: # NOTE(review): these aloha cam_* keys look copied from the pi0 example — confirm against this libero dataset's meta/info.json features + observation.images.cam_high: observation.images.base_0_rgb + observation.images.cam_left_wrist: observation.images.left_wrist_0_rgb + observation.images.cam_right_wrist: observation.images.right_wrist_0_rgb + use_quantiles: false + # TODO: (yupu) I think these indices should belong to the policy config, maybe put it in the model config? + observation_delta_indices: [0] + action_delta_indices: [0,1,2,3,4,5,6,7] + preprocessor: + name: policy_preprocessor + steps: + - registry_name: rename_observations_processor + config: + rename_map: {} + - registry_name: to_batch_processor + config: {} + - registry_name: device_processor + config: + device: cuda + float_dtype: null + - registry_name: normalizer_processor + config: + eps: 1e-8 + features: {} + norm_map: + VISUAL: IDENTITY + STATE: MIN_MAX + ACTION: MIN_MAX + postprocessor: + name: policy_postprocessor + steps: + - registry_name: unnormalizer_processor + config: + eps: 1e-8 + features: {} + norm_map: + VISUAL: IDENTITY + STATE: MIN_MAX + ACTION: MIN_MAX diff --git a/flagscale/logger.py b/flagscale/logger.py index 738dbcb8d..61dcf26a1 100644 --- a/flagscale/logger.py +++ b/flagscale/logger.py @@ -22,19 +22,19 @@ def __init__(self, name, level=logging.INFO): self.logger.addHandler(stream_handler) def info(self, message): - self.logger.info(message) + self.logger.info(message, stacklevel=2) def warning(self, message): - self.logger.warning(message) + self.logger.warning(message, stacklevel=2) def error(self, message): - 
self.logger.error(message) + self.logger.error(message, stacklevel=2) def critical(self, message): - self.logger.critical(message) + self.logger.critical(message, stacklevel=2) def debug(self, message): - self.logger.debug(message) + self.logger.debug(message, stacklevel=2) GLOBAL_LOGGER = None diff --git a/flagscale/models/action_model/__init__.py b/flagscale/models/action_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flagscale/models/action_model/flow_matching_head/__init__.py b/flagscale/models/action_model/flow_matching_head/__init__.py new file mode 100644 index 000000000..3159bfe65 --- /dev/null +++ b/flagscale/models/action_model/flow_matching_head/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Mainly adopted from:
# https://github.com/starVLA/starVLA/blob/3f7feefbc5fc25890ad3a7d262b8a0aea1339aa7/starVLA/model/modules/action_model/flow_matching_head/action_encoder.py
# Below is the original copyright:

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn


def swish(x):
    """SiLU activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal encoding of timesteps.

    Given timesteps of shape ``(..., )`` (typically ``(B, T)``), produces an
    encoding of shape ``(..., embedding_dim)`` where the first half of the last
    dimension holds sines and the second half cosines, with log-spaced
    frequencies (base 10000, as in the Transformer positional encoding).

    Raises:
        ValueError: if ``embedding_dim`` is odd — the sin/cos halves could not
            fill the requested width and the output would be silently wrong.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(
                f"embedding_dim must be even to split into sin/cos halves, got {embedding_dim}"
            )
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        """Encode ``timesteps`` (any shape) to ``(*timesteps.shape, embedding_dim)``."""
        timesteps = timesteps.float()  # ensure float
        device = timesteps.device

        half_dim = self.embedding_dim // 2
        # Typical log-spaced frequencies for sinusoidal encoding.
        # math.log avoids allocating a scalar tensor on every call.
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            math.log(10000.0) / half_dim
        )
        # Broadcast each timestep against every frequency: (..., half_dim).
        freqs = timesteps.unsqueeze(-1) * exponent.exp()

        # (..., embedding_dim): [sin | cos]
        return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)


class ActionEncoder(nn.Module):
    """Embed a noised action trajectory together with its diffusion time.

    Architecture (W1/W2/W3 follow the paper's naming):
        W1: R^{w x d} action projection, W2: R^{w x 2w} fusion, W3: R^{w x w}.
    """

    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = nn.Linear(action_dim, hidden_size)  # (d -> w)
        self.W2 = nn.Linear(2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = nn.Linear(hidden_size, hidden_size)  # (w -> w)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- one scalar per batch item, replicated across T;
                   a pre-expanded (B, T) tensor is also accepted.
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Bring timesteps to shape (B, T): either replicate a per-sample
        #    scalar across all T steps, or accept an already-expanded tensor.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        elif tuple(timesteps.shape) != (B, T):
            raise ValueError(
                "Expected `timesteps` of shape (B,) (replicated across T) or (B, T); "
                f"got {tuple(timesteps.shape)} for actions of shape {(B, T)}."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x))

        # 5) Finally W3 => (B, T, w)
        return self.W3(x)
import torch
import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
    SinusoidalPositionalEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from torch import nn

from flagscale.logger import logger


class TimestepEncoder(nn.Module):
    """Embed scalar diffusion timesteps into vectors of size ``embedding_dim``.

    Pipeline: a 256-channel sinusoidal projection (``Timesteps``) followed by a
    small MLP (``TimestepEmbedding``).
    """

    # NOTE(review): `compute_dtype` is accepted but never used — forward() takes
    # the dtype from this module's own parameters. Confirm before removing.
    def __init__(self, embedding_dim, compute_dtype=torch.float32):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        # Cast the sinusoidal projection to the parameter dtype so mixed-precision runs work.
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb


class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale/shift are predicted from a conditioning embedding.

    ``temb`` is mapped through SiLU + Linear to ``2 * embedding_dim`` values,
    split into (scale, shift), and applied as ``norm(x) * (1 + scale) + shift``.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()
        # NOTE(review): `chunk_dim` is stored but forward() always chunks on
        # dim=1 — confirm whether this parameter is still meaningful.
        self.chunk_dim = chunk_dim
        output_dim = embedding_dim * 2
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        temb = self.linear(self.silu(temb))
        scale, shift = temb.chunk(2, dim=1)
        # Broadcast (B, D) scale/shift over the sequence dimension of x (B, T, D).
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class BasicTransformerBlock(nn.Module):
    """Transformer block: (Ada)LayerNorm -> attention -> LayerNorm -> feed-forward.

    When ``cross_attention_dim`` is set, ``attn1`` attends over
    ``encoder_hidden_states`` (cross-attention); otherwise it is self-attention.
    Residual connections wrap both the attention and the feed-forward sublayer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: int | None = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",  # NOTE(review): accepted but unused below — confirm.
        positional_embeddings: str | None = None,
        num_positional_embeddings: int | None = None,
        ff_inner_dim: int | None = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.norm_type = norm_type

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # (Numbering follows the upstream diffusers block; the cross-attn-only
        # "2." sublayer was removed, hence 1. and 3.)
        # 1. Self-Attn
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        if final_dropout:
            self.final_dropout = nn.Dropout(dropout)
        else:
            self.final_dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        temb: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        # 0. Self-Attention
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, temb)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # NOTE(review): the mask over the *encoder* tokens is deliberately passed
        # as `attention_mask` here (see original @JinhuiYE note) — in diffusers'
        # Attention, this is the mask applied to the key/value sequence, which is
        # `encoder_hidden_states` when cross-attending. Confirm for the
        # self-attention (encoder_hidden_states=None) path.
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,  # @JinhuiYE original attention_mask=attention_mask
        )
        if self.final_dropout:
            attn_output = self.final_dropout(attn_output)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        return hidden_states


class DiT(ModelMixin, ConfigMixin):
    """Diffusion transformer conditioned on a timestep embedding (AdaLN) and,
    optionally, cross-attending to encoder hidden states.

    With ``interleave_self_attention=True``, every odd-indexed block is pure
    self-attention; the rest cross-attend to ``encoder_hidden_states``.
    """

    _supports_gradient_checkpointing = True

    # `register_to_config` records the constructor kwargs on `self.config`, so
    # values are read back later as `self.config.xxx` rather than `self.xxx`.
    # TODO: replace with a singleton-style decorator that can merge parameter
    # configs (e.g. an @merge_param_config).
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: int | None = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: str | None = "sinusoidal",
        interleave_self_attention=False,
        cross_attention_dim: int | None = None,
        **kwargs,
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        # Timestep encoder
        # `self.config.compute_dtype` may be absent (e.g. when the config is
        # reloaded), so fall back explicitly before touching it.
        compute_dtype = getattr(self.config, "compute_dtype", torch.float32)
        # TODO(BUG): accessing self.config.compute_dtype directly does not raise
        # during training but does during eval — hence the getattr guard above.
        self.timestep_encoder = TimestepEncoder(
            embedding_dim=self.inner_dim, compute_dtype=compute_dtype
        )

        all_blocks = []
        for idx in range(self.config.num_layers):
            # Odd-indexed blocks become pure self-attention when interleaving.
            use_self_attn = idx % 2 == 1 and interleave_self_attention
            curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None

            all_blocks += [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                )
            ]
        self.transformer_blocks = nn.ModuleList(all_blocks)

        # Output blocks: AdaLN-style modulation (scale/shift from the timestep
        # embedding) followed by a projection to `output_dim`.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
        logger.info(
            f"Total number of DiT parameters: "
            f"{sum(p.numel() for p in self.parameters() if p.requires_grad)}"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: torch.LongTensor | None = None,
        return_all_hidden_states: bool = False,
        encoder_attention_mask=None,
    ):
        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        # Process through transformer blocks - single pass through the blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        all_hidden_states = [hidden_states]

        # Process through transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1 and self.config.interleave_self_attention:
                # Self-attention-only block: no encoder conditioning.
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                )
            all_hidden_states.append(hidden_states)

        # Output processing: timestep-conditioned scale/shift, then projection.
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)


class SelfAttentionTransformer(ModelMixin, ConfigMixin):
    """Stack of self-attention-only ``BasicTransformerBlock``s (no timestep
    conditioning, no cross-attention)."""

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: int | None = 1000,
        upcast_attention: bool = False,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: str | None = "sinusoidal",
        interleave_self_attention=False,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        logger.info(
            f"Total number of SelfAttentionTransformer parameters: "
            f"{sum(p.numel() for p in self.parameters() if p.requires_grad)}"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        return_all_hidden_states: bool = False,
    ):
        # Process through transformer blocks - single pass through the blocks
        hidden_states = hidden_states.contiguous()
        all_hidden_states = [hidden_states]

        # Process through transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            hidden_states = block(hidden_states)
            all_hidden_states.append(hidden_states)

        if return_all_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states


# Mainly adopted from:
# https://github.com/starVLA/starVLA/blob/3f7feefbc5fc25890ad3a7d262b8a0aea1339aa7/starVLA/model/modules/action_model/GR00T_ActionHeader.py
# Below is the original copyright:

# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved.
# Modified by [Junqiu YU/ Fudan University] in [2025].
# Modification: [rm and add some connect adapter to match with starVLA, e.g., "rm "].
# Action repeat is inspired by CogACT


from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Beta
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature

from flagscale.models.action_model.flow_matching_head.action_encoder import (
    SinusoidalPositionalEncoding,
    swish,
)
from flagscale.models.action_model.flow_matching_head.cross_attention_dit import DiT

# TODO try to merge DiT Modules with follow_match_head, they are just the same arch, but diff loss, use diffusers package will be simple


class CategorySpecificLinear(nn.Module):
    """Linear layer with a separate (W, b) per category (embodiment)."""

    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # For each category, we have separate weights and biases.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        # Gather the per-sample weights by category id, then apply a batched
        # matmul: (B, T, in) @ (B, in, hidden) + (B, 1, hidden).
        selected_W = self.W[cat_ids]
        selected_b = self.b[cat_ids]
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)


class CategorySpecificMLP(nn.Module):
    """Two-layer MLP built from category-specific linear layers (ReLU in between)."""

    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)


class MLP(nn.Module):
    """Plain two-layer MLP with ReLU: input_dim -> hidden_dim -> output_dim."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.layer2(F.relu(self.layer1(x)))


class ActionEncoder(nn.Module):
    """Embed a noised action trajectory together with its diffusion time.

    NOTE(review): functionally the same architecture as
    flow_matching_head.action_encoder.ActionEncoder (layer1/2/3 instead of
    W1/W2/W3) — candidate for deduplication.
    """

    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.action_dim = action_dim
        self.layer1 = nn.Linear(action_dim, hidden_size)
        self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps
        #    so that shape => (B, T)
        #    e.g. if timesteps is (B,), replicate across T
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # shape (B,) => (B,T)
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.layer1(actions)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then layer2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.layer2(x))

        # 5) Finally layer3 (the paper's W3) => (B, T, w)
        x = self.layer3(x)
        return x


class MultiEmbodimentActionEncoder(nn.Module):
    """ActionEncoder variant with per-embodiment (category-specific) weights."""

    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)  # (d -> w)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)  # (w -> w)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        cat_ids: shape (B,)
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps
        #    so that shape => (B, T)
        #    e.g. if timesteps is (B,), replicate across T
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # shape (B,) => (B,T)
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions, cat_ids)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x, cat_ids)
        return x


@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
    """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""

    # NOTE(review): unusual pattern — a @dataclass layered on PretrainedConfig,
    # with __init__ overridden below to copy arbitrary kwargs onto the instance.
    add_pos_embed: bool = field(
        default=True, metadata={"help": "Whether to add positional embedding"}
    )
    diffusion_model_cfg: dict = field(
        default=None, metadata={"help": "Diffusion model configuration."}
    )
    input_embedding_dim: int = field(
        default=1536, metadata={"help": "Input embedding channel dimension."}
    )

    hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
    max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
    action_dim: int = field(default=None, metadata={"help": "Action dimension."})
    action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
    noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
    noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
    noise_s: float = field(
        default=0.999, metadata={"help": "Flow matching noise Beta distribution s."}
    )
    num_timestep_buckets: int = field(
        default=1000, metadata={"help": "Number of timestep discretization buckets."}
    )
    num_inference_timesteps: int = field(
        default=None,
        metadata={"help": "Number of inference steps for noise diffusion."},
    )
    max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
    tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
    tune_diffusion_model: bool = field(
        default=True, metadata={"help": "Whether to tune the diffusion model."}
    )
    load_pretrained_det_decode_layer_path: str = field(
        default=None, metadata={"help": "Path to pretrained detection model."}
    )
    detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})

    freeze_decode_layer: bool = field(default=False)
    expand_batch: int = field(default=None)
    use_vlln: bool = field(default=True)

    vl_self_attention_cfg: dict = field(default=None)
    num_target_vision_tokens: int = field(
        default=32, metadata={"help": "Number of target vision tokens."}
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Copy every kwarg onto the instance so arbitrary config keys survive.
        for key, value in kwargs.items():
            setattr(self, key, value)


# Preset DiT sizes keyed by model-type name.
DiTConfig = {
    "DiT-B": {"input_embedding_dim": 768, "attention_head_dim": 64, "num_attention_heads": 12},
    "DiT-L": {"input_embedding_dim": 1536, "attention_head_dim": 48, "num_attention_heads": 32},
}


class FlowmatchingActionHead(nn.Module):
    """Flow-matching action head: a DiT denoiser over action trajectories,
    conditioned on vision-language embeddings (cross-attention) and optional
    robot state.

    Training: regress the flow-matching velocity (actions - noise) at a random
    interpolation time t. Inference: Euler integration from pure noise.
    """

    def __init__(
        self,
        full_config,
    ):
        super().__init__()
        # NOTE(review): `config` here is an attribute-style config
        # (full_config.model.action_model); several fields it reads (state_dim,
        # future_action_window_size, action_model_type) are not declared on
        # FlowmatchingActionHeadConfig — confirm against the YAML schema.
        config = full_config.model.action_model
        self.hidden_size = config.hidden_size  # @JinhuiYE
        self.full_config = full_config
        action_model_type = config.action_model_type
        action_model_cfg = DiTConfig[action_model_type]

        self.input_embedding_dim = action_model_cfg["input_embedding_dim"]
        diffusion_model_cfg = config.diffusion_model_cfg
        # Preset DiT sizes first, user-provided diffusion cfg overrides them.
        diffusion_model_cfg = {**action_model_cfg, **diffusion_model_cfg}
        self.model = DiT(**diffusion_model_cfg)
        self.action_dim = config.action_dim
        # NOTE(review): horizon is derived here, while predict_action() reads
        # config.action_horizon (dataclass default None) — verify both agree.
        self.action_horizon = config.future_action_window_size + 1
        self.num_inference_timesteps = config.num_inference_timesteps

        # Optional robot-state projection into the DiT token space.
        self.state_encoder = (
            MLP(
                input_dim=config.state_dim,
                hidden_dim=self.hidden_size,
                output_dim=self.input_embedding_dim,
            )
            if config.state_dim
            else None
        )

        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        # Projects DiT outputs back to action space.
        self.action_decoder = MLP(
            input_dim=self.model.config.output_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        # Learned query tokens prepended before the action tokens.
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        # Beta distribution over flow-matching times (biased per noise_s below).
        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.config = config

    def sample_time(self, batch_size, device, dtype):
        """Sample flow-matching times in (approximately) [0, 1), Beta-skewed via noise_s."""
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s

    def prepare_input(self, batch: dict) -> BatchFeature:
        # Thin passthrough: wrap the raw batch in a BatchFeature.
        return BatchFeature(data=batch)

    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        state: torch.Tensor = None,
        encoder_attention_mask=None,
    ):
        """
        vl_embs: shape (B, seq_length, feature_dim)
        actions: shape (B, future_action_window_size, D_action)
        Returns the scalar flow-matching MSE loss.
        """
        device = vl_embs.device

        # Validate action dimension
        if actions.shape[-1] != self.action_dim:
            raise ValueError(
                f"Action dimension mismatch: model expects {self.action_dim} dimensions "
                f"(from config), but received actions with {actions.shape[-1]} dimensions. "
                f"Please update config.model.action_model.action_dim to match your data."
            )
        # Embed noised action trajectory.
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)

        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]  # shape (B,1,1) for broadcast

        # Interpolation convention: t=0 is pure noise, t=1 is the clean action.
        noisy_trajectory = (1 - t) * noise + t * actions
        # Target velocity of the probability flow: d/dt of the interpolation.
        velocity = actions - noise

        # Convert (continuous) t -> discrete if needed
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy_trajectory, t_discretized)

        # embed state
        state_features = self.state_encoder(state) if state is not None else None

        # Maybe add position embedding.
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # state and action embedding along sequence dimension.
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
        sa_embs = (
            torch.cat((state_features, future_tokens, action_features), dim=1)
            if state_features is not None
            else torch.cat((future_tokens, action_features), dim=1)
        )

        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            encoder_attention_mask=encoder_attention_mask,
            timestep=t_discretized,
            return_all_hidden_states=False,  # NOTE (YL): not using flare now
        )
        pred = self.action_decoder(model_output)
        # The action tokens sit at the tail of the sequence.
        pred_actions = pred[:, -actions.shape[1] :]

        # Slice out only the action portion of pred and target.
        loss = ((pred_actions - velocity) ** 2).mean()
        return loss

    @torch.no_grad()
    def predict_action(self, vl_embs: torch.Tensor, state: torch.Tensor = None) -> torch.Tensor:
        """Generate an action trajectory by Euler-integrating the learned flow
        from pure noise, conditioned on `vl_embs` (and optional `state`)."""
        # Set initial actions as the sampled noise.
        batch_size = vl_embs.shape[0]
        device = vl_embs.device
        actions = torch.randn(  # yes, here make sure action_horizon align with data loader? or share from client?
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=vl_embs.dtype,
            device=device,
        )

        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        state_features = self.state_encoder(state) if state is not None else None

        # Run denoising steps.
        for t in range(num_steps):
            t_cont = t / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # Embed noised action trajectory.
            timesteps_tensor = torch.full(
                size=(batch_size,), fill_value=t_discretized, device=device
            )
            action_features = self.action_encoder(actions, timesteps_tensor)
            # Maybe add position embedding.
            if self.config.add_pos_embed:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
                action_features = action_features + pos_embs

            # Join vision, language, state and action embedding along sequence dimension.
            future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
            sa_embs = (
                torch.cat((state_features, future_tokens, action_features), dim=1)
                if state_features is not None
                else torch.cat((future_tokens, action_features), dim=1)
            )

            # Run model forward.
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                timestep=timesteps_tensor,
            )
            pred = self.action_decoder(model_output)

            pred_velocity = pred[:, -self.action_horizon :]

            # Update actions using euler integration.
            actions = actions + dt * pred_velocity
        return actions

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters are co-located.
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        # Dtype of the first parameter.
        return next(iter(self.parameters())).dtype


# --- flagscale/models/vla/__init__.py ---
from .action_models.flow_matching import FlowMatchingHead
from .protocols import ActionModel, VLMBackbone
from .qwen_gr00t import QwenGr00t
from .registry import (
    ACTION_MODEL_REGISTRY,
    VLM_REGISTRY,
    build_action_model,
    build_vlm,
    register_action_model,
    register_vlm,
)
from .utils import get_vlm_config

# Explicit registration
# NOTE(review): registrations write directly into the registry dicts even though
# register_vlm/register_action_model are exported — confirm which path is canonical.
from .vlm.qwen_vl import Qwen3VLBackbone, Qwen25VLBackbone

VLM_REGISTRY["qwen2.5-vl"] = Qwen25VLBackbone
VLM_REGISTRY["qwen3-vl"] = Qwen3VLBackbone
ACTION_MODEL_REGISTRY["flow_matching"] = FlowMatchingHead

__all__ = [
    "VLMBackbone",
    "ActionModel",
    "register_vlm",
    "register_action_model",
    "build_vlm",
    "build_action_model",
    "get_vlm_config",
    "QwenGr00t",
]


# --- flagscale/models/vla/action_models/__init__.py ---
from .flow_matching import FlowMatchingHead

__all__ = ["FlowMatchingHead"]


# --- flagscale/models/vla/action_models/flow_matching.py (header/imports) ---
import torch
import torch.nn as nn

from flagscale.models.action_model.gr00t_action_header import (
    FlowmatchingActionHead as _FlowmatchingActionHead,
)
from flagscale.models.utils.constants import ACTION
from flagscale.models.vla.utils import get_vlm_config
from flagscale.train.train_config import TrainConfig
from flagscale.train.train_config import TrainConfig


class FlowMatchingHead(nn.Module):
    """
    Flow matching action head wrapper for VLA framework.

    Adapts the GR00T ``FlowmatchingActionHead`` to the framework's
    (vlm_output, action_input) calling convention.

    Args:
        vlm_config: HF config object from VLM (used to get hidden_size).
        action_config: dict with action model settings.
        full_config: TrainConfig for initializing the underlying FlowmatchingActionHead.
    """

    def __init__(self, vlm_config, action_config: dict, full_config: TrainConfig = None):
        super().__init__()
        vlm_info = get_vlm_config(vlm_config)
        self.hidden_size = vlm_info["hidden_size"]

        # The DiT cross-attends over VLM embeddings, so its cross-attention
        # width must match the VLM hidden size.
        # TODO: pass cross_attention_dim directly to action head instead of mutating full_config
        full_config.model.action_model.diffusion_model_cfg.cross_attention_dim = self.hidden_size

        self._head = _FlowmatchingActionHead(full_config=full_config)

    def forward(
        self, vlm_output: dict[str, torch.Tensor], action_input: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            vlm_output: From VLM, contains 'hidden_states'.
            action_input: Raw batch with 'actions', 'state', etc.
        Returns:
            dict with 'loss'.
        """
        vl_embs = vlm_output["hidden_states"]
        actions = action_input["actions"]
        # 'state' and 'attention_mask' are optional; the head handles None.
        state = action_input.get("state")
        encoder_attention_mask = action_input.get("attention_mask")

        loss = self._head.forward(
            vl_embs=vl_embs,
            actions=actions,
            state=state,
            encoder_attention_mask=encoder_attention_mask,
        )
        return {"loss": loss}

    def predict_action(
        self, vlm_output: dict[str, torch.Tensor], action_input: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            vlm_output: From VLM, contains 'hidden_states'.
            action_input: Raw batch with 'state', etc.
        Returns:
            dict with 'actions': Tensor [B, horizon, action_dim].
        """
        vl_embs = vlm_output["hidden_states"]
        state = action_input.get("state")

        actions = self._head.predict_action(vl_embs=vl_embs, state=state)
        return {ACTION: actions}

    def fsdp_units(self) -> list[nn.Module]:
        """Shard each DiT transformer block individually under FSDP."""
        return list(self._head.model.transformer_blocks)


# --- flagscale/models/vla/protocols.py ---
from typing import Protocol

from torch import Tensor
from torch.nn import Module


class VLMBackbone(Protocol):
    """Structural interface for vision-language backbones."""

    @property
    def config(self):
        """HF config object (e.g., Qwen2VLConfig)."""
        ...

    def prepare_input(self, batch: dict) -> dict[str, Tensor]:
        """
        Args:
            batch: Raw batch with 'image', 'lang', etc.
        Returns:
            Tokenized inputs ready for forward().
        """
        ...

    def forward(self, batch: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
        """
        Args:
            batch: Tokenized inputs from prepare_input().
        Returns:
            dict with 'hidden_states': tuple of layer outputs.
        """
        ...

    def fsdp_units(self) -> list[Module]:
        """Return submodules that should each be individually sharded by FSDP."""
        ...


class ActionModel(Protocol):
    """Structural interface for action heads."""

    def forward(
        self, vlm_output: dict[str, Tensor], action_input: dict[str, Tensor], **kwargs
    ) -> dict[str, Tensor]:
        """
        Args:
            vlm_output: From VLM, contains 'hidden_states'.
            action_input: Raw batch - pick what you need ('actions', 'state', etc.).
        Returns:
            dict with 'loss'.
        """
        ...

    # Fixed: the protocol previously declared `predict`, but every
    # implementation (FlowMatchingHead, FlowmatchingActionHead) provides
    # `predict_action` — the structural contract never matched its members.
    def predict_action(
        self, vlm_output: dict[str, Tensor], action_input: dict[str, Tensor], **kwargs
    ) -> dict[str, Tensor]:
        """
        Args:
            vlm_output: From VLM, contains 'hidden_states'.
            action_input: Raw batch - pick what you need ('state', etc.).
        Returns:
            dict with 'actions': Tensor [B, horizon, action_dim].
        """
        ...
class QwenGr00t(PreTrainedModel):
    """
    Multimodal vision-language-action model.

    Components:
    - Qwen VL interface for fused language/vision token embeddings
    - DiT diffusion head for future action sequence modeling

    Focus: Predict future continuous actions conditioned on images + instruction.
    """

    # PreTrainedModel plumbing: a bare PretrainedConfig is used since the real
    # configuration lives in the framework's TrainConfig (self._config).
    config_class = PretrainedConfig

    def __init__(self, config: TrainConfig, **kwargs):
        super().__init__(PretrainedConfig())
        self._config = config

        # Backbone and head are resolved through the VLA registries so
        # alternative VLMs / action heads can be swapped via config.
        vlm_type = config.model.vlm.get("type", "qwen3-vl")
        self.vlm = build_vlm(vlm_type, config=config)

        action_model_type = config.model.action_model.get("type", "flow_matching")
        self.action_model = build_action_model(
            action_model_type,
            vlm_config=self.vlm.model_config,
            action_config={},
            full_config=config,
        )

        self.future_action_window_size = config.model.action_model.future_action_window_size

    def forward(self, examples: dict, **kwargs):
        """Compute the training loss for one collated batch.

        Args:
            examples: Collated batch dict; must contain ACTION (action tensor)
                plus whatever keys the VLM's prepare_input() reads.
        Returns:
            Scalar loss tensor from the action head.
        """
        # actions = [example["action"] for example in examples] # [B, T, action_dim]
        actions = examples[ACTION]
        state = None  # examples[OBS_STATE]

        # Step 1: QWenVL input format
        # NOTE: (yupu) The order of the images differs from starVLA, which is [image, wrist_image]
        qwen_inputs = self.vlm.prepare_input(examples)

        # TODO: (yupu) Hard-coded autocast and dtype, matches starVLA
        with torch.autocast("cuda", dtype=torch.bfloat16):
            vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
            # last_hidden_state: [B, seq_len, H]
            last_hidden = vlm_output["hidden_states"][-1]  # [B, L, H]

        # Step 4: Action Expert Forward and Loss
        with torch.autocast("cuda", dtype=torch.float32):
            # TODO: (yupu) Is this a bug or a feature? The action dtype would stay as bf16 under this autocast.
            actions = actions.to(device=last_hidden.device, dtype=last_hidden.dtype)
            # actions = torch.tensor(
            #     np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype
            # ) # [B, T_full, action_dim]

            # Keep only the trailing prediction window (current step + future).
            # TODO: does not match RoboBrainX, need to check
            actions_target = actions[
                :, -(self.future_action_window_size + 1) :, :
            ]  # (B, chunk_len, action_dim)

            # TODO: (yupu) I believe there is a bug in starVLA, the
            # `repeated_diffusion_steps` is not properly set in the config.
            repeated_diffusion_steps = self._config.model.action_model.get(
                "repeated_diffusion_steps", 4
            )

            # Repeat the batch so each sample is trained at several diffusion
            # timesteps per optimizer step (effective batch = B * repeats).
            actions_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
            last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)

            state_repeated = None
            if state is not None:
                state = state.to(device=last_hidden.device, dtype=last_hidden.dtype)
                state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

            # Use action head forward API
            vlm_output_repeated = {"hidden_states": last_hidden_repeated}
            action_input = {"actions": actions_repeated, "state": state_repeated}

            output = self.action_model.forward(vlm_output_repeated, action_input)

        return output["loss"]

    @torch.inference_mode()
    def predict_action(self, examples: list[dict], **kwargs) -> dict:
        """Sample a normalized action trajectory for a (preprocessed) batch.

        Steps:
            1. Encode inputs with the QwenVL backbone (hidden states retained).
            2. Run the diffusion/flow-matching action head to sample actions.

        Returns:
            dict mapping ACTION to normalized actions [B, T, action_dim]
            (as produced by the action head's predict_action).
        """
        # TODO: (yupu) Fix inference input format to use constants (OBS_IMAGE, OBS_LANGUAGE, OBS_STATE)
        # instead of hardcoded keys. The current keys are inconsistent with training batch format.
        # batch_images = [[to_pil_preserve(example["image"])] for example in examples] # [B, [PLT]]
        # instructions = [example["lang"] for example in examples] # [B, str]

        # We assume the images are already resized during preprocessing.
        qwen_inputs = self.vlm.prepare_input(examples)
        state = None  # examples[OBS_STATE]

        # state = (
        #     [example["state"] for example in examples] if "state" in examples[0] else None
        # ) # [B, 1, state_dim]

        # train_obs_image_size = getattr(self._config.data.vla_data, "image_size", None)
        # if train_obs_image_size:
        #     batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        # # Step 1: QWenVL input format
        # qwen_inputs = self.vlm.build_qwenvl_inputs(
        #     examples=None, images=batch_images, instructions=instructions
        # )

        with torch.autocast("cuda", dtype=torch.bfloat16):
            vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
            # last_hidden_state: [B, seq_len, H]
            last_hidden = vlm_output["hidden_states"][-1]  # [B, L, H]

        if state is not None:
            state = state.to(device=last_hidden.device, dtype=last_hidden.dtype)

        # state_tensor = (
        #     torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype)
        #     if state is not None
        #     else None
        # )

        # Step 4: Action Expert Forward
        with torch.autocast("cuda", dtype=torch.float32):
            vlm_output_for_action = {"hidden_states": last_hidden}
            action_input = {"state": state}
            output = self.action_model.predict_action(vlm_output_for_action, action_input)

        # Assume the action model's output is a dict mapping `ACTION` to the normalized actions
        return output
VLM_REGISTRY: + raise ValueError(f"Unknown VLM: {name}. Available: {list(VLM_REGISTRY.keys())}") + return VLM_REGISTRY[name](**kwargs) + + +def build_action_model(name: str, vlm_config, action_config: dict, **kwargs): + if name not in ACTION_MODEL_REGISTRY: + raise ValueError( + f"Unknown ActionModel: {name}. Available: {list(ACTION_MODEL_REGISTRY.keys())}" + ) + return ACTION_MODEL_REGISTRY[name](vlm_config=vlm_config, action_config=action_config, **kwargs) diff --git a/flagscale/models/vla/utils.py b/flagscale/models/vla/utils.py new file mode 100644 index 000000000..b0ee5fdd1 --- /dev/null +++ b/flagscale/models/vla/utils.py @@ -0,0 +1,29 @@ +def get_vlm_config(vlm_config) -> dict: + """ + Extract common fields from any VLM config, handling structural differences. + + Args: + vlm_config: HF config object (may have hidden_size directly or via text_config). + Returns: + dict with 'hidden_size' and 'num_hidden_layers'. + """ + return { + "hidden_size": _get_hidden_size(vlm_config), + "num_hidden_layers": _get_num_layers(vlm_config), + } + + +def _get_hidden_size(config) -> int: + if hasattr(config, "hidden_size"): + return config.hidden_size + if hasattr(config, "text_config"): + return config.text_config.hidden_size + raise ValueError(f"Cannot determine hidden_size from config: {type(config)}") + + +def _get_num_layers(config) -> int: + if hasattr(config, "num_hidden_layers"): + return config.num_hidden_layers + if hasattr(config, "text_config"): + return config.text_config.num_hidden_layers + raise ValueError(f"Cannot determine num_hidden_layers from config: {type(config)}") diff --git a/flagscale/models/vla/vlm/__init__.py b/flagscale/models/vla/vlm/__init__.py new file mode 100644 index 000000000..aa73dba2e --- /dev/null +++ b/flagscale/models/vla/vlm/__init__.py @@ -0,0 +1,3 @@ +from .qwen_vl import Qwen3VLBackbone, Qwen25VLBackbone, QwenVLBackbone + +__all__ = ["QwenVLBackbone", "Qwen25VLBackbone", "Qwen3VLBackbone"] diff --git 
def _to_pil(img):
    """Best-effort conversion of a single image to PIL.Image.

    Accepts PIL images (returned as-is), torch tensors, and numpy arrays.
    Anything else is passed through unchanged.

    NOTE(review): non-uint8 arrays are assumed to be floats in [0, 1] and in
    HWC layout — confirm against the dataset pipeline.
    """
    if isinstance(img, Image.Image):
        return img
    array = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
    if not isinstance(array, np.ndarray):
        # Unknown type: let the caller deal with it.
        return array
    if array.dtype != np.uint8:
        # float [0,1] → uint8
        array = (array * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(array)
+ self.model = self._load_model(self.model_id) + self.processor = AutoProcessor.from_pretrained(self.model_id) + # FIXME: Hard-coded padding side + self.processor.tokenizer.padding_side = "left" + self._config: TrainConfig = config + + def _load_model(self, model_id: str): + raise NotImplementedError + + @property + def model_config(self) -> PretrainedConfig: + """HF config object (e.g., Qwen2VLConfig).""" + return self.model.config + + def prepare_input(self, batch: dict) -> dict[str, torch.Tensor]: + raise NotImplementedError + + def forward(self, batch: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]: + with torch.autocast("cuda", dtype=torch.bfloat16): + outputs = self.model( + **batch, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + # TODO: (yupu) We should output the original outputs, not just the hidden states. + return {"hidden_states": outputs.hidden_states} + + def fsdp_units(self) -> list[nn.Module]: + return list(self.model.model.visual.blocks) + list(self.model.model.language_model.layers) + + +class Qwen25VLBackbone(QwenVLBackbone): + """Qwen2.5-VL backend.""" + + def _load_model(self, model_id: str): + # WARNING: hard-coded attn_implementation and torch_dtype + return Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_id, + attn_implementation="flash_attention_2", + torch_dtype="auto", + ) + + def prepare_input(self, batch: dict) -> dict[str, torch.Tensor]: + # TODO: (yupu) This is a hack, we should find a better way to handle this. 
    def build_qwenvl_inputs(
        self,
        examples,
        images=None,
        instructions=None,
        image_keys=None,
        **kwargs,
    ):
        """Build processor-ready Qwen2.5-VL inputs.

        Either pass `examples` (a collated batch dict) together with
        `image_keys`, or pass pre-extracted `images`/`instructions` directly.

        Args:
            examples: Collated batch dict; instructions come from the
                hard-coded "task" key, images from `image_keys`.
            images: Optional per-sample lists of PIL images; derived from
                `examples` when omitted.
            instructions: Optional list of instruction strings, one per sample.
            image_keys: Batch keys holding per-view image tensors.
        Returns:
            Processor output (tensor batch) moved to the current CUDA device.
        """
        if examples is not None and (images is None or instructions is None):
            # TODO: (yupu) hard-code task key to "task"
            instructions = examples["task"]
            if isinstance(instructions, torch.Tensor):
                instructions = instructions.detach().cpu().tolist()
            if isinstance(instructions, str):
                # Single-sample convenience: promote to a batch of one.
                instructions = [instructions]

            # Gather images view-by-view so each sample ends up with its views
            # in `image_keys` order: batch_images[b] = [view0, view1, ...].
            batch_images = None
            for key in image_keys:
                imgs = examples[key]
                if isinstance(imgs, torch.Tensor) and imgs.ndim == 3:
                    # Unbatched single image tensor — treat as a batch of one.
                    imgs = [imgs]
                key_images = [_to_pil(img) for img in imgs]
                if batch_images is None:
                    batch_images = [[img] for img in key_images]
                else:
                    for sample_images, img in zip(batch_images, key_images):
                        sample_images.append(img)

            # Drop missing views (e.g. a camera absent for some samples).
            for idx, sample_images in enumerate(batch_images):
                batch_images[idx] = [img for img in sample_images if img is not None]

            images = batch_images

        # Local import: qwen_vl_utils is only needed for the Qwen2.5 pipeline.
        from qwen_vl_utils import process_vision_info

        # Create messages: one message per sample
        messages = []
        assert len(images) == len(instructions)
        for imgs, instruction in zip(images, instructions):
            content = [{"type": "image", "image": img} for img in imgs]

            if "CoT_prompt" in self._config.data.vla_data:
                # Optional chain-of-thought template wrapping the instruction.
                CoT_prompt = self._config.data.vla_data.get("CoT_prompt", "")
                prompt = CoT_prompt.replace("{instruction}", instruction)
            else:
                prompt = instruction

            content.append({"type": "text", "text": prompt})
            msg = [{"role": "user", "content": content}]
            messages.append(msg)

        # Prepare text prompts using processor
        # default process is json --> message --> texts --> input_ids
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]

        # image_inputs = list of PIL
        image_inputs, video_inputs = process_vision_info(messages)
        batch_input = self.processor(
            text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
        )

        # Use current CUDA device instead of self.model.device, which returns
        # a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors.
        return batch_input.to(f"cuda:{torch.cuda.current_device()}")
    # TODO: (yupu) Refactor this args
    def build_qwenvl_inputs(
        self,
        examples,
        images=None,
        instructions=None,
        image_keys=None,
        **kwargs,
    ):
        """Build processor-ready Qwen3-VL inputs.

        Same contract as the Qwen2.5 variant, but tokenization happens in a
        single apply_chat_template call (no qwen_vl_utils vision pass).

        Args:
            examples: Collated batch dict; instructions come from the
                hard-coded "task" key, images from `image_keys`.
            images: Optional per-sample lists of PIL images; derived from
                `examples` when omitted.
            instructions: Optional list of instruction strings, one per sample.
            image_keys: Batch keys holding per-view image tensors.
        Returns:
            Processor output (tensor batch) moved to the current CUDA device.
        """
        if examples is not None and (images is None or instructions is None):
            # TODO: (yupu) hard-code task key to "task"
            instructions = examples["task"]
            if isinstance(instructions, torch.Tensor):
                instructions = instructions.detach().cpu().tolist()
            if isinstance(instructions, str):
                # Single-sample convenience: promote to a batch of one.
                instructions = [instructions]

            # Gather images view-by-view so each sample ends up with its views
            # in `image_keys` order: batch_images[b] = [view0, view1, ...].
            batch_images = None
            for key in image_keys:
                imgs = examples[key]
                if isinstance(imgs, torch.Tensor) and imgs.ndim == 3:
                    # Unbatched single image tensor — treat as a batch of one.
                    imgs = [imgs]
                key_images = [_to_pil(img) for img in imgs]
                if batch_images is None:
                    batch_images = [[img] for img in key_images]
                else:
                    for sample_images, img in zip(batch_images, key_images):
                        sample_images.append(img)

            # Drop missing views (e.g. a camera absent for some samples).
            for idx, sample_images in enumerate(batch_images):
                batch_images[idx] = [img for img in sample_images if img is not None]

            images = batch_images

        # Create messages: one message per sample
        messages = []
        assert len(images) == len(instructions)
        for imgs, instruction in zip(images, instructions):
            content = [{"type": "image", "image": img} for img in imgs]

            if "CoT_prompt" in self._config.data.vla_data:
                # Optional chain-of-thought template wrapping the instruction.
                CoT_prompt = self._config.data.vla_data.get("CoT_prompt", "")
                prompt = CoT_prompt.replace("{instruction}", instruction)
            else:
                prompt = instruction

            content.append({"type": "text", "text": prompt})
            msg = [{"role": "user", "content": content}]
            messages.append(msg)

        # Preparation for inference
        batch_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            padding=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        # Use current CUDA device instead of self.model.device, which returns
        # a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors.
        return batch_inputs.to(f"cuda:{torch.cuda.current_device()}")
+""" + +import functools + +import msgpack +import numpy as np + + +def pack_array(obj): + if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): + raise ValueError(f"Unsupported dtype: {obj.dtype}") + + if isinstance(obj, np.ndarray): + return { + b"__ndarray__": True, + b"data": obj.tobytes(), + b"dtype": obj.dtype.str, + b"shape": obj.shape, + } + + if isinstance(obj, np.generic): + return { + b"__npgeneric__": True, + b"data": obj.item(), + b"dtype": obj.dtype.str, + } + + return obj + + +def unpack_array(obj): + if b"__ndarray__" in obj: + return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) + + if b"__npgeneric__" in obj: + return np.dtype(obj[b"dtype"]).type(obj[b"data"]) + + return obj + + +Packer = functools.partial(msgpack.Packer, default=pack_array) +packb = functools.partial(msgpack.packb, default=pack_array) + +Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) +unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) diff --git a/flagscale/serve/run_serve_qwen_gr00t.py b/flagscale/serve/run_serve_qwen_gr00t.py new file mode 100644 index 000000000..bdae8b5fc --- /dev/null +++ b/flagscale/serve/run_serve_qwen_gr00t.py @@ -0,0 +1,92 @@ +# Mainly adopted from: +# https://github.com/starVLA/starVLA/blob/3f7feefbc5fc25890ad3a7d262b8a0aea1339aa7/deployment/model_server/server_policy.py + +import argparse +import importlib +import time + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf + +from flagscale.logger import logger +from flagscale.models.utils.constants import ACTION +from flagscale.serve.websocket_policy_server import WebsocketPolicyServer +from flagscale.train.utils.train_utils import load_checkpoint + + +class Policy: + def __init__(self, config: DictConfig | ListConfig): + self.config_engine = config["engine_args"] + + self.host = self.config_engine.get("host", "0.0.0.0") + self.port = self.config_engine.get("port", 5000) 
+ self.model = None + self.preprocessor = None + self.postprocessor = None + + self.load_model() + + def load_model(self): + t_s = time.perf_counter() + model_variant = self.config_engine.model_variant + policy = getattr(importlib.import_module("flagscale.models.vla"), model_variant) + self.model, self.preprocessor, self.postprocessor = load_checkpoint( + self.config_engine.model, policy, self.config_engine.device + ) + # TODO: (yupu): model.to(dtype)? + logger.info(f"Policy model loading latency: {time.perf_counter() - t_s:.2f}s") + + def infer(self, batch): + # FIXME: image reisze + logger.info("Start to inference") + print(f"batch: {batch}") + # TODO: (yupu) remove hard-code + batch = batch["examples"][0] + for k, v in batch.items(): + if "image" in k: + print(f"{k}: type {type(v)} shape {v.shape}") + batch = self.preprocessor(batch) + + with torch.no_grad(): + action = self.model.predict_action(batch) + logger.info(f"action before postprocessor: {action}") + + logger.info("Applying postprocessor...") + action = self.postprocessor(action) + + # Convert to numpy for msgpack serialization + action[ACTION] = action[ACTION].detach().cpu().numpy() + + return action + + +def parse_config() -> DictConfig | ListConfig: + """Parse the configuration file""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-path", type=str, required=True, help="Path to the configuration YAML file" + ) + parser.add_argument("--log-dir", type=str, required=True, help="Path to the log") + args = parser.parse_args() + config = OmegaConf.load(args.config_path) + return config + + +def main(config): + policy = Policy(config) + logger.info("Done") + # start websocket server + server = WebsocketPolicyServer( + policy=policy, + host=policy.host, + port=policy.port, + metadata={"env": "simpler_env"}, + ) + logger.info("server running ...") + server.serve_forever() + + +if __name__ == "__main__": + parsed_cfg = parse_config() + main(parsed_cfg["serve"][0]) diff --git 
@runtime_checkable
class Policy(Protocol):
    # Structural interface: anything with infer(obs) -> action dict qualifies.
    def infer(self, obs: dict) -> dict: ...


class WebsocketPolicyServer:
    """Serves a policy over websocket for evaluation inference.

    Protocol:
        1. On connect, server sends metadata dict to client.
        2. Client sends msgpack-encoded obs dict, server returns msgpack-encoded action dict.
        3. Each response includes a "server_timing" key with latency info.
    """

    def __init__(
        self,
        policy: Policy,
        host: str = "0.0.0.0",
        port: int = 10093,
        metadata: dict | None = None,
    ) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}

    def serve_forever(self) -> None:
        """Blocking entry point: runs the async server on a fresh event loop."""
        asyncio.run(self.run())

    async def run(self) -> None:
        # compression/max_size disabled: payloads are pre-packed binary blobs
        # (msgpack + raw array bytes); /healthz is answered before the
        # websocket handshake via process_request.
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection) -> None:
        """Per-connection loop: greet with metadata, then obs -> action forever."""
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        # Step 1 of the protocol: announce server metadata.
        await websocket.send(packer.pack(self._metadata))

        prev_total_time: float | None = None
        while True:
            try:
                start_time = time.monotonic()
                obs: dict = msgpack_numpy.unpackb(await websocket.recv())

                # Time only the policy call; the total (reported next turn)
                # also includes waiting for the client's message.
                infer_time = time.monotonic()
                action: dict = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_time

                action["server_timing"] = {"infer_ms": infer_time * 1000}
                if prev_total_time is not None:
                    # Round-trip duration of the PREVIOUS request/response.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                # Ship the traceback to the client as a text frame before
                # closing, then re-raise so the failure is visible server-side.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise


def _health_check(connection: _server.ServerConnection, request: Request) -> Response | None:
    """Answer GET /healthz with 200 OK before the websocket handshake; otherwise proceed."""
    if request.path == "/healthz":
        return connection.respond(http.HTTPStatus.OK, "OK\n")
    return None
+ decoder = VideoDecoder(video_path, seek_mode="approximate") + self._cache[video_path] = (decoder, None) return self._cache[video_path][0] @@ -198,7 +197,8 @@ def clear(self): """Clear the cache and close file handles.""" with self._lock: for _, file_handle in self._cache.values(): - file_handle.close() + if file_handle is not None: + file_handle.close() self._cache.clear() def size(self) -> int: diff --git a/flagscale/train/train_config.py b/flagscale/train/train_config.py index ec35b4b14..f3ab82ce0 100644 --- a/flagscale/train/train_config.py +++ b/flagscale/train/train_config.py @@ -4,28 +4,131 @@ from typing import Any -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field, field_validator -class OptimizerConfig(BaseModel): - """Optimizer configuration""" +class FreezeConfig(BaseModel): + """Pattern-based module freezing configuration (NeMo-style). - name: str = "AdamW" - lr: float = 2.5e-5 - betas: tuple[float, float] = (0.9, 0.95) - eps: float = 1e-8 - weight_decay: float = 0.01 + Freezing logic: + 1. For each parameter, check if name matches any `freeze_patterns` + 2. If matched, check if name also matches any `keep_patterns` + 3. If matched by freeze but NOT by keep → freeze (requires_grad=False) + + `keep_patterns` overrides `freeze_patterns` - this allows freezing a module + but keeping specific sub-components trainable. + + Patterns are regex patterns matched against full parameter names. + """ + + model_config = {"extra": "allow"} + + freeze_patterns: list[str] | None = None + keep_patterns: list[str] | None = None class SchedulerConfig(BaseModel): - """Learning rate scheduler configuration""" + """Learning rate scheduler configuration. + Uses transformers scheduler types when `name` is set. See transformers.SchedulerType for options: + linear, cosine, cosine_with_restarts, polynomial, constant, + constant_with_warmup, inverse_sqrt, cosine_with_min_lr, etc. 
class OptimizerConfig(BaseModel):
    """Optimizer configuration.

    Attributes:
        name: Optimizer class name. Currently supported: "AdamW".
        lr: Learning rate (default for all param groups).
        betas: Adam beta coefficients (beta1, beta2).
        eps: Adam epsilon for numerical stability.
        weight_decay: Weight decay (L2 penalty).
        param_groups: Per-module optimizer overrides. Maps module paths to optimizer kwargs.
            Example: {"encoder": {"lr": 1e-5}, "decoder": {"lr": 1e-3}}
        scheduler: LR scheduler config.

    Example config (YAML):
        optimizer:
          name: AdamW
          lr: 1e-4
          weight_decay: 0.01
          param_groups:
            vision_encoder:
              lr: 1e-5
            action_head:
              lr: 2e-4
          scheduler:
            name: cosine
            warmup_steps: 1000
    """

    name: str = "AdamW"
    # All numeric fields default to None so get_optimizer_kwargs() can omit
    # unset values and let the optimizer's own defaults apply.
    lr: float | None = None
    betas: tuple[float, float] | None = None
    eps: float | None = None
    weight_decay: float | None = None
    param_groups: dict[str, dict[str, Any]] | None = Field(
        default=None,
        description="Per-module optimizer settings. Maps module paths to optimizer kwargs.",
    )
    scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)

    @field_validator("betas", mode="before")
    @classmethod
    def normalize_betas(cls, v):
        """Convert list to tuple for betas if provided.

        Accepts both list and tuple inputs, but always stores as tuple.
        Also validates that betas has exactly two elements.
        """
        if v is None:
            return None
        if isinstance(v, list):
            # YAML lists arrive here; normalize to the tuple the field expects.
            if len(v) != 2:
                raise ValueError(f"betas must have exactly 2 elements, got {len(v)}")
            return tuple(v)
        if isinstance(v, tuple) and len(v) != 2:
            raise ValueError(f"betas must have exactly 2 elements, got {len(v)}")
        return v

    def get_optimizer_kwargs(self) -> dict[str, Any]:
        """Get non-None optimizer kwargs for passing to optimizer.

        Returns:
            Dict of optimizer kwargs, excluding None values.
        """
        return {
            k: v
            for k, v in {
                "lr": self.lr,
                "betas": self.betas,
                "eps": self.eps,
                "weight_decay": self.weight_decay,
            }.items()
            if v is not None
        }
return getattr(raw, name) + raise AttributeError(name) class ModelConfig(BaseModel): @@ -72,22 +195,39 @@ class ModelConfig(BaseModel): All other fields are passed through to the model's config class. """ - model_config = {"extra": "allow"} # Allow extra fields for model-specific config + model_config = { + "extra": "allow", + "arbitrary_types_allowed": True, + } # Allow extra fields for model-specific config # Required fields to identify which model and checkpoint to use model_name: str = Field(..., description="Model name: 'pi0' or 'pi0.5'") checkpoint_dir: str = Field(..., description="Path to pretrained model checkpoint") + freeze: FreezeConfig | None = None + optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + raw: DictConfig | None = Field(default=None, exclude=True) + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + raw = self.__dict__.get("raw") + if raw is not None and hasattr(raw, name): + return getattr(raw, name) + raise AttributeError(name) @field_validator("model_name") @classmethod def validate_model_name(cls, v): - if v not in ["pi0", "pi0.5"]: - raise ValueError(f"Invalid model_name: {v}. Must be 'pi0' or 'pi0.5'") + valid_names = {"pi0", "pi0.5", "qwen_gr00t"} + if v not in valid_names: + raise ValueError(f"Invalid model_name: {v}. 
Must be one of {valid_names}") return v def get_model_config_dict(self) -> dict[str, Any]: """Get all model-specific config fields (excluding train-level fields).""" - return self.model_dump(exclude={"model_name", "checkpoint_dir"}) + return self.model_dump( + exclude={"model_name", "checkpoint_dir", "freeze", "optimizer"} + ) class TrainConfig(BaseModel): @@ -100,10 +240,13 @@ class TrainConfig(BaseModel): @classmethod def from_hydra_config(cls, hydra_config) -> "TrainConfig": """Convert Hydra DictConfig to Pydantic TrainConfig""" - train_dict = OmegaConf.to_container(hydra_config.train, resolve=True) + train = hydra_config.train + train_dict = OmegaConf.to_container(train, resolve=True) + train_dict["system"] = SystemConfig(**train_dict["system"], raw=train.system) + train_dict["data"] = DataConfig(**train_dict["data"], raw=train.data) + train_dict["model"] = ModelConfig(**train_dict["model"], raw=train.model) return cls(**train_dict) class Config: # Allow arbitrary types for complex objects arbitrary_types_allowed = True - diff --git a/flagscale/train/train_pi.py b/flagscale/train/train_pi.py index ab03c747b..47c5bcb41 100644 --- a/flagscale/train/train_pi.py +++ b/flagscale/train/train_pi.py @@ -619,23 +619,23 @@ def main(config: TrainConfig, seed: int): ) # Convert optimizer_betas to tuple if it's a list - optimizer_betas = config.system.optimizer.betas + optimizer_betas = config.model.optimizer.betas if isinstance(optimizer_betas, list): optimizer_betas = tuple(optimizer_betas) # TODO: (yupu) Should we let the user choose between config and policy preset? 
optimizer = torch.optim.AdamW( policy.parameters(), - lr=config.system.optimizer.lr, + lr=config.model.optimizer.lr, betas=optimizer_betas, - eps=config.system.optimizer.eps, - weight_decay=config.system.optimizer.weight_decay, + eps=config.model.optimizer.eps, + weight_decay=config.model.optimizer.weight_decay, ) scheduler_config = CosineDecayWithWarmupSchedulerConfig( - num_warmup_steps=config.system.scheduler.warmup_steps, - num_decay_steps=config.system.scheduler.decay_steps, - peak_lr=config.system.optimizer.lr, - decay_lr=config.system.scheduler.decay_lr, + num_warmup_steps=config.model.optimizer.scheduler.warmup_steps, + num_decay_steps=config.model.optimizer.scheduler.decay_steps, + peak_lr=config.model.optimizer.lr, + decay_lr=config.model.optimizer.scheduler.decay_lr, ) lr_scheduler = scheduler_config.build(optimizer, config.system.train_steps) diff --git a/flagscale/train/train_qwen_gr00t.py b/flagscale/train/train_qwen_gr00t.py new file mode 100644 index 000000000..837e15c35 --- /dev/null +++ b/flagscale/train/train_qwen_gr00t.py @@ -0,0 +1,625 @@ +# Mainly adopted from +# https://github.com/huggingface/lerobot/blob/2b304eeb841ae6c371e3dd341bbbb9dd254b07cb/src/lerobot/scripts/lerobot_train.py + +import argparse +import os +import random +import time +from collections.abc import Iterator +from contextlib import nullcontext +from typing import Any + +from omegaconf import OmegaConf, DictConfig +import numpy as np +from PIL import Image +import torch +import torch.distributed as dist +from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions +from torch.optim import Optimizer + +from flagscale.logger import logger +from flagscale.train.train_config import TrainConfig, DataConfig +from flagscale.train.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) 
+from flagscale.train.datasets.utils import dataset_to_policy_features +from flagscale.train.processor import PolicyProcessorPipeline +from flagscale.models.utils.constants import ACTION, OBS_PREFIX, REWARD +from flagscale.models.configs.types import FeatureType +from flagscale.train.utils.logging_utils import ( + AverageMeter, + MetricsTracker, + format_big_number, +) +from flagscale.train.utils.train_utils import ( + save_checkpoint, + get_step_checkpoint_dir, + update_last_checkpoint, +) +from flagscale.train.utils.optim_setup import setup_optimizer_and_scheduler +from flagscale.models.vla.qwen_gr00t import QwenGr00t + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = False + torch.backends.cuda.matmul.allow_tf32 = False + + +def apply_fsdp2(policy, device_mesh): + """Apply FSDP2 sharding to QwenGr00t. + + Uses a MixedPrecisionPolicy that matches DeepSpeed bf16 behavior: + bf16.enabled=true + ZeRO-2 → param_dtype=bf16, reduce_dtype=bf16, reshard=False + """ + # Cast everything to fp32 first so the root param group has uniform dtype. 
+    """Resolves delta_timestamps by reading the 'delta_indices' attributes of the data config.
+    Args:
+        cfg: The data config to read reward/action/observation delta_indices from.
+        ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
+            delta_timestamps against.
+
+    Returns:
+        dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
+            {
+                "observation.state": [-0.04, -0.02, 0],
+                "observation.action": [-0.02, 0, 0.02]
+            }
+ """ + iterator = iter(iterable) + while True: + try: + yield next(iterator) + except StopIteration: + iterator = iter(iterable) + + +def format_train_tracker_step(train_tracker: MetricsTracker) -> str: + def _format_meter_val(meter: AverageMeter) -> str: + fmt = meter.fmt[1:] if meter.fmt.startswith(":") else meter.fmt + return f"{meter.name}:{format(meter.val, fmt)}" + + display_list = [ + f"step:{format_big_number(train_tracker.steps)}", + f"smpl:{format_big_number(train_tracker.samples)}", + f"ep:{format_big_number(train_tracker.episodes)}", + f"epch:{train_tracker.epochs:.2f}", + *[_format_meter_val(m) for m in train_tracker.metrics.values()], + ] + return " ".join(display_list) + + + +def make_policy( + config: TrainConfig, + ds_meta: LeRobotDatasetMetadata | None = None, +): + features = dataset_to_policy_features(ds_meta.features) + + # Use == instead of `is` for FeatureType.ACTION comparison + # because flagscale.FeatureType and lerobot.FeatureType are different enum classes + output_features = { + key: ft + for key, ft in features.items() + if ft.type == FeatureType.ACTION + } + input_features = {key: ft for key, ft in features.items() if key not in output_features} + + # TODO: (yupu) This is a hack, we should find a better way to handle this. LeRobot does this in the policy config. + # The order of the images is defined in the dataset config.json + image_features = { + key: ft for key, ft in input_features.items() if ft.type is FeatureType.VISUAL + } + config.data.vla_data.image_features = image_features + + policy = QwenGr00t(config=config) + policy.to("cuda") + + return policy, input_features, output_features + + +def make_preprocessor_from_config( + config: dict[str, Any] | list[str | dict[str, Any]], + overrides: dict[str, Any] | None = None, +) -> PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]: + """ + Create a preprocessor pipeline from step configurations with optional overrides. 
+ + This function creates a PolicyProcessorPipeline directly from step configurations, + without requiring a pretrained path. It supports overriding step configurations + similar to PolicyProcessorPipeline.from_pretrained(). + + Args: + config: Can be either: + - A dict with "name" and "steps" fields (JSON format): + {"name": "policy_preprocessor", "steps": [...]} + - A list of step configurations (concise format): + ["step_name", {"step_name": {...}}] + overrides: Optional dictionary to override step configurations. Keys should + match the step's registry_name. Example: + {"device_processor": {"device": "cuda"}, + "normalizer_processor": {"stats": dataset.meta.stats}} + + Returns: + A PolicyProcessorPipeline instance with the configured steps. + + Example (JSON format with overrides): + ```python + config = { + "name": "policy_preprocessor", + "steps": [ + {"registry_name": "device_processor", "config": {"device": "cpu"}}, + {"registry_name": "normalizer_processor", "config": {"eps": 1e-8}}, + ], + } + overrides = { + "device_processor": {"device": "cuda"}, + "normalizer_processor": {"stats": dataset.meta.stats, "features": {...}}, + } + preprocessor = make_preprocessor_from_config(config, overrides=overrides) + # device_processor will use device="cuda" (overridden) + # normalizer_processor will use eps=1e-8 (from config) and stats from overrides + ``` + + Example (concise list format): + ```python + steps = [ + "rename_observations_processor", + "device_processor", + {"normalizer_processor": {"eps": 1e-8}}, + ] + preprocessor = make_preprocessor_from_config(steps) + ``` + + Raises: + ValueError: If a step configuration is invalid or step cannot be instantiated. + KeyError: If a registry name is not found. 
+ """ + from flagscale.train.processor.pipeline import ProcessorStepRegistry + + overrides = overrides or {} + + # Determine format and extract step configs + if isinstance(config, (dict, DictConfig)) and "steps" in config: + # JSON format: {"name": "...", "steps": [...]} + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + step_configs = config["steps"] + pipeline_name = config.get("name", "policy_preprocessor") + elif isinstance(config, list): + # Concise list format + step_configs = config + pipeline_name = "policy_preprocessor" + else: + raise ValueError(f"Config must be a dict with 'steps' key or a list, got {type(config)}") + + steps = [] + for step_entry in step_configs: + # Determine step format and normalize to standard dict + if isinstance(step_entry, str): + # Concise format: "step_name" + step_dict = {"registry_name": step_entry, "config": {}} + elif isinstance(step_entry, (dict, DictConfig)): + if "registry_name" in step_entry: + # JSON format: {"registry_name": "...", "config": {...}} + if isinstance(step_entry, DictConfig): + step_entry = OmegaConf.to_container(step_entry, resolve=True) + step_dict = step_entry + elif len(step_entry) == 1: + # Concise format: {"step_name": {...}} + step_name = next(iter(step_entry.keys())) + step_config = step_entry[step_name] + if isinstance(step_config, DictConfig): + step_config = OmegaConf.to_container(step_config, resolve=True) + step_dict = {"registry_name": step_name, "config": step_config} + else: + raise ValueError( + f"Step config dict must have either 'registry_name' or exactly one key, " + f"got {list(step_entry.keys())}" + ) + else: + raise ValueError( + f"Step config must be str or dict, got {type(step_entry)}: {step_entry}" + ) + + # Get step class + registry_name = step_dict["registry_name"] + step_class = ProcessorStepRegistry.get(registry_name) + + # Merge config with overrides (overrides take precedence) + try: + base_config = step_dict.get("config", {}) + 
step_overrides = overrides.get(registry_name, {}) + merged_config = {**base_config, **step_overrides} + + step_instance = step_class(**merged_config) + steps.append(step_instance) + except Exception as e: + raise ValueError( + f"Failed to instantiate processor step '{registry_name}' " + f"with config {merged_config}. Error: {e!s}" + ) from e + + return PolicyProcessorPipeline( + steps=steps, + name=pipeline_name, + ) + + +def has_method(cls: object, method_name: str) -> bool: + return hasattr(cls, method_name) and callable(getattr(cls, method_name)) + + +def update_policy( + train_metrics: MetricsTracker, + policy, + batch: Any, + optimizer: Optimizer, + use_amp: bool, + grad_clip_norm: float, + lr_scheduler=None, + lock=None, +) -> MetricsTracker: + """ + Performs a single training step to update the policy's weights. + + This function executes the forward and backward passes, clips gradients, and steps the optimizer and + learning rate scheduler. + + Args: + train_metrics: A MetricsTracker instance to record training statistics. + policy: The policy model to be trained (FSDP2-sharded). + batch: A batch of training data. + optimizer: The optimizer used to update the policy's parameters. + use_amp: Whether to use automatic mixed precision. + grad_clip_norm: The maximum norm for gradient clipping. + lr_scheduler: An optional learning rate scheduler. + lock: An optional lock for thread-safe optimizer updates. + + Returns: + The updated MetricsTracker with new statistics for this step. 
+ """ + start_time = time.perf_counter() + + optimizer.zero_grad() + + autocast_context = ( + torch.amp.autocast("cuda", dtype=torch.bfloat16) if use_amp else nullcontext() + ) + with autocast_context: + loss = policy(batch) + + loss.backward() + + # Clip gradients (torch.nn.utils.clip_grad_norm_ works with DTensors in PyTorch ≥2.6) + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), grad_clip_norm if grad_clip_norm > 0 else float("inf") + ) + + with lock if lock is not None else nullcontext(): + optimizer.step() + + # Step through pytorch scheduler at every batch instead of epoch + if lr_scheduler is not None: + lr_scheduler.step() + + # Update internal buffers if policy has update method + if has_method(policy, "update"): + policy.update() + + train_metrics.loss = loss.item() + train_metrics.grad_norm = grad_norm.full_tensor().item() if hasattr(grad_norm, 'full_tensor') else grad_norm.item() + train_metrics.lr = optimizer.param_groups[0]["lr"] + train_metrics.update_s = time.perf_counter() - start_time + + return train_metrics + + +def main(config: TrainConfig, seed: int): + set_seed(seed) + + # --- Distributed init --- + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + is_main_process = rank == 0 + + dataset = make_dataset(config.data) + dist.barrier() + + policy, input_features, output_features = make_policy(config=config, ds_meta=dataset.meta) + dist.barrier() + + # --- Apply FSDP2 --- + device_mesh = init_device_mesh("cuda", (world_size,)) + apply_fsdp2(policy, device_mesh) + + # Create processors - only provide dataset_stats if not resuming from saved processors + preprocessor_overrides = { + "device_processor": {"device": device.type}, + "normalizer_processor": { + "stats": dataset.meta.stats, + "features": { + **input_features, + **output_features, + }, + }, + } 
+ + num_workers = 0 # config.system.num_workers + shuffle = config.system.shuffle + + # DistributedSampler ensures each rank gets different data + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + drop_last=False, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=config.system.batch_size, + shuffle=False, # Must be False when using sampler + sampler=sampler, + pin_memory=True, + drop_last=False, + prefetch_factor=2 if num_workers > 0 else None, + ) + + # Setup preprocessor + preprocessor = None + if config.data.preprocessor is not None: + preprocessor = make_preprocessor_from_config( + config.data.preprocessor, overrides=preprocessor_overrides + ) + + # Setup postprocessor (unnormalization for inference) + postprocessor = None + postprocessor_config = getattr(config.data, "postprocessor", None) + if postprocessor_config is not None: + postprocessor_overrides = { + "unnormalizer_processor": { + "stats": dataset.meta.stats, + "features": { + **input_features, + **output_features, + }, + }, + } + postprocessor = make_preprocessor_from_config( + postprocessor_config, overrides=postprocessor_overrides + ) + + # Setup optimizer and scheduler (applies freeze config internally) + optimizer, lr_scheduler = setup_optimizer_and_scheduler(policy, config) + + dist.barrier() + + dl_iter = cycle(dataloader) + + train_metrics = { + "loss": AverageMeter("loss", ":.3f"), + "grad_norm": AverageMeter("grdn", ":.3f"), + "lr": AverageMeter("lr", ":0.1e"), + "update_s": AverageMeter("updt_s", ":.3f"), + "dataloading_s": AverageMeter("data_s", ":.3f"), + } + + effective_batch_size = config.system.batch_size * world_size + + step = 0 + + train_tracker = MetricsTracker( + effective_batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, + ) + + # Ensures proper data shuffling across epochs in distributed training + epoch 
= 0 + samples_per_epoch = len(dataset) // effective_batch_size + sampler.set_epoch(epoch) + + for _ in range(step, config.system.train_steps): + start_time = time.perf_counter() + batch = next(dl_iter) + batch = { + k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + + if preprocessor is not None: + batch = preprocessor(batch) + train_tracker.dataloading_s = time.perf_counter() - start_time + + train_tracker = update_policy( + train_tracker, + policy, + batch, + optimizer, + use_amp=config.system.use_amp, + grad_clip_norm=config.system.grad_clip_norm, + lr_scheduler=lr_scheduler, + ) + + step += 1 + train_tracker.step() + + # Update epoch counter for sampler.set_epoch() when we've processed one epoch worth of samples + # This ensures proper data shuffling across epochs in distributed training + if samples_per_epoch > 0 and step % samples_per_epoch == 0: + epoch += 1 + sampler.set_epoch(epoch) + + if step % config.system.log_freq == 0 and is_main_process: + logger.info(f"step: {step} {format_train_tracker_step(train_tracker)}") + train_tracker.reset_averages() + + if ( + config.system.checkpoint.save_checkpoint + and step % config.system.checkpoint.save_freq == 0 + ): + dist.barrier() + + # get_model_state_dict is a collective — all ranks must call it + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + state_dict = get_model_state_dict(policy, options=options) + + if is_main_process: + from pathlib import Path + + logger.info(f"Saving checkpoint at step {step}") + output_dir = Path(config.system.checkpoint.output_directory) + checkpoint_dir = get_step_checkpoint_dir( + output_dir, config.system.train_steps, step + ) + save_checkpoint( + checkpoint_dir=checkpoint_dir, + policy=state_dict, + config=config, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) + update_last_checkpoint(checkpoint_dir) + + dist.barrier() + + if is_main_process: + logger.info("Training completed") + + 
dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Train QwenGr00t model. This script is typically called by the flagscale runner, not directly." + ) + parser.add_argument( + "--config-file", type=str, required=True, help="Path to the configuration YAML file" + ) + args = parser.parse_args() + + config_file_path = args.config_file + + # Load config from YAML file (Hydra-generated config.yaml contains both train and experiment) + config = OmegaConf.load(config_file_path) + + logger.info(f"full config: {config}") + + # Extract train config and convert to Pydantic TrainConfig (preserves raw configs) + train_config = TrainConfig.from_hydra_config(config) + + # Extract experiment config (seed, exp_dir, etc.) + experiment_config = OmegaConf.to_container(config.experiment, resolve=True) + seed = experiment_config.get("seed", 42) + + logger.info("=" * 100) + logger.info(f"Experiment: {experiment_config}") + logger.info(f"Train config: {train_config}") + + main(train_config, seed) diff --git a/flagscale/train/utils/optim_setup.py b/flagscale/train/utils/optim_setup.py new file mode 100644 index 000000000..47cf7bf1e --- /dev/null +++ b/flagscale/train/utils/optim_setup.py @@ -0,0 +1,383 @@ +"""Optimizer setup utilities: parameter freezing and per-module optimizer config. + +Supports: +- Freezing parameters via regex patterns +- Per-module optimizer settings (lr, weight_decay, betas, etc.) 
via config +- LR scheduler nested under optimizer + +Example config (YAML): + model: + optimizer: + lr: 1e-4 + weight_decay: 0.01 + param_groups: + qwen_backbone: + lr: 1e-5 + action_head: + lr: 2e-4 + weight_decay: 0.0 + scheduler: + name: cosine_with_min_lr + warmup_steps: 5000 + scheduler_kwargs: + min_lr: 1.0e-06 + freeze: + freeze_patterns: ["backbone.*"] +""" + +import re +from collections import defaultdict +from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn +from transformers import get_scheduler + +from flagscale.logger import logger + +if TYPE_CHECKING: + from flagscale.train.train_config import ( + FreezeConfig, + OptimizerConfig, + SchedulerConfig, + TrainConfig, + ) + + +class PatternMatcher: + """Helper for matching parameter names against regex patterns with usage tracking.""" + + def __init__(self, patterns: list[str]): + self.patterns = patterns + self.compiled = [re.compile(p) for p in patterns] + self.match_counts = {p: 0 for p in patterns} + + def matches(self, name: str) -> bool: + for i, pattern in enumerate(self.compiled): + if pattern.search(name): + self.match_counts[self.patterns[i]] += 1 + return True + return False + + def get_unused_patterns(self) -> list[str]: + return [p for p, count in self.match_counts.items() if count == 0] + + +def freeze_and_get_trainable_params( + named_parameters: Iterable[tuple[str, torch.nn.Parameter]], + freeze_patterns: list[str] | None = None, + keep_patterns: list[str] | None = None, +) -> Generator[torch.nn.Parameter, None, None]: + """ + Freeze parameters matching patterns and yield only trainable parameters. + + Args: + named_parameters: Output of model.named_parameters() + freeze_patterns: Regex patterns for params to freeze + keep_patterns: Regex patterns for params to keep trainable (overrides freeze_patterns) + + Yields: + Only parameters that should be trained (for optimizer). 
+ """ + freeze_matcher = PatternMatcher(freeze_patterns or []) + keep_matcher = PatternMatcher(keep_patterns or []) + + trainable_count, frozen_count = 0, 0 + previously_frozen_now_trainable = [] + + for name, param in named_parameters: + should_freeze = freeze_matcher.matches(name) and not keep_matcher.matches(name) + + if should_freeze: + param.requires_grad = False + frozen_count += param.numel() + else: + # Only force parameters to be trainable if freeze patterns are provided. + # Otherwise, preserve the original requires_grad state. + if freeze_patterns: + if not param.requires_grad: + previously_frozen_now_trainable.append(name) + param.requires_grad = True + if param.requires_grad: + trainable_count += param.numel() + yield param + else: + frozen_count += param.numel() + + # Log summary + total = trainable_count + frozen_count + pct = trainable_count / total if total > 0 else 0 + logger.info( + f"Parameters: trainable={trainable_count:,} ({pct:.2%}) | " + f"frozen={frozen_count:,} | total={total:,}" + ) + + if previously_frozen_now_trainable: + logger.warning( + f"{len(previously_frozen_now_trainable)} parameter(s) were already frozen " + f"(requires_grad=False) but don't match any freeze pattern and are being " + f"made trainable. Add them to freeze_patterns if they should stay frozen:" + ) + for name in previously_frozen_now_trainable: + logger.warning(f" unfrozen: {name}") + + # Warn about unused patterns + unused_freeze = freeze_matcher.get_unused_patterns() + if unused_freeze: + logger.warning(f"Freeze patterns matched nothing: {unused_freeze}") + + unused_keep = keep_matcher.get_unused_patterns() + if unused_keep: + logger.warning(f"Keep patterns matched nothing: {unused_keep}") + + +def apply_freeze_config(model: nn.Module, freeze_config) -> list: + """ + Apply freeze config and return list of trainable parameters for optimizer. 
+ + Args: + model: The model to freeze + freeze_config: FreezeConfig with freeze_patterns and keep_patterns + + Returns: + List of trainable parameters (pass directly to optimizer) + """ + if freeze_config is None: + return list(model.parameters()) + + return list( + freeze_and_get_trainable_params( + model.named_parameters(), + freeze_patterns=freeze_config.freeze_patterns, + keep_patterns=freeze_config.keep_patterns, + ) + ) + + +def log_trainable_params(model: nn.Module) -> dict: + """Log trainable/frozen parameter statistics by module.""" + trainable_by_module = defaultdict(int) + frozen_by_module = defaultdict(int) + + for name, param in model.named_parameters(): + module_name = name.split(".")[0] + if param.requires_grad: + trainable_by_module[module_name] += param.numel() + else: + frozen_by_module[module_name] += param.numel() + + logger.info("=" * 60) + logger.info("Parameter status by top-level module:") + all_modules = set(trainable_by_module.keys()) | set(frozen_by_module.keys()) + for mod in sorted(all_modules): + t = trainable_by_module.get(mod, 0) + f = frozen_by_module.get(mod, 0) + logger.info(f" {mod}: {t:,} trainable, {f:,} frozen") + logger.info("=" * 60) + + return {"trainable": dict(trainable_by_module), "frozen": dict(frozen_by_module)} + + +def print_param_names(model: nn.Module, pattern: str | None = None): + """Debug helper: print parameter names (optionally filtered by pattern).""" + for name, param in model.named_parameters(): + if pattern is None or re.search(pattern, name): + status = "trainable" if param.requires_grad else "FROZEN" + print(f"[{status}] {name}: {param.numel():,} params") + + +# TODO: (yupu) Freeze supports regex patterns, but param groups uses exact module paths. See if this is reasonable. +def build_optim_param_groups( + model: nn.Module, + optim_param_groups_config: dict[str, dict[str, Any]] | None = None, +) -> list[dict]: + """ + Build optimizer param groups with per-module settings. 
+ + Each module can have its own optimizer hyperparameters (lr, weight_decay, betas, etc.). + Parameters not belonging to any specified module go into a default group. + + Args: + model: The model to create param groups for. + optim_param_groups_config: Dict mapping module names to optimizer kwargs. + Example: {"encoder": {"lr": 1e-5}, "decoder": {"lr": 1e-3, "weight_decay": 0.01}} + Supports nested paths like "action_head.mlp". + + Returns: + List of param group dicts for optimizer. + """ + if optim_param_groups_config is None: + return [{"params": [p for p in model.parameters() if p.requires_grad]}] + + param_groups = [] + used_param_ids = set() + + for module_name, group_config in optim_param_groups_config.items(): + try: + module = model.get_submodule(module_name) + except AttributeError: + logger.warning( + f"build_optim_param_groups: Module '{module_name}' not found in model, skipping." + ) + continue + + # All trainable params for this module (including descendants) + module_params = [p for p in module.parameters() if p.requires_grad] + if not module_params: + logger.warning( + f"build_optim_param_groups: Module '{module_name}' has no trainable parameters." + ) + continue + # Avoid assigning the same parameter to multiple param groups by + # filtering out parameters that are already used by previous groups. + params = [p for p in module_params if id(p) not in used_param_ids] + if not params: + # All trainable params for this module were already included in + # previous param groups. This usually indicates overlapping + # module paths in the optimizer config (e.g., both "encoder" + # and "encoder.layer1"). + logger.warning( + "build_optim_param_groups: All trainable parameters for module " + f"'{module_name}' are already assigned to previous param groups. " + "This suggests overlapping module paths in the optimizer " + "configuration; this group will be skipped." 
+ ) + continue + if len(params) < len(module_params): + # Some, but not all, parameters were already assigned to previous + # groups. Warn the user so they are aware of the partial overlap. + logger.warning( + "build_optim_param_groups: Some trainable parameters for module " + f"'{module_name}' are already assigned to previous param groups " + "(overlapping module paths). Only unassigned parameters will be " + "included in this group." + ) + + used_param_ids.update(id(p) for p in params) + param_groups.append({"params": params, "name": module_name, **group_config}) + + param_count = sum(p.numel() for p in params) + logger.info(f"Param group '{module_name}': {param_count:,} params, {group_config}") + + # Remaining params go to default group + other_params = [ + p for p in model.parameters() if p.requires_grad and id(p) not in used_param_ids + ] + if other_params: + param_groups.insert(0, {"params": other_params, "name": "default"}) + logger.info(f"Param group 'default': {sum(p.numel() for p in other_params):,} params") + + return param_groups + + +def setup_optimizer( + model: nn.Module, + optimizer_config: "OptimizerConfig", + freeze_config: "FreezeConfig | None" = None, +) -> torch.optim.Optimizer: + """ + One-stop setup: apply freeze config, build param groups, create optimizer. + + Args: + model: The model to optimize. + optimizer_config: OptimizerConfig with name, lr, betas, eps, weight_decay, param_groups. + freeze_config: FreezeConfig with freeze_patterns and keep_patterns. + + Returns: + Configured optimizer instance. + """ + if freeze_config is not None: + apply_freeze_config(model, freeze_config) + log_trainable_params(model) + + param_groups = build_optim_param_groups(model, optimizer_config.param_groups) + total_params = sum(len(g["params"]) for g in param_groups) + if not total_params: + raise ValueError( + "No trainable parameters found. All parameters may be frozen, " + "or configured param groups have no trainable parameters." 
+ ) + + optimizer_kwargs = {"params": param_groups, **optimizer_config.get_optimizer_kwargs()} + + # Get optimizer class by name + optimizer_cls = _get_optimizer_class(optimizer_config.name) + return optimizer_cls(**optimizer_kwargs) + + +# Supported optimizers +_OPTIMIZER_REGISTRY: dict[str, type[torch.optim.Optimizer]] = { + "AdamW": torch.optim.AdamW, +} + + +def _get_optimizer_class(name: str) -> type[torch.optim.Optimizer]: + """Get optimizer class by name.""" + if name not in _OPTIMIZER_REGISTRY: + supported = list(_OPTIMIZER_REGISTRY.keys()) + raise ValueError(f"Unsupported optimizer: {name}. Supported: {supported}") + return _OPTIMIZER_REGISTRY[name] + + +def setup_scheduler( + optimizer: torch.optim.Optimizer, + scheduler_config: "SchedulerConfig", + num_training_steps: int, +) -> torch.optim.lr_scheduler.LRScheduler: + """ + Create LR scheduler using transformers' get_scheduler. + + Args: + optimizer: The optimizer to schedule. + scheduler_config: Config with name, warmup_steps, scheduler_kwargs. + num_training_steps: Total training steps. + + Returns: + A learning rate scheduler instance. + + Raises: + ValueError: If scheduler_config.name is None. + """ + + if scheduler_config.name is None: + raise ValueError("scheduler_config.name must be specified to use setup_scheduler") + + return get_scheduler( + name=scheduler_config.name, + optimizer=optimizer, + num_warmup_steps=scheduler_config.warmup_steps, + num_training_steps=num_training_steps, + scheduler_specific_kwargs=scheduler_config.scheduler_kwargs, + ) + + +def setup_optimizer_and_scheduler( + model: nn.Module, + train_config: "TrainConfig", +) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """ + One-stop setup for both optimizer and scheduler from TrainConfig. + + Args: + model: The model to optimize. + train_config: TrainConfig containing model (optimizer, scheduler, + freeze config) and system (train_steps). + + Returns: + Tuple of (optimizer, lr_scheduler). 
+ + Raises: + ValueError: If scheduler_config.name is None. + """ + optimizer = setup_optimizer( + model, + train_config.model.optimizer, + freeze_config=train_config.model.freeze, + ) + scheduler = setup_scheduler( + optimizer, + train_config.model.optimizer.scheduler, + num_training_steps=train_config.system.train_steps, + ) + return optimizer, scheduler diff --git a/flagscale/train/utils/train_utils.py b/flagscale/train/utils/train_utils.py index 6a0b6e1c8..45ee857d7 100644 --- a/flagscale/train/utils/train_utils.py +++ b/flagscale/train/utils/train_utils.py @@ -17,24 +17,17 @@ # limitations under the License. from pathlib import Path -# from lerobot.configs.train import TrainPipelineConfig -from flagscale.models.pi0.modeling_pi0 import PI0Policy +from omegaconf import OmegaConf +from safetensors.torch import load_model, save_file -# from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state -# from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state -# from lerobot.policies.pretrained import PreTrainedPolicy -# from lerobot.processor import PolicyProcessorPipeline from flagscale.models.utils.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, - # TRAINING_STATE_DIR, TRAINING_STEP, ) from flagscale.train.datasets.utils import load_json, write_json -# from lerobot.utils.random_utils import load_rng_state, save_rng_state - def get_step_identifier(step: int, total_steps: int) -> str: num_digits = max(6, len(str(total_steps))) @@ -66,102 +59,130 @@ def update_last_checkpoint(checkpoint_dir: Path) -> Path: def save_checkpoint( checkpoint_dir: Path, - # step: int, - # cfg: PI0Config, - policy: PI0Policy, - # optimizer: Optimizer, - # scheduler: LRScheduler | None = None, - # preprocessor: PolicyProcessorPipeline | None = None, - # postprocessor: PolicyProcessorPipeline | None = None, + policy, + config, + preprocessor=None, + postprocessor=None, ) -> None: - """This function creates the following 
directory structure: - - 005000/ # training step at checkpoint - ├── pretrained_model/ - │ ├── config.json # policy config - │ ├── model.safetensors # policy weights - │ ├── train_config.json # train config - │ ├── processor.json # processor config (if preprocessor provided) - │ └── step_*.safetensors # processor state files (if any) - └── training_state/ - ├── optimizer_param_groups.json # optimizer param groups - ├── optimizer_state.safetensors # optimizer state - ├── rng_state.safetensors # rng states - ├── scheduler_state.json # scheduler state - └── training_step.json # training step + """Save model weights, config, and preprocessor state. + + Creates the following directory structure: + 005000/ + └── pretrained_model/ + ├── train_config.yaml # train config (OmegaConf) + ├── model.safetensors # All weights (VLM + action head) + ├── policy_preprocessor.json # Preprocessor pipeline config + └── policy_preprocessor_step_*.safetensors # Norm stats + + Args: + checkpoint_dir: Directory to save checkpoint (e.g., checkpoints/005000) + policy: The model + config: Training config (OmegaConf, Pydantic, or dict) + preprocessor: Optional PolicyProcessorPipeline + """ + pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR + pretrained_dir.mkdir(parents=True, exist_ok=True) + + # Save train config as YAML + # Handle OmegaConf, Pydantic, and dict configs + if hasattr(config, "model_dump"): + config = OmegaConf.create(config.model_dump()) + elif not OmegaConf.is_config(config): + config = OmegaConf.create(config) + OmegaConf.save(config, pretrained_dir / "train_config.yaml") + + # Accept either a model or a pre-gathered state_dict (e.g. from FSDP2). + # Clone tensors to avoid safetensors errors with non-contiguous views. 
+ if isinstance(policy, dict): + state_dict = {k: v.clone().contiguous() for k, v in policy.items()} + else: + state_dict = {k: v.clone().contiguous() for k, v in policy.state_dict().items()} + save_file(state_dict, pretrained_dir / "model.safetensors") + + if preprocessor is not None: + preprocessor.save_pretrained(pretrained_dir) + if postprocessor is not None: + postprocessor.save_pretrained(pretrained_dir) + + +def load_checkpoint( + checkpoint_dir: str | Path, + model_cls, + device: str = "cpu", +): + """Load config, model weights, and preprocessor from checkpoint. Args: - cfg (TrainPipelineConfig): The training config used for this run. - step (int): The training step at that checkpoint. - policy (PreTrainedPolicy): The policy to save. - optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. - scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. - preprocessor: The preprocessor/pipeline to save. Defaults to None. + checkpoint_dir: Checkpoint directory (e.g., checkpoints/005000) + model_cls: Model class. 
+ device: Device to load weights to + + Returns: + If model_cls provided: tuple of (model, preprocessor) + If model_cls is None: tuple of (config, state_dict, preprocessor) + + Raises: + FileNotFoundError: If checkpoint directory or required files don't exist """ + from flagscale.train.processor import PolicyProcessorPipeline + + print(f"Loading checkpoint from {checkpoint_dir}") + + if isinstance(checkpoint_dir, str): + checkpoint_dir = Path(checkpoint_dir) + pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR - policy.save_pretrained(pretrained_dir) - # cfg.save_pretrained(pretrained_dir) - # if preprocessor is not None: - # preprocessor.save_pretrained(pretrained_dir) - # if postprocessor is not None: - # postprocessor.save_pretrained(pretrained_dir) - # save_training_state(checkpoint_dir, step, optimizer, scheduler) - - -# def save_training_state( -# checkpoint_dir: Path, -# train_step: int, -# optimizer: Optimizer | None = None, -# scheduler: LRScheduler | None = None, -# ) -> None: -# """ -# Saves the training step, optimizer state, scheduler state, and rng state. - -# Args: -# save_dir (Path): The directory to save artifacts to. -# train_step (int): Current training step. -# optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict. -# Defaults to None. -# scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. -# Defaults to None. -# """ -# save_dir = checkpoint_dir / TRAINING_STATE_DIR -# save_dir.mkdir(parents=True, exist_ok=True) -# save_training_step(train_step, save_dir) -# save_rng_state(save_dir) -# if optimizer is not None: -# save_optimizer_state(optimizer, save_dir) -# if scheduler is not None: -# save_scheduler_state(scheduler, save_dir) - - -# def load_training_state( -# checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None -# ) -> tuple[int, Optimizer, LRScheduler | None]: -# """ -# Loads the training step, optimizer state, scheduler state, and rng state. 
-# This is used to resume a training run. - -# Args: -# checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. -# optimizer (Optimizer): The optimizer to load the state_dict to. -# scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). - -# Raises: -# NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir - -# Returns: -# tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their -# state_dict loaded. -# """ -# training_state_dir = checkpoint_dir / TRAINING_STATE_DIR -# if not training_state_dir.is_dir(): -# raise NotADirectoryError(training_state_dir) - -# load_rng_state(training_state_dir) -# step = load_training_step(training_state_dir) -# optimizer = load_optimizer_state(optimizer, training_state_dir) -# if scheduler is not None: -# scheduler = load_scheduler_state(scheduler, training_state_dir) - -# return step, optimizer, scheduler + + if not pretrained_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint directory not found: {pretrained_dir}") + + config_path = pretrained_dir / "train_config.yaml" + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + config = OmegaConf.load(config_path) + + model = model_cls(config) + + weights_path = pretrained_dir / "model.safetensors" + if not weights_path.exists(): + raise FileNotFoundError(f"Weights file not found: {weights_path}") + # strict=False to handle tied weights saved as separate entries + missing_keys, unexpected_keys = load_model(model, weights_path, device=device, strict=False) + if missing_keys: + print(f"Warning: Missing keys when loading checkpoint: {len(missing_keys)} keys") + if len(missing_keys) <= 10: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:10]: + print(f" - {key}") + print(f" ... 
and {len(missing_keys) - 10} more") + if unexpected_keys: + print(f"Warning: Unexpected keys in checkpoint: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 10: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:10]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 10} more") + + model.to(device) + + preprocessor = None + preprocessor_config_path = pretrained_dir / "policy_preprocessor.json" + if preprocessor_config_path.exists(): + preprocessor = PolicyProcessorPipeline.from_pretrained( + pretrained_dir, + config_filename="policy_preprocessor.json", + ) + + postprocessor = None + postprocessor_config_path = pretrained_dir / "policy_postprocessor.json" + if postprocessor_config_path.exists(): + postprocessor = PolicyProcessorPipeline.from_pretrained( + pretrained_dir, + config_filename="policy_postprocessor.json", + ) + + return model, preprocessor, postprocessor diff --git a/tests/unit_tests/inference/test_qwen3_vl_apply_chat_template.py b/tests/unit_tests/inference/test_qwen3_vl_apply_chat_template.py new file mode 100644 index 000000000..d54ea2023 --- /dev/null +++ b/tests/unit_tests/inference/test_qwen3_vl_apply_chat_template.py @@ -0,0 +1,67 @@ +import os + +import pytest +import torch + +from flagscale.models.vlm.qwen3_vl import DEFAULT_IMAGE_TOKEN +from flagscale.train.utils.image_tools import to_pil_preserve + + +def _load_processor(): + pytest.importorskip("transformers") + from transformers import AutoProcessor + + model_id = os.environ.get("QWEN3_VL_TEST_MODEL", "Qwen/Qwen3-VL-4B-Instruct") + try: + return AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + except Exception as exc: + pytest.skip(f"Unable to load processor for {model_id}: {exc}") + + +def test_apply_chat_template_batched_images_match_per_sample_messages(): + processor = _load_processor() + batch_size = 2 + num_images = 2 + height = 32 + width = 32 + images = torch.rand(batch_size, num_images, 3, height, width) + 
pil_images = [ + [to_pil_preserve(img.permute(1, 2, 0).numpy()) for img in sample] for sample in images + ] + + instruction = "Describe." + per_sample_messages = [] + for sample_images in pil_images: + content = [{"type": "image", "image": img} for img in sample_images] + content.append({"type": "text", "text": instruction}) + per_sample_messages.append({"role": "user", "content": content}) + + rendered_from_messages = processor.apply_chat_template( + per_sample_messages, + tokenize=False, + add_generation_prompt=True, + ) + + prompt = f"{DEFAULT_IMAGE_TOKEN}\n" * num_images + instruction + batched_messages = [ + {"role": "user", "content": [{"type": "text", "text": prompt}]} + ] * batch_size + rendered_from_prefix = processor.apply_chat_template( + batched_messages, + tokenize=False, + add_generation_prompt=True, + ) + + list_inputs = processor( + text=rendered_from_messages, + images=[img for sample in pil_images for img in sample], + padding=True, + return_tensors="pt", + ) + batched_inputs = processor( + text=rendered_from_prefix, + images=images.view(-1, 3, height, width), + padding=True, + return_tensors="pt", + ) + assert torch.equal(list_inputs["input_ids"], batched_inputs["input_ids"]) diff --git a/tests/unit_tests/models/vla/qwen_gr00t_ref.py b/tests/unit_tests/models/vla/qwen_gr00t_ref.py new file mode 100644 index 000000000..7aa818ca7 --- /dev/null +++ b/tests/unit_tests/models/vla/qwen_gr00t_ref.py @@ -0,0 +1,206 @@ +# Mainly adopted from: +# https://github.com/starVLA/starVLA/blob/3f7feefbc5fc25890ad3a7d262b8a0aea1339aa7/starVLA/model/framework/QwenGR00T.py +# Below is the original copyright: + +# Copyright 2025 starVLA community. All rights reserved. +# Licensed under the MIT License, Version 1.0 (the "License"); +# Implemented by [Junqiu YU / Fudan University] in [2025]. +# Design and Merged by [Jinhui YE / HKUST University] in [2025]. 
+ +""" +Qwen-GR00T Framework +A lightweight implementation that Qwen-VL + Flow-matching head to directly predict continuous actions +Flow-matching header is copyright from GR00T N1.5, +""" + +import numpy as np +import torch +from transformers import PretrainedConfig, PreTrainedModel + +from flagscale.models.action_model.gr00t_action_header import FlowmatchingActionHead +from flagscale.models.utils.constants import ACTION, OBS_STATE + +# from flagscale.models.vlm.qwen2_5_vl import _QWen_VL_Interface +from flagscale.models.vlm.qwen3_vl import _QWen3_VL_Interface +from flagscale.train.utils.image_tools import to_pil_preserve +from flagscale.train.utils.trainer_tools import resize_images + + +class QwenGR00T(PreTrainedModel): + """ + Multimodal vision-language-action model. + + Components: + - Qwen2.5 VL interface for fused language/vision token embeddings + - Layer-wise QFormer for multi-layer feature aggregation + - DINO encoder for dense multi-view spatial tokens + - DiT diffusion head for future action sequence modeling + + Focus: Predict future continuous actions conditioned on images + instruction. + """ + + config_class = PretrainedConfig + + def __init__( + self, + config: dict | None = None, + **kwargs, + ) -> None: + """ + Construct all submodules and cache key configuration values. + + Args: + config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections. + **kwargs: Reserved for future overrides (unused). + """ + super().__init__(PretrainedConfig()) + self.config = config + # self.qwen_vl_interface = _QWen_VL_Interface(config=self.config) + self.qwen_vl_interface = _QWen3_VL_Interface(config=self.config) + # align dims --> we should put them to config or no? 
+ self.config.model.action_model.diffusion_model_cfg.cross_attention_dim = ( + self.qwen_vl_interface.model.config.hidden_size + ) + + self.action_model: FlowmatchingActionHead = FlowmatchingActionHead( + full_config=self.config + ) # 修复后续引用 + + self.future_action_window_size = config.model.action_model.future_action_window_size + self.past_action_window_size = config.model.action_model.past_action_window_size + self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size + + def forward( + self, + examples: list[dict] | None = None, + **kwargs, + ) -> tuple: + """ """ + # FIXME: state is None + # from torchvision import transforms + # image_transform = transforms.ToPILImage() + + # batch_images = [example["image"] for example in examples] # [B,[PLT]] + # instructions = [example["lang"] for example in examples] # [B, str] + # actions = [example["action"] for example in examples] # label [B, len, 7] + + actions = examples[ACTION] + state = examples[OBS_STATE] + + # state = ( + # [example["state"] for example in examples] if "state" in examples[0] else None + # ) # [B, 1, state_dim] + + # Step 1: QWenVL input format + qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs( + examples=examples, + image_keys=self.config.data.vla_data.image_features, + # images=batch_images, instructions=instructions + ) + + # print(f"qwen_inputs: {qwen_inputs}") + + with torch.autocast("cuda", dtype=torch.bfloat16): + qwenvl_outputs = self.qwen_vl_interface( + **qwen_inputs, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + # last_hidden_state: [B, seq_len, H] + last_hidden = qwenvl_outputs.hidden_states[-1] # [B, L, H] + + # Step 4: Action Expert Forward and Loss + with torch.autocast("cuda", dtype=torch.float32): + # TODO: (yupu) Is this a bug or a feature? The action dtype would stay as bf16 under this autocast. 
+ if isinstance(actions, torch.Tensor): + actions = actions.to(device=last_hidden.device, dtype=last_hidden.dtype) + else: + actions = torch.tensor( + np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype + ) + # TODO: does not match RoboBrainX, need to check + # actions = torch.tensor( + # np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype + # ) # [B, T_full, action_dim] + actions_target = actions[ + :, -(self.future_action_window_size + 1) :, : + ] # (B, chunk_len, action_dim) + + # TODO: (yupu) I believe there is a bug in starVLA, the + # `repeated_diffusion_steps` is not properly set in the config. + repeated_diffusion_steps = self.config.model.action_model.get( + "repeated_diffusion_steps", 4 + ) + + actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1) + last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1) + + state_repeated = None + if state is not None: + state = state.to(device=last_hidden.device, dtype=last_hidden.dtype) + state_repeated = state.repeat(repeated_diffusion_steps, 1, 1) + + action_loss = self.action_model( + last_hidden_repeated, actions_target_repeated, state_repeated + ) # (B, chunk_len, action_dim) + + return action_loss + + @torch.inference_mode() + def predict_action( + self, + examples: list[dict], + **kwargs: str, + ) -> np.ndarray: + """ + Steps: + 1. Resize images to training resolution (if specified) + 2. Encode with QwenVL (hidden states retained) + 6. Return normalized action trajectory + Returns: + dict: + normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions. 
+ """ + if type(examples) is not list: + examples = [examples] + batch_images = [to_pil_preserve(example["image"]) for example in examples] # [B, [PLT]] + instructions = [example["lang"] for example in examples] # [B, str] + + state = ( + [example["state"] for example in examples] if "state" in examples[0] else None + ) # [B, 1, state_dim] + + train_obs_image_size = getattr(self.config.data.vla_data, "image_size", None) + if train_obs_image_size: + batch_images = resize_images(batch_images, target_size=train_obs_image_size) + + # Step 1: QWenVL input format + qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs( + images=batch_images, instructions=instructions + ) + with torch.autocast("cuda", dtype=torch.bfloat16): + qwenvl_outputs = self.qwen_vl_interface( + **qwen_inputs, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + + # last_hidden_state: [B, seq_len, H] + last_hidden = qwenvl_outputs.hidden_states[-1] # [B, L, H] + + state = ( + torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype) + if state is not None + else None + ) + + # Step 4: Action Expert Forward + with torch.autocast("cuda", dtype=torch.float32): + pred_actions = self.action_model.predict_action( + last_hidden, state + ) # (B, chunk_len, action_dim) + + normalized_actions = pred_actions.detach().cpu().numpy() + return {"normalized_actions": normalized_actions} diff --git a/tests/unit_tests/models/vla/test_protocols.py b/tests/unit_tests/models/vla/test_protocols.py new file mode 100644 index 000000000..edcd3a51c --- /dev/null +++ b/tests/unit_tests/models/vla/test_protocols.py @@ -0,0 +1,57 @@ +import unittest + +import torch + + +class MockVLM: + @property + def config(self): + return {"hidden_size": 1024} + + def prepare_input(self, batch): + return batch + + def forward(self, batch, **kwargs): + return {"hidden_states": (torch.randn(1, 10, 1024),)} + + +class MockActionModel: + def forward(self, vlm_output, action_input, 
**kwargs): + return {"loss": torch.tensor(0.5)} + + def predict(self, vlm_output, action_input, **kwargs): + return {"actions": torch.randn(1, 16, 7)} + + +class TestVLMBackboneProtocol(unittest.TestCase): + def test_mock_vlm_has_protocol_methods(self): + vlm = MockVLM() + self.assertTrue(hasattr(vlm, "config")) + self.assertTrue(hasattr(vlm, "prepare_input")) + self.assertTrue(hasattr(vlm, "forward")) + + output = vlm.forward({}) + self.assertIn("hidden_states", output) + + +class TestActionModelProtocol(unittest.TestCase): + def test_mock_action_model_has_protocol_methods(self): + model = MockActionModel() + self.assertTrue(hasattr(model, "forward")) + self.assertTrue(hasattr(model, "predict")) + + def test_forward_returns_loss(self): + model = MockActionModel() + vlm_output = {"hidden_states": (torch.randn(1, 10, 1024),)} + action_input = {"actions": torch.randn(1, 16, 7)} + + output = model.forward(vlm_output, action_input) + self.assertIn("loss", output) + + def test_predict_returns_actions(self): + model = MockActionModel() + vlm_output = {"hidden_states": (torch.randn(1, 10, 1024),)} + action_input = {} + + pred = model.predict(vlm_output, action_input) + self.assertIn("actions", pred) diff --git a/tests/unit_tests/models/vla/test_qwen_gr00t_parity.py b/tests/unit_tests/models/vla/test_qwen_gr00t_parity.py new file mode 100644 index 000000000..0ba0bfddd --- /dev/null +++ b/tests/unit_tests/models/vla/test_qwen_gr00t_parity.py @@ -0,0 +1,153 @@ +import unittest + +import torch +from omegaconf import OmegaConf + +from flagscale.models.utils.constants import ACTION, OBS_STATE + + +class TestQwenGR00TParity(unittest.TestCase): + """ + End-to-end parity test between QwenGR00T and QwenGr00t. + + Note: This test requires GPU and the actual model weights. + Skip in CI environments without GPU. 
+ """ + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU available") + def test_forward_parity(self): + """Test that QwenGr00t produces same loss as QwenGR00T.""" + from tests.unit_tests.models.vla.qwen_gr00t_ref import QwenGR00T + + from flagscale.models.vla.qwen_gr00t import QwenGr00t + + # Create config + config = self._create_test_config() + + # Create both models + model_v1 = QwenGR00T(config=config).cuda() + model_v2 = QwenGr00t(config=config).cuda() + + # Copy action model weights from v1 to v2 + model_v2.action_model._head.load_state_dict(model_v1.action_model.state_dict()) + + # Create test batch + batch = self._create_test_batch() + + # Set same random seed for both + torch.manual_seed(42) + loss_v1 = model_v1.forward(batch) + + torch.manual_seed(42) + loss_v2 = model_v2.forward(batch) + + # Compare losses + self.assertTrue( + torch.allclose(loss_v1, loss_v2, atol=1e-5), + f"Loss mismatch: v1={loss_v1.item()}, v2={loss_v2.item()}", + ) + + def _create_test_config(self): + """Create config matching examples/qwen_gr00t/conf/train/qwen_gr00t.yaml.""" + config_dict = { + "model": { + "model_name": "qwen_gr00t", + "checkpoint_dir": "/workspace/models/Qwen/Qwen3-VL-4B-Instruct/", + "vlm": { + "type": "qwen3-vl", + }, + "qwenvl": { + "base_vlm": "/workspace/models/Qwen/Qwen3-VL-4B-Instruct/", + "attn_implementation": "flash_attention_2", + "vl_hidden_dim": 2048, + }, + "action_model": { + "type": "flow_matching", + "action_model_type": "DiT-B", + "action_hidden_dim": 1024, + "hidden_size": 1024, + "add_pos_embed": True, + "max_seq_len": 1024, + "action_dim": 7, + "state_dim": 8, + "future_action_window_size": 7, + "action_horizon": 8, + "past_action_window_size": 0, + "repeated_diffusion_steps": 4, + "noise_beta_alpha": 1.5, + "noise_beta_beta": 1.0, + "noise_s": 0.999, + "num_timestep_buckets": 1000, + "num_inference_timesteps": 4, + "num_target_vision_tokens": 32, + "diffusion_model_cfg": { + "cross_attention_dim": 2048, + "dropout": 0.2, + 
"final_dropout": True, + "interleave_self_attention": True, + "norm_type": "ada_norm", + "num_layers": 16, + "output_dim": 1024, + "positional_embeddings": None, + }, + }, + "reduce_in_full_precision": True, + }, + "data": { + "data_path": "", + "vla_data": { + "image_features": [ + "observation.images.image", + "observation.images.wrist_image", + ], + }, + }, + "system": { + "batch_size": 16, + "train_steps": 80000, + "log_freq": 10, + "grad_clip_norm": 1.0, + "optimizer": {"name": "AdamW", "lr": 2.5e-5}, + "scheduler": {"warmup_steps": 5000}, + "checkpoint": { + "save_checkpoint": False, + "save_freq": 1000, + "output_directory": "/tmp", + }, + }, + } + return OmegaConf.create(config_dict) + + def _create_test_batch(self): + """ + Create test batch matching actual training data format. + + Actual batch structure: + - action: [16, 8, 7] float32 + - task: list of 16 strings + - observation.images.wrist_image: [16, 3, 224, 224] float32 + - observation.images.image: [16, 3, 224, 224] float32 + - observation.state: [16, 1, 8] float32 + """ + batch_size = 16 + action_horizon = 8 + action_dim = 7 + state_dim = 8 + img_channels = 3 + img_size = 224 + + return { + ACTION: torch.randn(batch_size, action_horizon, action_dim, dtype=torch.float32), + "task": ["put the bowl on the plate"] * batch_size, + "observation.images.image": torch.randn( + batch_size, img_channels, img_size, img_size, dtype=torch.float32 + ), + "observation.images.wrist_image": torch.randn( + batch_size, img_channels, img_size, img_size, dtype=torch.float32 + ), + OBS_STATE: torch.randn(batch_size, 1, state_dim, dtype=torch.float32), + } + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/models/vla/test_registry.py b/tests/unit_tests/models/vla/test_registry.py new file mode 100644 index 000000000..4a4a91bae --- /dev/null +++ b/tests/unit_tests/models/vla/test_registry.py @@ -0,0 +1,41 @@ +import unittest + +from flagscale.models.vla.registry import ( + 
ACTION_MODEL_REGISTRY, + VLM_REGISTRY, + build_action_model, + build_vlm, + register_action_model, + register_vlm, +) + + +class TestRegistry(unittest.TestCase): + def test_register_vlm(self): + @register_vlm("test-vlm") + class TestVLM: + def __init__(self, **kwargs): + self.kwargs = kwargs + + self.assertIn("test-vlm", VLM_REGISTRY) + vlm = build_vlm("test-vlm", model_id="test") + self.assertEqual(vlm.kwargs["model_id"], "test") + + def test_register_action_model(self): + @register_action_model("test-model") + class TestModel: + def __init__(self, vlm_config, action_config): + self.vlm_config = vlm_config + self.action_config = action_config + + self.assertIn("test-model", ACTION_MODEL_REGISTRY) + model = build_action_model("test-model", vlm_config={}, action_config={"action_dim": 7}) + self.assertEqual(model.action_config["action_dim"], 7) + + def test_build_unknown_vlm_raises(self): + with self.assertRaises(ValueError): + build_vlm("nonexistent-vlm-xyz") + + def test_build_unknown_action_model_raises(self): + with self.assertRaises(ValueError): + build_action_model("nonexistent-model-xyz", vlm_config={}, action_config={}) diff --git a/tests/unit_tests/models/vla/test_utils.py b/tests/unit_tests/models/vla/test_utils.py new file mode 100644 index 000000000..0ff384d48 --- /dev/null +++ b/tests/unit_tests/models/vla/test_utils.py @@ -0,0 +1,34 @@ +import unittest + +from flagscale.models.vla.utils import get_vlm_config + + +class MockConfigDirect: + hidden_size = 2048 + num_hidden_layers = 28 + + +class MockConfigNested: + class text_config: + hidden_size = 1536 + num_hidden_layers = 24 + + +class MockConfigInvalid: + pass + + +class TestGetVlmConfig(unittest.TestCase): + def test_direct_config(self): + info = get_vlm_config(MockConfigDirect()) + self.assertEqual(info["hidden_size"], 2048) + self.assertEqual(info["num_hidden_layers"], 28) + + def test_nested_config(self): + info = get_vlm_config(MockConfigNested()) + self.assertEqual(info["hidden_size"], 1536) + 
self.assertEqual(info["num_hidden_layers"], 24) + + def test_invalid_config_raises(self): + with self.assertRaises(ValueError): + get_vlm_config(MockConfigInvalid()) diff --git a/tests/unit_tests/models/vla/vlm/test_qwen_vl.py b/tests/unit_tests/models/vla/vlm/test_qwen_vl.py new file mode 100644 index 000000000..13c56e718 --- /dev/null +++ b/tests/unit_tests/models/vla/vlm/test_qwen_vl.py @@ -0,0 +1,29 @@ +import unittest + +from flagscale.models.vla.registry import VLM_REGISTRY + + +class TestQwenVLRegistration(unittest.TestCase): + def test_qwen25_vl_registered(self): + from flagscale.models.vla.vlm import qwen_vl # noqa: F401 + + self.assertIn("qwen2.5-vl", VLM_REGISTRY) + + def test_qwen3_vl_registered(self): + from flagscale.models.vla.vlm import qwen_vl # noqa: F401 + + self.assertIn("qwen3-vl", VLM_REGISTRY) + + def test_qwen25_has_required_methods(self): + from flagscale.models.vla.vlm.qwen_vl import Qwen25VLBackbone + + self.assertTrue(hasattr(Qwen25VLBackbone, "model_config")) + self.assertTrue(hasattr(Qwen25VLBackbone, "prepare_input")) + self.assertTrue(hasattr(Qwen25VLBackbone, "forward")) + + def test_qwen3_has_required_methods(self): + from flagscale.models.vla.vlm.qwen_vl import Qwen3VLBackbone + + self.assertTrue(hasattr(Qwen3VLBackbone, "model_config")) + self.assertTrue(hasattr(Qwen3VLBackbone, "prepare_input")) + self.assertTrue(hasattr(Qwen3VLBackbone, "forward")) diff --git a/tests/unit_tests/models/vla/vlm/test_vlm_init.py b/tests/unit_tests/models/vla/vlm/test_vlm_init.py new file mode 100644 index 000000000..61087f33f --- /dev/null +++ b/tests/unit_tests/models/vla/vlm/test_vlm_init.py @@ -0,0 +1,9 @@ +import unittest + + +class TestVLMInit(unittest.TestCase): + def test_imports(self): + from flagscale.models.vla.vlm import Qwen3VLBackbone, Qwen25VLBackbone + + self.assertIsNotNone(Qwen25VLBackbone) + self.assertIsNotNone(Qwen3VLBackbone) diff --git a/tests/unit_tests/train/configs/test_train_config.py 
b/tests/unit_tests/train/configs/test_train_config.py index 6a4ebadae..42994cf30 100644 --- a/tests/unit_tests/train/configs/test_train_config.py +++ b/tests/unit_tests/train/configs/test_train_config.py @@ -20,10 +20,10 @@ class TestOptimizerConfig(unittest.TestCase): def test_default_values(self): config = OptimizerConfig() self.assertEqual(config.name, "AdamW") - self.assertEqual(config.lr, 2.5e-5) - self.assertEqual(config.betas, (0.9, 0.95)) - self.assertEqual(config.eps, 1e-8) - self.assertEqual(config.weight_decay, 0.01) + self.assertIsNone(config.lr) + self.assertIsNone(config.betas) + self.assertIsNone(config.eps) + self.assertIsNone(config.weight_decay) def test_custom_values(self): config = OptimizerConfig( @@ -84,32 +84,22 @@ class TestSystemConfig(unittest.TestCase): def test_hierarchical_structure(self): config = SystemConfig( batch_size=8, - optimizer=OptimizerConfig(lr=1e-4), - scheduler=SchedulerConfig(warmup_steps=100), checkpoint=CheckpointConfig(output_directory="/tmp"), ) - # Test hierarchical access self.assertEqual(config.batch_size, 8) - self.assertEqual(config.optimizer.lr, 1e-4) - self.assertEqual(config.scheduler.warmup_steps, 100) self.assertEqual(config.checkpoint.output_directory, "/tmp") def test_from_dict(self): config_dict = { "batch_size": 16, "train_steps": 5000, - "optimizer": {"lr": 5e-5, "betas": (0.9, 0.999)}, - "scheduler": {"warmup_steps": 200}, "checkpoint": {"output_directory": "/output", "save_freq": 100}, } config = SystemConfig(**config_dict) self.assertEqual(config.batch_size, 16) self.assertEqual(config.train_steps, 5000) - self.assertEqual(config.optimizer.lr, 5e-5) - self.assertEqual(config.optimizer.betas, (0.9, 0.999)) - self.assertEqual(config.scheduler.warmup_steps, 200) self.assertEqual(config.checkpoint.save_freq, 100) @@ -200,8 +190,6 @@ def test_full_config_creation(self): "system": { "batch_size": 4, "train_steps": 10000, - "optimizer": {"lr": 1e-4}, - "scheduler": {"warmup_steps": 500}, "checkpoint": 
{"output_directory": "/tmp/ckpt"}, }, "model": { @@ -209,21 +197,21 @@ def test_full_config_creation(self): "checkpoint_dir": "/model", "tokenizer_path": "/tokenizer", "action_steps": 50, + "optimizer": {"lr": 1e-4, "scheduler": {"warmup_steps": 500}}, }, "data": {"data_path": "/data", "use_imagenet_stats": True}, } config = TrainConfig(**config_dict) - # Test hierarchical access self.assertEqual(config.system.batch_size, 4) self.assertEqual(config.system.train_steps, 10000) - self.assertEqual(config.system.optimizer.lr, 1e-4) - self.assertEqual(config.system.scheduler.warmup_steps, 500) self.assertEqual(config.system.checkpoint.output_directory, "/tmp/ckpt") self.assertEqual(config.model.model_name, "pi0") self.assertEqual(config.model.checkpoint_dir, "/model") + self.assertEqual(config.model.optimizer.lr, 1e-4) + self.assertEqual(config.model.optimizer.scheduler.warmup_steps, 500) self.assertEqual(config.data.data_path, "/data") self.assertEqual(config.data.use_imagenet_stats, True) @@ -234,11 +222,13 @@ def test_from_hydra_config(self): "train": { "system": { "batch_size": 8, - "optimizer": {"lr": 2e-5}, - "scheduler": {}, "checkpoint": {"output_directory": "/out"}, }, - "model": {"model_name": "pi0.5", "checkpoint_dir": "/ckpt"}, + "model": { + "model_name": "pi0.5", + "checkpoint_dir": "/ckpt", + "optimizer": {"lr": 2e-5, "scheduler": {}}, + }, "data": {"data_path": "/dataset"}, } } @@ -247,7 +237,7 @@ def test_from_hydra_config(self): config = TrainConfig.from_hydra_config(hydra_config) self.assertEqual(config.system.batch_size, 8) - self.assertEqual(config.system.optimizer.lr, 2e-5) + self.assertEqual(config.model.optimizer.lr, 2e-5) self.assertEqual(config.model.model_name, "pi0.5") self.assertEqual(config.data.data_path, "/dataset") @@ -255,8 +245,6 @@ def test_type_validation_error(self): config_dict = { "system": { "batch_size": "invalid", # Should be int - "optimizer": {}, - "scheduler": {}, "checkpoint": {"output_directory": "/tmp"}, }, "model": 
{"model_name": "pi0", "checkpoint_dir": "/model"}, @@ -269,8 +257,6 @@ def test_type_validation_error(self): def test_missing_required_field(self): config_dict = { "system": { - "optimizer": {}, - "scheduler": {}, "checkpoint": {"output_directory": "/tmp"}, }, "model": { @@ -291,11 +277,14 @@ def test_dict_roundtrip(self): config = TrainConfig( system=SystemConfig( batch_size=16, - optimizer=OptimizerConfig(), - scheduler=SchedulerConfig(), checkpoint=CheckpointConfig(output_directory="/tmp"), ), - model=ModelConfig(model_name="pi0", checkpoint_dir="/model", action_steps=50), + model=ModelConfig( + model_name="pi0", + checkpoint_dir="/model", + action_steps=50, + optimizer=OptimizerConfig(), + ), data=DataConfig(data_path="/data"), ) diff --git a/tests/unit_tests/train/utils/test_optim_setup.py b/tests/unit_tests/train/utils/test_optim_setup.py new file mode 100644 index 000000000..437f1a1f5 --- /dev/null +++ b/tests/unit_tests/train/utils/test_optim_setup.py @@ -0,0 +1,909 @@ +"""Unit tests for optimizer setup utilities.""" + +import unittest +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +from flagscale.train.utils.optim_setup import ( + apply_freeze_config, + build_optim_param_groups, + freeze_and_get_trainable_params, + log_trainable_params, + print_param_names, + setup_optimizer, + setup_optimizer_and_scheduler, + setup_scheduler, +) + + +class SimpleModel(nn.Module): + """Simple model for testing freeze patterns.""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, 10), + ) + self.decoder = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, 10), + ) + self.head = nn.Linear(10, 5) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return self.head(x) + + +class NestedModel(nn.Module): + """Model with nested structure similar to QwenGR00T.""" + + def __init__(self): + super().__init__() + self.vlm = nn.ModuleDict( 
+ { + "visual": nn.Sequential( + nn.Linear(10, 20), + nn.Linear(20, 10), + ), + "language": nn.ModuleDict( + { + "layers": nn.ModuleList([nn.Linear(10, 10) for _ in range(5)]), + "embed": nn.Embedding(100, 10), + } + ), + } + ) + self.action_model = nn.ModuleDict( + { + "encoder": nn.Linear(10, 20), + "decoder": nn.Linear(20, 10), + "transformer_blocks": nn.ModuleList([nn.Linear(10, 10) for _ in range(4)]), + } + ) + + def forward(self, x): + return x + + +class TestFreezeAndGetTrainableParams(unittest.TestCase): + """Test freeze_and_get_trainable_params function.""" + + def setUp(self): + self.model = SimpleModel() + + def test_no_patterns_all_trainable(self): + """Without patterns, all params should be trainable.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=None, + keep_patterns=None, + ) + ) + + all_params = list(self.model.parameters()) + self.assertEqual(len(params), len(all_params)) + + for param in self.model.parameters(): + self.assertTrue(param.requires_grad) + + def test_freeze_single_module(self): + """Test freezing a single module by pattern.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*"], + keep_patterns=None, + ) + ) + + # Check encoder is frozen + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + self.assertFalse(param.requires_grad, f"{name} should be frozen") + else: + self.assertTrue(param.requires_grad, f"{name} should be trainable") + + # Returned params should only be trainable ones + for param in params: + self.assertTrue(param.requires_grad) + + def test_freeze_multiple_modules(self): + """Test freezing multiple modules.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*", "decoder\\..*"], + keep_patterns=None, + ) + ) + + # Only head should be trainable + for name, param in self.model.named_parameters(): 
+ if name.startswith("head"): + self.assertTrue(param.requires_grad) + else: + self.assertFalse(param.requires_grad) + + # Returned params should only be head params + head_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("head") + ) + self.assertEqual(len(params), head_param_count) + + def test_freeze_all_pattern(self): + """Test freezing everything with '.*' pattern.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=[".*"], + keep_patterns=None, + ) + ) + + self.assertEqual(len(params), 0) + for param in self.model.parameters(): + self.assertFalse(param.requires_grad) + + def test_keep_patterns_override_freeze(self): + """Test that keep_patterns override freeze_patterns.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=[".*"], # Freeze everything + keep_patterns=["head\\..*"], # But keep head trainable + ) + ) + + # Only head should be trainable + for name, param in self.model.named_parameters(): + if name.startswith("head"): + self.assertTrue(param.requires_grad, f"{name} should be trainable") + else: + self.assertFalse(param.requires_grad, f"{name} should be frozen") + + # Should only return head params + self.assertEqual(len(params), 2) # head.weight and head.bias + + def test_partial_pattern_match(self): + """Test that patterns use search (partial match).""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["weight"], # Matches all weights + keep_patterns=None, + ) + ) + + # Only biases should be trainable + for name, param in self.model.named_parameters(): + if "weight" in name: + self.assertFalse(param.requires_grad) + else: + self.assertTrue(param.requires_grad) + + # Returned params should only be biases + bias_param_count = sum( + 1 for name, _ in self.model.named_parameters() if "weight" not in name + ) + self.assertEqual(len(params), bias_param_count) 
+ + +class TestFreezeWithNestedModel(unittest.TestCase): + """Test freeze patterns with nested model structure.""" + + def setUp(self): + self.model = NestedModel() + + def test_freeze_vlm_module(self): + """Test freezing entire VLM module.""" + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["vlm\\..*"], + keep_patterns=None, + ) + ) + + for name, param in self.model.named_parameters(): + if name.startswith("vlm"): + self.assertFalse(param.requires_grad, f"{name} should be frozen") + else: + self.assertTrue(param.requires_grad, f"{name} should be trainable") + + # Returned params should only be action_model params + action_model_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("action_model") + ) + self.assertEqual(len(params), action_model_param_count) + + def test_freeze_specific_layers(self): + """Test freezing specific layers by index.""" + # Freeze layers 0-2 + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["vlm\\.language\\.layers\\.[0-2]\\..*"], + keep_patterns=None, + ) + ) + + for name, param in self.model.named_parameters(): + if ( + "vlm.language.layers.0" in name + or "vlm.language.layers.1" in name + or "vlm.language.layers.2" in name + ): + self.assertFalse(param.requires_grad, f"{name} should be frozen") + + # Layers 3-4 should still be trainable + for name, param in self.model.named_parameters(): + if "vlm.language.layers.3" in name or "vlm.language.layers.4" in name: + self.assertTrue(param.requires_grad, f"{name} should be trainable") + + # Returned params should exclude frozen layers + trainable_param_count = sum( + 1 for name, param in self.model.named_parameters() if param.requires_grad + ) + self.assertEqual(len(params), trainable_param_count) + + def test_freeze_vlm_keep_visual(self): + """Test freezing VLM but keeping visual encoder trainable.""" + params = list( + 
freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["vlm\\..*"], + keep_patterns=["vlm\\.visual\\..*"], + ) + ) + + for name, param in self.model.named_parameters(): + if name.startswith("vlm.visual"): + self.assertTrue(param.requires_grad, f"{name} should be trainable") + elif name.startswith("vlm"): + self.assertFalse(param.requires_grad, f"{name} should be frozen") + + # Returned params should include visual and action_model params + trainable_param_count = sum( + 1 for name, param in self.model.named_parameters() if param.requires_grad + ) + self.assertEqual(len(params), trainable_param_count) + + +class TestApplyFreezeConfig(unittest.TestCase): + """Test apply_freeze_config function.""" + + def setUp(self): + self.model = SimpleModel() + + def test_none_config_returns_all_params(self): + """With None config, should return all parameters.""" + params = apply_freeze_config(self.model, None) + + all_params = list(self.model.parameters()) + self.assertEqual(len(params), len(all_params)) + + def test_with_freeze_config(self): + """Test with a FreezeConfig-like object.""" + freeze_config = MagicMock() + freeze_config.freeze_patterns = ["encoder\\..*"] + freeze_config.keep_patterns = None + + params = apply_freeze_config(self.model, freeze_config) + + # Should only return non-encoder params + encoder_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("encoder") + ) + total_param_count = sum(1 for _ in self.model.parameters()) + + self.assertEqual(len(params), total_param_count - encoder_param_count) + + +class TestLogTrainableParams(unittest.TestCase): + """Test log_trainable_params function.""" + + def setUp(self): + self.model = SimpleModel() + + def test_all_trainable(self): + """Test logging when all params are trainable.""" + result = log_trainable_params(self.model) + + self.assertIn("trainable", result) + self.assertIn("frozen", result) + self.assertIn("encoder", result["trainable"]) + 
self.assertIn("decoder", result["trainable"]) + self.assertIn("head", result["trainable"]) + + def test_partial_frozen(self): + """Test logging with some frozen params.""" + # Freeze encoder + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + param.requires_grad = False + + result = log_trainable_params(self.model) + + self.assertIn("encoder", result["frozen"]) + self.assertIn("decoder", result["trainable"]) + self.assertIn("head", result["trainable"]) + self.assertGreater(result["frozen"]["encoder"], 0) + + +class TestUnusedPatternWarnings(unittest.TestCase): + """Test that unused patterns trigger warnings.""" + + def setUp(self): + self.model = SimpleModel() + + @patch("flagscale.train.utils.optim_setup.logger") + def test_warns_on_unused_freeze_pattern(self, mock_logger): + """Should warn when freeze pattern matches nothing.""" + list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["nonexistent_module\\..*"], + keep_patterns=None, + ) + ) + + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args[0][0] + self.assertIn("Freeze patterns matched nothing", warning_call) + + @patch("flagscale.train.utils.optim_setup.logger") + def test_warns_on_unused_keep_pattern(self, mock_logger): + """Should warn when keep pattern matches nothing.""" + list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*"], + keep_patterns=["nonexistent_module\\..*"], + ) + ) + + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args[0][0] + self.assertIn("Keep patterns matched nothing", warning_call) + + +class TestPrintParamNames(unittest.TestCase): + """Test print_param_names debug helper.""" + + def setUp(self): + self.model = SimpleModel() + + @patch("builtins.print") + def test_prints_all_params(self, mock_print): + """Should print all params when no pattern given.""" + print_param_names(self.model) + + 
self.assertGreater(mock_print.call_count, 0) + + @patch("builtins.print") + def test_filters_by_pattern(self, mock_print): + """Should only print params matching pattern.""" + print_param_names(self.model, pattern="encoder") + + # Should only print encoder params + for call in mock_print.call_args_list: + self.assertIn("encoder", call[0][0]) + + +class TestParameterCounts(unittest.TestCase): + """Test that parameter counts are correctly reported.""" + + def setUp(self): + self.model = SimpleModel() + + @patch("flagscale.train.utils.optim_setup.logger") + def test_parameter_count_logging(self, mock_logger): + """Verify correct parameter counts are logged.""" + # Count total params + total_params = sum(p.numel() for p in self.model.parameters()) + + # Count encoder params + encoder_params = sum( + p.numel() for name, p in self.model.named_parameters() if name.startswith("encoder") + ) + + # Freeze encoder + list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*"], + keep_patterns=None, + ) + ) + + # Check that info was logged with correct counts + mock_logger.info.assert_called() + info_call = mock_logger.info.call_args[0][0] + self.assertIn(f"trainable={total_params - encoder_params:,}", info_call) + self.assertIn(f"frozen={encoder_params:,}", info_call) + + +class TestBuildOptimParamGroups(unittest.TestCase): + """Test build_optim_param_groups function (NeMo-style per-module config).""" + + def setUp(self): + self.model = SimpleModel() + + def test_none_config_returns_single_group(self): + """With None config, should return single group with all params.""" + param_groups = build_optim_param_groups(self.model, None) + + self.assertEqual(len(param_groups), 1) + all_params = list(self.model.parameters()) + self.assertEqual(len(param_groups[0]["params"]), len(all_params)) + + def test_single_module_config(self): + """Test with config for single module.""" + config = {"encoder": {"lr": 1e-5}} + param_groups = 
build_optim_param_groups(self.model, config) + + # Should have 2 groups: default + encoder + self.assertEqual(len(param_groups), 2) + + # Find encoder group + encoder_group = next(g for g in param_groups if g.get("name") == "encoder") + self.assertEqual(encoder_group["lr"], 1e-5) + + # Encoder params count + encoder_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("encoder") + ) + self.assertEqual(len(encoder_group["params"]), encoder_param_count) + + def test_multiple_module_config(self): + """Test with config for multiple modules.""" + config = { + "encoder": {"lr": 1e-5, "weight_decay": 0.01}, + "decoder": {"lr": 2e-5}, + } + param_groups = build_optim_param_groups(self.model, config) + + # Should have 3 groups: default + encoder + decoder + self.assertEqual(len(param_groups), 3) + + encoder_group = next(g for g in param_groups if g.get("name") == "encoder") + decoder_group = next(g for g in param_groups if g.get("name") == "decoder") + + self.assertEqual(encoder_group["lr"], 1e-5) + self.assertEqual(encoder_group["weight_decay"], 0.01) + self.assertEqual(decoder_group["lr"], 2e-5) + + def test_default_group_contains_remaining_params(self): + """Default group should contain params not in other groups.""" + config = {"encoder": {"lr": 1e-5}} + param_groups = build_optim_param_groups(self.model, config) + + default_group = next(g for g in param_groups if g.get("name") == "default") + + # Default should contain decoder + head params + non_encoder_count = sum( + 1 for name, _ in self.model.named_parameters() if not name.startswith("encoder") + ) + self.assertEqual(len(default_group["params"]), non_encoder_count) + + def test_respects_requires_grad(self): + """Should only include trainable params.""" + # Freeze encoder + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + param.requires_grad = False + + config = {"encoder": {"lr": 1e-5}} + param_groups = build_optim_param_groups(self.model, 
config) + + # Encoder group should not be added when it has no trainable params + encoder_groups = [g for g in param_groups if g.get("name") == "encoder"] + self.assertEqual( + len(encoder_groups), + 0, + "Encoder group should not exist when all params are frozen", + ) + + @patch("flagscale.train.utils.optim_setup.logger") + def test_warns_on_nonexistent_module(self, mock_logger): + """Should warn when module doesn't exist.""" + config = {"nonexistent": {"lr": 1e-5}} + build_optim_param_groups(self.model, config) + + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args[0][0] + self.assertIn("nonexistent", warning_call) + + +class TestBuildOptimParamGroupsNested(unittest.TestCase): + """Test build_optim_param_groups with nested model structure.""" + + def setUp(self): + self.model = NestedModel() + + def test_nested_module_path(self): + """Test accessing nested modules via dot path.""" + config = {"vlm.visual": {"lr": 1e-5}} + param_groups = build_optim_param_groups(self.model, config) + + visual_group = next(g for g in param_groups if g.get("name") == "vlm.visual") + self.assertEqual(visual_group["lr"], 1e-5) + + # Count visual params + visual_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("vlm.visual") + ) + self.assertEqual(len(visual_group["params"]), visual_param_count) + + def test_multiple_nested_paths(self): + """Test multiple nested module configs.""" + config = { + "vlm.visual": {"lr": 1e-5}, + "vlm.language": {"lr": 2e-5}, + "action_model": {"lr": 1e-4}, + } + param_groups = build_optim_param_groups(self.model, config) + + # 3 configured groups + default (though default may be empty) + groups_with_params = [g for g in param_groups if len(g["params"]) > 0] + self.assertGreaterEqual(len(groups_with_params), 3) + + +class TestSetupScheduler(unittest.TestCase): + """Test setup_scheduler function.""" + + def setUp(self): + self.model = SimpleModel() + self.optimizer = 
torch.optim.AdamW(self.model.parameters(), lr=1e-4) + + def test_cosine_scheduler(self): + """Test creating a cosine scheduler.""" + scheduler_config = MagicMock() + scheduler_config.name = "cosine" + scheduler_config.warmup_steps = 100 + scheduler_config.scheduler_kwargs = None + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=1000) + + self.assertIsNotNone(scheduler) + self.assertTrue(hasattr(scheduler, "step")) + + def test_linear_scheduler(self): + """Test creating a linear scheduler.""" + scheduler_config = MagicMock() + scheduler_config.name = "linear" + scheduler_config.warmup_steps = 50 + scheduler_config.scheduler_kwargs = None + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=500) + + self.assertIsNotNone(scheduler) + + def test_constant_with_warmup_scheduler(self): + """Test creating a constant_with_warmup scheduler.""" + scheduler_config = MagicMock() + scheduler_config.name = "constant_with_warmup" + scheduler_config.warmup_steps = 100 + scheduler_config.scheduler_kwargs = None + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=1000) + + self.assertIsNotNone(scheduler) + + def test_cosine_with_min_lr(self): + """Test creating a cosine scheduler with min_lr.""" + scheduler_config = MagicMock() + scheduler_config.name = "cosine_with_min_lr" + scheduler_config.warmup_steps = 100 + scheduler_config.scheduler_kwargs = {"min_lr": 1e-6} + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=1000) + + self.assertIsNotNone(scheduler) + + def test_raises_error_when_name_is_none(self): + """Should raise ValueError when scheduler name is None.""" + scheduler_config = MagicMock() + scheduler_config.name = None + scheduler_config.warmup_steps = 100 + scheduler_config.scheduler_kwargs = None + + with self.assertRaises(ValueError) as context: + setup_scheduler(self.optimizer, scheduler_config, num_training_steps=1000) + + 
self.assertIn("name must be specified", str(context.exception)) + + def test_scheduler_step_updates_lr(self): + """Test that scheduler step updates learning rate.""" + scheduler_config = MagicMock() + scheduler_config.name = "linear" + scheduler_config.warmup_steps = 10 + scheduler_config.scheduler_kwargs = None + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=100) + + initial_lr = self.optimizer.param_groups[0]["lr"] + for _ in range(50): + scheduler.step() + final_lr = self.optimizer.param_groups[0]["lr"] + + self.assertNotEqual(initial_lr, final_lr) + + def test_warmup_phase(self): + """Test that warmup phase increases lr.""" + scheduler_config = MagicMock() + scheduler_config.name = "linear" + scheduler_config.warmup_steps = 100 + scheduler_config.scheduler_kwargs = None + + scheduler = setup_scheduler(self.optimizer, scheduler_config, num_training_steps=1000) + + lrs = [] + for _ in range(50): + lrs.append(self.optimizer.param_groups[0]["lr"]) + scheduler.step() + + # During warmup, LR should generally increase + self.assertLess(lrs[0], lrs[-1]) + + +class TestSetupOptimizerAndScheduler(unittest.TestCase): + """Test setup_optimizer_and_scheduler function.""" + + def setUp(self): + self.model = SimpleModel() + + def _make_train_config(self, freeze_patterns=None, keep_patterns=None): + """Helper to create a mock TrainConfig.""" + train_config = MagicMock() + # System config + train_config.system = MagicMock() + train_config.system.train_steps = 1000 + # Model config with optimizer, scheduler, and freeze + train_config.model = MagicMock() + train_config.model.optimizer = MagicMock() + train_config.model.optimizer.name = "AdamW" + train_config.model.optimizer.lr = 1e-4 + train_config.model.optimizer.param_groups = None + train_config.model.optimizer.get_optimizer_kwargs.return_value = {"lr": 1e-4} + train_config.model.optimizer.scheduler = MagicMock() + train_config.model.optimizer.scheduler.name = "cosine" + 
train_config.model.optimizer.scheduler.warmup_steps = 100 + train_config.model.optimizer.scheduler.scheduler_kwargs = None + if freeze_patterns is not None: + train_config.model.freeze = MagicMock() + train_config.model.freeze.freeze_patterns = freeze_patterns + train_config.model.freeze.keep_patterns = keep_patterns + else: + train_config.model.freeze = None + return train_config + + def test_returns_optimizer_and_scheduler(self): + """Test that function returns both optimizer and scheduler.""" + train_config = self._make_train_config() + + optimizer, scheduler = setup_optimizer_and_scheduler(self.model, train_config) + + self.assertIsInstance(optimizer, torch.optim.AdamW) + self.assertIsNotNone(scheduler) + self.assertTrue(hasattr(scheduler, "step")) + + def test_with_freeze_config(self): + """Test with freeze config applied.""" + train_config = self._make_train_config(freeze_patterns=["encoder\\..*"]) + train_config.model.optimizer.scheduler.name = "linear" + train_config.model.optimizer.scheduler.warmup_steps = 50 + train_config.system.train_steps = 500 + + optimizer, scheduler = setup_optimizer_and_scheduler(self.model, train_config) + + # Encoder should be frozen + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + self.assertFalse(param.requires_grad) + else: + self.assertTrue(param.requires_grad) + + self.assertIsInstance(optimizer, torch.optim.AdamW) + self.assertIsNotNone(scheduler) + + def test_scheduler_uses_train_steps(self): + """Test that scheduler uses train_steps from TrainConfig.""" + train_config = self._make_train_config() + train_config.model.optimizer.scheduler.name = "linear" + train_config.model.optimizer.scheduler.warmup_steps = 10 + train_config.system.train_steps = 100 + + optimizer, scheduler = setup_optimizer_and_scheduler(self.model, train_config) + + # Step through warmup first + for _ in range(15): + optimizer.step() + scheduler.step() + peak_lr = optimizer.param_groups[0]["lr"] + + # Step through 
decay phase + for _ in range(80): + optimizer.step() + scheduler.step() + final_lr = optimizer.param_groups[0]["lr"] + + # After decay, LR should be less than peak + self.assertLess(final_lr, peak_lr) + + +class TestFreezeRequiresGradPreservation(unittest.TestCase): + """Test that freeze logic correctly preserves or overrides requires_grad.""" + + def setUp(self): + self.model = SimpleModel() + + @patch("flagscale.train.utils.optim_setup.logger") + def test_no_freeze_patterns_preserves_requires_grad(self, mock_logger): + """Params with requires_grad=False should stay frozen when no freeze patterns provided.""" + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + param.requires_grad = False + + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=None, + keep_patterns=None, + ) + ) + + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + self.assertFalse(param.requires_grad, f"{name} should remain frozen") + + encoder_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("encoder") + ) + total_count = sum(1 for _ in self.model.parameters()) + self.assertEqual(len(params), total_count - encoder_count) + + @patch("flagscale.train.utils.optim_setup.logger") + def test_freeze_patterns_forces_unmatched_trainable(self, mock_logger): + """Params not matching freeze patterns become trainable even if originally frozen.""" + for param in self.model.parameters(): + param.requires_grad = False + + params = list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*"], + keep_patterns=None, + ) + ) + + for name, param in self.model.named_parameters(): + if name.startswith("encoder"): + self.assertFalse(param.requires_grad, f"{name} should be frozen") + else: + self.assertTrue(param.requires_grad, f"{name} should be forced trainable") + + encoder_count = sum( + 1 for name, _ in 
self.model.named_parameters() if name.startswith("encoder") + ) + total_count = sum(1 for _ in self.model.parameters()) + self.assertEqual(len(params), total_count - encoder_count) + + @patch("flagscale.train.utils.optim_setup.logger") + def test_warns_when_unfreezing_previously_frozen(self, mock_logger): + """Should warn about params that were frozen but are being made trainable.""" + for name, param in self.model.named_parameters(): + if name.startswith("decoder"): + param.requires_grad = False + + list( + freeze_and_get_trainable_params( + self.model.named_parameters(), + freeze_patterns=["encoder\\..*"], + keep_patterns=None, + ) + ) + + warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] + summary_warnings = [w for w in warning_calls if "already frozen" in w] + self.assertEqual(len(summary_warnings), 1) + + per_param_warnings = [w for w in warning_calls if "unfrozen:" in w] + decoder_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("decoder") + ) + self.assertEqual(len(per_param_warnings), decoder_param_count) + for w in per_param_warnings: + self.assertIn("decoder", w) + + +class TestBuildOptimParamGroupsOverlap(unittest.TestCase): + """Test build_optim_param_groups with overlapping module paths.""" + + def setUp(self): + self.model = NestedModel() + + @patch("flagscale.train.utils.optim_setup.logger") + def test_parent_child_overlap_dedup(self, mock_logger): + """Parent module listed before child: child group gets skipped.""" + config = { + "vlm": {"lr": 1e-5}, + "vlm.visual": {"lr": 2e-5}, + } + param_groups = build_optim_param_groups(self.model, config) + + vlm_group = next(g for g in param_groups if g.get("name") == "vlm") + visual_groups = [g for g in param_groups if g.get("name") == "vlm.visual"] + + vlm_params = [p for p in self.model.vlm.parameters() if p.requires_grad] + self.assertEqual(len(vlm_group["params"]), len(vlm_params)) + self.assertEqual(len(visual_groups), 0) + + warning_calls 
= [call[0][0] for call in mock_logger.warning.call_args_list] + overlap_warnings = [w for w in warning_calls if "already assigned" in w] + self.assertGreater(len(overlap_warnings), 0) + + @patch("flagscale.train.utils.optim_setup.logger") + def test_child_parent_overlap_partial(self, mock_logger): + """Child module listed before parent: parent group excludes child's params.""" + config = { + "vlm.visual": {"lr": 2e-5}, + "vlm": {"lr": 1e-5}, + } + param_groups = build_optim_param_groups(self.model, config) + + visual_group = next(g for g in param_groups if g.get("name") == "vlm.visual") + vlm_group = next(g for g in param_groups if g.get("name") == "vlm") + + visual_param_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("vlm.visual") + ) + all_vlm_count = sum( + 1 for name, _ in self.model.named_parameters() if name.startswith("vlm") + ) + self.assertEqual(len(visual_group["params"]), visual_param_count) + self.assertEqual(len(vlm_group["params"]), all_vlm_count - visual_param_count) + + def test_no_duplicate_params_across_groups(self): + """No parameter should appear in more than one group.""" + config = { + "vlm.visual": {"lr": 2e-5}, + "vlm": {"lr": 1e-5}, + "action_model": {"lr": 1e-4}, + } + param_groups = build_optim_param_groups(self.model, config) + + all_param_ids = [] + for group in param_groups: + all_param_ids.extend(id(p) for p in group["params"]) + self.assertEqual(len(all_param_ids), len(set(all_param_ids))) + + +class TestSetupOptimizerEmptyParamGroups(unittest.TestCase): + """Test setup_optimizer raises ValueError when all params are frozen.""" + + def setUp(self): + self.model = SimpleModel() + + def test_all_frozen_raises_value_error(self): + freeze_config = MagicMock() + freeze_config.freeze_patterns = [".*"] + freeze_config.keep_patterns = None + + optimizer_config = MagicMock() + optimizer_config.name = "AdamW" + optimizer_config.param_groups = None + optimizer_config.get_optimizer_kwargs.return_value = {"lr": 
1e-4} + + with self.assertRaises(ValueError) as context: + setup_optimizer(self.model, optimizer_config, freeze_config=freeze_config) + + self.assertIn("No trainable parameters found", str(context.exception))