Skip to content

Commit

Permalink
feat: Hook for chain execution tracing (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
TesAnti authored Jan 9, 2024
1 parent 3e12067 commit 67dedc9
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 18 deletions.
13 changes: 3 additions & 10 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,13 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Ollama.IntegrationTests", "src\tests\LangChain.Providers.Ollama.IntegrationTests\LangChain.Providers.Ollama.IntegrationTests.csproj", "{72B1E2CC-1A34-470E-A579-034CB0972BB7}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Ollama.IntegrationTests", "src\tests\LangChain.Providers.Ollama.IntegrationTests\LangChain.Providers.Ollama.IntegrationTests.csproj", "{72B1E2CC-1A34-470E-A579-034CB0972BB7}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Ollama", "src\libs\Providers\LangChain.Providers.Ollama\LangChain.Providers.Ollama.csproj", "{4913844F-74EC-4E74-AE8A-EA825569E6BA}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Automatic1111", "src\libs\Providers\LangChain.Providers.Automatic1111\LangChain.Providers.Automatic1111.csproj", "{BF4C7B87-0997-4208-84EF-D368DF7B9861}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111", "src\libs\Providers\LangChain.Providers.Automatic1111\LangChain.Providers.Automatic1111.csproj", "{BF4C7B87-0997-4208-84EF-D368DF7B9861}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{738984A2-7D3F-42E7-9B4D-3528E2539197}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down Expand Up @@ -404,10 +402,6 @@ Global
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Debug|Any CPU.Build.0 = Debug|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Release|Any CPU.ActiveCfg = Release|Any CPU
{738984A2-7D3F-42E7-9B4D-3528E2539197}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -475,7 +469,6 @@ Global
{4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{A6CF79BC-8365-46E8-9230-1A4AD615D40B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{738984A2-7D3F-42E7-9B4D-3528E2539197} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.HelperChains.Exceptions;
using LangChain.Chains.StackableChains.Context;
using LangChain.Schema;

namespace LangChain.Chains.HelperChains;
Expand Down Expand Up @@ -86,10 +87,13 @@ string FormatInputValues(IChainValues values)
public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null, IReadOnlyDictionary<string, object>? metadata = null)
{


if (values == null)
{
throw new ArgumentNullException(nameof(values));
}

try
{
return InternalCall(values);
Expand All @@ -108,8 +112,9 @@ public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks =

throw new StackableChainException(message, ex);
}

}

/// <summary>
///
/// </summary>
Expand Down Expand Up @@ -143,9 +148,11 @@ public static StackChain BitwiseOr(BaseStackableChain left, BaseStackableChain r
///
/// </summary>
/// <returns></returns>
public async Task<IChainValues> Run()
public async Task<IChainValues> Run(StackableChainHook? hook=null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var values = new StackableChainValues() {Hook = hook};
hook?.ChainStart(values);
var res = await CallAsync(values).ConfigureAwait(false);
return res;
}

Expand All @@ -154,9 +161,9 @@ public async Task<IChainValues> Run()
/// </summary>
/// <param name="resultKey"></param>
/// <returns></returns>
public async Task<string?> Run(string resultKey)
public async Task<string?> Run(string resultKey, StackableChainHook? hook = null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var res = await CallAsync(new StackableChainValues() { Hook = hook }).ConfigureAwait(false);
return res.Value[resultKey].ToString();
}

Expand All @@ -166,12 +173,17 @@ public async Task<IChainValues> Run()
/// <param name="resultKey"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public async Task<T> Run<T>(string resultKey)
public async Task<T> Run<T>(string resultKey, StackableChainHook? hook = null)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
var res = await CallAsync(new StackableChainValues() { Hook = hook }).ConfigureAwait(false);
return (T)res.Value[resultKey];
}

public Task<string?> Run(string resultKey)
{
return Run(resultKey, null);
}

/// <summary>
///
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

using LangChain.Chains.HelperChains;

namespace LangChain.Chains.StackableChains.Context;

public class ConsoleTraceHook: StackableChainHook

Check warning on line 6 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook'
{
public bool UseColors { get; set; }=true;

Check warning on line 8 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.UseColors'
public int ValuesLength { get; set; } = 40;

Check warning on line 9 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.ValuesLength'
public override void ChainStart(StackableChainValues values)

Check warning on line 10 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.ChainStart(StackableChainValues)'
{
Console.WriteLine();
}
public override void LinkEnter(BaseStackableChain chain, StackableChainValues values)

Check warning on line 14 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.LinkEnter(BaseStackableChain, StackableChainValues)'
{

Console.Write("|");
Console.Write(chain.GetType().Name);
Console.WriteLine();
if (chain.InputKeys.Count > 0)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
Console.Write("Input:");
Console.WriteLine();
foreach (string inputKey in chain.InputKeys)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
var value = values.Value[inputKey];
var oldColor = Console.ForegroundColor;
Console.ForegroundColor = GetColorForKey(inputKey);
Console.Write($" {inputKey}={ShortenString(value.ToString() ?? "", ValuesLength)}");
Console.ForegroundColor = oldColor;
Console.WriteLine();
}
}


}

Dictionary<string, ConsoleColor> _colorMap = new Dictionary<string, ConsoleColor>();

ConsoleColor GetColorForKey(string key)
{
if(!UseColors)
return Console.ForegroundColor;
// if key is not in map, get unique color(except black and white)
// if there no unique colors left, return white
if (!_colorMap.ContainsKey(key))
{
var color = ConsoleColor.White;
var colors = Enum.GetValues(typeof(ConsoleColor));
foreach (ConsoleColor c in colors)
{
if (c == ConsoleColor.Black || c == ConsoleColor.White)
continue;
if (!_colorMap.ContainsValue(c))
{
color = c;
break;
}
}
_colorMap.Add(key, color);
}
return _colorMap[key];
}

string ShortenString(string str, int length)
{
if (str.Length <= length)
return str;
return str.Substring(0, length - 3) + "...";
}

public override void LinkExit(BaseStackableChain chain, StackableChainValues values)

Check warning on line 78 in src/libs/LangChain.Core/Chains/StackableChains/Hooks/ConsoleTraceHook.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'ConsoleTraceHook.LinkExit(BaseStackableChain, StackableChainValues)'
{
if (chain.OutputKeys.Count > 0)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
Console.Write("Output:");
Console.WriteLine();
foreach (string outputKey in chain.OutputKeys)
{
Console.Write("|");
Console.Write(" ");
Console.Write("\u2514");
var value = values.Value[outputKey];
var oldColor = Console.ForegroundColor;
Console.ForegroundColor = GetColorForKey(outputKey);
Console.Write($" {outputKey}={ShortenString(value.ToString() ?? "", ValuesLength)}");
Console.ForegroundColor = oldColor;
Console.WriteLine();
}
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using LangChain.Chains.HelperChains;

namespace LangChain.Chains.StackableChains.Context;

public class StackableChainHook
{
public virtual void ChainStart(StackableChainValues values)
{

}

public virtual void LinkEnter(BaseStackableChain chain, StackableChainValues values)
{

}

public virtual void LinkExit(BaseStackableChain chain, StackableChainValues values)
{

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using LangChain.Schema;

namespace LangChain.Chains.StackableChains.Context;

public class StackableChainValues : ChainValues
{
public StackableChainHook? Hook { get; set; }
}
13 changes: 12 additions & 1 deletion src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.StackableChains.Context;
using LangChain.Schema;

namespace LangChain.Chains.HelperChains;
Expand Down Expand Up @@ -60,15 +61,25 @@ protected override async Task<IChainValues> InternalCall(IChainValues values)

if (IsolatedInputKeys.Count > 0)
{
var res = new ChainValues();
var res = new StackableChainValues(){Hook = (values as StackableChainValues)?.Hook};
foreach (var key in IsolatedInputKeys)
{
res.Value[key] = values.Value[key];
}
values = res;
}
if(a is not StackChain)
(values as StackableChainValues)?.Hook?.LinkEnter(a, (values as StackableChainValues));

Check warning on line 72 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkEnter(BaseStackableChain chain, StackableChainValues values)'.
await a.CallAsync(values).ConfigureAwait(false);
if (a is not StackChain)
(values as StackableChainValues)?.Hook?.LinkExit(a, (values as StackableChainValues));

Check warning on line 75 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkExit(BaseStackableChain chain, StackableChainValues values)'.

if (b is not StackChain)
(values as StackableChainValues)?.Hook?.LinkEnter(b, (values as StackableChainValues));

Check warning on line 78 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkEnter(BaseStackableChain chain, StackableChainValues values)'.
await b.CallAsync(values).ConfigureAwait(false);
if (b is not StackChain)
(values as StackableChainValues)?.Hook?.LinkExit(b, (values as StackableChainValues));

Check warning on line 81 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference argument for parameter 'values' in 'void StackableChainHook.LinkExit(BaseStackableChain chain, StackableChainValues values)'.

if (IsolatedOutputKeys.Count > 0)
{
foreach (var key in IsolatedOutputKeys)
Expand Down

0 comments on commit 67dedc9

Please sign in to comment.