diff --git a/rgym_exp/src/data.py b/rgym_exp/src/data.py
index 714a41d9..dc184c76 100644
--- a/rgym_exp/src/data.py
+++ b/rgym_exp/src/data.py
@@ -129,8 +129,6 @@ def load_reasoning_gym_dataset(
Returns:
A Dataset object containing the samples from the reseeding dataset
"""
- dataset_dict = {"question": [], "answer": [], "metadata": []}
-
if split in ("test", "validation"):
max_samples = self.num_samples["evaluation"]
else: # Default to train
@@ -139,22 +137,39 @@ def load_reasoning_gym_dataset(
if num_samples is not None:
max_samples = min(num_samples, max_samples)
- for i in range(max_samples):
- item = next(self.reseeding_dataset)
+ # Optimized batch processing - pre-allocate lists for better memory efficiency
+ dataset_dict = {
+ "question": [None] * max_samples,
+ "answer": [None] * max_samples,
+ "metadata": [None] * max_samples
+ }
+
+ # Process in smaller batches to reduce memory pressure
+ batch_size = min(1000, max_samples)
+
+ for batch_start in range(0, max_samples, batch_size):
+ batch_end = min(batch_start + batch_size, max_samples)
+
+ # Load batch items
+ batch_items = []
+ for i in range(batch_end - batch_start):
+ batch_items.append(next(self.reseeding_dataset))
- idx = i
+ # Process batch efficiently
+ for i, item in enumerate(batch_items):
+ idx = batch_start + i
- dataset_dict["question"].append(item["question"])
- dataset_dict["answer"].append(item["answer"])
+ dataset_dict["question"][idx] = item["question"]
+ dataset_dict["answer"][idx] = item["answer"]
- metadata = item.get("metadata", {})
- if not isinstance(metadata, dict):
- metadata = {"original_metadata": metadata}
+ metadata = item.get("metadata", {})
+ if not isinstance(metadata, dict):
+ metadata = {"original_metadata": metadata}
- metadata["dataset_index"] = idx
- metadata["split"] = split
+ metadata["dataset_index"] = idx
+ metadata["split"] = split
- dataset_dict["metadata"].append(metadata)
+ dataset_dict["metadata"][idx] = metadata
return Dataset.from_dict(dataset_dict)
diff --git a/rgym_exp/src/manager.py b/rgym_exp/src/manager.py
index 02c012a9..dca9ae0b 100644
--- a/rgym_exp/src/manager.py
+++ b/rgym_exp/src/manager.py
@@ -120,25 +120,42 @@ def _get_my_rewards(self, signal_by_agent):
def _try_submit_to_chain(self, signal_by_agent):
elapsed_time_hours = (time.time() - self.time_since_submit) / 3600
if elapsed_time_hours > self.submit_period:
- try:
- self.coordinator.submit_reward(
- self.state.round, 0, int(self.batched_signals), self.peer_id
- )
- self.batched_signals = 0.0
- if len(signal_by_agent) > 0:
- max_agent, max_signal = max(
- signal_by_agent.items(), key=lambda x: x[1]
+ # Exponential backoff for chain operations
+ max_retries = 3
+ base_delay = 1.0 # Start with 1 second
+
+ for attempt in range(max_retries):
+ try:
+ self.coordinator.submit_reward(
+ self.state.round, 0, int(self.batched_signals), self.peer_id
)
- else: # if we have no signal_by_agents, just submit ourselves.
- max_agent = self.peer_id
+ self.batched_signals = 0.0
- self.coordinator.submit_winners(
- self.state.round, [max_agent], self.peer_id
- )
- self.time_since_submit = time.time()
- self.submitted_this_round = True
- except Exception as e:
- get_logger().debug(str(e))
+ if len(signal_by_agent) > 0:
+ max_agent, max_signal = max(
+ signal_by_agent.items(), key=lambda x: x[1]
+ )
+ else: # if we have no signal_by_agents, just submit ourselves.
+ max_agent = self.peer_id
+
+ self.coordinator.submit_winners(
+ self.state.round, [max_agent], self.peer_id
+ )
+ self.time_since_submit = time.time()
+ self.submitted_this_round = True
+ break # Success - exit retry loop
+
+ except Exception as e:
+ if attempt < max_retries - 1:
+ # Calculate exponential backoff delay
+ delay = base_delay * (2 ** attempt)
+ get_logger().debug(
+ f"Chain submission attempt {attempt + 1} failed: {e}. "
+ f"Retrying in {delay} seconds."
+ )
+ time.sleep(delay)
+ else:
+ get_logger().debug(f"All chain submission attempts failed: {e}")
def _hook_after_rewards_updated(self):
signal_by_agent = self._get_total_rewards_by_agent()
diff --git a/rgym_exp/src/trainer.py b/rgym_exp/src/trainer.py
index 0d2f0cc5..b571b9cb 100644
--- a/rgym_exp/src/trainer.py
+++ b/rgym_exp/src/trainer.py
@@ -166,32 +166,52 @@ def _get_choice_logits(self, input_ids: torch.Tensor, choices: List[str]) -> tor
Returns a tensor of shape (len(choices),) giving, for each choice,
the sum of log-probabilities that the model assigns to generating
"{choice}" after the given input_ids.
- """
+ Optimized version that processes choices in batches and uses memory-efficient computation.
+ """
device = input_ids.device
batch_size, prompt_len = input_ids.shape
- logits_list = []
+ # Pre-tokenize all choices for efficiency
+ choice_token_cache = {}
for choice in choices:
- # 1) build the full token sequence: prompt + "…"
- # TODO: Make the dtype changes from genrl here?
answer_str = f"{choice}"
- choice_ids = self.processing_class(
+ choice_token_cache[choice] = self.processing_class(
answer_str,
return_tensors="pt",
add_special_tokens=False
- ).input_ids.to(device) # shape (1, L)
+ ).input_ids.to(device)
+
+ # Process choices in batches to optimize memory usage
+ max_batch_size = min(8, len(choices)) # Avoid OOM with large choice sets
+ logits_list = []
+
+ for i in range(0, len(choices), max_batch_size):
+ batch_choices = choices[i:i + max_batch_size]
+ batch_logits = []
+
+ # Process each choice in the current batch
+ for choice in batch_choices:
+ choice_ids = choice_token_cache[choice]
+
+ # Build sequence more efficiently
+ seq = torch.cat([input_ids, choice_ids], dim=1)
+
+ # Use no_grad for memory efficiency since we only need the loss
+ with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
+ # Build labels efficiently
+ labels = seq.clone()
+ labels[:, :prompt_len] = -100 # ignore prompt positions
- seq = torch.cat([input_ids, choice_ids], dim=1) # (1, prompt_len + L)
+ outputs = self.model(input_ids=seq, labels=labels)
+ total_log_prob = -outputs.loss * choice_ids.size(1)
+ batch_logits.append(total_log_prob)
- # build labels that only include the answer positions
- labels = seq.clone()
- labels[:, :prompt_len] = -100 # ignore prompt positions in loss
- outputs = self.model(input_ids=seq, labels=labels)
- # outputs.loss is average negative log-likelihood over the L answer tokens
+ # Clear GPU cache after each choice to prevent OOM
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
- total_log_prob = -outputs.loss * choice_ids.size(1)
- logits_list.append(total_log_prob)
+ logits_list.extend(batch_logits)
- # stack into a single tensor of shape (num_choices,)
+ # Stack into a single tensor of shape (num_choices,)
return torch.stack(logits_list)
\ No newline at end of file
diff --git a/rgym_exp/src/utils/judge_client.py b/rgym_exp/src/utils/judge_client.py
index 22bc2270..1dbeb69b 100644
--- a/rgym_exp/src/utils/judge_client.py
+++ b/rgym_exp/src/utils/judge_client.py
@@ -1,33 +1,97 @@
import requests
+from requests.adapters import HTTPAdapter
+try:
+ # Try modern import path first (urllib3 2.0+)
+ from urllib3.util.retry import Retry
+except ImportError:
+ # Fallback to legacy import path (older versions)
+ from requests.packages.urllib3.util.retry import Retry
from typing import Dict, Any, Optional
+import time
+import socket
from genrl.logging_utils.global_defs import get_logger
+class SocketOptionsHTTPAdapter(HTTPAdapter):
+ """
+ Custom HTTPAdapter that supports socket options.
+
+ The socket_options parameter is not directly supported by requests.adapters.HTTPAdapter,
+ but can be implemented by overriding init_poolmanager to pass options to urllib3.
+ """
+
+ def __init__(self, socket_options=None, *args, **kwargs):
+ self.socket_options = socket_options
+ super().__init__(*args, **kwargs)
+
+ def init_poolmanager(self, *args, **kwargs):
+ if self.socket_options is not None:
+ kwargs["socket_options"] = self.socket_options
+ return super().init_poolmanager(*args, **kwargs)
+
+
class JudgeClient:
"""
Client for interacting with the judge API service.
- Handles question requests and answer submissions.
+ Handles question requests and answer submissions with connection pooling and retry logic.
"""
-
- def __init__(self, base_url: str):
+
+ def __init__(self, base_url: str, timeout: int = 30):
"""
- Initialize the judge client.
-
+ Initialize the judge client with performance optimizations.
+
Args:
base_url: Base URL for the judge API service
+ timeout: Request timeout in seconds
"""
self.base_url = base_url.rstrip('/')
self.logger = get_logger()
+ self.timeout = timeout
+
+ # Set up session with connection pooling and retry strategy
+ self.session = requests.Session()
+
+ # Configure retry strategy
+ retry_strategy = Retry(
+ total=3, # Total number of retries
+ backoff_factor=1, # Wait time multiplier between retries
+ status_forcelist=[429, 500, 502, 503, 504], # HTTP status codes to retry
+ allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST"]
+ )
+
+ # Configure connection adapter with pooling and socket options
+ adapter = SocketOptionsHTTPAdapter(
+ max_retries=retry_strategy,
+ pool_connections=10, # Number of connection pools
+ pool_maxsize=20, # Max connections per pool
+ socket_options=[(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)] # SO_REUSEADDR
+ )
+
+ # Mount adapter for both HTTP and HTTPS
+ self.session.mount("http://", adapter)
+ self.session.mount("https://", adapter)
+
+ # Set session headers for keep-alive
+ self.session.headers.update({
+ 'Connection': 'keep-alive',
+ 'Content-Type': 'application/json',
+ 'User-Agent': 'rl-swarm-judge-client/1.0'
+ })
+
+ def __del__(self):
+ """Clean up session resources."""
+ if hasattr(self, 'session'):
+ self.session.close()
def request_question(self, user_id: str, round_number: int, model_name: str) -> Optional[Dict[str, Any]]:
"""
Request a question from the judge service.
-
+
Args:
user_id: ID of the user/peer
round_number: Current round number
model_name: Name of the model being used
-
+
Returns:
Dictionary containing question data or None if request failed
"""
@@ -37,12 +101,13 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
"round_number": round_number,
"model_name": model_name,
}
-
- response = requests.post(
- f"{self.base_url}/request-question/",
- json=request_data
+
+ response = self.session.post(
+ f"{self.base_url}/request-question/",
+ json=request_data,
+ timeout=self.timeout
)
-
+
if response.status_code == 200:
result = response.json()
self.logger.debug(f'Received question: {result["question"]}')
@@ -50,7 +115,7 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
else:
self.logger.debug(f"Failed to receive question: {response.status_code}")
return None
-
+
except Exception as e:
self.logger.debug(f"Failed to request question: {e}")
return None
@@ -58,13 +123,16 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
def get_current_clue(self) -> Optional[Dict[str, Any]]:
"""
Get the current clue from the judge service.
-
+
Returns:
Dictionary containing clue data or None if request failed
"""
try:
- response = requests.get(f"{self.base_url}/current_clue/")
-
+ response = self.session.get(
+ f"{self.base_url}/current_clue/",
+ timeout=self.timeout
+ )
+
if response.status_code == 200:
result = response.json()
self.logger.debug(f'Received clue: {result["clue"]}')
@@ -72,21 +140,20 @@ def get_current_clue(self) -> Optional[Dict[str, Any]]:
else:
self.logger.debug(f"Failed to receive clue: {response.status_code}")
return None
-
+
except Exception as e:
self.logger.debug(f"Failed to get current clue: {e}")
return None
-
def submit_answer(self, session_id: str, round_number: int, user_answer: str) -> Optional[Dict[str, Any]]:
"""
Submit an answer to the judge service.
-
+
Args:
session_id: Session ID from the question request
round_number: Current round number
user_answer: The user's answer to submit
-
+
Returns:
Dictionary containing score data or None if submission failed
"""
@@ -97,9 +164,10 @@ def submit_answer(self, session_id: str, round_number: int, user_answer: str) ->
"user_answer": user_answer,
}
- response = requests.post(
- f"{self.base_url}/submit-answer/",
- json=submission_data
+ response = self.session.post(
+ f"{self.base_url}/submit-answer/",
+ json=submission_data,
+ timeout=self.timeout
)
if response.status_code == 200:
diff --git a/web/api/dht_pub.py b/web/api/dht_pub.py
index 52e32801..43aa2725 100644
--- a/web/api/dht_pub.py
+++ b/web/api/dht_pub.py
@@ -6,7 +6,7 @@
import random
from abc import ABC, abstractmethod
from datetime import datetime, timezone
-from typing import Any, Optional
+from typing import Any, Optional, Dict, Tuple
from .game_tree import Payload, from_bytes
from hivemind.dht import DHT
@@ -63,6 +63,12 @@ def __init__(
self.last_polled = None
self.poll_id = None
+ # Performance optimization: DHT result caching
+ self._cache_ttl = 60 # Cache TTL in seconds
+ self._rewards_cache: Dict[str, Tuple[float, Any]] = {}
+ self._outputs_cache: Dict[str, Tuple[float, Any]] = {}
+ self._peer_name_cache: Dict[str, str] = {}
+
# Store the class name for use in logging
self.class_name = self.__class__.__name__
@@ -96,22 +102,76 @@ def get_last_polled(self):
"""Get the time of the last poll."""
return self.last_polled
+ def _is_cache_valid(self, cache_time: float) -> bool:
+ """Check if cache entry is still valid based on TTL."""
+ return (time.time() - cache_time) < self._cache_ttl
+
def _get_rewards_data(
self, round_num: int, stage_num: int
) -> dict[str, Any] | None:
rewards_key_str = rewards_key(round_num, stage_num)
+
+ # Check cache first
+ if rewards_key_str in self._rewards_cache:
+ cache_time, cached_data = self._rewards_cache[rewards_key_str]
+ if self._is_cache_valid(cache_time):
+ return cached_data
+
+ # Cache miss or expired - fetch from DHT
rewards_data = get_dht_value(self.dht, key=rewards_key_str, beam_size=500)
+
+ # Cache the result
+ self._rewards_cache[rewards_key_str] = (time.time(), rewards_data)
+
+ # Clean old cache entries periodically
+ if len(self._rewards_cache) > 100: # Arbitrary limit
+ self._cleanup_cache(self._rewards_cache)
+
return rewards_data
def _get_outputs_data(
self, node_key: str, round_num: int, stage_num: int
) -> dict[str, Any] | None:
outputs_key_str = outputs_key(node_key, round_num, stage_num)
+
+ # Check cache first
+ if outputs_key_str in self._outputs_cache:
+ cache_time, cached_data = self._outputs_cache[outputs_key_str]
+ if self._is_cache_valid(cache_time):
+ return cached_data
+
+ # Cache miss or expired - fetch from DHT
outputs_data = get_dht_value(self.dht, key=outputs_key_str)
+
+ # Cache the result
+ self._outputs_cache[outputs_key_str] = (time.time(), outputs_data)
+
+ # Clean old cache entries periodically
+ if len(self._outputs_cache) > 100: # Arbitrary limit
+ self._cleanup_cache(self._outputs_cache)
+
return outputs_data
def _get_peer_name_from_id(self, peer_id: str) -> str:
- return get_name_from_peer_id(peer_id) or peer_id
+ # Check cache first for peer names (these rarely change)
+ if peer_id in self._peer_name_cache:
+ return self._peer_name_cache[peer_id]
+
+ # Cache miss - fetch from utils
+ peer_name = get_name_from_peer_id(peer_id) or peer_id
+ self._peer_name_cache[peer_id] = peer_name
+
+ return peer_name
+
+ def _cleanup_cache(self, cache: Dict[str, Tuple[float, Any]]) -> None:
+ """Remove expired entries from cache."""
+ current_time = time.time()
+ expired_keys = [
+ key for key, (cache_time, _) in cache.items()
+ if (current_time - cache_time) > self._cache_ttl
+ ]
+ for key in expired_keys:
+ del cache[key]
def _poll_loop(self):
"""Main polling loop."""
@@ -203,34 +263,50 @@ def _poll_once(self):
bytes = value_with_expiration.value
payload_dict = from_bytes(bytes)
- # Flatten the payloads into a list of payloads.
+ # Optimized batch processing of payloads
all_payloads = []
for _, payload_list in payload_dict.items():
all_payloads.extend(payload_list)
- # For each payload, generate a gossip message.
- for payload in all_payloads:
- world_state_tuple = payload.world_state
- question = world_state_tuple.environment_states["question"]
- actions = payload.actions
- source_dataset = world_state_tuple.environment_states["metadata"]["source_dataset"]
- action = random.choice(actions) if actions else ""
-
- # Stamp the message with the current time.
- now_utc = datetime.now(timezone.utc)
- ts = int(now_utc.timestamp())
-
- # Generate a unique ID for the gossip message.
- gossip_id = hashlib.md5(f"{question}-{peer_id}-{self.current_round}-{action}-{source_dataset}".encode()).hexdigest()
- round_gossip.append((
- ts, {
- "id": gossip_id,
- "message": f"{question}...{action}",
- "node": get_name_from_peer_id(peer_id),
- "nodeId": peer_id,
- "dataset": source_dataset,
- }
- ))
+ # Process payloads in batches for better performance
+ batch_size = min(50, len(all_payloads)) # Process up to 50 at a time
+ now_utc = datetime.now(timezone.utc)
+ ts = int(now_utc.timestamp())
+
+ # Pre-compute peer name once
+ peer_name = self._get_peer_name_from_id(peer_id)
+
+ for i in range(0, len(all_payloads), batch_size):
+ batch_payloads = all_payloads[i:i + batch_size]
+
+ # Process batch efficiently
+ batch_gossip = []
+ for payload in batch_payloads:
+ try:
+ world_state_tuple = payload.world_state
+ question = world_state_tuple.environment_states["question"]
+ actions = payload.actions
+ source_dataset = world_state_tuple.environment_states["metadata"]["source_dataset"]
+ action = random.choice(actions) if actions else ""
+
+ # Generate unique ID more efficiently
+ id_string = f"{question}-{peer_id}-{self.current_round}-{action}-{source_dataset}"
+ gossip_id = hashlib.md5(id_string.encode()).hexdigest()
+
+ batch_gossip.append((
+ ts, {
+ "id": gossip_id,
+ "message": f"{question}...{action}",
+ "node": peer_name, # Use pre-computed name
+ "nodeId": peer_id,
+ "dataset": source_dataset,
+ }
+ ))
+ except (KeyError, AttributeError) as e:
+ self.logger.debug(f"Skipping malformed payload: {e}")
+ continue
+
+ round_gossip.extend(batch_gossip)
self.logger.info("Got gossip messages", extra={
"message_count": len(round_gossip),