Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions rgym_exp/src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ def load_reasoning_gym_dataset(
Returns:
A Dataset object containing the samples from the reseeding dataset
"""
dataset_dict = {"question": [], "answer": [], "metadata": []}

if split in ("test", "validation"):
max_samples = self.num_samples["evaluation"]
else: # Default to train
Expand All @@ -139,22 +137,39 @@ def load_reasoning_gym_dataset(
if num_samples is not None:
max_samples = min(num_samples, max_samples)

for i in range(max_samples):
item = next(self.reseeding_dataset)
# Optimized batch processing - pre-allocate lists for better memory efficiency
dataset_dict = {
"question": [None] * max_samples,
"answer": [None] * max_samples,
"metadata": [None] * max_samples
}

# Process in smaller batches to reduce memory pressure
batch_size = min(1000, max_samples)

for batch_start in range(0, max_samples, batch_size):
batch_end = min(batch_start + batch_size, max_samples)

# Load batch items
batch_items = []
for i in range(batch_end - batch_start):
batch_items.append(next(self.reseeding_dataset))

idx = i
# Process batch efficiently
for i, item in enumerate(batch_items):
idx = batch_start + i

dataset_dict["question"].append(item["question"])
dataset_dict["answer"].append(item["answer"])
dataset_dict["question"][idx] = item["question"]
dataset_dict["answer"][idx] = item["answer"]

metadata = item.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {"original_metadata": metadata}
metadata = item.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {"original_metadata": metadata}

metadata["dataset_index"] = idx
metadata["split"] = split
metadata["dataset_index"] = idx
metadata["split"] = split

dataset_dict["metadata"].append(metadata)
dataset_dict["metadata"][idx] = metadata

return Dataset.from_dict(dataset_dict)

Expand Down
51 changes: 34 additions & 17 deletions rgym_exp/src/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,42 @@ def _get_my_rewards(self, signal_by_agent):
def _try_submit_to_chain(self, signal_by_agent):
    """Submit the batched reward and the round winner to the coordinator.

    Does nothing unless more than ``self.submit_period`` hours have elapsed
    since ``self.time_since_submit``. Otherwise makes up to three attempts,
    sleeping 1 s and then 2 s between them. Exceptions are logged at debug
    level and are not re-raised.

    Args:
        signal_by_agent: Mapping of agent id to signal. The agent with the
            largest signal is submitted as the winner; when the mapping is
            empty this peer submits itself.
    """
    elapsed_time_hours = (time.time() - self.time_since_submit) / 3600
    if elapsed_time_hours <= self.submit_period:
        return

    max_retries = 3
    base_delay = 1.0  # seconds; doubled after each failed attempt
    reward_submitted = False

    for attempt in range(max_retries):
        try:
            # Submit the reward at most once per call. batched_signals is
            # zeroed right after a successful submit, so a retry caused by a
            # submit_winners failure would otherwise send a second reward of 0.
            if not reward_submitted:
                self.coordinator.submit_reward(
                    self.state.round, 0, int(self.batched_signals), self.peer_id
                )
                self.batched_signals = 0.0
                reward_submitted = True

            if len(signal_by_agent) > 0:
                max_agent, _ = max(
                    signal_by_agent.items(), key=lambda kv: kv[1]
                )
            else:  # if we have no signal_by_agents, just submit ourselves.
                max_agent = self.peer_id

            self.coordinator.submit_winners(
                self.state.round, [max_agent], self.peer_id
            )
            self.time_since_submit = time.time()
            self.submitted_this_round = True
            return  # Success - no further attempts needed

        except Exception as e:
            if attempt < max_retries - 1:
                # Exponential backoff: base_delay, then 2 * base_delay.
                delay = base_delay * (2 ** attempt)
                get_logger().debug(
                    f"Chain submission attempt {attempt + 1} failed: {e}. "
                    f"Retrying in {delay} seconds."
                )
                time.sleep(delay)
            else:
                get_logger().debug(f"All chain submission attempts failed: {e}")

def _hook_after_rewards_updated(self):
signal_by_agent = self._get_total_rewards_by_agent()
Expand Down
50 changes: 35 additions & 15 deletions rgym_exp/src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,32 +166,52 @@ def _get_choice_logits(self, input_ids: torch.Tensor, choices: List[str]) -> tor
Returns a tensor of shape (len(choices),) giving, for each choice,
the sum of log-probabilities that the model assigns to generating
"<answer>{choice}</answer>" after the given input_ids.
"""

Optimized version that processes choices in batches and uses memory-efficient computation.
"""
device = input_ids.device
batch_size, prompt_len = input_ids.shape
logits_list = []

# Pre-tokenize all choices for efficiency
choice_token_cache = {}
for choice in choices:
# 1) build the full token sequence: prompt + "<answer>…</answer>"
# TODO: Make the dtype changes from genrl here?
answer_str = f"<answer>{choice}</answer>"
choice_ids = self.processing_class(
choice_token_cache[choice] = self.processing_class(
answer_str,
return_tensors="pt",
add_special_tokens=False
).input_ids.to(device) # shape (1, L)
).input_ids.to(device)

# Process choices in batches to optimize memory usage
max_batch_size = min(8, len(choices)) # Avoid OOM with large choice sets
logits_list = []

for i in range(0, len(choices), max_batch_size):
batch_choices = choices[i:i + max_batch_size]
batch_logits = []

# Process each choice in the current batch
for choice in batch_choices:
choice_ids = choice_token_cache[choice]

# Build sequence more efficiently
seq = torch.cat([input_ids, choice_ids], dim=1)

# Use no_grad for memory efficiency since we only need the loss
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
# Build labels efficiently
labels = seq.clone()
labels[:, :prompt_len] = -100 # ignore prompt positions

seq = torch.cat([input_ids, choice_ids], dim=1) # (1, prompt_len + L)
outputs = self.model(input_ids=seq, labels=labels)
total_log_prob = -outputs.loss * choice_ids.size(1)
batch_logits.append(total_log_prob)

# build labels that only include the answer positions
labels = seq.clone()
labels[:, :prompt_len] = -100 # ignore prompt positions in loss
outputs = self.model(input_ids=seq, labels=labels)
# outputs.loss is average negative log-likelihood over the L answer tokens
# Clear GPU cache after each choice to prevent OOM
if torch.cuda.is_available():
torch.cuda.empty_cache()

total_log_prob = -outputs.loss * choice_ids.size(1)
logits_list.append(total_log_prob)
logits_list.extend(batch_logits)

# stack into a single tensor of shape (num_choices,)
# Stack into a single tensor of shape (num_choices,)
return torch.stack(logits_list)
114 changes: 91 additions & 23 deletions rgym_exp/src/utils/judge_client.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,97 @@
import requests
from requests.adapters import HTTPAdapter
try:
# Try modern import path first (urllib3 2.0+)
from urllib3.util.retry import Retry
except ImportError:
# Fallback to legacy import path (older versions)
from requests.packages.urllib3.util.retry import Retry
from typing import Dict, Any, Optional
import time
import socket
from genrl.logging_utils.global_defs import get_logger


class SocketOptionsHTTPAdapter(HTTPAdapter):
    """
    Custom HTTPAdapter that supports socket options.

    The socket_options parameter is not directly supported by requests.adapters.HTTPAdapter,
    but can be implemented by overriding init_poolmanager to pass options to urllib3.

    Args:
        socket_options: Optional list of ``(level, optname, value)`` tuples
            forwarded to the urllib3 pool manager. When None, no
            ``socket_options`` keyword is passed on.
        *args, **kwargs: Forwarded unchanged to ``HTTPAdapter.__init__``.
    """

    # HTTPAdapter pickles only the attributes named in __attrs__ and rebuilds
    # its pool manager when unpickled. Listing socket_options here keeps the
    # setting across a pickle round-trip instead of losing it.
    __attrs__ = [*HTTPAdapter.__attrs__, "socket_options"]

    def __init__(self, socket_options=None, *args, **kwargs):
        # Must be assigned before super().__init__(), because the base
        # constructor calls init_poolmanager(), which reads it.
        self.socket_options = socket_options
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Create the pool manager, adding ``socket_options`` when configured."""
        # getattr guards against instances restored from state that predates
        # the socket_options attribute.
        socket_options = getattr(self, "socket_options", None)
        if socket_options is not None:
            kwargs["socket_options"] = socket_options
        return super().init_poolmanager(*args, **kwargs)


class JudgeClient:
"""
Client for interacting with the judge API service.
Handles question requests and answer submissions.
Handles question requests and answer submissions with connection pooling and retry logic.
"""
def __init__(self, base_url: str, timeout: int = 30):
    """
    Initialize the judge client with a pooled, retrying HTTP session.

    Args:
        base_url: Base URL for the judge API service (a trailing '/' is stripped)
        timeout: Request timeout in seconds
    """
    self.base_url = base_url.rstrip('/')
    self.logger = get_logger()
    self.timeout = timeout

    # Set up session with connection pooling and retry strategy
    self.session = requests.Session()

    # Configure retry strategy
    # NOTE(review): POST is in allowed_methods, so a submission that gets one
    # of the listed status codes is re-sent; confirm the judge endpoints
    # tolerate duplicate submissions.
    retry_strategy = Retry(
        total=3,  # Total number of retries
        backoff_factor=1,  # Wait time multiplier between retries
        status_forcelist=[429, 500, 502, 503, 504],  # HTTP status codes to retry
        allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST"]
    )

    # Configure connection adapter with pooling and socket options.
    # Passing socket_options replaces urllib3's default list, whose only
    # entry disables Nagle's algorithm, so TCP_NODELAY is restated here to
    # keep small JSON requests from being delayed.
    adapter = SocketOptionsHTTPAdapter(
        max_retries=retry_strategy,
        pool_connections=10,  # Number of connection pools
        pool_maxsize=20,  # Max connections per pool
        socket_options=[
            (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
            (socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),  # SO_REUSEADDR
        ]
    )

    # Mount adapter for both HTTP and HTTPS
    self.session.mount("http://", adapter)
    self.session.mount("https://", adapter)

    # Set session headers for keep-alive
    self.session.headers.update({
        'Connection': 'keep-alive',
        'Content-Type': 'application/json',
        'User-Agent': 'rl-swarm-judge-client/1.0'
    })

def __del__(self):
    """Close the HTTP session when the client is garbage-collected."""
    if not hasattr(self, 'session'):
        # No session attribute on this instance; nothing to close.
        return
    self.session.close()

def request_question(self, user_id: str, round_number: int, model_name: str) -> Optional[Dict[str, Any]]:
"""
Request a question from the judge service.

Args:
user_id: ID of the user/peer
round_number: Current round number
model_name: Name of the model being used

Returns:
Dictionary containing question data or None if request failed
"""
Expand All @@ -37,56 +101,59 @@ def request_question(self, user_id: str, round_number: int, model_name: str) ->
"round_number": round_number,
"model_name": model_name,
}

response = requests.post(
f"{self.base_url}/request-question/",
json=request_data

response = self.session.post(
f"{self.base_url}/request-question/",
json=request_data,
timeout=self.timeout
)

if response.status_code == 200:
result = response.json()
self.logger.debug(f'Received question: {result["question"]}')
return result
else:
self.logger.debug(f"Failed to receive question: {response.status_code}")
return None

except Exception as e:
self.logger.debug(f"Failed to request question: {e}")
return None

def get_current_clue(self) -> Optional[Dict[str, Any]]:
    """
    Get the current clue from the judge service.

    Issues a single GET to ``{base_url}/current_clue/`` through the shared
    session, bounded by ``self.timeout``. Exceptions raised while sending
    the request or decoding the body are caught, logged at debug level and
    reported as None.

    Returns:
        The decoded JSON object when the service answers 200 with an object
        containing a "clue" key; None otherwise.
    """
    try:
        response = self.session.get(
            f"{self.base_url}/current_clue/",
            timeout=self.timeout
        )

        if response.status_code != 200:
            self.logger.debug(f"Failed to receive clue: {response.status_code}")
            return None

        result = response.json()

        # Validate the payload explicitly instead of relying on the log line
        # below to raise KeyError/TypeError for a malformed response.
        if not isinstance(result, dict) or "clue" not in result:
            self.logger.debug("Failed to get current clue: response has no 'clue' field")
            return None

        self.logger.debug(f'Received clue: {result["clue"]}')
        return result

    except Exception as e:
        self.logger.debug(f"Failed to get current clue: {e}")
        return None


def submit_answer(self, session_id: str, round_number: int, user_answer: str) -> Optional[Dict[str, Any]]:
"""
Submit an answer to the judge service.

Args:
session_id: Session ID from the question request
round_number: Current round number
user_answer: The user's answer to submit

Returns:
Dictionary containing score data or None if submission failed
"""
Expand All @@ -97,9 +164,10 @@ def submit_answer(self, session_id: str, round_number: int, user_answer: str) ->
"user_answer": user_answer,
}

response = requests.post(
f"{self.base_url}/submit-answer/",
json=submission_data
response = self.session.post(
f"{self.base_url}/submit-answer/",
json=submission_data,
timeout=self.timeout
)

if response.status_code == 200:
Expand Down
Loading