Skip to content

Commit

Permalink
OAI: support disabling parallel tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
lofcz committed Jun 8, 2024
1 parent 0c0ebba commit af6b83a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
44 changes: 44 additions & 0 deletions LlmTornado.Demo/ChatDemo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,50 @@ public static async Task OpenAiFunctions()
await chat.StreamResponseRich(handler);
}

public static async Task OpenAiDisableParallelFunctions()
{
Conversation chat = Program.Connect().Chat.CreateConversation(new ChatRequest
{
Model = ChatModel.OpenAi.Gpt4.O,
Tools =
[
new Tool(new ToolFunction("get_weather", "gets the current weather", new
{
type = "object",
properties = new
{
location = new
{
type = "string",
description = "The location for which the weather information is required."
}
},
required = new List<string> { "location" }
}))
],
ParallelToolCalls = false
})
.AppendSystemMessage("You are a helpful assistant")
.AppendUserInput("What is the weather like today in Prague and Paris?");

ChatStreamEventHandler handler = new ChatStreamEventHandler
{
MessageTokenHandler = (x) =>
{
Console.Write(x);
return Task.CompletedTask;
},
FunctionCallHandler = (calls) =>
{
calls.ForEach(x => x.Result = new FunctionResult(x, "A mild rain is expected around noon.", null));
return Task.CompletedTask;
},
AfterFunctionCallsResolvedHandler = async (results, handler) => { await chat.StreamResponseRich(handler); }
};

await chat.StreamResponseRich(handler);
}

public static async Task AnthropicFunctionsParallel()
{
StringBuilder sb = new StringBuilder();
Expand Down
2 changes: 2 additions & 0 deletions LlmTornado.Demo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public enum Demos
CohereFunctionsStreamingInteractive,
[Flaky("interactive demo")]
CrossVendorFunctionsStreamingInteractive,
DisableParallelTools,
Last
}

Expand Down Expand Up @@ -190,6 +191,7 @@ public static async Task<bool> SetupApi()
Demos.AnthropicFunctionsStreamingInteractive => ChatDemo.AnthropicFunctionsStreamingInteractive,
Demos.CohereFunctionsStreamingInteractive => ChatDemo.CohereFunctionsStreamingInteractive,
Demos.CrossVendorFunctionsStreamingInteractive => ChatDemo.CrossVendorFunctionsStreamingInteractive,
Demos.DisableParallelTools => ChatDemo.OpenAiDisableParallelFunctions,
_ => null
};

Expand Down
17 changes: 14 additions & 3 deletions LlmTornado/Chat/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,22 @@ public ChatRequest(ChatRequest? basedOn)
Temperature = basedOn.Temperature;
TopP = basedOn.TopP;
NumChoicesPerMessage = basedOn.NumChoicesPerMessage;
StopSequence = basedOn.StopSequence;
MultipleStopSequences = basedOn.MultipleStopSequences;
MaxTokens = basedOn.MaxTokens;
FrequencyPenalty = basedOn.FrequencyPenalty;
PresencePenalty = basedOn.PresencePenalty;
LogitBias = basedOn.LogitBias;
Tools = basedOn.Tools;
ToolChoice = basedOn.ToolChoice;
OuboundFunctionsContent = basedOn.OuboundFunctionsContent;
OutboundFunctionsContent = basedOn.OutboundFunctionsContent;
Adapter = basedOn.Adapter;
VendorExtensions = basedOn.VendorExtensions;
StreamOptions = basedOn.StreamOptions;
TrimResponseStart = basedOn.TrimResponseStart;
ParallelToolCalls = basedOn.ParallelToolCalls;
Seed = basedOn.Seed;
User = basedOn.User;
}

/// <summary>
Expand Down Expand Up @@ -169,7 +173,7 @@ public string? StopSequence
set
{
if (value != null)
MultipleStopSequences = new[] { value };
MultipleStopSequences = [value];
}
}

Expand Down Expand Up @@ -216,6 +220,13 @@ public string? StopSequence
[JsonProperty("tools")]
public List<Tool>? Tools { get; set; }

/// <summary>
/// Parallel function calling can be disabled / enabled for vendors supporting the feature.
/// As of 6/24, the only vendor supporting the feature is OpenAI.
/// </summary>
[JsonProperty("parallel_tool_calls")]
public bool ParallelToolCalls { get; set; } = true;

/// <summary>
/// Represents an optional field when sending tools calling prompt.
/// This field determines which function to call.
Expand All @@ -237,7 +248,7 @@ public string? StopSequence
/// somewhat lower.
/// </summary>
[JsonIgnore]
public Ref<string>? OuboundFunctionsContent { get; internal set; }
public Ref<string>? OutboundFunctionsContent { get; internal set; }

/// <summary>
/// This can be any API provider specific data. Currently used in KoboldCpp.
Expand Down

0 comments on commit af6b83a

Please sign in to comment.