Add debug logging to generate() in src/petals/client/remote_generation.py:
on entry, when a session is selected or created, and immediately before and
after the call to super().generate().

NOTE(review): the patch assumes a module-level `logger` already exists in the
target file; it adds no import or logger definition - confirm before applying.
NOTE(review): the post-image blob id on the index line is carried over from
the original patch and will not match the result of this cleaned-up version.

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 4405121e..6d4d1c59 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -85,15 +85,18 @@ def generate(
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        logger.debug("Entered generate method with kwargs: %s", kwargs)
         if inputs is None:
             inputs = kwargs.pop("input_ids", None)
 
         if session is not None:
             # If a session specified explicitly, use it
             context_manager = self.use_session(session)
+            logger.debug("Using specified session: %s", session)
         elif self.active_session is not None:
             # If there's an active session, don't do anything
             context_manager = contextlib.nullcontext(self.active_session)
+            logger.debug("Using active session: %s", self.active_session)
         else:
             # If there's no active session, create a new one
 
@@ -109,6 +112,7 @@ def generate(
             else:
                 session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
+            logger.debug("Created new session with max length: %d", session_max_length)
 
         with context_manager as session:
             # Prepend the tokens from the previous .generate() call
@@ -134,7 +138,9 @@ def generate(
                 past_key_values.update_seen(session.position)
                 kwargs["past_key_values"] = past_key_values
 
+            logger.debug("Starting generation with input ids: %s", inputs)
             result = super().generate(inputs, *args, **kwargs)
+            logger.debug("Generated result: %s", result)
 
             sequences = result.sequences if isinstance(result, ModelOutput) else result
             # Save tokens from this .generate() call