
Commit

Merge pull request #743 from martindevans/Conversation_allLogits
Conversation Generate All Logits
martindevans authored May 19, 2024
2 parents 5db45c6 + 91cad8d commit 0834008
Showing 1 changed file with 46 additions and 25 deletions.
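
For context, a minimal C# sketch of how the API added here is intended to be used. Only Prompt(tokens, allLogits) and Sample(offset) come from this diff; the executor members (Create, Infer), the Tokenize call, and the prompt text are assumptions based on the doc comments below and the surrounding LLamaSharp code, so treat this as an illustration rather than a definitive example.

// Hypothetical usage sketch; see the caveats above.
var conversation = executor.Create();

// Tokenize a prompt (addBos: true, special: false) and request logits for
// every position, not just the last one.
var tokens = executor.Context.Tokenize("The quick brown fox", true, false);
conversation.Prompt(tokens, allLogits: true);

// Run inference for the whole batch (Infer is assumed to be awaitable).
await executor.Infer();

// Span<float> locals cannot be declared in async methods before C# 13, so the
// logits are read inside a synchronous local function.
void ReadLogits()
{
    // offset counts back from the end of the prompt: 0 is the last prompted token.
    Span<float> last = conversation.Sample(offset: 0);
    Span<float> previous = conversation.Sample(offset: 1);
}
ReadLogits();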
71 changes: 46 additions & 25 deletions LLama/Batched/Conversation.cs
@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;
@@ -15,10 +16,19 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchSampleIndex;
private bool _disposed;

/// <summary>
/// Indicates if this conversation has been "forked" and may share logits with another conversation.
/// </summary>
private bool _forked;

/// <summary>
/// Stores the indices to sample from. Contains <see cref="_batchSampleCount"/> valid items.
/// </summary>
private int[] _batchSampleIndices = new int[4];
private int _batchSampleCount;

/// <summary>
/// The executor which this conversation belongs to
/// </summary>
@@ -108,7 +118,8 @@ public Conversation Fork()
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
_batchSampleIndex = _batchSampleIndex,
_batchSampleIndices = _batchSampleIndices.ToArray(),
_batchSampleCount = _batchSampleCount,
_forked = true,

_end = _end,
@@ -128,20 +139,24 @@ public Conversation Fork()
/// <summary>
/// Get the logits from this conversation, ready for sampling
/// </summary>
/// <param name="offset">How far from the <b>end</b> of the previous prompt should logits be sampled. Any value other than 0 requires allLogits to have been set during prompting</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
/// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
public Span<float> Sample()
public Span<float> Sample(int offset = 0)
{
AssertNotDisposed();

if (_requiredEpoch < Executor.Epoch)
throw new CannotSampleRequiresPromptException();
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();

var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
if (offset >= _batchSampleCount)
throw new ArgumentException("Cannot sample offset more than the previous prompt count", nameof(offset));

var index = _batchSampleIndices[_batchSampleCount - offset - 1];
var span = Executor.Context.NativeHandle.GetLogitsIth(index);

// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
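
For clarity, a small illustration of the indexing above (not part of the diff): suppose four tokens were prompted with allLogits set, so _batchSampleCount is 4.

conversation.Sample(0); // reads _batchSampleIndices[3]: logits after the last prompted token
conversation.Sample(3); // reads _batchSampleIndices[0]: logits after the first prompted token
conversation.Sample(4); // throws ArgumentException, since offset >= _batchSampleCount
// With allLogits left false, _batchSampleCount is 1 and only Sample(0) is valid.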
@@ -161,33 +176,21 @@ private void AssertCanBePrompted()
throw new AlreadyPromptedConversationException();
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="input"></param>
/// <returns></returns>
[Obsolete("Tokenize the text and pass the tokens instead")]
public void Prompt(string input, bool addBos, bool special)
{
AssertCanBePrompted();

Prompt(Executor.Context.Tokenize(input, addBos, special));
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(List<LLamaToken> tokens)
public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();

#if NET6_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
Prompt(span);
Prompt(span, allLogits);
#else
// Borrow an array and copy tokens into it
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
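
The #if split above follows a familiar pattern: on .NET 6+ the List's backing array is viewed directly as a span via CollectionsMarshal.AsSpan (no copy), while older targets rent a buffer from the ArrayPool, copy the tokens into it, and return it afterwards. The rest of that branch is collapsed in this view; a sketch of its likely shape, for illustration only:

#if NET6_0_OR_GREATER
// Zero-copy: view the List's backing array directly.
Prompt(CollectionsMarshal.AsSpan(tokens), allLogits);
#else
// Rent a buffer, copy the tokens in, call the span overload, then return it.
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
    tokens.CopyTo(arr, 0);
    Prompt(arr.AsSpan(0, tokens.Count), allLogits);
}
finally
{
    ArrayPool<LLamaToken>.Shared.Return(arr);
}
#endif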
@@ -204,15 +207,16 @@ public void Prompt(List<LLamaToken> tokens)
}
#endif
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(ReadOnlySpan<LLamaToken> tokens)
public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();

@@ -221,8 +225,25 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)
return;

// Add the prompt to the batch
for (var i = 0; i < tokens.Length; i++)
_batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
if (allLogits)
{
if (_batchSampleIndices.Length < tokens.Length)
_batchSampleIndices = new int[tokens.Length];

_batchSampleCount = tokens.Length;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
}
else
{
_batchSampleCount = 1;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
}

// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
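
One use for generating all logits is scoring every position of a prompt. Below is a minimal sketch under the same assumptions as the example above; the helper name and the scoring comment are illustrative, not part of this change, and it is kept synchronous because Span<float> locals cannot be declared in async methods before C# 13.

static void InspectPromptLogits(Conversation conversation, int promptLength)
{
    // Walk from the first prompted token (largest offset) down to the last (offset 0).
    for (var offset = promptLength - 1; offset >= 0; offset--)
    {
        Span<float> logits = conversation.Sample(offset);
        // e.g. softmax these logits and read the probability assigned to the token
        // that actually followed at this position, to score the prompt.
    }
}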
