Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions rgym_exp/src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
batch_item_id_column=batch_item_id_column,
data_generator=self.load_reasoning_gym_dataset, # TODO: this was confusing, we should document or change the way this is done
)

self.logger = get_logger()
self.yaml_config_path = yaml_config_path
self.eval_split_ratio = eval_split_ratio
self.chunk_size = chunk_size
Expand Down Expand Up @@ -164,7 +164,7 @@ def state_to_system_prompt(self, state: WorldState) -> str:

def state_to_user_prompt(self, state: WorldState) -> str:
    """Convert the state to a user prompt.

    Args:
        state: World state whose ``environment_states`` mapping is read.

    Returns:
        The value stored under the ``"question"`` key of
        ``state.environment_states``, or an empty string when that key
        is absent.
    """
    # Use .get() so a state without a question yields "" rather than
    # raising KeyError; a single return also removes the unreachable
    # duplicate that followed the subscript lookup.
    return state.environment_states.get("question", "")

def state_to_answer(self, state: WorldState) -> str:
"""Extract the answer from the state."""
Expand Down Expand Up @@ -259,13 +259,19 @@ def prepare_states(
if batch_id not in trees[agent]:
trees[agent][batch_id] = None
payload = transplants[pair]
received_states, received_actions, received_metadata = (
payload["world_state"],
payload["actions"],
payload["metadata"],
)
received_states = payload.get("world_state")
received_actions = payload.get("actions")
received_metadata = payload.get("metadata")

if received_states is None or received_actions is None or received_metadata is None:
self.logger.warning(f"Incomplete payload: {payload}")
continue
world_state = received_states.environment_states
payload_batch_id = generate_md5_hash_id(world_state["question"])
question = world_state.get("question")
if not question:
self.logger.warning(f"No 'question' found in world_state: {world_state}")
continue
payload_batch_id = generate_md5_hash_id(question)
assert payload_batch_id == batch_id
if (
trees[agent][batch_id] is None
Expand Down Expand Up @@ -297,11 +303,8 @@ def transplant_trees(
for batch_id in swarm_states[agent]:
for payload in swarm_states[agent][batch_id]:
if (
self.num_generations
and hasattr(payload, "actions")
and payload.actions is not None
and isinstance(payload.actions, list)
and len(payload.actions) == self.num_generations
payload.get("actions") is not None
and len(payload["actions"]) == self.num_generations
):
transplants[(agent, batch_id)] = payload
if len(transplants) >= num_transplants:
Expand Down