Skip to content

Commit

Permalink
feat: Small fixes and agents support (#99)
Browse files Browse the repository at this point in the history
* Ollama bug fix

* standartized key names to data type. this allows to use chains in small scenarious without using key names

* agents support
  • Loading branch information
TesAnti authored Jan 14, 2024
1 parent 991aa87 commit b746a44
Show file tree
Hide file tree
Showing 13 changed files with 493 additions and 19 deletions.
48 changes: 33 additions & 15 deletions src/libs/LangChain.Core/Chains/Chain.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using LangChain.Chains.HelperChains;
using LangChain.Chains.StackableChains;
using LangChain.Chains.StackableChains.Agents;
using LangChain.Chains.StackableChains.Agents.Crew;
using LangChain.Chains.StackableChains.Files;
using LangChain.Chains.StackableChains.ImageGeneration;
using LangChain.Chains.StackableChains.ReAct;
Expand All @@ -25,7 +26,7 @@ public static class Chain
/// <returns></returns>
public static PromptChain Template(
string template,
string outputKey = "prompt")
string outputKey = "text")
{
return new PromptChain(template, outputKey);
}
Expand All @@ -38,7 +39,7 @@ public static PromptChain Template(
/// <returns></returns>
public static SetChain Set(
object value,
string outputKey = "value")
string outputKey = "text")
{
return new SetChain(value, outputKey);
}
Expand All @@ -51,11 +52,19 @@ public static SetChain Set(
/// <returns></returns>
public static SetLambdaChain Set(
Func<string> valueGetter,
string outputKey = "value")
string outputKey = "text")
{
return new SetLambdaChain(valueGetter, outputKey);
}

public static DoChain Do(
Action<Dictionary<string, object>> func)
{
return new DoChain(func);
}



