Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 72 additions & 32 deletions rgym_exp/src/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from collections import defaultdict
from pathlib import Path

from genrl.blockchain import SwarmCoordinator
from genrl.communication import Communication
Expand Down Expand Up @@ -57,59 +58,87 @@ def __init__(
assert isinstance(self.communication, HivemindBackend)
self.train_timeout = 60 * 60 * 24 * 31 # 1 month

# Ensure log directory exists
if not os.path.exists(log_dir):
try:
os.makedirs(log_dir, exist_ok=True)
except OSError as e:
get_logger().error(f"Failed to create log directory {log_dir}: {e}")
raise

# Logging Setup
self.peer_id = self.communication.get_id()
self.state.peer_id = self.peer_id
self.animal_name = get_name_from_peer_id(self.peer_id, True)
format_msg = f"[{self.animal_name}] %(asctime)s %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=format_msg)
formatter = logging.Formatter(format_msg)
file_handler = logging.FileHandler(
os.path.join(log_dir, f"training_{self.animal_name}.log")
)
file_handler.setFormatter(formatter)
_LOG = get_logger()
_LOG.addHandler(file_handler)

try:
file_handler = logging.FileHandler(
os.path.join(log_dir, f"training_{self.animal_name}.log")
)
file_handler.setFormatter(formatter)
_LOG = get_logger()
_LOG.addHandler(file_handler)
except OSError as e:
get_logger().error(f"Failed to create log file: {e}")
# Continue without file logging if file creation fails

# Register peer_id and get current round from the chain
self.coordinator = coordinator
self.coordinator.register_peer(self.peer_id)
round, _ = self.coordinator.get_round_and_stage()
self.state.round = round
try:
self.coordinator.register_peer(self.peer_id)
round, _ = self.coordinator.get_round_and_stage()
self.state.round = round
except Exception as e:
get_logger().error(f"Failed to register peer or get round/stage: {e}")
raise

self.communication.step_ = (
self.state.round
) # initialize communication module to contract's round

# enable push to HF if token was provided
# Enable push to HF if token was provided and is valid
self.hf_token = hf_token
if self.hf_token not in [None, "None"]:
username = whoami(token=self.hf_token)["name"]
model_name = self.trainer.model.config.name_or_path.split("/")[-1]
model_name += "-Gensyn-Swarm"
model_name += f"-{self.animal_name}"
self.trainer.args.hub_model_id = f"{username}/{model_name}"
self.trainer.args.push_to_hub = True
self.trainer.args.hub_token = self.hf_token
self.hf_push_frequency = hf_push_frequency
get_logger().info("Logging into Hugging Face Hub...")

login(self.hf_token)
if self.hf_token and self.hf_token.strip() and self.hf_token not in ["None", ""]:
try:
username = whoami(token=self.hf_token)["name"]
model_name = self.trainer.model.config.name_or_path.split("/")[-1]
model_name += "-Gensyn-Swarm"
model_name += f"-{self.animal_name}"
self.trainer.args.hub_model_id = f"{username}/{model_name}"
self.trainer.args.push_to_hub = True
self.trainer.args.hub_token = self.hf_token
self.hf_push_frequency = hf_push_frequency
get_logger().info("Logging into Hugging Face Hub...")
login(self.hf_token)
except Exception as e:
get_logger().error(f"Failed to setup HuggingFace Hub: {e}")
self.hf_token = None # Disable HF functionality if setup fails
else:
self.hf_token = None

get_logger().info(
f"🐱 Hello 🐈 [{get_name_from_peer_id(self.peer_id)}] 🦮 [{self.peer_id}]!"
)
get_logger().info(f"bootnodes: {kwargs.get('bootnodes', [])}")
get_logger().info(f"Using Model: {self.trainer.model.config.name_or_path}")

with open(os.path.join(log_dir, f"system_info.txt"), "w") as f:
f.write(get_system_info())
# Write system info to file with proper error handling
try:
with open(os.path.join(log_dir, f"system_info.txt"), "w") as f:
f.write(get_system_info())
except OSError as e:
get_logger().error(f"Failed to write system info: {e}")

self.batched_signals = 0.0
self.time_since_submit = time.time() #seconds
self.submit_period = 3.0 #hours
self.time_since_submit = time.time() # seconds
self.submit_period = 3.0 # hours
self.submitted_this_round = False

def _get_total_rewards_by_agent(self):
"""Calculate total rewards for each agent across all stages."""
rewards_by_agent = defaultdict(int)
for stage in range(self.state.stage):
rewards = self.rewards[stage]
Expand All @@ -123,6 +152,7 @@ def _get_total_rewards_by_agent(self):
return rewards_by_agent

def _get_my_rewards(self, signal_by_agent):
"""Calculate rewards for this peer."""
if len(signal_by_agent) == 0:
return 0
if self.peer_id in signal_by_agent:
Expand All @@ -135,6 +165,7 @@ def _get_my_rewards(self, signal_by_agent):
return my_signal

def _try_submit_to_chain(self, signal_by_agent):
"""Attempt to submit rewards and winners to the blockchain."""
elapsed_time_hours = (time.time() - self.time_since_submit) / 3600
if elapsed_time_hours > self.submit_period:
try:
Expand All @@ -150,6 +181,7 @@ def _try_submit_to_chain(self, signal_by_agent):
self.coordinator.submit_winners(self.state.round, [max_agent], self.peer_id)
self.time_since_submit = time.time()
self.submitted_this_round = True
get_logger().info(f"Successfully submitted to chain for round {self.state.round}")
except Exception as e:
get_logger().exception(
"Failed to submit to chain.\n"
Expand All @@ -160,13 +192,14 @@ def _try_submit_to_chain(self, signal_by_agent):
"including the full stacktrace."
)


def _hook_after_rewards_updated(self):
    """Hook called after rewards are updated.

    Fetches the per-agent reward totals once, adds this peer's share of
    them to ``self.batched_signals``, and then hands the same totals to
    ``_try_submit_to_chain``.
    """
    totals_by_agent = self._get_total_rewards_by_agent()
    # Accumulate before the submit attempt so the submit step sees the
    # updated batch value.
    self.batched_signals += self._get_my_rewards(totals_by_agent)
    self._try_submit_to_chain(totals_by_agent)

def _hook_after_round_advanced(self):
"""Hook called after round is advanced."""
self._save_to_hf()

# Try to submit to chain again if necessary, but don't update our signal twice
Expand All @@ -181,11 +214,13 @@ def _hook_after_round_advanced(self):
self.agent_block()

def _hook_after_game(self):
    """Hook called after game ends.

    Delegates to ``_save_to_hf``; this hook adds no logic of its own.
    """
    self._save_to_hf()

def _save_to_hf(self):
"""Save model to HuggingFace Hub if configured."""
if (
self.hf_token not in [None, "None"]
self.hf_token is not None
and self.state.round % self.hf_push_frequency == 0
):
get_logger().info(f"pushing model to huggingface")
Expand All @@ -206,23 +241,28 @@ def _save_to_hf(self):
f"I am {self.animal_name}",
],
)
except Exception:
get_logger().info(f"Successfully pushed model to HuggingFace Hub: {repo_id}")
except Exception as e:
get_logger().exception(
"Failed to push model to the Hugging Face Hub. When you conclude training please try manually pushing it yourself using the instructions here: https://huggingface.co/docs/hub/en/models-uploading",
stack_info=True,
"Failed to push model to the Hugging Face Hub. When you conclude training please try manually pushing it yourself using the instructions here: https://huggingface.co/docs/hub/en/models-uploading"
)

def agent_block(
self, check_interval=5.0, log_timeout=10.0, max_check_interval=60.0 * 15
):
"""Block until the swarm advances to a new round."""
start_time = time.monotonic()
fetch_log_time = start_time
check_backoff = (
check_interval # Exponential backoff for already finished rounds.
)
while time.monotonic() - start_time < self.train_timeout:
curr_time = time.monotonic()
_ = self.communication.dht.get_visible_maddrs(latest=True)

try:
_ = self.communication.dht.get_visible_maddrs(latest=True)
except Exception as e:
get_logger().debug(f"Failed to get visible maddrs: {e}")

# Retrieve current round and stage.
try:
Expand Down