Skip to content

Commit

Permalink
Added something similar to LCEL
Browse files Browse the repository at this point in the history
  • Loading branch information
TesAnti committed Nov 8, 2023
1 parent 6dceb8a commit 2b0bf52
Show file tree
Hide file tree
Showing 19 changed files with 576 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@
using System.Threading.Tasks;
using LangChain.Abstractions.Embeddings.Base;
using LangChain.Docstore;
using LangChain.Indexes;
using LangChain.TextSplitters;
using LangChain.VectorStores;

namespace LangChain.Databases.InMemory
{
public class InMemoryVectorStore:VectorStore
{
public static async Task<VectorStoreIndexWrapper> CreateIndexFromDocuments(IEmbeddings embeddings,List<Document> documents)
{
InMemoryVectorStore vectorStore = new InMemoryVectorStore(embeddings);
var textSplitter = new CharacterTextSplitter();
VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
var index = await indexCreator.FromDocumentsAsync(documents);
return index;
}

private readonly Func<float[], float[], float> _distanceFunction;
List<(float[] vec, string id, Document doc)> _storage = new List<(float[] vec, string id, Document doc)>();
Expand Down
38 changes: 38 additions & 0 deletions src/libs/LangChain.Core/Chains/Chain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Chains.HelperChains;
using LangChain.Indexes;
using LangChain.Providers;

namespace LangChain.Chains;

public static class Chain
{
public static BaseStackableChain Template(string template,
string outputKey = "prompt")
{
return new PromptChain(template, outputKey);
}

public static BaseStackableChain Set(string value, string outputKey = "value")
{
return new SetChain(value, outputKey);
}

public static BaseStackableChain LLM(IChatModel llm,
string inputKey = "prompt", string outputKey = "text")
{
return new LLMChain(llm, inputKey, outputKey);
}

public static BaseStackableChain RetreiveDocuments(VectorStoreIndexWrapper index,
string inputKey = "query", string outputKey = "documents")
{
return new RetreiveDocumentsChain(index, inputKey, outputKey);
}

public static BaseStackableChain StuffDocuments(
string inputKey = "documents", string outputKey = "combined")
{
return new StuffDocumentsChain(inputKey, outputKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class StuffDocumentsChainInput(ILlmChain llmChain) : BaseCombineDocuments
public ILlmChain LlmChain { get; } = llmChain;

/// <summary>
/// Prompt to use to format each document, gets passed to `format_document`.
/// Template to use to format each document, gets passed to `format_document`.
/// </summary>
public BasePromptTemplate DocumentPrompt { get; set; } = new PromptTemplate(
new PromptTemplateInput(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.HelperChains.Exceptions;

namespace LangChain.Chains.HelperChains;

public abstract class BaseStackableChain:IChain
{
public string Name { get; set; }
public virtual string[] InputKeys { get; protected set; }
public virtual string[] OutputKeys { get; protected set; }

protected string GenerateName()
{
return GetType().Name;
}

private string GetInputs()
{
return string.Join(",", InputKeys);
}

private string GetOutputs()
{
return string.Join(",", OutputKeys);
}

string FormatInputValues(IChainValues values)
{
List<string> res = new();
foreach (var key in InputKeys)
{
if (!values.Value.ContainsKey(key))
{
res.Add($"{key} is expected but missing");
continue;
};
res.Add($"{key}={values.Value[key]}");
}
return string.Join(",\n", res);
}

public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null,
List<string>? tags = null, Dictionary<string, object>? metadata = null)
{
try
{
return InternallCall(values);
}
catch (StackableChainException)
{
throw;
}
catch (Exception ex)
{
var name=Name??GenerateName();
var inputValues= FormatInputValues(values);
var message = $"Error occured in {name} with inputs \n{inputValues}\n.";

throw new StackableChainException(message,ex);
}

}

protected abstract Task<IChainValues> InternallCall(IChainValues values);

public static StackChain operator |(BaseStackableChain a, BaseStackableChain b)

Check warning on line 68 in src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs

View workflow job for this annotation

GitHub Actions / Build abd test / Build, test and publish

Provide a method named 'BitwiseOr' as a friendly alternate for operator op_BitwiseOr (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca2225)
{
return new StackChain(a, b);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace LangChain.Chains.HelperChains.Exceptions;

public class StackableChainException:Exception

Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs

View workflow job for this annotation

GitHub Actions / Build abd test / Build, test and publish

Add the following constructor to StackableChainException: public StackableChainException() (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1032)

Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs

View workflow job for this annotation

GitHub Actions / Build abd test / Build, test and publish

Add the following constructor to StackableChainException: public StackableChainException(string message) (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1032)
{
public StackableChainException(string message,Exception inner) : base(message, inner)
{
}
}
28 changes: 28 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Providers;

namespace LangChain.Chains.HelperChains;

public class LLMChain:BaseStackableChain
{
private readonly IChatModel _llm;

public LLMChain(IChatModel llm,
string inputKey="prompt",
string outputKey="text"
)
{
InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };
_llm = llm;
}

protected override async Task<IChainValues> InternallCall(IChainValues values)
{
var prompt = values.Value[InputKeys[0]].ToString();
var response=await _llm.GenerateAsync(new ChatRequest(new List<Message>() { prompt.AsSystemMessage() }));
values.Value[OutputKeys[0]] = response.Messages.Last().Content;
return values;
}
}
52 changes: 52 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System.Text.RegularExpressions;
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.LLM;
using LangChain.Prompts;

namespace LangChain.Chains.HelperChains;

public class PromptChain: BaseStackableChain
{
private readonly string _template;

public PromptChain(string template,string outputKey="prompt")
{
OutputKeys = new[] { outputKey };
_template = template;
InputKeys = GetVariables().ToArray();
}

List<string> GetVariables()
{
string pattern = @"\{([^\{\}]+)\}";
var variables = new List<string>();
var matches = Regex.Matches(_template, pattern);
foreach (Match match in matches)
{
variables.Add(match.Groups[1].Value);
}
return variables;
}




protected override Task<IChainValues> InternallCall(IChainValues values)
{
// validate that input keys containing all variables
var valueKeys = values.Value.Keys;
var missing = InputKeys.Except(valueKeys);
if (missing.Any())
{
throw new Exception($"Input keys must contain all variables in template. Missing: {string.Join(",",missing)}");
}

var formattedPrompt = PromptTemplate.InterpolateFString(_template,values.Value);

values.Value[OutputKeys[0]]= formattedPrompt;

return Task.FromResult(values);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using System.Numerics;
using LangChain.Indexes;

namespace LangChain.Chains.HelperChains;

public class RetreiveDocumentsChain:BaseStackableChain
{
private readonly VectorStoreIndexWrapper _index;
private readonly int _amount;

public RetreiveDocumentsChain(VectorStoreIndexWrapper index, string inputKey="query", string outputKey="documents", int amount=4)
{
_index = index;
_amount = amount;
InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };
}

protected override async Task<IChainValues> InternallCall(IChainValues values)
{
var retreiver = _index.Store.AsRetreiver();
retreiver.K = _amount;

var query = values.Value[InputKeys[0]].ToString();
var results = await retreiver.GetRelevantDocumentsAsync(query);
values.Value[OutputKeys[0]] = results.ToList();
return values;
}
}
23 changes: 23 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;

namespace LangChain.Chains.HelperChains;

public class SetChain: BaseStackableChain
{
private readonly string _query;
public SetChain(string query, string outputKey="query")
{
OutputKeys = new[] { outputKey };
_query = query;
}

protected override Task<IChainValues> InternallCall(IChainValues values)
{
values.Value[OutputKeys[0]] = _query;
return Task.FromResult(values);
}


}
78 changes: 78 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Schema;

namespace LangChain.Chains.HelperChains;

public class StackChain:BaseStackableChain
{
private readonly BaseStackableChain _a;
private readonly BaseStackableChain _b;

public string[] IsolatedInputKeys { get; set; }=new string[0];

Check warning on line 12 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build abd test / Build, test and publish

Avoid unnecessary zero-length array allocations. Use Array.Empty<string>() instead. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1825)
public string[] IsolatedOutputKeys { get; set; }=new string[0];

Check warning on line 13 in src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs

View workflow job for this annotation

GitHub Actions / Build abd test / Build, test and publish

Avoid unnecessary zero-length array allocations. Use Array.Empty<string>() instead. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1825)

public StackChain(BaseStackableChain a, BaseStackableChain b)
{
_a = a;
_b = b;

}

public StackChain AsIsolated(string[] inputKeys = null, string[] outputKeys = null)
{
IsolatedInputKeys = inputKeys ?? IsolatedInputKeys;
IsolatedOutputKeys = outputKeys ?? IsolatedOutputKeys;
return this;
}

public StackChain AsIsolated(string inputKey = null, string outputKey = null)
{
if (inputKey != null) IsolatedInputKeys = new[] { inputKey };
if (outputKey != null) IsolatedOutputKeys = new[] { outputKey };
return this;
}

protected override async Task<IChainValues> InternallCall(IChainValues values)
{
// since it is reference type, the values would be changed anyhow
var originalValues = values;

if (IsolatedInputKeys.Length>0)
{
var res = new ChainValues();
foreach (var key in IsolatedInputKeys)
{
res.Value[key] = values.Value[key];
}
values = res;
}
await _a.CallAsync(values);
await _b.CallAsync(values);
if (IsolatedOutputKeys.Length > 0)
{

foreach (var key in IsolatedOutputKeys)
{
originalValues.Value[key] = values.Value[key];
}

}
return originalValues;
}



public async Task<IChainValues> Run()
{

var res = await CallAsync(new ChainValues());
return res;
}

public async Task<string> Run(string resultKey)
{
var res = await CallAsync(new ChainValues());
return res.Value[resultKey].ToString();
}
}
Loading

0 comments on commit 2b0bf52

Please sign in to comment.