diff --git a/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py index d4265552bcf..d2cdbf1082a 100644 --- a/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py +++ b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py @@ -42,7 +42,7 @@ def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: metrics = {} old_log_probs = exps.batch["old_log_probs"] # student sampling logprobs - teacher_log_probs = exps.batch["teacher_log_probs"] + teacher_log_probs = exps.batch["teacher_logprobs"] response_mask = exps.batch["response_mask"] # advantages = -(student - teacher) = teacher - student