From 6b7c5dcbc89c98ebc8390996d95975d9b04c5459 Mon Sep 17 00:00:00 2001 From: Khoroshev Evgeniy Date: Mon, 6 Nov 2023 23:02:45 +0700 Subject: [PATCH] feat: callbacks refactor (#49) * callbacks refactor * naming * build fix * implement base and console trace * console logger * remove commented --------- Co-authored-by: Evgenii Khoroshev --- examples/LangChain.Samples.Prompts/Program.cs | 5 +- .../Program.cs | 17 +- .../Base/BaseCallbackHandler.cs | 128 ++-- src/libs/LangChain.Core/Base/BaseChain.cs | 48 +- .../LangChain.Core/Base/BaseChainInput.cs | 37 ++ src/libs/LangChain.Core/Base/BaseLangChain.cs | 6 +- src/libs/LangChain.Core/Base/ChainInputs.cs | 9 +- src/libs/LangChain.Core/Base/Handler.cs | 4 +- .../Base/IBaseCallbackHandler.cs | 69 +- .../Base/IBaseCallbackHandlerInput.cs | 21 +- .../Base/IBaseLangChainParams.cs | 5 +- src/libs/LangChain.Core/Base/IChainInputs.cs | 24 +- .../Base/Tracers/BaseCallbackHandlerInput.cs | 11 + .../LangChain.Core/Base/Tracers/BaseTracer.cs | 595 ++++++++++++++++++ .../Base/Tracers/ConsoleCallbackHandler.cs | 223 +++++++ .../LangChain.Core/Base/Tracers/RunBase.cs | 114 ++++ .../Base/Tracers/StringExtensions.cs | 20 + .../Base/Tracers/TracerException.cs | 11 + .../LangChain.Core/Callback/BaseRunManager.cs | 73 ++- .../Callback/CallbackManager.cs | 348 +++++----- .../Callback/CallbackManagerForChainRun.cs | 88 +-- .../Callback/CallbackManagerForLlmRun.cs | 58 +- .../CallbackManagerForRetrieverRun.cs | 31 +- .../Callback/CallbackManagerForToolRun.cs | 71 +-- .../LangChain.Core/Callback/ICallbacks.cs | 9 + .../Callback/ParentRunManager.cs | 47 ++ src/libs/LangChain.Core/Chains/Base/IChain.cs | 8 +- .../CombineDocuments/AnalyzeDocumentChain.cs | 5 +- .../BaseCombineDocumentsChain.cs | 6 +- .../BaseCombineDocumentsChainInput.cs | 5 +- .../LangChain.Core/Chains/LLM/LLMChain.cs | 12 +- .../Chains/LLM/LLMChainInput.cs | 6 +- .../RetrievalQA/BaseRetrievalQaChain.cs | 18 +- .../RetrievalQA/BaseRetrievalQaChainInput.cs | 5 +- .../Chains/RetrievalQA/RetrievalQaChain.cs | 9 +- .../Chains/Sequentials/SequentialChain.cs | 11 +- .../Sequentials/SequentialChainInput.cs | 8 +- .../LangChain.Core/Memory/MemoryExtensions.cs | 7 +- .../Retrievers/BaseRetriever.cs | 56 +- .../VectorStores/VectorStoreRetriever.cs | 18 +- .../RetrievalQa/RetrievalQaChainTests.cs | 12 +- .../SequentialChainTests.cs | 9 +- 42 files changed, 1699 insertions(+), 568 deletions(-) create mode 100644 src/libs/LangChain.Core/Base/BaseChainInput.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/BaseCallbackHandlerInput.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/BaseTracer.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/ConsoleCallbackHandler.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/RunBase.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/StringExtensions.cs create mode 100644 src/libs/LangChain.Core/Base/Tracers/TracerException.cs create mode 100644 src/libs/LangChain.Core/Callback/ICallbacks.cs create mode 100644 src/libs/LangChain.Core/Callback/ParentRunManager.cs diff --git a/examples/LangChain.Samples.Prompts/Program.cs b/examples/LangChain.Samples.Prompts/Program.cs index 304cb2fe..944962e2 100644 --- a/examples/LangChain.Samples.Prompts/Program.cs +++ b/examples/LangChain.Samples.Prompts/Program.cs @@ -35,7 +35,10 @@ HumanMessagePromptTemplate.FromTemplate("{text}") }); -var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt)); +var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt) +{ + Verbose = true +}); var resultB = await chainB.CallAsync(new ChainValues(new Dictionary(3) { diff --git a/examples/LangChain.Samples.SequentialChain/Program.cs b/examples/LangChain.Samples.SequentialChain/Program.cs index 6150dead..2158f4d3 100644 --- a/examples/LangChain.Samples.SequentialChain/Program.cs +++ b/examples/LangChain.Samples.SequentialChain/Program.cs @@ -12,6 +12,7 @@ var chainOne = new LlmChain(new LlmChainInput(llm, firstPrompt) { + Verbose = true, OutputKey = "company_name" }); @@ -20,11 +21,15 @@ var chainTwo = new LlmChain(new LlmChainInput(llm, secondPrompt)); -var overallChain = new SequentialChain(new SequentialChainInput(new [] -{ - chainOne, - chainTwo -}, new []{"product"})); +var overallChain = new SequentialChain(new SequentialChainInput( + new[] + { + chainOne, + chainTwo + }, + new[] { "product" }, + new[] { "company_name", "text" } +)); var result = await overallChain.CallAsync(new ChainValues(new Dictionary(1) { @@ -32,4 +37,4 @@ })); Console.WriteLine(result.Value["text"]); -Console.WriteLine("Test"); \ No newline at end of file +Console.WriteLine("SequentialChain sample finished."); \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/BaseCallbackHandler.cs b/src/libs/LangChain.Core/Base/BaseCallbackHandler.cs index b8b3468e..c01cc541 100644 --- a/src/libs/LangChain.Core/Base/BaseCallbackHandler.cs +++ b/src/libs/LangChain.Core/Base/BaseCallbackHandler.cs @@ -1,4 +1,8 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Docstore; using LangChain.LLMS; +using LangChain.Providers; +using LangChain.Retrievers; using LangChain.Schema; namespace LangChain.Base; @@ -7,11 +11,36 @@ namespace LangChain.Base; public abstract class BaseCallbackHandler : IBaseCallbackHandler { /// - public string Name { get; protected set; } + public abstract string Name { get; } + + public bool IgnoreLlm { get; set; } + public bool IgnoreRetry { get; set; } + public bool IgnoreChain { get; set; } + public bool IgnoreAgent { get; set; } + public bool IgnoreRetriever { get; set; } + public bool IgnoreChatModel { get; set; } + + /// + /// + /// + /// + protected BaseCallbackHandler(IBaseCallbackHandlerInput input) + { + input = input ?? throw new ArgumentNullException(nameof(input)); + + IgnoreLlm = input.IgnoreLlm; + IgnoreRetry = input.IgnoreRetry; + IgnoreChain = input.IgnoreChain; + IgnoreAgent = input.IgnoreAgent; + IgnoreRetriever = input.IgnoreRetriever; + IgnoreChatModel = input.IgnoreChatModel; + } /// - public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string runId, string? parentRunId = null, - Dictionary? extraParams = null); + public abstract Task HandleLlmStartAsync( + BaseLlm llm, string[] prompts, string runId, string? parentRunId = null, + List? tags = null, Dictionary? metadata = null, + string name = null, Dictionary? extraParams = null); /// public abstract Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null); @@ -23,20 +52,42 @@ public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string r public abstract Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null); /// - public abstract Task HandleChatModelStartAsync(Dictionary llm, List> messages, string runId, string? parentRunId = null, + public abstract Task HandleChatModelStartAsync(BaseLlm llm, List> messages, string runId, + string? parentRunId = null, Dictionary? extraParams = null); /// - public abstract Task HandleChainStartAsync(Dictionary chain, Dictionary inputs, string runId, string? parentRunId = null); + public abstract Task HandleChainStartAsync(IChain chain, Dictionary inputs, + string runId, string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string runType = null, + string name = null, + Dictionary? extraParams = null); /// - public abstract Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null); + public abstract Task HandleChainErrorAsync( + Exception err, string runId, + Dictionary? inputs = null, + string? parentRunId = null); /// - public abstract Task HandleChainEndAsync(Dictionary outputs, string runId, string? parentRunId = null); + public abstract Task HandleChainEndAsync( + Dictionary? inputs, + Dictionary outputs, + string runId, + string? parentRunId = null); /// - public abstract Task HandleToolStartAsync(Dictionary tool, string input, string runId, string? parentRunId = null); + public abstract Task HandleToolStartAsync( + Dictionary tool, + string input, string runId, + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string runType = null, + string name = null, + Dictionary? extraParams = null); /// public abstract Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null); @@ -54,55 +105,24 @@ public abstract Task HandleChatModelStartAsync(Dictionary llm, L public abstract Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null); /// - public abstract Task HandleRetrieverStartAsync(string query, string runId, string? parentRunId); + public abstract Task HandleRetrieverStartAsync( + BaseRetriever retriever, + string query, + string runId, + string? parentRunId, + List? tags = null, + Dictionary? metadata = null, + string? runType = null, + string? name = null, + Dictionary? extraParams = null); /// - public abstract Task HandleRetrieverEndAsync(string query, string runId, string? parentRunId); + public abstract Task HandleRetrieverEndAsync( + string query, + List documents, + string runId, + string? parentRunId); /// public abstract Task HandleRetrieverErrorAsync(Exception error, string query, string runId, string? parentRunId); - - /// - /// - /// - public bool IgnoreLlm { get; set; } - - /// - /// - /// - public bool IgnoreChain { get; set; } - - /// - /// - /// - public bool IgnoreAgent { get; set; } - - public bool IgnoreRetriever { get; set; } - - /// - /// - /// - protected BaseCallbackHandler() - { - Name = Guid.NewGuid().ToString(); - } - - /// - /// - /// - /// - protected BaseCallbackHandler(IBaseCallbackHandlerInput input) : this() - { - input = input ?? throw new ArgumentNullException(nameof(input)); - - IgnoreLlm = input.IgnoreLlm; - IgnoreChain = input.IgnoreChain; - IgnoreAgent = input.IgnoreAgent; - } - - /// - /// - /// - /// - public abstract IBaseCallbackHandler Copy(); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/BaseChain.cs b/src/libs/LangChain.Core/Base/BaseChain.cs index 0d4bf2f7..7347bc6e 100644 --- a/src/libs/LangChain.Core/Base/BaseChain.cs +++ b/src/libs/LangChain.Core/Base/BaseChain.cs @@ -1,5 +1,6 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; +using LangChain.Callback; using LangChain.Chains; using LangChain.Schema; @@ -9,7 +10,7 @@ namespace LangChain.Base; using LoadValues = Dictionary; /// -public abstract class BaseChain : IChain +public abstract class BaseChain(IChainInputs fields) : IChain { const string RunKey = "__run"; @@ -57,7 +58,7 @@ public abstract class BaseChain : IChain throw new Exception("Return values have multiple keys, 'run' only supported when one key currently"); } - + /// /// Run the chain using a simple input/output. /// @@ -83,8 +84,49 @@ public virtual async Task Run(Dictionary input) /// Execute the chain, using the values provided. /// /// The to use. + /// + /// + /// + /// + public async Task CallAsync( + IChainValues values, + ICallbacks? callbacks = null, + List? tags = null, + Dictionary? metadata = null) + { + var callbackManager = await CallbackManager.Configure( + callbacks, + fields.Callbacks, + fields.Verbose, + tags, + fields.Tags, + metadata, + fields.Metadata); + + var runManager = await callbackManager.HandleChainStart(this, values); + + try + { + var result = await CallAsync(values, runManager); + + await runManager.HandleChainEndAsync(values, result); + + return result; + } + catch (Exception e) + { + await runManager.HandleChainErrorAsync(e, values); + throw; + } + } + + /// + /// Execute the chain, using the values provided. + /// + /// The to use. + /// /// - public abstract Task CallAsync(IChainValues values); + protected abstract Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager); /// /// diff --git a/src/libs/LangChain.Core/Base/BaseChainInput.cs b/src/libs/LangChain.Core/Base/BaseChainInput.cs new file mode 100644 index 00000000..cc10ea76 --- /dev/null +++ b/src/libs/LangChain.Core/Base/BaseChainInput.cs @@ -0,0 +1,37 @@ +using LangChain.Callback; + +namespace LangChain.Base; + +public interface IBaseChainInput +{ + /// + /// Optional list of callback handlers (or callback manager). Defaults to None. + /// Callback handlers are called throughout the lifecycle of a call to a chain, + /// starting with on_chain_start, ending with on_chain_end or on_chain_error. + /// Each custom chain can optionally call additional callback methods, see Callback docs + /// for full details. + /// + public ICallbacks? Callbacks { get; set; } + + /// + /// Whether or not run in verbose mode. In verbose mode, some intermediate logs + /// will be printed to the console. + /// + public bool Verbose { get; set; } + + /// + /// Optional list of tags associated with the chain. Defaults to None. + /// These tags will be associated with each call to this chain, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a chain with its use case. + /// + public List Tags { get; set; } + + /// + /// Optional metadata associated with the chain. Defaults to None. + /// This metadata will be associated with each call to this chain, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a chain with its use case. + /// + public Dictionary Metadata { get; set; } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/BaseLangChain.cs b/src/libs/LangChain.Core/Base/BaseLangChain.cs index ac4f0260..f047b6f4 100644 --- a/src/libs/LangChain.Core/Base/BaseLangChain.cs +++ b/src/libs/LangChain.Core/Base/BaseLangChain.cs @@ -3,12 +3,10 @@ namespace LangChain.Base; /// public abstract class BaseLangChain : IBaseLangChainParams { - private const bool DefaultVerbosity = false; - /// /// /// - public bool? Verbose { get; set; } + public bool Verbose { get; set; } /// /// @@ -18,6 +16,6 @@ protected BaseLangChain(IBaseLangChainParams parameters) { parameters = parameters ?? throw new ArgumentNullException(nameof(parameters)); - Verbose = parameters.Verbose ?? DefaultVerbosity; + Verbose = parameters.Verbose; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/ChainInputs.cs b/src/libs/LangChain.Core/Base/ChainInputs.cs index 86bb88c3..5de78a45 100644 --- a/src/libs/LangChain.Core/Base/ChainInputs.cs +++ b/src/libs/LangChain.Core/Base/ChainInputs.cs @@ -5,9 +5,8 @@ namespace LangChain.Base; /// public class ChainInputs : IChainInputs { - /// - public CallbackManager? CallbackManager { get; set; } - - /// - public bool? Verbose { get; set; } + public ICallbacks? Callbacks { get; set; } + public List Tags { get; set; } + public Dictionary Metadata { get; set; } + public bool Verbose { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Handler.cs b/src/libs/LangChain.Core/Base/Handler.cs index c3f60a5c..9fb0ee97 100644 --- a/src/libs/LangChain.Core/Base/Handler.cs +++ b/src/libs/LangChain.Core/Base/Handler.cs @@ -3,9 +3,7 @@ namespace LangChain.Base; /// public abstract class Handler : BaseCallbackHandler { - /// - public override IBaseCallbackHandler Copy() + protected Handler(IBaseCallbackHandlerInput input) : base(input) { - throw new NotImplementedException(); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/IBaseCallbackHandler.cs b/src/libs/LangChain.Core/Base/IBaseCallbackHandler.cs index 430a5abc..05e2bc94 100644 --- a/src/libs/LangChain.Core/Base/IBaseCallbackHandler.cs +++ b/src/libs/LangChain.Core/Base/IBaseCallbackHandler.cs @@ -1,4 +1,8 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Docstore; using LangChain.LLMS; +using LangChain.Providers; +using LangChain.Retrievers; using LangChain.Schema; namespace LangChain.Base; @@ -16,21 +20,13 @@ public interface IBaseCallbackHandler /// /// /// - /// - /// - /// - /// - /// - /// - public Task HandleLlmStartAsync( - BaseLlm llm, - string[] prompts, - string runId, - string? parentRunId = null, - Dictionary? extraParams = null); + public abstract Task HandleLlmStartAsync( + BaseLlm llm, string[] prompts, string runId, string? parentRunId = null, + List? tags = null, Dictionary? metadata = null, + string name = null, Dictionary? extraParams = null); /// - /// + /// Run on new LLM token. Only available when streaming is enabled. /// /// /// @@ -41,49 +37,40 @@ public Task HandleLlmNewTokenAsync( string runId, string? parentRunId = null); - /// - /// - /// - /// - /// - /// - /// public Task HandleLlmErrorAsync( Exception err, string runId, string? parentRunId = null); - /// - /// - /// - /// - /// - /// - /// public Task HandleLlmEndAsync( LlmResult output, string runId, string? parentRunId = null); - public Task HandleChatModelStartAsync( - Dictionary llm, - List> messages, + public Task HandleChatModelStartAsync(BaseLlm llm, + List> messages, string runId, string? parentRunId = null, Dictionary? extraParams = null); - public Task HandleChainStartAsync( - Dictionary chain, + public Task HandleChainStartAsync(IChain chain, Dictionary inputs, string runId, - string? parentRunId = null); + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string runType = null, + string name = null, + Dictionary? extraParams = null); public Task HandleChainErrorAsync( Exception err, string runId, + Dictionary inputs, string? parentRunId = null); public Task HandleChainEndAsync( + Dictionary? inputs, Dictionary outputs, string runId, string? parentRunId = null); @@ -92,7 +79,12 @@ public Task HandleToolStartAsync( Dictionary tool, string input, string runId, - string? parentRunId = null); + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string runType = null, + string name = null, + Dictionary? extraParams = null); public Task HandleToolErrorAsync( Exception err, @@ -120,12 +112,19 @@ public Task HandleAgentEndAsync( string? parentRunId = null); public Task HandleRetrieverStartAsync( + BaseRetriever retriever, string query, string runId, - string? parentRunId); + string? parentRunId, + List? tags = null, + Dictionary? metadata = null, + string? runType = null, + string? name = null, + Dictionary? extraParams = null); public Task HandleRetrieverEndAsync( string query, + List documents, string runId, string? parentRunId); diff --git a/src/libs/LangChain.Core/Base/IBaseCallbackHandlerInput.cs b/src/libs/LangChain.Core/Base/IBaseCallbackHandlerInput.cs index 4934ce38..f97246dc 100644 --- a/src/libs/LangChain.Core/Base/IBaseCallbackHandlerInput.cs +++ b/src/libs/LangChain.Core/Base/IBaseCallbackHandlerInput.cs @@ -5,18 +5,21 @@ namespace LangChain.Base; /// public interface IBaseCallbackHandlerInput { - /// - /// - /// + /// Whether to ignore LLM callbacks. bool IgnoreLlm { get; set; } - /// - /// - /// + /// Whether to ignore retry callbacks. + bool IgnoreRetry { get; set; } + + /// Whether to ignore chain callbacks. bool IgnoreChain { get; set; } - /// - /// - /// + /// Whether to ignore agent callbacks. bool IgnoreAgent { get; set; } + + /// Whether to ignore retriever callbacks. + bool IgnoreRetriever { get; set; } + + /// Whether to ignore chat model callbacks. + bool IgnoreChatModel { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/IBaseLangChainParams.cs b/src/libs/LangChain.Core/Base/IBaseLangChainParams.cs index 5030ebd9..c49cf463 100644 --- a/src/libs/LangChain.Core/Base/IBaseLangChainParams.cs +++ b/src/libs/LangChain.Core/Base/IBaseLangChainParams.cs @@ -6,7 +6,8 @@ namespace LangChain.Base; public interface IBaseLangChainParams { /// - /// + /// Whether or not run in verbose mode. In verbose mode, some intermediate logs + /// will be printed to the console. /// - bool? Verbose { get; set; } + bool Verbose { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/IChainInputs.cs b/src/libs/LangChain.Core/Base/IChainInputs.cs index d6420aed..77cac6c7 100644 --- a/src/libs/LangChain.Core/Base/IChainInputs.cs +++ b/src/libs/LangChain.Core/Base/IChainInputs.cs @@ -6,7 +6,27 @@ namespace LangChain.Base; public interface IChainInputs : IBaseLangChainParams { /// - /// + /// Optional list of callback handlers (or callback manager). Defaults to None. + /// Callback handlers are called throughout the lifecycle of a call to a chain, + /// starting with on_chain_start, ending with on_chain_end or on_chain_error. + /// Each custom chain can optionally call additional callback methods, see Callback docs + /// for full details. /// - CallbackManager? CallbackManager { get; set; } + public ICallbacks? Callbacks { get; set; } + + /// + /// Optional list of tags associated with the chain. Defaults to None. + /// These tags will be associated with each call to this chain, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a chain with its use case. + /// + public List Tags { get; set; } + + /// + /// Optional metadata associated with the chain. Defaults to None. + /// This metadata will be associated with each call to this chain, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a chain with its use case. + /// + public Dictionary Metadata { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Tracers/BaseCallbackHandlerInput.cs b/src/libs/LangChain.Core/Base/Tracers/BaseCallbackHandlerInput.cs new file mode 100644 index 00000000..e9980f91 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/BaseCallbackHandlerInput.cs @@ -0,0 +1,11 @@ +namespace LangChain.Base.Tracers; + +public class BaseCallbackHandlerInput : IBaseCallbackHandlerInput +{ + public bool IgnoreLlm { get; set; } + public bool IgnoreRetry { get; set; } + public bool IgnoreChain { get; set; } + public bool IgnoreAgent { get; set; } + public bool IgnoreRetriever { get; set; } + public bool IgnoreChatModel { get; set; } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Tracers/BaseTracer.cs b/src/libs/LangChain.Core/Base/Tracers/BaseTracer.cs new file mode 100644 index 00000000..ecb79db8 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/BaseTracer.cs @@ -0,0 +1,595 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Docstore; +using LangChain.LLMS; +using LangChain.Providers; +using LangChain.Retrievers; +using LangChain.Schema; + +namespace LangChain.Base.Tracers; + +/// +/// Base class for tracers. +/// +public abstract class BaseTracer(IBaseCallbackHandlerInput input) : BaseCallbackHandler(input) +{ + protected Dictionary RunMap { get; } = new(); + + protected abstract Task PersistRun(Run run); + + public override async Task HandleLlmStartAsync( + BaseLlm llm, + string[] prompts, + string runId, + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string name = null, + Dictionary? extraParams = null) + { + var executionOrder = GetExecutionOrder(parentRunId); + var startTime = DateTime.UtcNow; + if (metadata != null) + { + extraParams.Add("metadata", metadata); + } + + var run = new Run + { + Id = runId, + ParentRunId = parentRunId, + //todo: pass llm or dumpd(llm) + // serialized = serialized, + Inputs = new Dictionary { ["prompts"] = prompts }, + ExtraData = extraParams, + Events = new List> + { + new() + { + ["name"] = "start", + ["time"] = startTime + } + }, + StartTime = startTime, + ExecutionOrder = executionOrder, + ChildExecutionOrder = executionOrder, + RunType = "llm", + Tags = tags ?? new List(), + Name = name + }; + + StartTrace(run); + await HandleLlmStartAsync(run); + } + + public override async Task HandleLlmErrorAsync(Exception err, string runId, string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_llm_error callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "llm") + { + throw new TracerException($"No LLM Run found to be traced for {runId}"); + } + + run.Error = err.ToString(); + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { ["name"] = "error", ["time"] = run.EndTime }); + + EndTrace(run); + await HandleLlmErrorAsync(run); + } + + public override async Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_llm_end callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "llm") + { + throw new TracerException($"No LLM Run found to be traced for {runId}"); + } + + run.Outputs = output.LlmOutput; + for (int i = 0; i < output.Generations.Length; i++) + { + var generation = output.Generations[i]; + var outputGeneration = (run.Outputs["generations"] as List>)[i]; + if (outputGeneration.ContainsKey("message")) + { + outputGeneration["message"] = (generation as ChatGeneration)?.Message; + } + } + + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { { "name", "end" }, { "time", run.EndTime } }); + + EndTrace(run); + await HandleLlmEndAsync(run); + } + + public override async Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_llm_new_token callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "llm") + { + throw new TracerException($"No LLM Run found to be traced for {runId}"); + } + + var eventData = new Dictionary { ["token"] = token }; + + run.Events.Add( + new() + { + ["name"] = "new_token", + ["time"] = DateTime.UtcNow, + ["kwargs"] = eventData, + }); + + await HandleLlmNewTokenAsync(run, token); + } + + public override async Task HandleChatModelStartAsync(BaseLlm llm, List> messages, string runId, + string? parentRunId = null, + Dictionary? extraParams = null) + { + throw new NotImplementedException(); + } + + public override async Task HandleChainStartAsync( + IChain chain, + Dictionary inputs, + string runId, + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string? runType = null, + string? name = null, + Dictionary? extraParams = null) + { + var executionOrder = GetExecutionOrder(parentRunId); + var startTime = DateTime.UtcNow; + + if (metadata != null) + { + extraParams.Add("metadata", metadata); + } + + var chainRun = new Run + { + Id = runId, + ParentRunId = parentRunId, + // serialized=serialized, + Inputs = inputs, + ExtraData = extraParams, + Events = new List> { new() { ["name"] = "start", ["time"] = startTime } }, + StartTime = startTime, + ExecutionOrder = executionOrder, + ChildExecutionOrder = executionOrder, + ChildRuns = new(), + RunType = runType ?? "chain", + Name = name, + Tags = tags ?? new() + }; + + StartTrace(chainRun); + await HandleChainStartAsync(chainRun); + } + + /// + /// Handle an error for a chain run. + /// + public override async Task HandleChainErrorAsync( + Exception err, + string runId, + Dictionary? inputs = null, + string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_chain_error callback."); + } + + if (!RunMap.TryGetValue(runId, out var run)) + { + throw new TracerException($"No chain Run found to be traced for {runId}"); + } + + run.Error = err.ToString(); + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { ["name"] = "error", ["time"] = run.EndTime }); + + run.Inputs = inputs; + EndTrace(run); + await HandleChainErrorAsync(run); + } + + /// + /// End a trace for a chain run. + /// + public override async Task HandleChainEndAsync( + Dictionary? inputs, + Dictionary outputs, + string runId, + string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_chain_end callback."); + } + + if (!RunMap.TryGetValue(runId, out var run)) + { + throw new TracerException($"No chain Run found to be traced for {runId}"); + } + + run.Outputs = outputs; + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { ["name"] = "end", ["time"] = run.EndTime }); + + run.Inputs = inputs; + + EndTrace(run); + await HandleChainEndAsync(run); + } + + public override async Task HandleToolStartAsync( + Dictionary tool, + string input, + string runId, + string? parentRunId = null, + List? tags = null, + Dictionary? metadata = null, + string runType = null, + string name = null, + Dictionary? extraParams = null) + { + var executionOrder = GetExecutionOrder(parentRunId); + var startTime = DateTime.UtcNow; + + if (metadata != null) + { extraParams.Add("metadata", metadata);} + + var run = new Run + { + Id = runId, + ParentRunId = parentRunId, + Serialized = tool, + Inputs = new Dictionary { ["input"] = input }, + ExtraData = extraParams, + Events = new List> { new() { ["name"] = "start", ["time"] = startTime } }, + StartTime = startTime, + ExecutionOrder = executionOrder, + ChildExecutionOrder = executionOrder, + ChildRuns = new(), + RunType = "tool", + Tags = tags ?? new(), + Name = name, + }; + + StartTrace(run); + await HandleToolStartAsync(run); + } + + /// + /// Handle an error for a tool run. + /// + public override async Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_tool_error callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "tool") + { + throw new TracerException($"No retriever Run found to be traced for {runId}"); + } + + run.Error = err.ToString(); + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { ["name"] = "error", ["time"] = run.EndTime }); + EndTrace(run); + await HandleToolErrorAsync(run); + } + + /// + /// + /// + /// + /// + /// + /// + public override async Task HandleToolEndAsync(string output, string runId, string? parentRunId = null) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_tool_end callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "tool") + { + throw new TracerException($"No retriever Run found to be traced for {runId}"); + } + + run.Outputs = new Dictionary() + { + ["output"] = output + }; + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary { ["name"] = "end", ["time"] = run.EndTime }); + EndTrace(run); + await HandleToolEndAsync(run); + } + + /// + /// Run when Retriever starts running. + /// + public override async Task HandleRetrieverStartAsync( + BaseRetriever retriever, + string query, + string runId, + string? parentRunId, + List? tags = null, + Dictionary? metadata = null, + string? runType = null, + string? name = null, + Dictionary? extraParams = null) + { + var executionOrder = GetExecutionOrder(parentRunId); + var startTime = DateTime.UtcNow; + + if (metadata != null) + { + extraParams.Add("metadata", metadata); + } + + var run = new Run + { + Id = runId, + Name = name ?? "Retriever", + ParentRunId = parentRunId, + // TODO: pass retriever or dumpd(retriever)? + // serialized=serialized, + Inputs = new Dictionary { ["query"] = query }, + ExtraData = extraParams, + Events = new List> { new() { ["name"] = "start", ["time"] = startTime } }, + StartTime = startTime, + ExecutionOrder = executionOrder, + ChildExecutionOrder = executionOrder, + Tags = tags, + ChildRuns = new(), + RunType = "retriever", + }; + + StartTrace(run); + await HandleRetrieverStartAsync(run); + } + + /// + /// Run when Retriever ends running. + /// + public override async Task HandleRetrieverEndAsync( + string query, + List documents, + string runId, + string? parentRunId) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_retriever_end callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "retriever") + { + throw new TracerException($"No retriever Run found to be traced for {runId}"); + } + + run.Outputs = new Dictionary { ["documents"] = documents }; + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary + { ["name"] = "end", ["time"] = run.EndTime }); + + EndTrace(run); + await HandleRetrieverEndAsync(run); + } + + /// + /// Run when Retriever errors. + /// + public override async Task HandleRetrieverErrorAsync(Exception error, string query, string runId, + string? parentRunId) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_retriever_end callback."); + } + + if (!RunMap.TryGetValue(runId, out var run) || run.RunType != "retriever") + { + throw new TracerException($"No retriever Run found to be traced for {runId}"); + } + + run.Error = error.ToString(); + run.EndTime = DateTime.UtcNow; + run.Events.Add(new Dictionary + { + ["name"] = "error", + ["time"] = run.EndTime + }); + + EndTrace(run); + await HandleRetrieverErrorAsync(run); + } + + public override async Task HandleTextAsync(string text, string runId, string? parentRunId = null) + { + } + + public override async Task HandleAgentActionAsync(Dictionary action, string runId, string? parentRunId = null) + { + } + + public override async Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null) + { + } + + /*public Run OnRetry(RetryCallState retryState, string runId) + { + if (runId == null) + { + throw new TracerException("No run_id provided for on_retry callback."); + } + + if (!_runMap.TryGetValue(runId, out var run) || run == null) + { + throw new TracerException("No Run found to be traced for on_retry"); + } + + var kwargs = new Dictionary + { + { "slept", retryState.IdleFor }, + { "attempt", retryState.AttemptNumber } + }; + + if (retryState.Outcome == null) + { + kwargs["outcome"] = "N/A"; + } + else if (retryState.Outcome.Failed) + { + kwargs["outcome"] = "failed"; + Exception exception = retryState.Outcome.Exception(); + kwargs["exception"] = exception.ToString(); + kwargs["exception_type"] = exception.GetType().Name; + } + else + { + kwargs["outcome"] = "success"; + kwargs["result"] = retryState.Outcome.Result().ToString(); + } + + run.events.Add(new Dictionary + { + { "name", "retry" }, + { "time", DateTime.UtcNow }, + { "kwargs", kwargs } + }); + + return run; + }*/ + + /// + /// Process a run upon creation. + /// + protected abstract void OnRunCreate(Run run); + + /// + /// Process a run upon update. + /// + protected abstract void OnRunUpdate(Run run); + + protected abstract Task HandleLlmStartAsync(Run run); + protected abstract Task HandleLlmNewTokenAsync(Run run, string token); + protected abstract Task HandleLlmErrorAsync(Run run); + protected abstract Task HandleLlmEndAsync(Run run); + protected abstract Task HandleChatModelStartAsync(Run run); + protected abstract Task HandleChainStartAsync(Run run); + protected abstract Task HandleChainErrorAsync(Run run); + protected abstract Task HandleChainEndAsync(Run run); + protected abstract Task HandleToolStartAsync(Run run); + protected abstract Task HandleToolErrorAsync(Run run); + protected abstract Task HandleToolEndAsync(Run run); + protected abstract Task HandleTextAsync(Run run); + protected abstract Task HandleAgentActionAsync(Run run); + protected abstract Task HandleAgentEndAsync(Run run); + protected abstract Task HandleRetrieverStartAsync(Run run); + protected abstract Task HandleRetrieverEndAsync(Run run); + protected abstract Task HandleRetrieverErrorAsync(Run run); + + /// + /// Add child run to a chain run or tool run. + /// + /// + /// + private static void AddChildRun(Run parentRun, Run childRun) => parentRun.ChildRuns.Add(childRun); + + //Start a trace for a run. + private void StartTrace(Run run) + { + if (run.ParentRunId != null) + { + if (RunMap.TryGetValue(run.ParentRunId, out var parentRun)) + { + AddChildRun(parentRun, run); + + parentRun.ChildExecutionOrder = + Math.Max(parentRun.ChildExecutionOrder ?? 0, run.ChildExecutionOrder ?? 0); + } + else + { + Console.WriteLine($"Parent run with id {run.ParentRunId} not found."); + } + } + + RunMap[run.Id] = run; + OnRunCreate(run); + } + + //End a trace for a run. + private void EndTrace(Run run) + { + if (run.ParentRunId == null) + { + PersistRun(run); + } + else + { + if (RunMap.TryGetValue(run.ParentRunId, out var parentRun)) + { + if (run.ChildExecutionOrder != null && parentRun.ChildExecutionOrder != null && + run.ChildExecutionOrder > parentRun.ChildExecutionOrder) + { + parentRun.ChildExecutionOrder = run.ChildExecutionOrder; + } + } + else + { + Console.WriteLine($"Parent run with id {run.ParentRunId} not found."); + } + } + + RunMap.Remove(run.Id); + OnRunUpdate(run); + } + + //Get the execution order for a run. + private int GetExecutionOrder(string? parentRunId = null) + { + if (parentRunId == null) + { + return 1; + } + + if (RunMap.TryGetValue(parentRunId, out var parentRun)) + { + if (parentRun.ChildExecutionOrder == null) + { + throw new TracerException($"Parent run with id {parentRunId} has no child execution order."); + } + } + else + { + Console.WriteLine($"Parent run with id {parentRunId} not found."); + } + + return 1; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Tracers/ConsoleCallbackHandler.cs b/src/libs/LangChain.Core/Base/Tracers/ConsoleCallbackHandler.cs new file mode 100644 index 00000000..1ed3aa42 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/ConsoleCallbackHandler.cs @@ -0,0 +1,223 @@ +namespace LangChain.Base.Tracers; + +public class ConsoleCallbackHandlerInput : BaseCallbackHandlerInput +{ +} + +public class ConsoleCallbackHandler(ConsoleCallbackHandlerInput fields) : BaseTracer(fields) +{ + public ConsoleCallbackHandler() : this(new ConsoleCallbackHandlerInput()) + { + + } + + public override string Name => "console_callback_handler"; + protected override Task PersistRun(Run run) => Task.CompletedTask; + + protected override async Task HandleLlmStartAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + object inputs = run.Inputs.TryGetValue("prompts", out var input) + ? new Dictionary> { { "prompts", (input as List)?.Select(p => p.Trim()).ToList() } } + : run.Inputs; + + Print( + $"{GetColoredText("[llm/start]", ConsoleFormats.Green)} {GetColoredText($"[{crumbs}] Entering LLM run with input:", ConsoleFormats.Bold)}\n" + + $"{JsonSerializeOrDefault(inputs, "[inputs]")}" + ); + } + + protected override async Task HandleLlmNewTokenAsync(Run run, string token) { } + + protected override async Task HandleLlmErrorAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + + Print($"{GetColoredText("[llm/error]", ConsoleFormats.Red)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] LLM run errored with error:", ConsoleFormats.Bold)}\n" + + $"{JsonSerializeOrDefault(run.Error, "[error]")}" + ); + } + + protected override async Task HandleLlmEndAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + + Print($"{GetColoredText("[llm/end]", ConsoleFormats.Blue)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] Exiting LLM run with output:", ConsoleFormats.Bold)}\n" + + $"{JsonSerializeOrDefault(run.Outputs, "[response]")}" + ); + } + + protected override async Task HandleChatModelStartAsync(Run run) { } + + protected override async Task HandleChainStartAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + var runType = run.RunType.Capitalize(); + var input = JsonSerializeOrDefault(run.Inputs, "[inputs]"); + + Print( + $"{GetColoredText("[chain/start]", ConsoleFormats.Green)} {GetColoredText($"[{crumbs}] Entering {runType} run with input:", ConsoleFormats.Bold)}\n" + + $"{input}" + ); + } + + + protected override async Task HandleChainErrorAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + var runType = run.RunType.Capitalize(); + var error = JsonSerializeOrDefault(run.Error, "[error]"); + Print( + $"{GetColoredText("[chain/error]", ConsoleFormats.Red)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] {runType} run errored with error:", ConsoleFormats.Bold)}\n" + + $"{error}" + ); + } + + protected override async Task HandleChainEndAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + var runType = run.RunType.Capitalize(); + var outputs = JsonSerializeOrDefault(run.Outputs, "[outputs]"); + + Print( + $"{GetColoredText("[chain/end]", ConsoleFormats.Blue)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] Exiting {runType} run with output:", ConsoleFormats.Bold)}\n" + + $"{outputs}" + ); + } + + protected override async Task HandleToolStartAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + Print( + $"{GetColoredText("[chain/start]", ConsoleFormats.Green)} {GetColoredText($"[{crumbs}] Entering Tool run with input:", ConsoleFormats.Bold)}\n" + + $"{run.Inputs["input"].ToString().Trim()}" + ); + } + + protected override async Task HandleToolErrorAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + Print( + $"{GetColoredText("[chain/error]", ConsoleFormats.Red)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] Tool run errored with error:", ConsoleFormats.Bold)}\n" + + $"{run.Error}" + ); + } + + protected override async Task HandleToolEndAsync(Run run) + { + var crumbs = GetBreadcrumbs(run); + if (run.Outputs.Count != 0) + Print( + $"{GetColoredText("[chain/end]", ConsoleFormats.Blue)} {GetColoredText($"[{crumbs}] [{Elapsed(run)}] Exiting Tool run with output:", ConsoleFormats.Bold)}\n" + + $"{run.Outputs["output"].ToString().Trim()}" + ); + } + + protected override async Task HandleTextAsync(Run run) + { + } + + protected override async Task HandleAgentActionAsync(Run run) + { + } + + protected override async Task HandleAgentEndAsync(Run run) + { + } + + protected override async Task HandleRetrieverStartAsync(Run run) + { + } + + protected override async Task HandleRetrieverEndAsync(Run run) + { + } + + protected override async Task HandleRetrieverErrorAsync(Run run) + { + } + + protected override void OnRunCreate(Run run) + { + } + + protected override void OnRunUpdate(Run run) + { + } + + private List GetParents(Run run) + { + var parents = new List(); + var currentRun = run; + while (currentRun.ParentRunId != null) + { + if (RunMap.TryGetValue(currentRun.ParentRunId, out var parent) && parent != null) + { + parents.Add(parent); + currentRun = parent; + } + else break; + } + + return parents; + } + + private string GetBreadcrumbs(Run run) + { + var parents = GetParents(run); + parents.Reverse(); + parents.Add(run); + + var breadcrumbs = parents.Select((parent, i) => $"{parent.ExecutionOrder}:{parent.RunType}:{parent.Name}"); + var result = string.Join(" > ", breadcrumbs); + + return result; + } + + private void Print(string text) => Console.WriteLine(text); + + private string GetColoredText(string text, string format) + { + return $"{format}{text}{ConsoleFormats.Normal}"; + } + + private string JsonSerializeOrDefault(object obj, string @default) + { + try + { + return System.Text.Json.JsonSerializer.Serialize(obj); + } + catch (Exception _) + { + return @default; + } + } + + /// + /// Get the elapsed time of a run. + /// + /// A string with the elapsed time in seconds or milliseconds if time is less than a second. + private string Elapsed(Run run) + { + if (!run.EndTime.HasValue) + return "N/A"; + + var elapsedTime = run.EndTime.Value - run.StartTime; + var milliseconds = elapsedTime.TotalMilliseconds; + + return elapsedTime.TotalMilliseconds < 1000 + ? $"{milliseconds}ms" + : $"{elapsedTime.TotalSeconds:F1}s"; + } + + private static class ConsoleFormats + { + public static string Normal = Console.IsOutputRedirected ? "" : "\x1b[39m"; + public static string Red = Console.IsOutputRedirected ? "" : "\x1b[91m"; + public static string Green = Console.IsOutputRedirected ? "" : "\x1b[92m"; + public static string Yellow = Console.IsOutputRedirected ? "" : "\x1b[93m"; + public static string Blue = Console.IsOutputRedirected ? "" : "\x1b[94m"; + public static string Bold = Console.IsOutputRedirected ? "" : "\x1b[1m"; + public static string Underline = Console.IsOutputRedirected ? "" : "\x1b[4m"; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Tracers/RunBase.cs b/src/libs/LangChain.Core/Base/Tracers/RunBase.cs new file mode 100644 index 00000000..24b9bbd0 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/RunBase.cs @@ -0,0 +1,114 @@ +namespace LangChain.Base.Tracers; + +/// +/// Base Run schema. +/// Contains the fundamental fields to define a run in a system. +/// +public abstract class RunBase +{ + /// + /// Unique identifier for the run. + /// + public string Id { get; set; } + + /// + /// Human-readable name for the run. + /// + public string Name { get; set; } + + /// + /// Start time of the run. + /// + public DateTime StartTime { get; set; } + + /// + /// The type of run, such as tool, chain, llm, retriever, + /// embedding, prompt, parser. + /// + public string RunType { get; set; } + + /// + /// End time of the run, if applicable. + /// + public DateTime? EndTime { get; set; } + + /// + /// Additional metadata or settings related to the run. + /// + public Dictionary ExtraData { get; set; } + + /// + /// Error message, if the run encountered any issues. + /// + public string Error { get; set; } + + /// + /// Serialized object that executed the run for potential reuse. + /// + public Dictionary Serialized { get; set; } + + /// + /// List of events associated with the run, like start and end events. + /// + public List> Events { get; set; } + + /// + /// Inputs used for the run. + /// + public Dictionary Inputs { get; set; } + + /// + /// Outputs generated by the run, if any. + /// + public Dictionary Outputs { get; set; } + + /// + /// Reference to an example that this run may be based on. + /// + public Guid? ReferenceExampleId { get; set; } + + /// + /// Identifier for a parent run, if this run is a sub-run. + /// + public string? ParentRunId { get; set; } + + /// + /// Tags for categorizing or annotating the run. + /// + public List Tags { get; set; } +} + +/// +/// Run schema in the Tracer +/// +public class Run : RunBase +{ + /// + /// The execution order of the run within a run trace. + /// + public int ExecutionOrder { get; set; } + public int? ChildExecutionOrder { get; set; } + + /// + /// The child runs of this run + /// + public List ChildRuns { get; set; } = new(); + + // TODO: name init; + // @root_validator(pre=True) + // def assign_name(cls, values: dict) -> dict: + // """Assign name to the run.""" + // if values.get("name") is None: + // if "name" in values["serialized"]: + // values["name"] = values["serialized"]["name"] + // elif "id" in values["serialized"]: + // values["name"] = values["serialized"]["id"][-1] + // if values.get("events") is None: + // values["events"] = [] + // return values +} + +public static class RunExtensions +{ + +} diff --git a/src/libs/LangChain.Core/Base/Tracers/StringExtensions.cs b/src/libs/LangChain.Core/Base/Tracers/StringExtensions.cs new file mode 100644 index 00000000..9dac6246 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/StringExtensions.cs @@ -0,0 +1,20 @@ +namespace LangChain.Base.Tracers; + +public static class StringExtensions +{ + public static string Capitalize(this string? word) + { + if (word == null) + { + return word; + } + + if (word.Length == 1) + { + return word.ToUpper(System.Globalization.CultureInfo.CurrentCulture); + } + + return word.Substring(0, 1).ToUpper(System.Globalization.CultureInfo.CurrentCulture) + + word.Substring(1).ToLower(System.Globalization.CultureInfo.CurrentCulture); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Base/Tracers/TracerException.cs b/src/libs/LangChain.Core/Base/Tracers/TracerException.cs new file mode 100644 index 00000000..bdffe5c2 --- /dev/null +++ b/src/libs/LangChain.Core/Base/Tracers/TracerException.cs @@ -0,0 +1,11 @@ +namespace LangChain.Base.Tracers; + +/// +/// Base class for exceptions in tracers module. +/// +public class TracerException : Exception +{ + public TracerException(string message) : base(message) + { + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/BaseRunManager.cs b/src/libs/LangChain.Core/Callback/BaseRunManager.cs index 8b48f0e8..f28fb415 100644 --- a/src/libs/LangChain.Core/Callback/BaseRunManager.cs +++ b/src/libs/LangChain.Core/Callback/BaseRunManager.cs @@ -5,28 +5,32 @@ namespace LangChain.Callback; /// /// /// -public class BaseRunManager +public abstract class BaseRunManager { - /// - /// - /// + /// public string RunId { get; } - /// - /// - /// + /// protected List Handlers { get; } - /// - /// - /// + /// protected List InheritableHandlers { get; } - /// - /// - /// + /// protected string? ParentRunId { get; } + /// + protected List Tags { get; } + + /// + protected List InheritableTags { get; } + + /// + protected Dictionary Metadata { get; } + + /// + protected Dictionary InheritableMetadata { get; } + /// /// /// @@ -34,18 +38,42 @@ public class BaseRunManager /// /// /// - public BaseRunManager(string runId, List handlers, List inheritableHandlers, string? parentRunId = null) + /// + /// + /// + /// + public BaseRunManager( + string runId, + List handlers, + List inheritableHandlers, + string? parentRunId = null, + List? tags = null, + List? inheritableTags = null, + Dictionary? metadata = null, + Dictionary? inheritableMetadata = null) { RunId = runId; Handlers = handlers; InheritableHandlers = inheritableHandlers; + Tags = tags; + InheritableTags = inheritableTags ?? new(); + Metadata = metadata ?? new(); + InheritableMetadata = inheritableMetadata ?? new(); ParentRunId = parentRunId; } + protected BaseRunManager() + : this( + runId: Guid.NewGuid().ToString("N"), + handlers: new(), + inheritableHandlers: new()) + { + } + /// - /// + /// Run when text is received. /// - /// + /// The received text. public async Task HandleText(string text) { foreach (var handler in Handlers) @@ -60,4 +88,15 @@ public async Task HandleText(string text) } } } -} \ No newline at end of file + + /// + /// Return a manager that doesn't perform any operations. + /// TODO: (static abstract not supported by some target runtimes) + /// + public static T GetNoopManager() where T : IRunManagerImplementation, new() + { + return new T(); + } +} + +public interface IRunManagerImplementation where TThis : IRunManagerImplementation, new(); diff --git a/src/libs/LangChain.Core/Callback/CallbackManager.cs b/src/libs/LangChain.Core/Callback/CallbackManager.cs index c7f9bf05..abf829bd 100644 --- a/src/libs/LangChain.Core/Callback/CallbackManager.cs +++ b/src/libs/LangChain.Core/Callback/CallbackManager.cs @@ -1,17 +1,14 @@ +using LangChain.Abstractions.Schema; using LangChain.Base; +using LangChain.Base.Tracers; using LangChain.LLMS; using LangChain.Providers; using LangChain.Retrievers; -using LangChain.Schema; namespace LangChain.Callback; -using System; -using System.Collections.Generic; -using System.Threading.Tasks; - /// -/// +/// Base callback manager that handles callbacks from LangChain. /// public class CallbackManager { @@ -25,17 +22,81 @@ public class CallbackManager /// public List InheritableHandlers { get; private set; } public string Name { get; } = "callback_manager"; - private readonly string? _parentRunId; + public readonly string? ParentRunId; + + protected List Tags { get; } + protected List InheritableTags { get; } + protected Dictionary Metadata { get; } + protected Dictionary InheritableMetadata { get; } /// /// /// + /// /// - public CallbackManager(string? parentRunId = null) + /// + /// + /// + /// + /// + public CallbackManager( + List? handlers = null, + List? inheritableHandlers = null, + List? tags = null, + List? inheritableTags = null, + Dictionary? metadata = null, + Dictionary? inheritableMetadata = null, + string? parentRunId = null) + { + Handlers = handlers ?? new List(); + InheritableHandlers = inheritableHandlers ?? new List(); + ParentRunId = parentRunId; + + Tags = tags ?? new(); + InheritableTags = inheritableTags ?? new(); + Metadata = metadata ?? new(); + InheritableMetadata = inheritableMetadata ?? new(); + } + + public void AddTags(List tags, bool inherit = true) + { + Tags.RemoveAll(tag => tags.Contains(tag)); + Tags.AddRange(tags); + + if (inherit) + { + InheritableTags.AddRange(tags); + } + } + + public void RemoveTags(List tags) { - Handlers = new List(); - InheritableHandlers = new List(); - _parentRunId = parentRunId; + foreach (var tag in tags) + { + Tags.Remove(tag); + InheritableTags.Remove(tag); + } + } + + public void AddMetadata(Dictionary metadata, bool inherit = true) + { + foreach (var kv in metadata) + { + Metadata[kv.Key] = kv.Value; + if (inherit) + { + InheritableMetadata[kv.Key] = kv.Value; + } + } + } + + public void RemoveMetadata(List keys) + { + foreach (var key in keys) + { + Metadata.Remove(key); + InheritableMetadata.Remove(key); + } } public async Task HandleLlmStart( @@ -45,13 +106,15 @@ public async Task HandleLlmStart( string? parentRunId = null, Dictionary? extraParams = null) { + runId ??= Guid.NewGuid().ToString(); + foreach (var handler in Handlers) { if (!handler.IgnoreLlm) { try { - await handler.HandleLlmStartAsync(llm, prompts.ToArray(), runId ?? Guid.NewGuid().ToString(), _parentRunId, extraParams); + await handler.HandleLlmStartAsync(llm, prompts.ToArray(), runId, ParentRunId, extraParams: extraParams); } catch (Exception ex) { @@ -60,7 +123,7 @@ public async Task HandleLlmStart( } } - return new CallbackManagerForLlmRun(runId, Handlers, InheritableHandlers, _parentRunId); + return new CallbackManagerForLlmRun(runId, Handlers, InheritableHandlers, ParentRunId); } public async Task HandleChatModelStart( @@ -70,55 +133,49 @@ public async Task HandleChatModelStart( string? parentRunId = null, Dictionary? extraParams = null) { - List messageStrings = null; + runId ??= Guid.NewGuid().ToString(); + foreach (var handler in Handlers) { if (!handler.IgnoreLlm) { - /*try + try { - if (handler is IHandleChatModelStart handleChatModelStartHandler) - { - await handleChatModelStartHandler.HandleChatModelStart(llm, messages, runId ?? Guid.NewGuid().ToString(), _parentRunId, extraParams); - } - else if (handler is IHandleLLMStart handleLLMStartHandler) - { - messageStrings = messages.Select(x => GetBufferString(x)).ToList(); - await handleLLMStartHandler.HandleLLMStart(llm, messageStrings, runId ?? Guid.NewGuid().ToString(), _parentRunId, extraParams); - } + await handler.HandleChatModelStartAsync(llm, messages, runId, ParentRunId, extraParams); } catch (Exception ex) { Console.Error.WriteLine($"Error in handler {handler.GetType().Name}, HandleLLMStart: {ex}"); - }*/ + } } } - return new CallbackManagerForLlmRun(runId, Handlers, InheritableHandlers, _parentRunId); + return new CallbackManagerForLlmRun(runId, Handlers, InheritableHandlers, ParentRunId); } public async Task HandleChainStart( BaseChain chain, - ChainValues inputs, + IChainValues inputs, string? runId = null) { + runId ??= Guid.NewGuid().ToString(); + foreach (var handler in Handlers) { - //TODO: Implement methods - // if (!handler.IgnoreChain) - // { - // try - // { - // await handler.HandleChainStart(chain, inputs, runId ?? Guid.NewGuid().ToString(), _parentRunId); - // } - // catch (Exception ex) - // { - // Console.Error.WriteLine($"Error in handler {handler.GetType().Name}, HandleChainStart: {ex}"); - // } - // } + if (!handler.IgnoreChain) + { + try + { + await handler.HandleChainStartAsync(chain, inputs.Value, runId, ParentRunId); + } + catch (Exception ex) + { + Console.Error.WriteLine($"Error in handler {handler.GetType().Name}, HandleChainStart: {ex}"); + } + } } - return new CallbackManagerForChainRun(runId, Handlers, InheritableHandlers, _parentRunId); + return new CallbackManagerForChainRun(runId, Handlers, InheritableHandlers, ParentRunId); } public async Task HandleRetrieverStart( @@ -128,13 +185,16 @@ public async Task HandleRetrieverStart( string? parentRunId = null, Dictionary? extraParams = null) { + runId ??= Guid.NewGuid().ToString(); + foreach (var handler in Handlers) { if (!handler.IgnoreLlm) { try { - await handler.HandleRetrieverStartAsync(query, runId ?? Guid.NewGuid().ToString(), _parentRunId); + // TODO: pass extraParams ? + await handler.HandleRetrieverStartAsync(retriever, query, runId, ParentRunId, extraParams: extraParams); } catch (Exception ex) { @@ -143,7 +203,17 @@ public async Task HandleRetrieverStart( } } - return new CallbackManagerForRetrieverRun(runId, Handlers, InheritableHandlers, _parentRunId); + var manager = new CallbackManagerForRetrieverRun( + runId, + Handlers, + InheritableHandlers, + ParentRunId, + Tags, + InheritableTags, + Metadata, + InheritableMetadata); + + return manager; } public void AddHandler(BaseCallbackHandler handler, bool inherit = true) @@ -155,11 +225,6 @@ public void AddHandler(BaseCallbackHandler handler, bool inherit = true) } } - public void AddHandler(BaseCallbackHandler handler) - { - throw new NotImplementedException(); - } - public void RemoveHandler(BaseCallbackHandler handler) { Handlers.Remove(handler); @@ -183,12 +248,13 @@ public void SetHandlers(List handlers, bool inherit = true) public CallbackManager Copy(List? additionalHandlers = null, bool inherit = true) { - var manager = new CallbackManager(_parentRunId); + var manager = new CallbackManager(parentRunId: ParentRunId); foreach (var handler in Handlers) { var inheritable = InheritableHandlers.Contains(handler); manager.AddHandler(handler, inheritable); } + if (additionalHandlers != null) { foreach (var handler in additionalHandlers) @@ -215,130 +281,98 @@ public static CallbackManager FromHandlers(List handlers) return manager; } + // TODO: review! motivation? + // ICallbackManagerOptions? options = null, public static async Task Configure( - List? inheritableHandlers = null, - List? localHandlers = null, - ICallbackManagerOptions? options = null) + ICallbacks? inheritableCallbacks = null, + ICallbacks? localCallbacks = null, + bool verbose = false, + List? localTags = null, + List? inheritableTags = null, + Dictionary? localMetadata = null, + Dictionary? inheritableMetadata = null) { - CallbackManager callbackManager = null; - if (inheritableHandlers != null || localHandlers != null) + // TODO: parentRunId using AsyncLocal + // python version using `contextvars` lib + // run_tree = get_run_tree_context() + // parent_run_id = None if run_tree is None else getattr(run_tree, "id") + string parentId = null; + + CallbackManager callbackManager; + + if (inheritableCallbacks != null || localCallbacks != null) { - if (inheritableHandlers is List || inheritableHandlers == null) + switch (inheritableCallbacks) { - callbackManager = new CallbackManager(); - callbackManager.SetHandlers(inheritableHandlers?.Cast().ToList() ?? new List(), true); + case HandlersCallbacks inheritableHandlers: + callbackManager = new CallbackManager(parentRunId: parentId); + callbackManager.SetHandlers(inheritableHandlers.Value, true); + break; + + case ManagerCallbacks managerCallbacks: + // ToList() and ToDictionary() used to create copy + callbackManager = new CallbackManager( + managerCallbacks.Value.Handlers.ToList(), + managerCallbacks.Value.InheritableHandlers.ToList(), + managerCallbacks.Value.Tags.ToList(), + managerCallbacks.Value.InheritableTags.ToList(), + managerCallbacks.Value.Metadata.ToDictionary(kv => kv.Key, kv => kv.Value), + managerCallbacks.Value.InheritableMetadata.ToDictionary(kv => kv.Key, kv => kv.Value), + parentRunId: managerCallbacks.Value.ParentRunId); + break; + + default: + callbackManager = new CallbackManager(parentRunId: parentId); + break; } - callbackManager = callbackManager.Copy( - localHandlers, - false); + var localHandlers = localCallbacks switch + { + HandlersCallbacks localHandlersCallbacks => localHandlersCallbacks.Value, + ManagerCallbacks managerCallbacks => managerCallbacks.Value.Handlers, + _ => new List() + }; + + callbackManager = callbackManager.Copy(localHandlers, false); } - var verboseEnabled = (Environment.GetEnvironmentVariable("LANGCHAIN_VERBOSE") != null || options?.Verbose == true); + else + { + callbackManager = new CallbackManager(parentRunId: parentId); + } + + if (inheritableTags != null) callbackManager.AddTags(inheritableTags); + if (localTags != null) callbackManager.AddTags(localTags, inherit: false); + + if (inheritableMetadata != null) callbackManager.AddMetadata(inheritableMetadata); + if (localMetadata != null) callbackManager.AddMetadata(localMetadata, inherit: false); + + var verboseEnabled = (Environment.GetEnvironmentVariable("LANGCHAIN_VERBOSE") != null || verbose); var tracingV2Enabled = (Environment.GetEnvironmentVariable("LANGCHAIN_TRACING_V2") != null); var tracingEnabled = tracingV2Enabled || (Environment.GetEnvironmentVariable("LANGCHAIN_TRACING") != null); if (verboseEnabled || tracingEnabled) { - if (callbackManager == null) - { - callbackManager = new CallbackManager(); - } - //TODO: Implement handlers - /*if (!callbackManager.Handlers.Any(h => h.Name == ConsoleCallbackHandler.Name)) + // TODO: replace inlined name "console_callback_handler" with const + if (callbackManager.Handlers.All(h => h.Name != "console_callback_handler")) { var consoleHandler = new ConsoleCallbackHandler(); - callbackManager.AddHandler(consoleHandler, true); + callbackManager.AddHandler(consoleHandler, inherit: true); } - if (!callbackManager.Handlers.Any(h => h.Name == "langchain_tracer")) - { - if (tracingV2Enabled) - { - callbackManager.AddHandler(await GetTracingV2CallbackHandler(), true); - } - else - { - var session = Environment.GetEnvironmentVariable("LANGCHAIN_SESSION"); - callbackManager.AddHandler(await GetTracingCallbackHandler(session), true); - } - }*/ - } - return callbackManager; - } - - private static string GetBufferString(List messages) - { - // Implement your logic here to convert messages to a string - throw new NotImplementedException(); - } - - public Task HandleLlmStartAsync(Dictionary llm, string[] prompts, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChatModelStartAsync(Dictionary llm, List> messages, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainStartAsync(Dictionary chain, Dictionary inputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainEndAsync(Dictionary outputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - public Task HandleToolStartAsync(Dictionary tool, string input, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolEndAsync(string output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleTextAsync(string text, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleAgentActionAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } + // TODO: implement handlers + // if (callbackManager.Handlers.All(h => h.Name != "langchain_tracer")) + // { + // if (tracingV2Enabled) + // { + // callbackManager.AddHandler(await GetTracingV2CallbackHandler(), true); + // } + // else + // { + // var session = Environment.GetEnvironmentVariable("LANGCHAIN_SESSION"); + // callbackManager.AddHandler(await GetTracingCallbackHandler(session), true); + // } + // } + } - public Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); + return callbackManager; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/CallbackManagerForChainRun.cs b/src/libs/LangChain.Core/Callback/CallbackManagerForChainRun.cs index 50b4f04c..fa404304 100644 --- a/src/libs/LangChain.Core/Callback/CallbackManagerForChainRun.cs +++ b/src/libs/LangChain.Core/Callback/CallbackManagerForChainRun.cs @@ -1,23 +1,25 @@ +using LangChain.Abstractions.Schema; using LangChain.Base; -using LangChain.Schema; namespace LangChain.Callback; -public class CallbackManagerForChainRun : BaseRunManager +public class CallbackManagerForChainRun : ParentRunManager, IRunManagerImplementation { - public CallbackManagerForChainRun(string runId, List handlers, List inheritableHandlers, string? parentRunId = null) - : base(runId, handlers, inheritableHandlers, parentRunId) + public CallbackManagerForChainRun() { + } - public CallbackManager GetChild() + public CallbackManagerForChainRun( + string runId, + List handlers, + List inheritableHandlers, + string? parentRunId = null) + : base(runId, handlers, inheritableHandlers, parentRunId) { - var manager = new CallbackManager(RunId); - manager.SetHandlers(InheritableHandlers); - return manager; } - public async Task HandleChainEndAsync(ChainValues output) + public async Task HandleChainEndAsync(IChainValues input, IChainValues output) { foreach (var handler in Handlers) { @@ -25,7 +27,7 @@ public async Task HandleChainEndAsync(ChainValues output) { try { - await handler.HandleChainEndAsync(output.Value, RunId, ParentRunId); + await handler.HandleChainEndAsync(input.Value, output.Value, RunId, ParentRunId); } catch (Exception ex) { @@ -35,39 +37,7 @@ public async Task HandleChainEndAsync(ChainValues output) } } - public Task HandleLlmStartAsync(Dictionary llm, string[] prompts, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChatModelStartAsync(Dictionary llm, List> messages, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainStartAsync(Dictionary chain, Dictionary inputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public async Task HandleChainErrorAsync(Exception error, string runId, string? parentRunId = null) + public async Task HandleChainErrorAsync(Exception error, IChainValues input) { foreach (var handler in Handlers) { @@ -75,7 +45,7 @@ public async Task HandleChainErrorAsync(Exception error, string runId, string? p { try { - await handler.HandleChainErrorAsync(error, RunId, ParentRunId); + await handler.HandleChainErrorAsync(error, RunId, input.Value, ParentRunId); } catch (Exception ex) { @@ -85,38 +55,8 @@ public async Task HandleChainErrorAsync(Exception error, string runId, string? p } } - public Task HandleChainEndAsync(Dictionary outputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolStartAsync(Dictionary tool, string input, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolEndAsync(string output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - public Task HandleTextAsync(string text, string runId, string? parentRunId = null) { throw new NotImplementedException(); } - - public Task HandleAgentActionAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/CallbackManagerForLlmRun.cs b/src/libs/LangChain.Core/Callback/CallbackManagerForLlmRun.cs index 51669909..5fd96f16 100644 --- a/src/libs/LangChain.Core/Callback/CallbackManagerForLlmRun.cs +++ b/src/libs/LangChain.Core/Callback/CallbackManagerForLlmRun.cs @@ -10,32 +10,7 @@ public CallbackManagerForLlmRun(string runId, List handlers { } - public Task HandleToolEndAsync(string output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleTextAsync(string text, string runId, string parentRunId) - { - throw new NotImplementedException(); - } - - public Task HandleAgentActionAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmStartAsync(Dictionary llm, string[] prompts, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - + // TODO: remove? public async Task HandleLlmNewTokenAsync(string token, string runId, string parentRunId) { foreach (var handler in Handlers) @@ -89,35 +64,4 @@ public async Task HandleLlmEndAsync(LlmResult output, string runId, string paren } } } - - public Task HandleChatModelStartAsync(Dictionary llm, List> messages, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainStartAsync(Dictionary chain, Dictionary inputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainEndAsync(Dictionary outputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolStartAsync(Dictionary tool, string input, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/CallbackManagerForRetrieverRun.cs b/src/libs/LangChain.Core/Callback/CallbackManagerForRetrieverRun.cs index d35345bd..e9007b14 100644 --- a/src/libs/LangChain.Core/Callback/CallbackManagerForRetrieverRun.cs +++ b/src/libs/LangChain.Core/Callback/CallbackManagerForRetrieverRun.cs @@ -1,19 +1,37 @@ using LangChain.Base; +using LangChain.Docstore; namespace LangChain.Callback; -public class CallbackManagerForRetrieverRun : BaseRunManager + + +/// +/// Callback manager for retriever run. +/// +public class CallbackManagerForRetrieverRun : ParentRunManager, IRunManagerImplementation { + public CallbackManagerForRetrieverRun() + { + + } + public CallbackManagerForRetrieverRun( string runId, List handlers, List inheritableHandlers, - string? parentRunId = null) - : base(runId, handlers, inheritableHandlers, parentRunId) + string? parentRunId = null, + List? tags = null, + List? inheritableTags = null, + Dictionary? metadata = null, + Dictionary? inheritableMetadata = null) + : base(runId, handlers, inheritableHandlers, parentRunId, tags, inheritableTags, metadata, inheritableMetadata) { } - public async Task HandleRetrieverEndAsync(string query) + /// + /// Run when retriever ends running. + /// + public async Task HandleRetrieverEndAsync(string query, IEnumerable docs) { foreach (var handler in Handlers) { @@ -21,7 +39,7 @@ public async Task HandleRetrieverEndAsync(string query) { try { - await handler.HandleRetrieverEndAsync(query, RunId, ParentRunId); + await handler.HandleRetrieverEndAsync(query, docs.ToList(), RunId, ParentRunId); } catch (Exception ex) { @@ -31,6 +49,9 @@ public async Task HandleRetrieverEndAsync(string query) } } + /// + /// Run when retriever errors. + /// public async Task HandleRetrieverErrorAsync(Exception error, string query) { foreach (var handler in Handlers) diff --git a/src/libs/LangChain.Core/Callback/CallbackManagerForToolRun.cs b/src/libs/LangChain.Core/Callback/CallbackManagerForToolRun.cs index f2271f59..9120d0c5 100644 --- a/src/libs/LangChain.Core/Callback/CallbackManagerForToolRun.cs +++ b/src/libs/LangChain.Core/Callback/CallbackManagerForToolRun.cs @@ -3,67 +3,13 @@ namespace LangChain.Callback; -public class CallbackManagerForToolRun : BaseRunManager +public class CallbackManagerForToolRun : ParentRunManager { public CallbackManagerForToolRun(string runId, List handlers, List inheritableHandlers, string? parentRunId = null) : base(runId, handlers, inheritableHandlers, parentRunId) { } - public CallbackManager GetChild() - { - var manager = new CallbackManager(RunId); - manager.SetHandlers(InheritableHandlers); - return manager; - } - - public Task HandleLlmStartAsync(Dictionary llm, string[] prompts, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChatModelStartAsync(Dictionary llm, List> messages, string runId, string? parentRunId = null, - Dictionary? extraParams = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainStartAsync(Dictionary chain, Dictionary inputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleChainEndAsync(Dictionary outputs, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleToolStartAsync(Dictionary tool, string input, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - public Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null) { throw new NotImplementedException(); @@ -73,19 +19,4 @@ public Task HandleToolEndAsync(string output, string runId, string? parentRunId { throw new NotImplementedException(); } - - public Task HandleTextAsync(string text, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleAgentActionAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } - - public Task HandleAgentEndAsync(Dictionary action, string runId, string? parentRunId = null) - { - throw new NotImplementedException(); - } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/ICallbacks.cs b/src/libs/LangChain.Core/Callback/ICallbacks.cs new file mode 100644 index 00000000..46e0c174 --- /dev/null +++ b/src/libs/LangChain.Core/Callback/ICallbacks.cs @@ -0,0 +1,9 @@ +using LangChain.Base; + +namespace LangChain.Callback; + +public interface ICallbacks; + +public record ManagerCallbacks(CallbackManager Value) : ICallbacks; + +public record HandlersCallbacks(List Value) : ICallbacks; \ No newline at end of file diff --git a/src/libs/LangChain.Core/Callback/ParentRunManager.cs b/src/libs/LangChain.Core/Callback/ParentRunManager.cs new file mode 100644 index 00000000..e65c99a0 --- /dev/null +++ b/src/libs/LangChain.Core/Callback/ParentRunManager.cs @@ -0,0 +1,47 @@ +using LangChain.Base; + +namespace LangChain.Callback; + +/// +/// Sync Parent Run Manager. +/// +public class ParentRunManager : BaseRunManager +{ + public ParentRunManager() + { + + } + + public ParentRunManager( + string runId, + List handlers, + List inheritableHandlers, + string? parentRunId = null, + List? tags = null, + List? inheritableTags = null, + Dictionary? metadata = null, + Dictionary? inheritableMetadata = null) + : base(runId, handlers, inheritableHandlers, parentRunId, tags, inheritableTags, metadata, inheritableMetadata) + { + } + + /// + /// Get a child callback manager. + /// + /// The tag for the child callback manager. + /// The child callback manager. + public CallbackManager GetChild(string? tag = null) + { + var manager = new CallbackManager(parentRunId: RunId); + + manager.SetHandlers(InheritableHandlers); + + manager.AddTags(InheritableTags); + manager.AddMetadata(InheritableMetadata); + + if (tag != null) + manager.AddTags(new List { tag }, inherit: false); + + return manager; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/Base/IChain.cs b/src/libs/LangChain.Core/Chains/Base/IChain.cs index 29b5ff78..dcae992b 100644 --- a/src/libs/LangChain.Core/Chains/Base/IChain.cs +++ b/src/libs/LangChain.Core/Chains/Base/IChain.cs @@ -1,4 +1,5 @@ using LangChain.Abstractions.Schema; +using LangChain.Callback; namespace LangChain.Abstractions.Chains.Base; @@ -6,5 +7,10 @@ public interface IChain { string[] InputKeys { get; } string[] OutputKeys { get; } - Task CallAsync(IChainValues values); + + Task CallAsync( + IChainValues values, + ICallbacks? callbacks = null, + List? tags = null, + Dictionary? metadata = null); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/AnalyzeDocumentChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/AnalyzeDocumentChain.cs index ba11992e..d1d76974 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/AnalyzeDocumentChain.cs +++ b/src/libs/LangChain.Core/Chains/CombineDocuments/AnalyzeDocumentChain.cs @@ -1,13 +1,14 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; using LangChain.Base; +using LangChain.Callback; using LangChain.Docstore; using LangChain.Schema; using LangChain.TextSplitters; namespace LangChain.Chains.CombineDocuments; -public class AnalyzeDocumentChain(AnalyzeDocumentsChainInput fields) : BaseChain, IChain +public class AnalyzeDocumentChain(AnalyzeDocumentsChainInput fields) : BaseChain(fields), IChain { private readonly string _inputKey = fields.InputKey; private readonly string _outputKey = fields.OutputKey; @@ -20,7 +21,7 @@ public class AnalyzeDocumentChain(AnalyzeDocumentsChainInput fields) : BaseChain public override string[] InputKeys => new [] { _inputKey }; public override string[] OutputKeys => new [] { _outputKey }; - public override async Task CallAsync(IChainValues values) + protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) { var documents = values.Value[_inputKey]; var docs = _textSplitter.SplitDocuments(documents as List); diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs index 3f25caa5..e047a8ad 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs +++ b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs @@ -1,6 +1,7 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; using LangChain.Base; +using LangChain.Callback; using LangChain.Docstore; using LangChain.Schema; @@ -17,7 +18,7 @@ namespace LangChain.Chains.CombineDocuments; /// determine whether it's safe to pass a list of documents into this chain or whether /// that will longer than the context length). /// -public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput fields) : BaseChain, IChain +public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput fields) : BaseChain(fields), IChain { public readonly string InputKey = fields.InputKey; public readonly string OutputKey = fields.OutputKey; @@ -29,8 +30,9 @@ public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput f /// Prepare inputs, call combine docs, prepare outputs. /// /// + /// /// - public override async Task CallAsync(IChainValues values) + protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) { var docs = values.Value["input_documents"]; diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChainInput.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChainInput.cs index 39cebab8..1d50964e 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChainInput.cs +++ b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChainInput.cs @@ -1,13 +1,10 @@ using LangChain.Base; -using LangChain.Callback; namespace LangChain.Chains.CombineDocuments; /// -public abstract class BaseCombineDocumentsChainInput : IChainInputs +public abstract class BaseCombineDocumentsChainInput : ChainInputs { public string InputKey { get; set; } = "input_documents"; public string OutputKey { get; set; } = "output_text"; - public bool? Verbose { get; set; } - public CallbackManager? CallbackManager { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs b/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs index 670cce2d..a0365c88 100644 --- a/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs +++ b/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs @@ -13,7 +13,7 @@ namespace LangChain.Chains.LLM; using System.Collections.Generic; using System.Threading.Tasks; -public class LlmChain(LlmChainInput fields) : BaseChain, ILlmChain +public class LlmChain(LlmChainInput fields) : BaseChain(fields), ILlmChain { public BasePromptTemplate Prompt { get; } = fields.Prompt; public IChatModel Llm { get; } = fields.Llm; @@ -22,9 +22,13 @@ public class LlmChain(LlmChainInput fields) : BaseChain, ILlmChain public override string ChainType() => "llm_chain"; - public bool? Verbose { get; set; } public CallbackManager? CallbackManager { get; set; } + public bool Verbose { get; set; } + public ICallbacks? Callbacks { get; set; } + public List Tags { get; set; } + public Dictionary Metadata { get; set; } + public override string[] InputKeys => Prompt.InputVariables.ToArray(); public override string[] OutputKeys => new[] { OutputKey }; @@ -40,8 +44,9 @@ public class LlmChain(LlmChainInput fields) : BaseChain, ILlmChain /// Execute the chain. /// /// The values to use when executing the chain. + /// /// The resulting output . - public override async Task CallAsync(IChainValues values) + protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) { List? stop = new List(); @@ -79,5 +84,4 @@ public async Task Predict(ChainValues values) var output = await CallAsync(values); return output.Value[OutputKey]; } - } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs b/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs index 918b7f8f..a5b82eda 100644 --- a/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs +++ b/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs @@ -1,4 +1,4 @@ -using LangChain.Callback; +using LangChain.Base; using LangChain.Memory; using LangChain.Prompts.Base; using LangChain.Providers; @@ -9,12 +9,10 @@ public class LlmChainInput( IChatModel llm, BasePromptTemplate prompt, BaseMemory? memory = null) - : ILlmChainInput + : ChainInputs, ILlmChainInput { public BasePromptTemplate Prompt { get; set; } = prompt; public IChatModel Llm { get; set; } = llm; public string OutputKey { get; set; } = "text"; - public bool? Verbose { get; set; } - public CallbackManager CallbackManager { get; set; } public BaseMemory? Memory { get; set; } = memory; } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs index 6b8465df..056fc3c0 100644 --- a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs +++ b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs @@ -1,6 +1,7 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; using LangChain.Base; +using LangChain.Callback; using LangChain.Chains.CombineDocuments; using LangChain.Docstore; using LangChain.Schema; @@ -11,7 +12,7 @@ namespace LangChain.Chains.RetrievalQA; /// Base class for question-answering chains. /// /// -public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : BaseChain, IChain +public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : BaseChain(fields), IChain { private readonly string _inputKey = fields.InputKey; private readonly string _outputKey = fields.OutputKey; @@ -20,6 +21,8 @@ public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : B private const string SourceDocuments = "source_documents"; + public CallbackManager? CallbackManager { get; set; } + public override string[] InputKeys => new [] { _inputKey }; public override string[] OutputKeys => fields.ReturnSourceDocuments ? new [] { _outputKey, SourceDocuments } @@ -32,14 +35,18 @@ public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : B /// the retrieved documents as well under the key 'source_documents'. /// /// + /// /// /// - public override async Task CallAsync(IChainValues values) + protected override async Task CallAsync( + IChainValues values, + CallbackManagerForChainRun? runManager) { - + runManager ??= BaseRunManager.GetNoopManager(); + var question = values.Value[_inputKey].ToString(); - var docs = (await GetDocsAsync(question)).ToList(); + var docs = (await GetDocsAsync(question, runManager)).ToList(); var input = new Dictionary { @@ -66,5 +73,6 @@ public override async Task CallAsync(IChainValues values) /// Get documents to do question answering over. /// /// - public abstract Task> GetDocsAsync(string question); + /// + public abstract Task> GetDocsAsync(string question, CallbackManagerForChainRun runManager); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs index 3c3dc7bb..cf01466e 100644 --- a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs +++ b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs @@ -1,10 +1,9 @@ using LangChain.Base; -using LangChain.Callback; using LangChain.Chains.CombineDocuments; namespace LangChain.Chains.RetrievalQA; -public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocumentsChain) : IChainInputs +public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocumentsChain) : ChainInputs { /// Chain to use to combine the documents. public BaseCombineDocumentsChain CombineDocumentsChain { get; } = combineDocumentsChain; @@ -14,6 +13,4 @@ public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocument public string InputKey { get; set; } = "question"; public string OutputKey { get; set; } = "output_text"; - public bool? Verbose { get; set; } - public CallbackManager? CallbackManager { get; set; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/RetrievalQA/RetrievalQaChain.cs b/src/libs/LangChain.Core/Chains/RetrievalQA/RetrievalQaChain.cs index c1a5a60f..5c6e6abb 100644 --- a/src/libs/LangChain.Core/Chains/RetrievalQA/RetrievalQaChain.cs +++ b/src/libs/LangChain.Core/Chains/RetrievalQA/RetrievalQaChain.cs @@ -1,3 +1,4 @@ +using LangChain.Callback; using LangChain.Docstore; using LangChain.Retrievers; @@ -13,10 +14,10 @@ public class RetrievalQaChain(RetrievalQaChainInput fields) : BaseRetrievalQaCha public override string ChainType() => "retrieval_qa"; - public override async Task> GetDocsAsync(string question) + public override async Task> GetDocsAsync(string question, CallbackManagerForChainRun runManager) { - // todo: runid - var runId = "???"; - return await _retriever.GetRelevantDocumentsAsync(question, runId, fields.CallbackManager); + return await _retriever.GetRelevantDocumentsAsync( + question, + callbacks: new ManagerCallbacks(runManager.GetChild())); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/Sequentials/SequentialChain.cs b/src/libs/LangChain.Core/Chains/Sequentials/SequentialChain.cs index 6bac131b..f98450b3 100644 --- a/src/libs/LangChain.Core/Chains/Sequentials/SequentialChain.cs +++ b/src/libs/LangChain.Core/Chains/Sequentials/SequentialChain.cs @@ -1,6 +1,7 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; using LangChain.Base; +using LangChain.Callback; using LangChain.Schema; namespace LangChain.Chains.Sequentials; @@ -30,7 +31,7 @@ public class SequentialChain : BaseChain /// /// /// - public SequentialChain(SequentialChainInput input) + public SequentialChain(SequentialChainInput input) : base(input) { Chains = input.Chains; InputKeys = input.InputVariables; @@ -45,11 +46,9 @@ public SequentialChain(SequentialChainInput input) } } - public override string ChainType() - { - return "sequential_chain"; - } - public override async Task CallAsync(IChainValues values) + public override string ChainType() => "sequential_chain"; + + protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) { var allChainValues = new ChainValues(new Dictionary(_allOutputKeys.Count)); foreach (var input in InputKeys) diff --git a/src/libs/LangChain.Core/Chains/Sequentials/SequentialChainInput.cs b/src/libs/LangChain.Core/Chains/Sequentials/SequentialChainInput.cs index d737811d..d0ef2a3c 100644 --- a/src/libs/LangChain.Core/Chains/Sequentials/SequentialChainInput.cs +++ b/src/libs/LangChain.Core/Chains/Sequentials/SequentialChainInput.cs @@ -1,12 +1,13 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Base; +using LangChain.Callback; namespace LangChain.Chains.Sequentials; /// /// /// -public class SequentialChainInput +public class SequentialChainInput : IChainInputs { /// /// @@ -28,6 +29,11 @@ public class SequentialChainInput /// public bool ReturnAll { get; } + public bool Verbose { get; set; } + public ICallbacks? Callbacks { get; set; } + public List Tags { get; set; } + public Dictionary Metadata { get; set; } + /// /// /// diff --git a/src/libs/LangChain.Core/Memory/MemoryExtensions.cs b/src/libs/LangChain.Core/Memory/MemoryExtensions.cs index 37662bf0..8c947575 100644 --- a/src/libs/LangChain.Core/Memory/MemoryExtensions.cs +++ b/src/libs/LangChain.Core/Memory/MemoryExtensions.cs @@ -6,7 +6,11 @@ public static class MemoryExtensions { public static IReadOnlyCollection WithHistory(this IReadOnlyCollection messages, BaseMemory? memory) { - if(memory == null) return messages; + if (memory == null) + { + return messages; + } + var history = "These are our previous conversations:\n"; var previousMessages = memory.LoadMemoryVariables(null); if (previousMessages.Value is { } messageDict && @@ -23,5 +27,4 @@ public static IReadOnlyCollection WithHistory(this IReadOnlyCollection< history.AsHumanMessage(), }.Concat(messages).ToArray(); } - } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs index ca4bd91f..7e13107a 100644 --- a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs +++ b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs @@ -12,7 +12,25 @@ namespace LangChain.Retrievers; /// public abstract class BaseRetriever { - protected abstract Task> GetRelevantDocumentsAsync(string query, int k = 4); + /// + /// Optional list of tags associated with the retriever. Defaults to None + /// These tags will be associated with each call to this retriever, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a retriever with its + /// use case. + /// + public List Tags { get; set; } + + /// + /// Optional metadata associated with the retriever. Defaults to None + /// This metadata will be associated with each call to this retriever, + /// and passed as arguments to the handlers defined in `callbacks`. + /// You can use these to eg identify a specific instance of a retriever with its + /// use case. + /// + public Dictionary Metadata { get; set; } + + protected abstract Task> GetRelevantDocumentsCoreAsync(string query, CallbackManagerForRetrieverRun runManager = null); /// /// Retrieve documents relevant to a query. @@ -20,20 +38,32 @@ public abstract class BaseRetriever /// string to find relevant documents for /// /// - /// - public virtual async Task> GetRelevantDocumentsAsync(string query, string runId, CallbackManager? callbacks = null) + /// + /// + /// + /// Relevant documents + public virtual async Task> GetRelevantDocumentsAsync( + string query, + string? runId = null, + ICallbacks? callbacks = null, + bool verbose = false, + List? tags = null, + Dictionary? metadata = null) { - CallbackManagerForRetrieverRun runManager=null; - if (callbacks != null) - { - runManager = await callbacks.HandleRetrieverStart(this, query, runId); - } - - try + var callbackManager = await CallbackManager.Configure( + callbacks, + localCallbacks: null, + verbose: verbose, + localTags: Tags, + inheritableTags: tags, + localMetadata: Metadata, + inheritableMetadata: metadata); + + var runManager = await callbackManager.HandleRetrieverStart(this, query, runId); + try { - var docs = await GetRelevantDocumentsAsync(query); - if(runManager!=null) - await runManager.HandleRetrieverEndAsync(query); + var docs = await GetRelevantDocumentsCoreAsync(query, runManager); + await runManager.HandleRetrieverEndAsync(query, docs.ToList()); return docs; } diff --git a/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs b/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs index 38a8a29c..99e6bb29 100644 --- a/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs +++ b/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs @@ -1,3 +1,4 @@ +using LangChain.Callback; using LangChain.Docstore; using LangChain.Retrievers; @@ -10,12 +11,15 @@ namespace LangChain.VectorStores; public class VectorStoreRetriever : BaseRetriever { public VectorStore Vectorstore { get; init; } - - + private ESearchType SearchType { get; init; } + private int K { get; init; } = 4; + private float? ScoreThreshold { get; init; } - public VectorStoreRetriever(VectorStore vectorstore, ESearchType searchType = ESearchType.Similarity, + public VectorStoreRetriever( + VectorStore vectorstore, + ESearchType searchType = ESearchType.Similarity, float? scoreThreshold = null) { SearchType = searchType; @@ -28,19 +32,19 @@ public VectorStoreRetriever(VectorStore vectorstore, ESearchType searchType = ES ScoreThreshold = scoreThreshold; } - protected override async Task> GetRelevantDocumentsAsync(string query, int k = 4) + protected override async Task> GetRelevantDocumentsCoreAsync(string query, CallbackManagerForRetrieverRun runManager = null) { switch (SearchType) { case ESearchType.Similarity: - return await Vectorstore.SimilaritySearchAsync(query, k); + return await Vectorstore.SimilaritySearchAsync(query, K); case ESearchType.SimilarityScoreThreshold: - var docsAndSimilarities = await Vectorstore.SimilaritySearchWithRelevanceScores(query, k); + var docsAndSimilarities = await Vectorstore.SimilaritySearchWithRelevanceScores(query, K); return docsAndSimilarities.Select(dws => dws.Item1); case ESearchType.MMR: - return await Vectorstore.MaxMarginalRelevanceSearch(query, k); + return await Vectorstore.MaxMarginalRelevanceSearch(query, K); default: throw new ArgumentException($"{SearchType} not supported"); diff --git a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs index af902027..1a05679c 100644 --- a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs +++ b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs @@ -28,7 +28,10 @@ public async Task Retrieval_Ok() m => m.GetRelevantDocumentsAsync( It.Is(x => x == "question?"), It.IsAny(), - It.IsAny()), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny>()), Times.Once()); combineDocumentsMock @@ -49,8 +52,11 @@ private Mock CreateRetrieverMock() .GetRelevantDocumentsAsync( It.IsAny(), It.IsAny(), - It.IsAny())) - .Returns((query, _, _) => + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny>())) + .Returns, Dictionary>((query, _, _, _, _, _) => { var docs = new List { diff --git a/src/tests/LangChain.UnitTest/SequentialChainTests.cs b/src/tests/LangChain.UnitTest/SequentialChainTests.cs index c9683a32..d3a20194 100644 --- a/src/tests/LangChain.UnitTest/SequentialChainTests.cs +++ b/src/tests/LangChain.UnitTest/SequentialChainTests.cs @@ -1,5 +1,6 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Abstractions.Schema; +using LangChain.Callback; using LangChain.Chains.Sequentials; using LangChain.Schema; using Moq; @@ -92,8 +93,12 @@ private Mock CreateFakeChainMock(string[] inputVariables, string[] outpu fakeChainMock.Setup(_ => _.InputKeys).Returns(inputVariables); fakeChainMock.Setup(_ => _.OutputKeys).Returns(outputVariables); - fakeChainMock.Setup(x => x.CallAsync(It.IsAny())) - .Returns(chainValues => + fakeChainMock.Setup(x => x.CallAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny>())) + .Returns, Dictionary>((chainValues, _, _, _) => { var output = new ChainValues();