diff --git a/rgym_exp/src/data.py b/rgym_exp/src/data.py
index 714a41d9..dc184c76 100644
--- a/rgym_exp/src/data.py
+++ b/rgym_exp/src/data.py
@@ -129,8 +129,6 @@ def load_reasoning_gym_dataset(
         Returns:
             A Dataset object containing the samples from the reseeding dataset
         """
-        dataset_dict = {"question": [], "answer": [], "metadata": []}
-
         if split in ("test", "validation"):
            max_samples = self.num_samples["evaluation"]
        else:  # Default to train
@@ -139,22 +137,39 @@ def load_reasoning_gym_dataset(
        if num_samples is not None:
            max_samples = min(num_samples, max_samples)
 
-        for i in range(max_samples):
-            item = next(self.reseeding_dataset)
+        # Optimized batch processing - pre-allocate lists for better memory efficiency
+        dataset_dict = {
+            "question": [None] * max_samples,
+            "answer": [None] * max_samples,
+            "metadata": [None] * max_samples
+        }
+
+        # Process in smaller batches to reduce memory pressure
+        batch_size = max(1, min(1000, max_samples))  # max(1, ...) keeps the range() step nonzero when max_samples == 0
+
+        for batch_start in range(0, max_samples, batch_size):
+            batch_end = min(batch_start + batch_size, max_samples)
+
+            # Load batch items
+            batch_items = []
+            for i in range(batch_end - batch_start):
+                batch_items.append(next(self.reseeding_dataset))
 
-            idx = i
+            # Process batch efficiently
+            for i, item in enumerate(batch_items):
+                idx = batch_start + i
 
-            dataset_dict["question"].append(item["question"])
-            dataset_dict["answer"].append(item["answer"])
+                dataset_dict["question"][idx] = item["question"]
+                dataset_dict["answer"][idx] = item["answer"]
 
-            metadata = item.get("metadata", {})
-            if not isinstance(metadata, dict):
-                metadata = {"original_metadata": metadata}
+                metadata = item.get("metadata", {})
+                if not isinstance(metadata, dict):
+                    metadata = {"original_metadata": metadata}
 
-            metadata["dataset_index"] = idx
-            metadata["split"] = split
+                metadata["dataset_index"] = idx
+                metadata["split"] = split
 
-            dataset_dict["metadata"].append(metadata)
+                dataset_dict["metadata"][idx] = metadata
 
        return Dataset.from_dict(dataset_dict)
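The pre-allocation pattern above can be read in isolation: sized lists are filled by index, and the iterator is drained in bounded chunks. A minimal standalone sketch of the same idea, assuming only an iterator that yields dicts with "question" and "answer" keys (names here are illustrative, not taken from the repo):

    from itertools import islice

    def load_preallocated(source_iter, max_samples, batch_size=1000):
        """Fill index-addressed lists from an iterator in bounded batches."""
        out = {"question": [None] * max_samples, "answer": [None] * max_samples}
        step = max(1, min(batch_size, max_samples))
        filled = 0
        while filled < max_samples:
            # islice materializes at most one batch of the iterator at a time
            batch = list(islice(source_iter, min(step, max_samples - filled)))
            if not batch:
                break  # iterator exhausted early
            for item in batch:
                out["question"][filled] = item["question"]
                out["answer"][filled] = item["answer"]
                filled += 1
        return out

Compared with repeated append, index assignment into pre-sized lists avoids incremental list regrowth; the batching chiefly bounds how much of the reseeding iterator is held in memory at once.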
diff --git a/rgym_exp/src/manager.py b/rgym_exp/src/manager.py
index 02c012a9..dca9ae0b 100644
--- a/rgym_exp/src/manager.py
+++ b/rgym_exp/src/manager.py
@@ -120,25 +120,42 @@ def _get_my_rewards(self, signal_by_agent):
    def _try_submit_to_chain(self, signal_by_agent):
        elapsed_time_hours = (time.time() - self.time_since_submit) / 3600
        if elapsed_time_hours > self.submit_period:
-            try:
-                self.coordinator.submit_reward(
-                    self.state.round, 0, int(self.batched_signals), self.peer_id
-                )
-                self.batched_signals = 0.0
-                if len(signal_by_agent) > 0:
-                    max_agent, max_signal = max(
-                        signal_by_agent.items(), key=lambda x: x[1]
+            # Exponential backoff for chain operations
+            max_retries = 3
+            base_delay = 1.0  # Start with 1 second
+
+            for attempt in range(max_retries):
+                try:
+                    self.coordinator.submit_reward(
+                        self.state.round, 0, int(self.batched_signals), self.peer_id
                    )
-                else:  # if we have no signal_by_agents, just submit ourselves.
-                    max_agent = self.peer_id
+                    self.batched_signals = 0.0
 
-                self.coordinator.submit_winners(
-                    self.state.round, [max_agent], self.peer_id
-                )
-                self.time_since_submit = time.time()
-                self.submitted_this_round = True
-            except Exception as e:
-                get_logger().debug(str(e))
+                    if len(signal_by_agent) > 0:
+                        max_agent, max_signal = max(
+                            signal_by_agent.items(), key=lambda x: x[1]
+                        )
+                    else:  # if we have no signal_by_agents, just submit ourselves.
+                        max_agent = self.peer_id
+
+                    self.coordinator.submit_winners(
+                        self.state.round, [max_agent], self.peer_id
+                    )
+                    self.time_since_submit = time.time()
+                    self.submitted_this_round = True
+                    break  # Success - exit retry loop
+
+                except Exception as e:
+                    if attempt < max_retries - 1:
+                        # Calculate exponential backoff delay
+                        delay = base_delay * (2 ** attempt)
+                        get_logger().debug(
+                            f"Chain submission attempt {attempt + 1} failed: {e}. "
+                            f"Retrying in {delay} seconds."
+                        )
+                        time.sleep(delay)
+                    else:
+                        get_logger().debug(f"All chain submission attempts failed: {e}")
 
    def _hook_after_rewards_updated(self):
        signal_by_agent = self._get_total_rewards_by_agent()
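The schedule used here is plain exponential backoff: the wait before the next attempt is base_delay * 2**attempt, i.e. 1 s after the first failure and 2 s after the second, with no sleep after the final failure. The same policy as a generic helper, a sketch rather than code from this PR:

    import time

    def retry_with_backoff(fn, max_retries=3, base_delay=1.0, logger=None):
        """Call fn(); on failure wait base_delay * 2**attempt and retry."""
        for attempt in range(max_retries):
            try:
                return fn()
            except Exception as e:
                if attempt == max_retries - 1:
                    raise  # out of attempts - surface the last error
                delay = base_delay * (2 ** attempt)
                if logger:
                    logger.debug(f"Attempt {attempt + 1} failed: {e}; retrying in {delay}s")
                time.sleep(delay)

One behavioral detail worth noting in the hunk itself: if submit_reward succeeds (zeroing batched_signals) and submit_winners then fails, the retry re-invokes submit_reward with the already-zeroed signal, so the coordinator needs to tolerate a duplicate submission for the same round.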
diff --git a/rgym_exp/src/trainer.py b/rgym_exp/src/trainer.py
index 0d2f0cc5..b571b9cb 100644
--- a/rgym_exp/src/trainer.py
+++ b/rgym_exp/src/trainer.py
@@ -166,32 +166,52 @@ def _get_choice_logits(self, input_ids: torch.Tensor, choices: List[str]) -> torch.Tensor:
        Returns a tensor of shape (len(choices),) giving, for each choice,
        the sum of log-probabilities that the model assigns to generating
        "{choice}" after the given input_ids.
-        """
+        Optimized version that processes choices in batches and uses memory-efficient computation.
+        """
        device = input_ids.device
        batch_size, prompt_len = input_ids.shape
 
-        logits_list = []
+        # Pre-tokenize all choices for efficiency
+        choice_token_cache = {}
        for choice in choices:
-            # 1) build the full token sequence: prompt + ""
-            # TODO: Make the dtype changes from genrl here?
            answer_str = f"{choice}"
-            choice_ids = self.processing_class(
+            choice_token_cache[choice] = self.processing_class(
                answer_str, return_tensors="pt", add_special_tokens=False
-            ).input_ids.to(device)  # shape (1, L)
+            ).input_ids.to(device)
+
+        # Process choices in batches to optimize memory usage
+        max_batch_size = max(1, min(8, len(choices)))  # Avoid OOM with large choice sets; max(1, ...) keeps the range step nonzero
+        logits_list = []
+
+        for i in range(0, len(choices), max_batch_size):
+            batch_choices = choices[i:i + max_batch_size]
+            batch_logits = []
+
+            # Process each choice in the current batch
+            for choice in batch_choices:
+                choice_ids = choice_token_cache[choice]
+
+                # Build sequence more efficiently
+                seq = torch.cat([input_ids, choice_ids], dim=1)
+
+                # Use no_grad for memory efficiency since we only need the loss
+                with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
+                    # Build labels efficiently
+                    labels = seq.clone()
+                    labels[:, :prompt_len] = -100  # ignore prompt positions
 
-            seq = torch.cat([input_ids, choice_ids], dim=1)  # (1, prompt_len + L)
+                    outputs = self.model(input_ids=seq, labels=labels)
+                    total_log_prob = -outputs.loss * choice_ids.size(1)
+                    batch_logits.append(total_log_prob)
 
-            # build labels that only include the answer positions
-            labels = seq.clone()
-            labels[:, :prompt_len] = -100  # ignore prompt positions in loss
-            outputs = self.model(input_ids=seq, labels=labels)
-            # outputs.loss is average negative log-likelihood over the L answer tokens
+                # Clear GPU cache after each choice to prevent OOM
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
-            total_log_prob = -outputs.loss * choice_ids.size(1)
-            logits_list.append(total_log_prob)
+            logits_list.extend(batch_logits)
 
-        # stack into a single tensor of shape (num_choices,)
+        # Stack into a single tensor of shape (num_choices,)
        return torch.stack(logits_list)
\ No newline at end of file
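The scoring identity the hunk leans on: Hugging Face causal-LM loss with -100-masked labels is the mean negative log-likelihood over the L unmasked (answer) positions, so multiplying by -L recovers the summed log-probability of the choice. A compact restatement of just that computation, a sketch assuming an HF-style model and tokenizer rather than the trainer's exact attributes:

    import torch

    def choice_log_prob(model, tokenizer, prompt_ids: torch.Tensor, choice: str) -> torch.Tensor:
        """Summed log-prob of `choice` given `prompt_ids` of shape (1, prompt_len)."""
        choice_ids = tokenizer(
            choice, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(prompt_ids.device)
        seq = torch.cat([prompt_ids, choice_ids], dim=1)
        labels = seq.clone()
        labels[:, :prompt_ids.shape[1]] = -100  # score only the answer tokens
        with torch.no_grad():
            loss = model(input_ids=seq, labels=labels).loss  # mean NLL per answer token
        return -loss * choice_ids.size(1)  # mean x count = summed log-prob

Note also that torch.cuda.amp.autocast is a deprecated alias in recent PyTorch releases; torch.amp.autocast("cuda", ...) is the current spelling.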
diff --git a/rgym_exp/src/utils/judge_client.py b/rgym_exp/src/utils/judge_client.py
index 22bc2270..1dbeb69b 100644
--- a/rgym_exp/src/utils/judge_client.py
+++ b/rgym_exp/src/utils/judge_client.py
@@ -1,33 +1,97 @@
 import requests
+from requests.adapters import HTTPAdapter
+try:
+    # Try modern import path first (urllib3 2.0+)
+    from urllib3.util.retry import Retry
+except ImportError:
+    # Fallback to legacy import path (older versions)
+    from requests.packages.urllib3.util.retry import Retry
 from typing import Dict, Any, Optional
+import time
+import socket
 
 from genrl.logging_utils.global_defs import get_logger
 
 
+class SocketOptionsHTTPAdapter(HTTPAdapter):
+    """
+    Custom HTTPAdapter that supports socket options.
+
+    The socket_options parameter is not directly supported by requests.adapters.HTTPAdapter,
+    but can be implemented by overriding init_poolmanager to pass options to urllib3.
+    """
+
+    def __init__(self, socket_options=None, *args, **kwargs):
+        self.socket_options = socket_options
+        super().__init__(*args, **kwargs)
+
+    def init_poolmanager(self, *args, **kwargs):
+        if self.socket_options is not None:
+            kwargs["socket_options"] = self.socket_options
+        return super().init_poolmanager(*args, **kwargs)
+
+
 class JudgeClient:
    """
    Client for interacting with the judge API service.
-    Handles question requests and answer submissions.
+    Handles question requests and answer submissions with connection pooling and retry logic.
    """
-
-    def __init__(self, base_url: str):
+
+    def __init__(self, base_url: str, timeout: int = 30):
        """
-        Initialize the judge client.
-
+        Initialize the judge client with performance optimizations.
+
        Args:
            base_url: Base URL for the judge API service
+            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.logger = get_logger()
+        self.timeout = timeout
+
+        # Set up session with connection pooling and retry strategy
+        self.session = requests.Session()
+
+        # Configure retry strategy
+        retry_strategy = Retry(
+            total=3,  # Total number of retries
+            backoff_factor=1,  # Wait time multiplier between retries
+            status_forcelist=[429, 500, 502, 503, 504],  # HTTP status codes to retry
+            allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST"]
+        )
+
+        # Configure connection adapter with pooling and socket options
+        adapter = SocketOptionsHTTPAdapter(
+            max_retries=retry_strategy,
+            pool_connections=10,  # Number of connection pools
+            pool_maxsize=20,  # Max connections per pool
+            socket_options=[(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)]  # SO_REUSEADDR
+        )
+
+        # Mount adapter for both HTTP and HTTPS
+        self.session.mount("http://", adapter)
+        self.session.mount("https://", adapter)
+
+        # Set session headers for keep-alive
+        self.session.headers.update({
+            'Connection': 'keep-alive',
+            'Content-Type': 'application/json',
+            'User-Agent': 'rl-swarm-judge-client/1.0'
+        })
+
+    def __del__(self):
+        """Clean up session resources."""
+        if hasattr(self, 'session'):
+            self.session.close()
 
    def request_question(self, user_id: str, round_number: int, model_name: str) -> Optional[Dict[str, Any]]:
        """
        Request a question from the judge service.
-
+
        Args:
            user_id: ID of the user/peer
            round_number: Current round number
            model_name: Name of the model being used
-
+
        Returns:
            Dictionary containing question data or None if request failed
        """
@@ -37,12 +101,13 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
                "round_number": round_number,
                "model_name": model_name,
            }
-
-            response = requests.post(
-                f"{self.base_url}/request-question/",
-                json=request_data
+
+            response = self.session.post(
+                f"{self.base_url}/request-question/",
+                json=request_data,
+                timeout=self.timeout
            )
-
+
            if response.status_code == 200:
                result = response.json()
                self.logger.debug(f'Received question: {result["question"]}')
@@ -50,7 +115,7 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
            else:
                self.logger.debug(f"Failed to receive question: {response.status_code}")
                return None
-
+
        except Exception as e:
            self.logger.debug(f"Failed to request question: {e}")
            return None
@@ -58,13 +123,16 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
    def get_current_clue(self) -> Optional[Dict[str, Any]]:
        """
        Get the current clue from the judge service.
-
+
        Returns:
            Dictionary containing clue data or None if request failed
        """
        try:
-            response = requests.get(f"{self.base_url}/current_clue/")
-
+            response = self.session.get(
+                f"{self.base_url}/current_clue/",
+                timeout=self.timeout
+            )
+
            if response.status_code == 200:
                result = response.json()
                self.logger.debug(f'Received clue: {result["clue"]}')
@@ -72,21 +140,20 @@ def get_current_clue(self) -> Optional[Dict[str, Any]]:
            else:
                self.logger.debug(f"Failed to receive clue: {response.status_code}")
                return None
-
+
        except Exception as e:
            self.logger.debug(f"Failed to get current clue: {e}")
            return None
-
 
    def submit_answer(self, session_id: str, round_number: int, user_answer: str) -> Optional[Dict[str, Any]]:
        """
        Submit an answer to the judge service.
-
+
        Args:
            session_id: Session ID from the question request
            round_number: Current round number
            user_answer: The user's answer to submit
-
+
        Returns:
            Dictionary containing score data or None if submission failed
        """
@@ -97,9 +164,10 @@ def submit_answer(self, session_id: str, round_number: int, user_answer: str) ->
                "user_answer": user_answer,
            }
 
-            response = requests.post(
-                f"{self.base_url}/submit-answer/",
-                json=submission_data
+            response = self.session.post(
+                f"{self.base_url}/submit-answer/",
+                json=submission_data,
+                timeout=self.timeout
            )
 
            if response.status_code == 200:
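A minimal usage sketch of the pooled client (URL, IDs, and the session_id response key are placeholders inferred from the docstrings, not verified against the judge API):

    client = JudgeClient("http://judge.example.invalid:8000", timeout=10)

    question = client.request_question(
        user_id="peer-123", round_number=7, model_name="my-model"
    )
    if question is not None:
        # the docstring says the session id comes from the question request;
        # the exact response key name is an assumption here
        score = client.submit_answer(
            session_id=question["session_id"], round_number=7, user_answer="42"
        )

One design caveat: the Retry policy lists POST in allowed_methods, which urllib3 deliberately excludes by default because POST is not idempotent. A request that times out after reaching the server will be replayed, so the judge endpoints need to tolerate duplicate submissions.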
diff --git a/web/api/dht_pub.py b/web/api/dht_pub.py
index 52e32801..43aa2725 100644
--- a/web/api/dht_pub.py
+++ b/web/api/dht_pub.py
@@ -6,7 +6,7 @@
 import random
 from abc import ABC, abstractmethod
 from datetime import datetime, timezone
-from typing import Any, Optional
+from typing import Any, Optional, Dict, Tuple
 
 from .game_tree import Payload, from_bytes
 from hivemind.dht import DHT
@@ -63,6 +63,12 @@ def __init__(
        self.last_polled = None
        self.poll_id = None
 
+        # Performance optimization: DHT result caching
+        self._cache_ttl = 60  # Cache TTL in seconds
+        self._rewards_cache: Dict[str, Tuple[float, Any]] = {}
+        self._outputs_cache: Dict[str, Tuple[float, Any]] = {}
+        self._peer_name_cache: Dict[str, str] = {}
+
        # Store the class name for use in logging
        self.class_name = self.__class__.__name__
 
@@ -96,22 +102,76 @@ def get_last_polled(self):
        """Get the time of the last poll."""
        return self.last_polled
 
+    def _is_cache_valid(self, cache_time: float) -> bool:
+        """Check if cache entry is still valid based on TTL."""
+        return (time.time() - cache_time) < self._cache_ttl
+
    def _get_rewards_data(
        self, round_num: int, stage_num: int
    ) -> dict[str, Any] | None:
        rewards_key_str = rewards_key(round_num, stage_num)
+
+        # Check cache first
+        if rewards_key_str in self._rewards_cache:
+            cache_time, cached_data = self._rewards_cache[rewards_key_str]
+            if self._is_cache_valid(cache_time):
+                return cached_data
+
+        # Cache miss or expired - fetch from DHT
        rewards_data = get_dht_value(self.dht, key=rewards_key_str, beam_size=500)
+
+        # Cache the result
+        self._rewards_cache[rewards_key_str] = (time.time(), rewards_data)
+
+        # Clean old cache entries periodically
+        if len(self._rewards_cache) > 100:  # Arbitrary limit
+            self._cleanup_cache(self._rewards_cache)
+
        return rewards_data
 
    def _get_outputs_data(
        self, node_key: str, round_num: int, stage_num: int
    ) -> dict[str, Any] | None:
        outputs_key_str = outputs_key(node_key, round_num, stage_num)
+
+        # Check cache first
+        if outputs_key_str in self._outputs_cache:
+            cache_time, cached_data = self._outputs_cache[outputs_key_str]
+            if self._is_cache_valid(cache_time):
+                return cached_data
+
+        # Cache miss or expired - fetch from DHT
        outputs_data = get_dht_value(self.dht, key=outputs_key_str)
+
+        # Cache the result
+        self._outputs_cache[outputs_key_str] = (time.time(), outputs_data)
+
+        # Clean old cache entries periodically
+        if len(self._outputs_cache) > 100:  # Arbitrary limit
+            self._cleanup_cache(self._outputs_cache)
+
        return outputs_data
 
    def _get_peer_name_from_id(self, peer_id: str) -> str:
-        return get_name_from_peer_id(peer_id) or peer_id
+        # Check cache first for peer names (these rarely change)
+        if peer_id in self._peer_name_cache:
+            return self._peer_name_cache[peer_id]
+
+        # Cache miss - fetch from utils
+        peer_name = get_name_from_peer_id(peer_id) or peer_id
+        self._peer_name_cache[peer_id] = peer_name
+
+        return peer_name
+
+    def _cleanup_cache(self, cache: Dict[str, Tuple[float, Any]]) -> None:
+        """Remove expired entries from cache."""
+        current_time = time.time()
+        expired_keys = [
+            key for key, (cache_time, _) in cache.items()
+            if (current_time - cache_time) > self._cache_ttl
+        ]
+        for key in expired_keys:
+            del cache[key]
 
    def _poll_loop(self):
        """Main polling loop."""
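The caching added above is a TTL-on-read scheme: each entry stores an (insert_time, value) pair, reads compare age against _cache_ttl, and expired entries are swept opportunistically once a cache grows past 100 keys. The same policy as a self-contained class, a sketch rather than the module's code:

    import time
    from typing import Any, Dict, Tuple

    class TTLCache:
        """Dict-backed cache whose entries expire ttl seconds after insertion."""

        def __init__(self, ttl: float = 60.0, sweep_threshold: int = 100):
            self._ttl = ttl
            self._threshold = sweep_threshold
            self._data: Dict[str, Tuple[float, Any]] = {}

        def get(self, key: str, default: Any = None) -> Any:
            entry = self._data.get(key)
            if entry is None:
                return default
            inserted, value = entry
            if time.time() - inserted >= self._ttl:
                del self._data[key]  # expired - drop eagerly on read
                return default
            return value

        def put(self, key: str, value: Any) -> None:
            self._data[key] = (time.time(), value)
            if len(self._data) > self._threshold:
                self._sweep()

        def _sweep(self) -> None:
            now = time.time()
            for key in [k for k, (t, _) in self._data.items() if now - t >= self._ttl]:
                del self._data[key]

Note that the diff stores the fetched value unconditionally, so a None returned by the DHT is also cached and will suppress refetching for a full TTL window.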
@@ -203,34 +263,50 @@ def _poll_once(self):
                    bytes = value_with_expiration.value
                    payload_dict = from_bytes(bytes)
 
-                    # Flatten the payloads into a list of payloads.
+                    # Optimized batch processing of payloads
                    all_payloads = []
                    for _, payload_list in payload_dict.items():
                        all_payloads.extend(payload_list)
 
-                    # For each payload, generate a gossip message.
-                    for payload in all_payloads:
-                        world_state_tuple = payload.world_state
-                        question = world_state_tuple.environment_states["question"]
-                        actions = payload.actions
-                        source_dataset = world_state_tuple.environment_states["metadata"]["source_dataset"]
-                        action = random.choice(actions) if actions else ""
-
-                        # Stamp the message with the current time.
-                        now_utc = datetime.now(timezone.utc)
-                        ts = int(now_utc.timestamp())
-
-                        # Generate a unique ID for the gossip message.
-                        gossip_id = hashlib.md5(f"{question}-{peer_id}-{self.current_round}-{action}-{source_dataset}".encode()).hexdigest()
-                        round_gossip.append((
-                            ts, {
-                                "id": gossip_id,
-                                "message": f"{question}...{action}",
-                                "node": get_name_from_peer_id(peer_id),
-                                "nodeId": peer_id,
-                                "dataset": source_dataset,
-                            }
-                        ))
+                    # Process payloads in batches for better performance
+                    batch_size = max(1, min(50, len(all_payloads)))  # Up to 50 at a time; max(1, ...) keeps the range() step nonzero
+                    now_utc = datetime.now(timezone.utc)
+                    ts = int(now_utc.timestamp())
+
+                    # Pre-compute peer name once
+                    peer_name = self._get_peer_name_from_id(peer_id)
+
+                    for i in range(0, len(all_payloads), batch_size):
+                        batch_payloads = all_payloads[i:i + batch_size]
+
+                        # Process batch efficiently
+                        batch_gossip = []
+                        for payload in batch_payloads:
+                            try:
+                                world_state_tuple = payload.world_state
+                                question = world_state_tuple.environment_states["question"]
+                                actions = payload.actions
+                                source_dataset = world_state_tuple.environment_states["metadata"]["source_dataset"]
+                                action = random.choice(actions) if actions else ""
+
+                                # Generate unique ID more efficiently
+                                id_string = f"{question}-{peer_id}-{self.current_round}-{action}-{source_dataset}"
+                                gossip_id = hashlib.md5(id_string.encode()).hexdigest()
+
+                                batch_gossip.append((
+                                    ts, {
+                                        "id": gossip_id,
+                                        "message": f"{question}...{action}",
+                                        "node": peer_name,  # Use pre-computed name
+                                        "nodeId": peer_id,
+                                        "dataset": source_dataset,
+                                    }
+                                ))
+                            except (KeyError, AttributeError) as e:
+                                self.logger.debug(f"Skipping malformed payload: {e}")
+                                continue
+
+                        round_gossip.extend(batch_gossip)
 
        self.logger.info("Got gossip messages", extra={
            "message_count": len(round_gossip),
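The restructured _poll_once loop is mostly about fault isolation: a single malformed payload no longer aborts the whole poll, because field access is wrapped per item. That shape is generic enough to state on its own, as a sketch with hypothetical names:

    def process_tolerantly(items, handle, logger):
        """Apply handle() per item; skip entries missing expected fields."""
        results = []
        for item in items:
            try:
                results.append(handle(item))
            except (KeyError, AttributeError) as e:
                # one bad item should not discard the rest of the batch
                logger.debug(f"Skipping malformed payload: {e}")
        return results

The batching itself buys little here, since all_payloads is already fully materialized in memory; the measurable wins are hoisting the timestamp and the peer-name lookup (now cached) out of the per-payload loop.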