diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index c5792ebc3..429656041 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -1,6 +1,7 @@
-using System;
+using System;
using System.Buffers;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;
@@ -15,10 +16,19 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
- private int _batchSampleIndex;
private bool _disposed;
+
+ /// <summary>
+ /// Indicates if this conversation has been "forked" and may share logits with another conversation.
+ /// </summary>
private bool _forked;
+ /// <summary>
+ /// Stores the indices to sample from. Contains <see cref="_batchSampleCount"/> valid items.
+ /// </summary>
+ private int[] _batchSampleIndices = new int[4];
+ private int _batchSampleCount;
+
/// <summary>
/// The executor which this conversation belongs to
/// </summary>
@@ -108,7 +118,8 @@ public Conversation Fork()
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
- _batchSampleIndex = _batchSampleIndex,
+ _batchSampleIndices = _batchSampleIndices.ToArray(),
+ _batchSampleCount = _batchSampleCount,
_forked = true,
_end = _end,
@@ -128,11 +139,12 @@ public Conversation Fork()
/// <summary>
/// Get the logits from this conversation, ready for sampling
/// </summary>
+ /// <param name="offset">How far from the end of the previous prompt should logits be sampled. Any value other than 0 requires allLogits to have been set during prompting.</param>
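+ /// <remarks>
+ /// For example, if 5 tokens were supplied in the last prompt call, offset 4 reads the logits of the
+ /// first token and offset 0 reads the logits of the fifth (and final) token.
+ /// </remarks>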
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
/// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
- public Span<float> Sample()
+ public Span<float> Sample(int offset = 0)
{
AssertNotDisposed();
@@ -140,8 +152,11 @@ public Span<float> Sample()
throw new CannotSampleRequiresPromptException();
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();
-
- var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
+ if (offset >= _batchSampleCount)
+ throw new ArgumentException("Cannot sample offset more than the previous prompt count", nameof(offset));
+
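+ // Offset counts back from the end of the last prompt: 0 is the final token, 1 is the one before it, and so on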
+ var index = _batchSampleIndices[_batchSampleCount - offset - 1];
+ var span = Executor.Context.NativeHandle.GetLogitsIth(index);
// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
@@ -161,33 +176,21 @@ private void AssertCanBePrompted()
throw new AlreadyPromptedConversationException();
}
- /// <summary>
- /// Add tokens to this conversation
- /// </summary>
- /// <param name="input"></param>
- /// <returns></returns>
- [Obsolete("Tokenize the text and pass the tokens instead")]
- public void Prompt(string input, bool addBos, bool special)
- {
- AssertCanBePrompted();
-
- Prompt(Executor.Context.Tokenize(input, addBos, special));
- }
-
/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
+ /// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
- public void Prompt(List<LLamaToken> tokens)
+ public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();
#if NET6_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
- Prompt(span);
+ Prompt(span, allLogits);
#else
// Borrow an array and copy tokens into it
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
@@ -204,15 +207,16 @@ public void Prompt(List<LLamaToken> tokens)
}
#endif
}
-
+
/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
+ /// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
- public void Prompt(ReadOnlySpan<LLamaToken> tokens)
+ public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();
@@ -221,8 +225,25 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)
return;
// Add the prompt to the batch
- for (var i = 0; i < tokens.Length; i++)
- _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+ if (allLogits)
+ {
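+ // Request logits for every token in the prompt and record the batch index of each one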
+ if (_batchSampleIndices.Length < tokens.Length)
+ _batchSampleIndices = new int[tokens.Length];
+
+ _batchSampleCount = tokens.Length;
+
+ for (var i = 0; i < tokens.Length; i++)
+ _batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
+ }
+ else
+ {
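+ // Only the logits of the final token are needed, so track just that one batch index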
+ _batchSampleCount = 1;
+
+ for (var i = 0; i < tokens.Length; i++)
+ _batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+ }
+
+
// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
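
Taken together, allLogits on Prompt() and the offset parameter on Sample() allow reading logits for every
position of the last prompt. A minimal usage sketch, assuming a BatchedExecutor set up in the usual way (the
executor, conversation and tokens variables here are illustrative and not part of this diff):

    // Request logits for every prompt token, not just the last one
    conversation.Prompt(tokens, allLogits: true);

    // Run inference for the whole batch
    await executor.Infer();

    // offset counts back from the end of the prompt:
    // 0 is the last token, tokens.Count - 1 is the first
    for (var offset = 0; offset < tokens.Count; offset++)
    {
        var logits = conversation.Sample(offset);
        // ... process the logits for this position ...
    }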