Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions lib/marin/src/marin/rl/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,13 +629,26 @@ def run(self):

# For inflight weight updates, wait for first weights before generating rollouts
if self.config.inflight_weight_updates:
logger.info("Waiting for first weight transfer before starting inference...")
while not self._first_weights_received.wait(timeout=10.0):
max_wait_time = self.config.weight_transfer.max_weight_transfer_wait_time
logger.info(
"Waiting for first weight transfer before starting inference (timeout %.1fs)...",
max_wait_time,
)
start_time = time.time()
while True:
if self._first_weights_received.wait(timeout=10.0):
break

if not self._running:
logger.info("Shutdown requested while waiting for first weights")
self._shutdown_complete.set()
return
logger.info("Still waiting for first weight transfer...")

elapsed = time.time() - start_time
if max_wait_time - elapsed <= 0:
raise RuntimeError("Timed out waiting for initial weight transfer.")

logger.info("Still waiting for first weight transfer (elapsed: %.1fs)", elapsed)
logger.info("First weights received, starting inference loop")

step = 0
Expand Down
3 changes: 2 additions & 1 deletion lib/marin/src/marin/rl/train_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def _loss_function(model, batch, key):
self.replay_buffer.set_current_step(-1)

# Wait for initial rollouts to ensure we have baseline measurements
self._wait_for_initial_rollouts()
if not self._wait_for_initial_rollouts():
raise RuntimeError("Timed out waiting for initial rollouts; aborting training.")

self._configure_training_hooks(trainer)

Expand Down
5 changes: 3 additions & 2 deletions lib/marin/src/marin/rl/weight_transfer/arrow_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,10 @@ def serve_weights(self, weight_id: int, model: PyTree) -> None:

barrier_sync()

except Exception as e:
except Exception:
self.metrics.failed_transfers += 1
logger.error(f"Failed to serve weights {weight_id} via Arrow Flight: {e}")
logger.exception(f"Failed to serve weights {weight_id} via Arrow Flight")
raise

def cleanup(self) -> None:
"""Cleanup Flight server resources."""
Expand Down