@@ -140,12 +140,12 @@ def record_episode_sample(key: str, episode):
140140 "policy_version" : episode .policy_version ,
141141 "prompt" : episode .request ,
142142 "response" : episode .response ,
143- "target" : episode .target ,
143+ "target" : str ( episode .target ) ,
144144 ** (
145145 episode .reward_breakdown or {}
146146 ), # per-fn breakdown including the average reward
147147 "advantage" : episode .advantage ,
148- "ref_logprobs" : (
148+ "ref_logprobs" : float (
149149 episode .ref_logprobs .mean ().item ()
150150 if episode .ref_logprobs is not None
151151 else None
@@ -1059,25 +1059,23 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
10591059
10601060 if not self .run or not samples :
10611061 return
1062+
10621063 for key , rows in samples .items ():
10631064 if not rows :
10641065 continue
1065- # Create a WandB Table dynamically based on keys of first sample
1066- columns = list (rows [0 ].keys ())
1066+
1067+ # Use all keys to avoid dropped fields
1068+ columns = sorted ({k for s in rows for k in s .keys ()})
10671069 table = wandb .Table (columns = columns )
1068- for sample in rows :
1069- # table.add_data(*[sample.get(c) for c in columns])
1070- values = [sample .get (c ) for c in columns ]
1071- logger .info (f"Adding row to { key } _table: { values } " )
1070+
1071+ for s in rows :
1072+ values = [s .get (c ) for c in columns ]
10721073 table .add_data (* values )
1073- self .run .log (
1074- {
1075- f"{ key } _step_{ step } _table" : table ,
1076- "_sample_rows_logged" : len (rows ),
1077- "global_step" : step ,
1078- },
1079- commit = True ,
1080- )
1074+
1075+ # Unique table name avoids overwrite; commit forces sync
1076+ table_name = f"{ key } _table_step{ step } "
1077+ self .run .log ({table_name : table , "_num_rows" : len (rows )}, commit = True )
1078+
10811079 logger .info (
10821080 f"WandbBackend: Logged { len (rows )} samples for { key } at step { step } "
10831081 )
0 commit comments