Skip to content

Commit

Permalink
LLamaSharp embeddings and test
Browse files Browse the repository at this point in the history
  • Loading branch information
TesAnti committed Nov 4, 2023
1 parent 1d7b83c commit 4d80e76
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using LangChain.Abstractions.Embeddings.Base;
using LLama.Common;
using LLama;

namespace LangChain.Providers.LLamaSharp;

public class LLamaSharpEmbeddings:IEmbeddings
{
protected readonly LLamaSharpConfiguration _configuration;
protected readonly LLamaWeights _model;
protected readonly ModelParams _parameters;
private readonly LLamaEmbedder _embedder;

public LLamaSharpEmbeddings(LLamaSharpConfiguration configuration)
{
_parameters = new ModelParams(configuration.PathToModelFile)
{
ContextSize = (uint)configuration.ContextSize,
Seed = (uint)configuration.Seed,

};
_model = LLamaWeights.LoadFromFile(_parameters);
_configuration = configuration;
_embedder = new LLamaEmbedder(_model, _parameters);
}

public Task<float[][]> EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default)
{
float[][] result = new float[texts.Length][];
for (int i = 0; i < texts.Length; i++)
{
result[i] = _embedder.GetEmbeddings(texts[i]);
}
return Task.FromResult(result);
}

public Task<float[]> EmbedQueryAsync(string text, CancellationToken cancellationToken = default)
{
return Task.FromResult(_embedder.GetEmbeddings(text));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,49 @@ public void InstructionTest()
Assert.AreEqual("4",response.Messages.Last().Content.Trim());

}

float VectorDistance(float[] a, float[] b)
{
float result = 0;
for (int i = 0; i < a.Length; i++)
{
result += (a[i] - b[i]) * (a[i] - b[i]);
}

return result;

}
[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void EmbeddingsTest()
{
var model = new LLamaSharpEmbeddings(new LLamaSharpConfiguration
{
PathToModelFile = ModelPath,
Temperature = 0
});

string[] texts = new string[]
{
"I spent entire day watching TV",
"My dog name is Bob",
"This icecream is delicious",
"It is cold in space"
};

var database = model.EmbedDocumentsAsync(texts).Result;


var query = model.EmbedQueryAsync("How do you call your pet?").Result;

var zipped = database.Zip(texts);

var ordered= zipped.Select(x=>new {text=x.Second,dist=VectorDistance(x.First,query)});

var closest = ordered.OrderBy(x => x.dist).First();

Assert.AreEqual("My dog name is Bob", closest.text);
}
}

0 comments on commit 4d80e76

Please sign in to comment.