From b746a4462f303a9e88fc02614353d01e7ac5c397 Mon Sep 17 00:00:00 2001
From: TesAnti <8780022+TesAnti@users.noreply.github.com>
Date: Sun, 14 Jan 2024 23:24:22 +0100
Subject: [PATCH] feat: Small fixes and agents support (#99)
* Ollama bug fix
* Standardized key names to data type. This allows chains to be composed in simple scenarios without specifying key names (see the usage sketch below).
* Agents support (see the crew sketch below).
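
Illustrative sketch of keyless chaining (not part of this patch; `model` stands for any IChatModel instance):

    // With key names standardized to the data type, simple chains need no
    // explicit inputKey/outputKey wiring - values flow through "text".
    var chain = Chain.Set("Why is the sky blue?")          // writes "text"
                | Chain.Template("Answer briefly: {text}") // reads "text", writes "text"
                | Chain.LLM(model);                        // reads "text", writes "text"
    var answer = chain.Run("text").Result;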
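
Illustrative sketch of the agents support (agent names, goals, and the task are placeholders, not from this patch):

    // A manager agent coordinates a crew of CrewAgents; CrewChain reads the
    // task from "text" and writes the final answer back to "text".
    var researcher = new CrewAgent(model, role: "researcher", goal: "find relevant facts");
    var writer = new CrewAgent(model, role: "writer", goal: "summarize the findings");
    var manager = new CrewAgent(model, role: "manager", goal: "coordinate the crew");

    var crew = Chain.Set("Write a short note about Mars")
               | Chain.Crew(new[] { researcher, writer }, manager);
    var result = crew.Run("text").Result;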
---
src/libs/LangChain.Core/Chains/Chain.cs | 48 +++--
.../StackableChains/Agents/Crew/AgentTask.cs | 22 +++
.../StackableChains/Agents/Crew/Crew.cs | 26 +++
.../StackableChains/Agents/Crew/CrewAgent.cs | 172 ++++++++++++++++++
.../StackableChains/Agents/Crew/CrewChain.cs | 33 ++++
.../StackableChains/Agents/Crew/Prompts.cs | 76 ++++++++
.../Agents/Crew/Tools/AskQuestionTool.cs | 35 ++++
.../Agents/Crew/Tools/CrewAgentTool.cs | 18 ++
.../Agents/Crew/Tools/CrewAgentToolLambda.cs | 16 ++
.../Agents/Crew/Tools/DelegateWorkTool.cs | 34 ++++
.../Chains/StackableChains/DoChain.cs | 19 ++
.../GenerateCompletionRequest.cs | 9 +-
.../OllamaLanguageModelInstruction.cs | 4 +-
13 files changed, 493 insertions(+), 19 deletions(-)
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewChain.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Prompts.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Tools/AskQuestionTool.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Tools/CrewAgentTool.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Tools/CrewAgentToolLambda.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Tools/DelegateWorkTool.cs
create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/DoChain.cs
diff --git a/src/libs/LangChain.Core/Chains/Chain.cs b/src/libs/LangChain.Core/Chains/Chain.cs
index 767c37b3..148a1d15 100644
--- a/src/libs/LangChain.Core/Chains/Chain.cs
+++ b/src/libs/LangChain.Core/Chains/Chain.cs
@@ -1,6 +1,7 @@
using LangChain.Chains.HelperChains;
using LangChain.Chains.StackableChains;
using LangChain.Chains.StackableChains.Agents;
+using LangChain.Chains.StackableChains.Agents.Crew;
using LangChain.Chains.StackableChains.Files;
using LangChain.Chains.StackableChains.ImageGeneration;
using LangChain.Chains.StackableChains.ReAct;
@@ -25,7 +26,7 @@ public static class Chain
///
public static PromptChain Template(
string template,
- string outputKey = "prompt")
+ string outputKey = "text")
{
return new PromptChain(template, outputKey);
}
@@ -38,7 +39,7 @@ public static PromptChain Template(
///
public static SetChain Set(
object value,
- string outputKey = "value")
+ string outputKey = "text")
{
return new SetChain(value, outputKey);
}
@@ -51,11 +52,19 @@ public static SetChain Set(
///
public static SetLambdaChain Set(
Func<string> valueGetter,
- string outputKey = "value")
+ string outputKey = "text")
{
return new SetLambdaChain(valueGetter, outputKey);
}
+ public static DoChain Do(
+ Action<Dictionary<string, object>> func)
+ {
+ return new DoChain(func);
+ }
+
+
+
///
///
///
@@ -65,7 +74,7 @@ public static SetLambdaChain Set(
///
public static LLMChain LLM(
IChatModel llm,
- string inputKey = "prompt",
+ string inputKey = "text",
string outputKey = "text")
{
return new LLMChain(llm, inputKey, outputKey);
@@ -80,8 +89,8 @@ public static LLMChain LLM(
///
public static RetrieveDocumentsChain RetrieveDocuments(
VectorStoreIndexWrapper index,
- string inputKey = "query",
- string outputKey = "documents")
+ string inputKey = "text",
+ string outputKey = "text")
{
return new RetrieveDocumentsChain(index, inputKey, outputKey);
}
@@ -93,8 +102,8 @@ public static RetrieveDocumentsChain RetrieveDocuments(
///
///
public static StuffDocumentsChain StuffDocuments(
- string inputKey = "documents",
- string outputKey = "combined")
+ string inputKey = "docs",
+ string outputKey = "c")
{
return new StuffDocumentsChain(inputKey, outputKey);
}
@@ -108,7 +117,7 @@ public static StuffDocumentsChain StuffDocuments(
///
public static UpdateMemoryChain UpdateMemory(
BaseChatMemory memory,
- string requestKey = "query",
+ string requestKey = "text",
string responseKey = "text")
{
return new UpdateMemoryChain(memory, requestKey, responseKey);
@@ -163,8 +172,8 @@ public static ReActAgentExecutorChain ReActAgentExecutor(
IChatModel model,
string? reActPrompt = null,
int maxActions = 5,
- string inputKey = "input",
- string outputKey = "final_answer")
+ string inputKey = "text",
+ string outputKey = "text")
{
return new ReActAgentExecutorChain(model, reActPrompt, maxActions, inputKey, outputKey);
}
@@ -177,7 +186,7 @@ public static ReActAgentExecutorChain ReActAgentExecutor(
///
public static ReActParserChain ReActParser(
string inputKey = "text",
- string outputKey = "answer")
+ string outputKey = "text")
{
return new ReActParserChain(inputKey, outputKey);
}
@@ -195,8 +204,8 @@ public static GroupChat GroupChat(
IList agents,
string? stopPhrase = null,
int messagesLimit = 10,
- string inputKey = "input",
- string outputKey = "output")
+ string inputKey = "text",
+ string outputKey = "text")
{
return new GroupChat(agents, stopPhrase, messagesLimit, inputKey, outputKey);
}
@@ -210,7 +219,7 @@ public static GroupChat GroupChat(
///
public static ImageGenerationChain GenerateImage(
IGenerateImageModel model,
- string inputKey = "prompt",
+ string inputKey = "text",
string outputKey = "image")
{
return new ImageGenerationChain(model, inputKey, outputKey);
@@ -228,4 +237,13 @@ public static SaveIntoFileChain SaveIntoFile(
{
return new SaveIntoFileChain(path, inputKey);
}
+
+
+ public static CrewChain Crew(
+ IEnumerable<CrewAgent> allAgents, CrewAgent manager,
+ string inputKey = "text",
+ string outputKey = "text")
+ {
+ return new CrewChain(allAgents, manager, inputKey, outputKey);
+ }
}
diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs
new file mode 100644
index 00000000..891902b4
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs
@@ -0,0 +1,22 @@
+using LangChain.Abstractions.Chains.Base;
+using LangChain.Chains.StackableChains.Agents.Crew.Tools;
+using LangChain.Chains.StackableChains.ReAct;
+
+namespace LangChain.Chains.StackableChains.Agents.Crew;
+
+public class AgentTask(CrewAgent agent, string description, List<CrewAgentTool>? tools = null)
+{
+ public CrewAgent Agent { get; set; } = agent;
+ public List<CrewAgentTool> Tools { get; set; } = tools ?? new List<CrewAgentTool>();
+ public string Description { get; set; } = description;
+
+ public string Execute(string? context = null)
+ {
+ Agent.AddTools(Tools);
+ Agent.Context = context;
+ var chain = Chain.Set(Description, "task")
+ | Agent;
+ var res = chain.Run("result").Result;
+ return res;
+ }
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs
new file mode 100644
index 00000000..2b6cb1bd
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs
@@ -0,0 +1,26 @@
+using LangChain.Chains.StackableChains.Agents.Crew.Tools;
+
+namespace LangChain.Chains.StackableChains.Agents.Crew;
+
+public class Crew(IEnumerable<CrewAgent> agents, IEnumerable<AgentTask> tasks)
+{
+
+ public string Run()
+ {
+ string? context = null;
+
+ foreach (var task in tasks)
+ {
+ task.Tools.Add(new AskQuestionTool(agents.Except(new []{task.Agent})));
+ task.Tools.Add(new DelegateWorkTool(agents.Except(new[] { task.Agent })));
+ var res = task.Execute(context);
+ context = res;
+ }
+
+ return context;
+
+ }
+
+
+
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs
new file mode 100644
index 00000000..63ecbdd2
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs
@@ -0,0 +1,172 @@
+using LangChain.Abstractions.Schema;
+using LangChain.Chains.HelperChains;
+using LangChain.Chains.LLM;
+using LangChain.Chains.StackableChains.Agents.Crew.Tools;
+using LangChain.Chains.StackableChains.ReAct;
+using LangChain.Memory;
+using LangChain.Providers;
+using static LangChain.Chains.Chain;
+
+namespace LangChain.Chains.StackableChains.Agents.Crew;
+
+public class CrewAgent : BaseStackableChain
+{
+ public event Action<string> ReceivedTask = delegate { };
+ public event Action CalledAction = delegate { };
+ public event Action ActionResult = delegate { };
+ public event Action Answered = delegate { };
+
+ public string Role { get; }
+ public string Goal { get; }
+ public string? Backstory { get; }
+ private readonly IChatModel _model;
+ private readonly List<string> _actionsHistory;
+ public bool UseMemory { get; set; }=false;
+ public bool UseCache { get; set; }
+ private IChainValues _currentValues;
+ private Dictionary<string, CrewAgentTool> _tools = new Dictionary<string, CrewAgentTool>();
+
+ private StackChain? _chain=null;
+ private readonly List<string> _memory;
+ private int _maxActions=5;
+
+ public CrewAgent(IChatModel model, string role, string goal, string? backstory = "")
+ {
+ Role = role;
+ Goal = goal;
+ Backstory = backstory;
+ _model = model;
+
+ InputKeys = new[] {"task"};
+ OutputKeys = new[] {"result"};
+
+ _actionsHistory = new List<string>();
+ _memory = new List<string>();
+ }
+
+ public void AddTools(IEnumerable<CrewAgentTool> tools)
+ {
+ _tools = tools
+ .Where(x => !_tools.ContainsKey(x.Name))
+ .ToDictionary(x => x.Name, x => x);
+ InitializeChain();
+ }
+
+ public string? Context { get; set; } = null;
+
+ public int MaxActions
+ {
+ get => _maxActions;
+ set => _maxActions = value;
+ }
+
+ private string GenerateToolsDescriptions()
+ {
+ if (_tools.Count==0) return "";
+ return string.Join("\n", _tools.Select(x => $"- {x.Value.Name}, {x.Value.Description}\n"));
+ }
+
+ private string GenerateToolsNamesList()
+ {
+ if (_tools.Count == 0) return "";
+ return string.Join(", ", _tools.Select(x => x.Key));
+ }
+
+ private void InitializeChain()
+ {
+ string prompt;
+ if (UseMemory)
+ {
+ prompt = Prompts.TaskExecutionWithMemory;
+ }
+ else
+ {
+ prompt = Prompts.TaskExecutionWithoutMemory;
+ }
+
+
+ var chain = Set(GenerateToolsDescriptions, "tools")
+ | Set(GenerateToolsNamesList, "tool_names")
+ | Set(Role, "role")
+ | Set(Goal, "goal")
+ | Set(Backstory, "backstory")
+ | Set(() => string.Join("\n", _memory), "memory")
+ | Set(() => string.Join("\n", _actionsHistory), "actions_history")
+ | Template(prompt)
+ | Chain.LLM(_model).UseCache(UseCache)
+ | Do(x => _actionsHistory.Add((x["text"] as string)))
+ | ReActParser(inputKey: "text", outputKey: OutputKeys[0])
+ | Do(AddToMemory);
+
+
+ _chain = chain;
+ }
+
+ private void AddToMemory(Dictionary<string, object> obj)
+ {
+ if (!UseMemory) return;
+ var res = obj[OutputKeys[0]];
+ if (res is AgentFinish a)
+ {
+ _memory.Add(a.Output);
+ }
+ }
+
+
+ protected override async Task<IChainValues> InternalCall(IChainValues values)
+ {
+ var task = values.Value[InputKeys[0]] as string;
+ _actionsHistory.Clear();
+
+ ReceivedTask(task);
+
+ if (Context!=null)
+ {
+ task += "\n" + "This is the context you are working with:\n"+Context;
+ }
+
+ if (_chain == null)
+ {
+ InitializeChain();
+ }
+ var chain =
+ Set(task, "task")
+ | _chain!;
+ for (int i = 0; i < _maxActions; i++)
+ {
+
+ var res = await chain!.Run