/// <summary>
///
/// </summary>
Expand All @@ -65,7 +74,7 @@ public static SetLambdaChain Set(
/// <returns></returns>
public static LLMChain LLM(
IChatModel llm,
string inputKey = "prompt",
string inputKey = "text",
string outputKey = "text")
{
return new LLMChain(llm, inputKey, outputKey);
Expand All @@ -80,8 +89,8 @@ public static LLMChain LLM(
/// <returns></returns>
public static RetrieveDocumentsChain RetrieveDocuments(
VectorStoreIndexWrapper index,
string inputKey = "query",
string outputKey = "documents")
string inputKey = "text",
string outputKey = "text")
{
return new RetrieveDocumentsChain(index, inputKey, outputKey);
}
Expand All @@ -93,8 +102,8 @@ public static RetrieveDocumentsChain RetrieveDocuments(
/// <param name="outputKey"></param>
/// <returns></returns>
public static StuffDocumentsChain StuffDocuments(
string inputKey = "documents",
string outputKey = "combined")
string inputKey = "docs",
string outputKey = "c")
{
return new StuffDocumentsChain(inputKey, outputKey);
}
Expand All @@ -108,7 +117,7 @@ public static StuffDocumentsChain StuffDocuments(
/// <returns></returns>
public static UpdateMemoryChain UpdateMemory(
BaseChatMemory memory,
string requestKey = "query",
string requestKey = "text",
string responseKey = "text")
{
return new UpdateMemoryChain(memory, requestKey, responseKey);
Expand Down Expand Up @@ -163,8 +172,8 @@ public static ReActAgentExecutorChain ReActAgentExecutor(
IChatModel model,
string? reActPrompt = null,
int maxActions = 5,
string inputKey = "input",
string outputKey = "final_answer")
string inputKey = "text",
string outputKey = "text")
{
return new ReActAgentExecutorChain(model, reActPrompt, maxActions, inputKey, outputKey);
}
Expand All @@ -177,7 +186,7 @@ public static ReActAgentExecutorChain ReActAgentExecutor(
/// <returns></returns>
public static ReActParserChain ReActParser(
string inputKey = "text",
string outputKey = "answer")
string outputKey = "text")
{
return new ReActParserChain(inputKey, outputKey);
}
Expand All @@ -195,8 +204,8 @@ public static GroupChat GroupChat(
IList<AgentExecutorChain> agents,
string? stopPhrase = null,
int messagesLimit = 10,
string inputKey = "input",
string outputKey = "output")
string inputKey = "text",
string outputKey = "text")
{
return new GroupChat(agents, stopPhrase, messagesLimit, inputKey, outputKey);
}
Expand All @@ -210,7 +219,7 @@ public static GroupChat GroupChat(
/// <returns></returns>
public static ImageGenerationChain GenerateImage(
IGenerateImageModel model,
string inputKey = "prompt",
string inputKey = "text",
string outputKey = "image")
{
return new ImageGenerationChain(model, inputKey, outputKey);
Expand All @@ -228,4 +237,13 @@ public static SaveIntoFileChain SaveIntoFile(
{
return new SaveIntoFileChain(path, inputKey);
}


public static CrewChain Crew(
IEnumerable<CrewAgent> allAgents, CrewAgent manager,
string inputKey = "text",
string outputKey = "text")
{
return new CrewChain(allAgents, manager, inputKey, outputKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Chains.StackableChains.Agents.Crew.Tools;
using LangChain.Chains.StackableChains.ReAct;

namespace LangChain.Chains.StackableChains.Agents.Crew;

public class AgentTask(CrewAgent agent, string description, List<CrewAgentTool>? tools=null)
{
public CrewAgent Agent { get; set; } = agent;
public List<CrewAgentTool> Tools { get; set; } = tools??new List<CrewAgentTool>();
public string Description { get; set; } = description;

public string Execute(string context=null)

Check warning on line 13 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Cannot convert null literal to non-nullable reference type.
{
Agent.AddTools(Tools);
Agent.Context = context;
var chain = Chain.Set(Description, "task")
| Agent;
var res = chain.Run("result").Result;
return res;

Check warning on line 20 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/AgentTask.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference return.
}
}
26 changes: 26 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using LangChain.Chains.StackableChains.Agents.Crew.Tools;

namespace LangChain.Chains.StackableChains.Agents.Crew;

public class Crew(IEnumerable<CrewAgent> agents, IEnumerable<AgentTask> tasks)
{

public string Run()
{
string? context = null;

foreach (var task in tasks)
{
task.Tools.Add(new AskQuestionTool(agents.Except(new []{task.Agent})));
task.Tools.Add(new DelegateWorkTool(agents.Except(new[] { task.Agent })));
var res = task.Execute(context);

Check warning on line 16 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'context' in 'string AgentTask.Execute(string context = null)'.
context = res;
}

return context;

Check warning on line 20 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/Crew.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference return.

}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Chains.LLM;
using LangChain.Chains.StackableChains.Agents.Crew.Tools;
using LangChain.Chains.StackableChains.ReAct;
using LangChain.Memory;
using LangChain.Providers;
using static LangChain.Chains.Chain;

namespace LangChain.Chains.StackableChains.Agents.Crew;

public class CrewAgent : BaseStackableChain
{
public event Action<string> ReceivedTask=delegate{};
public event Action<string,string> CalledAction = delegate { };
public event Action<string> ActionResult = delegate { };
public event Action<string> Answered = delegate { };

public string Role { get; }
public string Goal { get; }
public string? Backstory { get; }
private readonly IChatModel _model;
private readonly List<string> _actionsHistory;
public bool UseMemory { get; set; }=false;
public bool UseCache { get; set; }
private IChainValues _currentValues;
private Dictionary<string, CrewAgentTool> _tools=new Dictionary<string, CrewAgentTool>();

private StackChain? _chain=null;
private readonly List<string> _memory;
private int _maxActions=5;

public CrewAgent(IChatModel model, string role, string goal, string? backstory = "")

Check warning on line 33 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable field '_currentValues' must contain a non-null value when exiting constructor. Consider declaring the field as nullable.
{
Role = role;
Goal = goal;
Backstory = backstory;
_model = model;

InputKeys = new[] {"task"};
OutputKeys = new[] {"result"};

_actionsHistory = new List<string>();
_memory = new List<string>();
}

public void AddTools(IEnumerable<CrewAgentTool> tools)
{
_tools = tools
.Where(x => !_tools.ContainsKey(x.Name))
.ToDictionary(x => x.Name, x => x);
InitializeChain();
}

public string? Context { get; set; } = null;

public int MaxActions
{
get => _maxActions;
set => _maxActions = value;
}

private string GenerateToolsDescriptions()
{
if (_tools.Count==0) return "";
return string.Join("\n", _tools.Select(x => $"- {x.Value.Name}, {x.Value.Description}\n"));
}

private string GenerateToolsNamesList()
{
if (_tools.Count == 0) return "";
return string.Join(", ", _tools.Select(x => x.Key));
}

private void InitializeChain()
{
string prompt;
if (UseMemory)
{
prompt = Prompts.TaskExecutionWithMemory;
}
else
{
prompt = Prompts.TaskExecutionWithoutMemory;
}


var chain = Set(GenerateToolsDescriptions, "tools")
| Set(GenerateToolsNamesList, "tool_names")
| Set(Role, "role")
| Set(Goal, "goal")
| Set(Backstory, "backstory")

Check warning on line 92 in src/libs/LangChain.Core/Chains/StackableChains/Agents/Crew/CrewAgent.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'value' in 'SetChain Chain.Set(object value, string outputKey = "text")'.
| Set(() => string.Join("\n", _memory), "memory")
| Set(() => string.Join("\n", _actionsHistory), "actions_history")
| Template(prompt)
| Chain.LLM(_model).UseCache(UseCache)
| Do(x => _actionsHistory.Add((x["text"] as string)))
| ReActParser(inputKey: "text", outputKey: OutputKeys[0])
| Do(AddToMemory);


_chain = chain;
}

private void AddToMemory(Dictionary<string, object> obj)
{
if (!UseMemory) return;
var res = obj[OutputKeys[0]];
if (res is AgentFinish a)
{
_memory.Add(a.Output);
}
}


protected override async Task<IChainValues> InternalCall(IChainValues values)
{
var task = values.Value[InputKeys[0]] as string;
_actionsHistory.Clear();

ReceivedTask(task);

if (Context!=null)
{
task += "\n" + "This is the context you are working with:\n"+Context;
}

if (_chain == null)
{
InitializeChain();
}
var chain =
Set(task, "task")
| _chain!;
for (int i = 0; i < _maxActions; i++)
{

var res = await chain!.Run<object>(OutputKeys[0]);
if (res is AgentAction action)
{
CalledAction(action.Action, action.ActionInput);

if (!_tools.ContainsKey(action.Action))
{
ActionResult("You don't have this tool");
_actionsHistory.Add("Observation: You don't have this tool");
_actionsHistory.Add("Thought:");
continue;
}

var tool = _tools[action.Action];
var toolRes = tool.ToolAction(action.ActionInput);
ActionResult(toolRes);
_actionsHistory.Add("Observation: " + toolRes);
_actionsHistory.Add("Thought:");

continue;
}
else if (res is AgentFinish finish)
{
values.Value.Add(OutputKeys[0], finish.Output);
if(UseMemory)
_memory.Add(finish.Output);

Answered(finish.Output);
return values;
}
}

throw new Exception($"Max actions exceeded({_maxActions})");
}
}
Loading

0 comments on commit b746a44

Please sign in to comment.