Skip to content

Commit 0e9e482

Browse files
committed
Update environment configuration in YAML files and adjust dataset setup for vlm grpo
- Added `env_name` to `vlm_grpo_3B_megatron.yaml` and `vlm_grpo_3B.yaml` for environment specification. - Modified `setup_data` function in `run_vlm_grpo.py` to use `env_name` for environment configuration, enhancing flexibility in dataset processing. Signed-off-by: ruit <[email protected]>
1 parent 8a6f265 commit 0e9e482

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ data:
228228
prompt_file: "examples/prompts/clevr_cogent_cot.txt"
229229
system_prompt_file: null
230230
dataset_name: "clevr-cogent"
231+
env_name: "clevr-cogent"
231232
split: "trainA"
232233
shuffle: true
233234
num_workers: 1

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ data:
180180
prompt_file: examples/prompts/clevr_cogent_cot.txt
181181
system_prompt_file: null
182182
dataset_name: clevr-cogent
183+
env_name: "clevr-cogent"
183184
split: trainA
184185
shuffle: true
185186
num_workers: 1

examples/run_vlm_grpo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,15 @@ def setup_data(
264264
)
265265
task_data_processors[task_name] = (vlm_task_spec, hf_data_processor)
266266

267+
env_name = data_config["env_name"]
267268
vlm_env = VLMEnvironment.options( # type: ignore # it's wrapped with ray.remote
268269
runtime_env={
269270
"py_executable": get_actor_python_env(
270271
"nemo_rl.environments.vlm_environment.VLMEnvironment"
271272
),
272273
"env_vars": dict(os.environ), # Pass thru all user environment variables
273274
}
274-
).remote(env_configs[task_name])
275+
).remote(env_configs[env_name])
275276

276277
dataset = AllTaskProcessedDataset(
277278
data.formatted_ds["train"],

0 commit comments

Comments
 (0)