-
-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Chains.HelperChains; | ||
using LangChain.Indexes; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains; | ||
|
||
public static class Chain | ||
{ | ||
public static BaseStackableChain Template(string template, | ||
string outputKey = "prompt") | ||
{ | ||
return new PromptChain(template, outputKey); | ||
} | ||
|
||
public static BaseStackableChain Set(string value, string outputKey = "value") | ||
{ | ||
return new SetChain(value, outputKey); | ||
} | ||
|
||
public static BaseStackableChain LLM(IChatModel llm, | ||
string inputKey = "prompt", string outputKey = "text") | ||
{ | ||
return new LLMChain(llm, inputKey, outputKey); | ||
} | ||
|
||
public static BaseStackableChain RetreiveDocuments(VectorStoreIndexWrapper index, | ||
string inputKey = "query", string outputKey = "documents") | ||
{ | ||
return new RetreiveDocumentsChain(index, inputKey, outputKey); | ||
} | ||
|
||
public static BaseStackableChain StuffDocuments( | ||
string inputKey = "documents", string outputKey = "combined") | ||
{ | ||
return new StuffDocumentsChain(inputKey, outputKey); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Chains.HelperChains.Exceptions; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public abstract class BaseStackableChain:IChain | ||
{ | ||
public string Name { get; set; } | ||
public virtual string[] InputKeys { get; protected set; } | ||
public virtual string[] OutputKeys { get; protected set; } | ||
|
||
protected string GenerateName() | ||
{ | ||
return GetType().Name; | ||
} | ||
|
||
private string GetInputs() | ||
{ | ||
return string.Join(",", InputKeys); | ||
} | ||
|
||
private string GetOutputs() | ||
{ | ||
return string.Join(",", OutputKeys); | ||
} | ||
|
||
string FormatInputValues(IChainValues values) | ||
{ | ||
List<string> res = new(); | ||
foreach (var key in InputKeys) | ||
{ | ||
if (!values.Value.ContainsKey(key)) | ||
{ | ||
res.Add($"{key} is expected but missing"); | ||
continue; | ||
}; | ||
res.Add($"{key}={values.Value[key]}"); | ||
} | ||
return string.Join(",\n", res); | ||
} | ||
|
||
public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null, | ||
List<string>? tags = null, Dictionary<string, object>? metadata = null) | ||
{ | ||
try | ||
{ | ||
return InternallCall(values); | ||
} | ||
catch (StackableChainException) | ||
{ | ||
throw; | ||
} | ||
catch (Exception ex) | ||
{ | ||
var name=Name??GenerateName(); | ||
var inputValues= FormatInputValues(values); | ||
var message = $"Error occured in {name} with inputs \n{inputValues}\n."; | ||
|
||
throw new StackableChainException(message,ex); | ||
} | ||
|
||
} | ||
|
||
protected abstract Task<IChainValues> InternallCall(IChainValues values); | ||
|
||
public static StackChain operator |(BaseStackableChain a, BaseStackableChain b) | ||
Check warning on line 68 in src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs GitHub Actions / Build abd test / Build, test and publish
|
||
{ | ||
return new StackChain(a, b); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
namespace LangChain.Chains.HelperChains.Exceptions; | ||
|
||
public class StackableChainException:Exception | ||
Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs GitHub Actions / Build abd test / Build, test and publish
Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs GitHub Actions / Build abd test / Build, test and publish
|
||
{ | ||
public StackableChainException(string message,Exception inner) : base(message, inner) | ||
{ | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class LLMChain:BaseStackableChain | ||
{ | ||
private readonly IChatModel _llm; | ||
|
||
public LLMChain(IChatModel llm, | ||
string inputKey="prompt", | ||
string outputKey="text" | ||
) | ||
{ | ||
InputKeys = new[] { inputKey }; | ||
OutputKeys = new[] { outputKey }; | ||
_llm = llm; | ||
} | ||
|
||
protected override async Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
var prompt = values.Value[InputKeys[0]].ToString(); | ||
var response=await _llm.GenerateAsync(new ChatRequest(new List<Message>() { prompt.AsSystemMessage() })); | ||
values.Value[OutputKeys[0]] = response.Messages.Last().Content; | ||
return values; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
using System.Text.RegularExpressions; | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Chains.LLM; | ||
using LangChain.Prompts; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class PromptChain: BaseStackableChain | ||
{ | ||
private readonly string _template; | ||
|
||
public PromptChain(string template,string outputKey="prompt") | ||
{ | ||
OutputKeys = new[] { outputKey }; | ||
_template = template; | ||
InputKeys = GetVariables().ToArray(); | ||
} | ||
|
||
List<string> GetVariables() | ||
{ | ||
string pattern = @"\{([^\{\}]+)\}"; | ||
var variables = new List<string>(); | ||
var matches = Regex.Matches(_template, pattern); | ||
foreach (Match match in matches) | ||
{ | ||
variables.Add(match.Groups[1].Value); | ||
} | ||
return variables; | ||
} | ||
|
||
|
||
|
||
|
||
protected override Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
// validate that input keys containing all variables | ||
var valueKeys = values.Value.Keys; | ||
var missing = InputKeys.Except(valueKeys); | ||
if (missing.Any()) | ||
{ | ||
throw new Exception($"Input keys must contain all variables in template. Missing: {string.Join(",",missing)}"); | ||
} | ||
|
||
var formattedPrompt = PromptTemplate.InterpolateFString(_template,values.Value); | ||
|
||
values.Value[OutputKeys[0]]= formattedPrompt; | ||
|
||
return Task.FromResult(values); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using System.Numerics; | ||
using LangChain.Indexes; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class RetreiveDocumentsChain:BaseStackableChain | ||
{ | ||
private readonly VectorStoreIndexWrapper _index; | ||
private readonly int _amount; | ||
|
||
public RetreiveDocumentsChain(VectorStoreIndexWrapper index, string inputKey="query", string outputKey="documents", int amount=4) | ||
{ | ||
_index = index; | ||
_amount = amount; | ||
InputKeys = new[] { inputKey }; | ||
OutputKeys = new[] { outputKey }; | ||
} | ||
|
||
protected override async Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
var retreiver = _index.Store.AsRetreiver(); | ||
retreiver.K = _amount; | ||
|
||
var query = values.Value[InputKeys[0]].ToString(); | ||
var results = await retreiver.GetRelevantDocumentsAsync(query); | ||
values.Value[OutputKeys[0]] = results.ToList(); | ||
return values; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class SetChain: BaseStackableChain | ||
{ | ||
private readonly string _query; | ||
public SetChain(string query, string outputKey="query") | ||
{ | ||
OutputKeys = new[] { outputKey }; | ||
_query = query; | ||
} | ||
|
||
protected override Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
values.Value[OutputKeys[0]] = _query; | ||
return Task.FromResult(values); | ||
} | ||
|
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Schema; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class StackChain:BaseStackableChain | ||
{ | ||
private readonly BaseStackableChain _a; | ||
private readonly BaseStackableChain _b; | ||
|
||
public string[] IsolatedInputKeys { get; set; }=new string[0]; | ||
Check warning on line 12 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs GitHub Actions / Build abd test / Build, test and publish
|
||
public string[] IsolatedOutputKeys { get; set; }=new string[0]; | ||
Check warning on line 13 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs GitHub Actions / Build abd test / Build, test and publish
|
||
|
||
public StackChain(BaseStackableChain a, BaseStackableChain b) | ||
{ | ||
_a = a; | ||
_b = b; | ||
|
||
} | ||
|
||
public StackChain AsIsolated(string[] inputKeys = null, string[] outputKeys = null) | ||
{ | ||
IsolatedInputKeys = inputKeys ?? IsolatedInputKeys; | ||
IsolatedOutputKeys = outputKeys ?? IsolatedOutputKeys; | ||
return this; | ||
} | ||
|
||
public StackChain AsIsolated(string inputKey = null, string outputKey = null) | ||
{ | ||
if (inputKey != null) IsolatedInputKeys = new[] { inputKey }; | ||
if (outputKey != null) IsolatedOutputKeys = new[] { outputKey }; | ||
return this; | ||
} | ||
|
||
protected override async Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
// since it is reference type, the values would be changed anyhow | ||
var originalValues = values; | ||
|
||
if (IsolatedInputKeys.Length>0) | ||
{ | ||
var res = new ChainValues(); | ||
foreach (var key in IsolatedInputKeys) | ||
{ | ||
res.Value[key] = values.Value[key]; | ||
} | ||
values = res; | ||
} | ||
await _a.CallAsync(values); | ||
await _b.CallAsync(values); | ||
if (IsolatedOutputKeys.Length > 0) | ||
{ | ||
|
||
foreach (var key in IsolatedOutputKeys) | ||
{ | ||
originalValues.Value[key] = values.Value[key]; | ||
} | ||
|
||
} | ||
return originalValues; | ||
} | ||
|
||
|
||
|
||
public async Task<IChainValues> Run() | ||
{ | ||
|
||
var res = await CallAsync(new ChainValues()); | ||
return res; | ||
} | ||
|
||
public async Task<string> Run(string resultKey) | ||
{ | ||
var res = await CallAsync(new ChainValues()); | ||
return res.Value[resultKey].ToString(); | ||
} | ||
} |