Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

retrievalqa implementation #48

Merged
merged 2 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ private string SerializeMetadata(Dictionary<string, object> metadata)

/// <summary>
/// Parses the JSON held in <c>metadata.AdditionalMetadata</c> into a dictionary,
/// using the configured serializer options.
/// </summary>
/// <param name="metadata">Record metadata whose additional metadata is read.</param>
/// <returns>The parsed dictionary, or an empty dictionary when the deserializer yields null.</returns>
private Dictionary<string, object> DeserializeMetadata(MemoryRecordMetadata metadata)
{
    // TODO: issue with this method is it returns values as JsonElements instead of primitive types
    var parsed = JsonSerializer.Deserialize<Dictionary<string, object>>(
        metadata.AdditionalMetadata,
        _jsonSerializerOptions);

    if (parsed == null)
    {
        return new Dictionary<string, object>();
    }

    return parsed;
}
Expand Down
26 changes: 24 additions & 2 deletions src/libs/LangChain.Core/Base/BaseChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public abstract class BaseChain : IChain
/// <param name="input">The string input to use to execute the chain.</param>
/// <returns>A text value containing the result of the chain.</returns>
/// <exception cref="ArgumentException">If the type of chain used expects multiple inputs, this method will throw an ArgumentException.</exception>
public async Task<string?> Run(string input)
public virtual async Task<string?> Run(string input)
{
var isKeylessInput = InputKeys.Length <= 1;

Expand All @@ -51,11 +51,33 @@ public abstract class BaseChain : IChain
if (keys.Count(p => p != RunKey) == 1)
{
var returnValue = returnValues.Value.FirstOrDefault(p => p.Key != RunKey).Value;
return returnValue == null ? null : returnValue.ToString();

return returnValue?.ToString();
}

throw new Exception("Return values have multiple keys, 'run' only supported when one key currently");
}

/// <summary>
/// Run the chain using a dictionary of inputs and return a single text output.
/// </summary>
/// <param name="input">
/// The input values to execute the chain with. The number of entries must equal the
/// number of input keys declared by the chain.
/// </param>
/// <returns>
/// The string form of the value returned under the first output key, or null when
/// that value is null or no returned entry carries that key.
/// </returns>
/// <exception cref="ArgumentNullException">If <paramref name="input"/> is null.</exception>
/// <exception cref="ArgumentException">If the number of supplied inputs differs from the number of input keys the chain expects.</exception>
public virtual async Task<string> Run(Dictionary<string, object> input)
{
    if (input is null)
    {
        throw new ArgumentNullException(nameof(input));
    }

    // Reject the call only when the counts disagree. The previous guard was inverted:
    // it threw when the counts matched and let mismatched inputs through.
    if (InputKeys.Length != input.Count)
    {
        throw new ArgumentException($"Chain {ChainType()} expects {InputKeys.Length} but, received {input.Count}");
    }

    var returnValues = await CallAsync(new ChainValues(input));

    // Only the first declared output key is surfaced through this simplified API.
    var returnValue = returnValues.Value.FirstOrDefault(kv => kv.Key == OutputKeys[0]).Value;

    return returnValue?.ToString();
}

