diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index c5792ebc3..429656041 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -1,6 +1,7 @@
-using System;
+using System;
 using System.Buffers;
 using System.Collections.Generic;
+using System.Linq;
 using System.Runtime.InteropServices;
 using System.Text.Json;
 using LLama.Native;
@@ -15,10 +16,19 @@ public sealed class Conversation
 {
     private ulong _requiredEpoch;
     private LLamaPos _end;
-    private int _batchSampleIndex;
     private bool _disposed;
+
+    /// <summary>
+    /// Indicates if this conversation has been "forked" and may share logits with another conversation.
+    /// </summary>
     private bool _forked;
 
+    /// <summary>
+    /// Stores the indices to sample from. Contains <see cref="_batchSampleCount"/> valid items.
+    /// </summary>
+    private int[] _batchSampleIndices = new int[4];
+    private int _batchSampleCount;
+
     /// <summary>
     /// The executor which this conversation belongs to
     /// </summary>
@@ -108,7 +118,8 @@ public Conversation Fork()
             // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
             // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
-            _batchSampleIndex = _batchSampleIndex,
+            _batchSampleIndices = _batchSampleIndices.ToArray(),
+            _batchSampleCount = _batchSampleCount,
             _forked = true,
 
             _end = _end,
@@ -128,11 +139,12 @@ public Conversation Fork()
     /// <summary>
     /// Get the logits from this conversation, ready for sampling
     /// </summary>
+    /// <param name="offset">How far from the end of the previous prompt should logits be sampled. Any value other than 0 requires allLogits to have been set during prompting</param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
     /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
     /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
-    public Span<float> Sample()
+    public Span<float> Sample(int offset = 0)
     {
         AssertNotDisposed();
 
@@ -140,8 +152,11 @@ public Span<float> Sample()
             throw new CannotSampleRequiresPromptException();
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();
-
-        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
+        if (offset >= _batchSampleCount)
+            throw new ArgumentException("Cannot sample offset more than the previous prompt count", nameof(offset));
+
+        var index = _batchSampleIndices[_batchSampleCount - offset - 1];
+        var span = Executor.Context.NativeHandle.GetLogitsIth(index);
 
         // If necessary copy the span, to protect it from modification. This is only done when
         // this conversation has been forked in this epoch.
@@ -161,33 +176,21 @@ private void AssertCanBePrompted()
             throw new AlreadyPromptedConversationException();
     }
 
-    /// <summary>
-    /// Add tokens to this conversation
-    /// </summary>
-    /// <param name="input"></param>
-    /// <returns></returns>
-    [Obsolete("Tokenize the text and pass the tokens instead")]
-    public void Prompt(string input, bool addBos, bool special)
-    {
-        AssertCanBePrompted();
-
-        Prompt(Executor.Context.Tokenize(input, addBos, special));
-    }
-
     /// <summary>
     /// Add tokens to this conversation
     /// </summary>
     /// <param name="tokens"></param>
+    /// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
     /// <exception cref="AlreadyPromptedConversationException"></exception>
-    public void Prompt(List<LLamaToken> tokens)
+    public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
     {
         AssertCanBePrompted();
 
 #if NET6_0_OR_GREATER
         var span = CollectionsMarshal.AsSpan(tokens);
-        Prompt(span);
+        Prompt(span, allLogits);
 #else
         // Borrow an array and copy tokens into it
         var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
@@ -204,15 +207,16 @@ public void Prompt(List<LLamaToken> tokens)
         }
 #endif
     }
-    
+
     /// <summary>
     /// Add tokens to this conversation
     /// </summary>
     /// <param name="tokens"></param>
+    /// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
     /// <exception cref="AlreadyPromptedConversationException"></exception>
-    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
     {
         AssertCanBePrompted();
 
@@ -221,8 +225,25 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)
             return;
 
         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Length; i++)
-            _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+        if (allLogits)
+        {
+            if (_batchSampleIndices.Length < tokens.Length)
+                _batchSampleIndices = new int[tokens.Length];
+
+            _batchSampleCount = tokens.Length;
+
+            for (var i = 0; i < tokens.Length; i++)
+                _batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
+        }
+        else
+        {
+            _batchSampleCount = 1;
+
+            for (var i = 0; i < tokens.Length; i++)
+                _batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+        }
+
+        // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;