/// <summary>
/// Execute the chain, using the values provided.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput f
public override async Task<IChainValues> CallAsync(IChainValues values)
{
var docs = values.Value[InputKey];

//Other keys are assumed to be needed for LLM prediction
var otherKeys = values.Value
.Where(kv => kv.Key != InputKey)
Expand Down
69 changes: 69 additions & 0 deletions src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Base;
using LangChain.Chains.CombineDocuments;
using LangChain.Docstore;
using LangChain.Schema;

namespace LangChain.Chains.RetrievalQA;

/// <summary>
/// Base class for question-answering chains.
/// </summary>
/// <param name="fields">
/// Chain configuration: input/output keys, the combine-documents chain and whether
/// the retrieved documents are returned alongside the answer. Values are read once
/// at construction time.
/// </param>
public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : BaseChain, IChain
{
    private readonly string _inputKey = fields.InputKey;
    private readonly string _outputKey = fields.OutputKey;
    private readonly bool _returnSourceDocuments = fields.ReturnSourceDocuments;
    private readonly BaseCombineDocumentsChain _combineDocumentsChain = fields.CombineDocumentsChain;

    private const string SourceDocuments = "source_documents";

    /// <summary> The single key under which the question is read from the input values. </summary>
    public override string[] InputKeys => new[] { _inputKey };

    /// <summary>
    /// The keys present in the chain output: the answer key, followed by
    /// 'source_documents' when source documents are returned.
    /// </summary>
    // Uses the same snapshot as CallAsync so the declared keys always match the actual
    // output, even if the input object is mutated after construction.
    public override string[] OutputKeys => _returnSourceDocuments
        ? new[] { _outputKey, SourceDocuments }
        : new[] { _outputKey };

    /// <summary>
    /// Run get_relevant_text and llm on input query.
    ///
    /// If chain has 'return_source_documents' as 'True', returns
    /// the retrieved documents as well under the key 'source_documents'.
    /// </summary>
    /// <param name="values">Input values; the question is read from the entry stored under the chain's input key.</param>
    /// <returns>
    /// Chain values holding the answer under the output key and, when enabled,
    /// the retrieved documents under 'source_documents'.
    /// </returns>
    public override async Task<IChainValues> CallAsync(IChainValues values)
    {
        var question = values.Value[_inputKey].ToString();

        var docs = await GetDocsAsync(question);

        // The combine-documents chain receives the documents plus the question text.
        var input = new Dictionary<string, object>
        {
            ["input_documents"] = docs,
            ["question"] = question
        };

        var answer = await _combineDocumentsChain.Run(input);

        var output = new Dictionary<string, object>
        {
            [_outputKey] = answer
        };

        if (_returnSourceDocuments)
        {
            output.Add(SourceDocuments, docs);
        }

        return new ChainValues(output);
    }

    /// <summary>
    /// Get documents to do question answering over.
    /// </summary>
    /// <param name="question">The question text used to look up documents.</param>
    /// <returns>The documents to answer the question from.</returns>
    public abstract Task<IEnumerable<Document>> GetDocsAsync(string question);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using LangChain.Base;
using LangChain.Callback;
using LangChain.Chains.CombineDocuments;

namespace LangChain.Chains.RetrievalQA;

/// <summary>
/// Configuration values for a retrieval question-answering chain.
/// </summary>
public class BaseRetrievalQaChainInput : IChainInputs
{
    /// <summary>
    /// Creates the input with the chain used to combine documents.
    /// </summary>
    /// <param name="combineDocumentsChain">Chain to use to combine the documents.</param>
    public BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocumentsChain)
    {
        CombineDocumentsChain = combineDocumentsChain;
    }

    /// <summary> Chain to use to combine the documents. </summary>
    public BaseCombineDocumentsChain CombineDocumentsChain { get; }

    /// <summary> Return the source documents or not. </summary>
    public bool ReturnSourceDocuments { get; set; }

    /// <summary> Input key of the chain; defaults to "input_documents". </summary>
    public string InputKey { get; set; } = "input_documents";

    /// <summary> Output key of the chain; defaults to "output_text". </summary>
    public string OutputKey { get; set; } = "output_text";

    /// <summary> Optional verbosity flag. </summary>
    public bool? Verbose { get; set; }

    /// <summary> Optional callback manager. </summary>
    public CallbackManager? CallbackManager { get; set; }
}
22 changes: 22 additions & 0 deletions src/libs/LangChain.Core/Chains/RetrievalQA/RetrievalQaChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using LangChain.Docstore;
using LangChain.Retrievers;

namespace LangChain.Chains.RetrievalQA;

/// <summary>
/// Question-answering chain that obtains its documents from a retriever.
/// </summary>
/// <param name="fields">Chain configuration, including the retriever and the optional callback manager.</param>
public class RetrievalQaChain(RetrievalQaChainInput fields) : BaseRetrievalQaChain(fields)
{
    private readonly BaseRetriever _retriever = fields.Retriever;

    /// <summary> Identifier of this chain type. </summary>
    /// <returns>The constant "retrieval_qa".</returns>
    public override string ChainType()
    {
        return "retrieval_qa";
    }

    /// <summary>
    /// Asks the retriever for the documents relevant to the question.
    /// </summary>
    /// <param name="question">The question passed to the retriever as the query.</param>
    /// <returns>The documents returned by the retriever.</returns>
    public override async Task<IEnumerable<Document>> GetDocsAsync(string question)
    {
        // todo: runid
        const string placeholderRunId = "???";

        var documents = await _retriever.GetRelevantDocumentsAsync(
            question,
            placeholderRunId,
            fields.CallbackManager);

        return documents;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using LangChain.Chains.CombineDocuments;
using LangChain.Retrievers;

namespace LangChain.Chains.RetrievalQA;

/// <summary>
/// Configuration for <see cref="RetrievalQaChain"/>: the base question-answering
/// settings plus the retriever that supplies the documents.
/// </summary>
public class RetrievalQaChainInput : BaseRetrievalQaChainInput
{
    /// <summary>
    /// Creates the input from the combine-documents chain and the retriever.
    /// </summary>
    /// <param name="combineDocumentsChain">Chain to use to combine the documents.</param>
    /// <param name="retriever">Documents retriever.</param>
    public RetrievalQaChainInput(
        BaseCombineDocumentsChain combineDocumentsChain,
        BaseRetriever retriever)
        : base(combineDocumentsChain)
    {
        Retriever = retriever;
    }

    /// <summary> Documents retriever. </summary>
    public BaseRetriever Retriever { get; }
}
1 change: 1 addition & 0 deletions src/libs/LangChain.Core/LangChain.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
</ItemGroup>

<ItemGroup>
<Folder Include="Chains\QuestionAnswering\" />
<Folder Include="Docstore\" />
<Folder Include="TextSplitters\" />
</ItemGroup>
Expand Down
7 changes: 5 additions & 2 deletions src/libs/LangChain.Core/Retrievers/BaseRetriever.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
namespace LangChain.Retrievers;

/// <summary>
/// BaseRetriever
/// Abstract base class for a Document retrieval system.
///
/// A retrieval system is defined as something that can take string queries and return
/// the most 'relevant' Documents from some source.
/// <see href="https://api.python.langchain.com/en/latest/_modules/langchain/schema/retriever.html" />
/// </summary>
public abstract class BaseRetriever
Expand All @@ -18,7 +21,7 @@ public abstract class BaseRetriever
/// <param name="runId"></param>
/// <param name="callbacks"></param>
/// <returns></returns>
public async Task<IEnumerable<Document>> GetRelevantDocumentsAsync(string query, string runId, CallbackManager? callbacks = null)
public virtual async Task<IEnumerable<Document>> GetRelevantDocumentsAsync(string query, string runId, CallbackManager? callbacks = null)
{
var runManager = await callbacks.HandleRetrieverStart(this, query, runId);
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public override List<string> SplitText(string text)
}
else
{
if (goodSplits.Any())
if (goodSplits.Count != 0)
{
List<string> mergedText = MergeSplits(goodSplits, separator);
finalChunks.AddRange(mergedText);
Expand All @@ -69,7 +69,7 @@ public override List<string> SplitText(string text)
}
}

if (goodSplits.Any())
if (goodSplits.Count != 0)
{
List<string> mergedText = MergeSplits(goodSplits, separator);
finalChunks.AddRange(mergedText);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using LangChain.Callback;
using LangChain.Chains.CombineDocuments;
using LangChain.Chains.RetrievalQA;
using LangChain.Docstore;
using LangChain.Retrievers;
using Moq;

namespace LangChain.Core.UnitTests.Chains.RetrievalQa;

/// <summary>
/// Tests for <see cref="RetrievalQaChain"/> running against a mocked retriever
/// and a mocked combine-documents chain.
/// </summary>
[TestClass]
public class RetrievalQaChainTests
{
    /// <summary>
    /// Running the chain with a question queries the retriever once, forwards the
    /// retrieved documents to the combine-documents chain once, and returns its answer.
    /// </summary>
    [TestMethod]
    public async Task Retrieval_Ok()
    {
        // Arrange
        var retriever = CreateRetrieverMock();
        var combineDocuments = CreateCombineDocumentsChainMock();

        var chain = new RetrievalQaChain(
            new RetrievalQaChainInput(combineDocuments.Object, retriever.Object));

        // Act
        var answer = await chain.Run("question?");

        // Assert
        answer.Should().BeEquivalentTo("answer");

        retriever.Verify(
            m => m.GetRelevantDocumentsAsync(
                It.Is<string>(q => q == "question?"),
                It.IsAny<string>(),
                It.IsAny<CallbackManager>()),
            Times.Once());

        combineDocuments.Verify(
            m => m.Run(
                It.Is<Dictionary<string, object>>(args =>
                    args["input_documents"].As<List<Document>>()
                        .Select(d => d.PageContent)
                        .Intersect(new string[] { "first", "second", "third" })
                        .Count() == 3)),
            Times.Once());
    }

    /// <summary>
    /// Builds a retriever mock that answers every query with three documents
    /// whose contents are "first", "second" and "third".
    /// </summary>
    private Mock<BaseRetriever> CreateRetrieverMock()
    {
        var retriever = new Mock<BaseRetriever>();

        retriever
            .Setup(r => r.GetRelevantDocumentsAsync(
                It.IsAny<string>(),
                It.IsAny<string>(),
                It.IsAny<CallbackManager>()))
            .Returns<string, string, CallbackManager>((_, _, _) =>
            {
                var documents = new List<Document>
                {
                    CreateDocument("first"),
                    CreateDocument("second"),
                    CreateDocument("third")
                };

                return Task.FromResult(documents.AsEnumerable());
            });

        return retriever;
    }

    /// <summary>
    /// Builds a combine-documents chain mock whose dictionary-based Run always yields "answer".
    /// </summary>
    private Mock<BaseCombineDocumentsChain> CreateCombineDocumentsChainMock()
    {
        var chainInput = new Mock<BaseCombineDocumentsChainInput>().Object;
        var combineChain = new Mock<BaseCombineDocumentsChain>(chainInput);

        combineChain
            .Setup(c => c.Run(It.IsAny<Dictionary<string, object>>()))
            .Returns<Dictionary<string, object>>(_ => Task.FromResult("answer"));

        return combineChain;
    }

    /// <summary> Creates a document with the given content and empty metadata. </summary>
    private Document CreateDocument(string content)
    {
        return new(content, new ());
    }
}
Loading