From bb81651ec013b7d2654f1cac8ebd38c5d5c92eae Mon Sep 17 00:00:00 2001 From: Wenbo Cao <104199@smsassist.com> Date: Fri, 11 Aug 2023 14:16:08 -0500 Subject: [PATCH] modify cosine function --- .../Plugins/MemVecDb/MemVectorDatabase.cs | 60 ++++++++++++++++--- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/src/Infrastructure/BotSharp.Core/Plugins/MemVecDb/MemVectorDatabase.cs b/src/Infrastructure/BotSharp.Core/Plugins/MemVecDb/MemVectorDatabase.cs index e724e2eec..5aafbacf4 100644 --- a/src/Infrastructure/BotSharp.Core/Plugins/MemVecDb/MemVectorDatabase.cs +++ b/src/Infrastructure/BotSharp.Core/Plugins/MemVecDb/MemVectorDatabase.cs @@ -1,5 +1,7 @@ using BotSharp.Abstraction.VectorStorage; +using Tensorflow; using Tensorflow.NumPy; +using static Tensorflow.Binding; namespace BotSharp.Core.Plugins.MemVecDb; @@ -60,18 +62,62 @@ private float[] CalEuclideanDistance(float[] vec, List records) return c.ToArray(); } - public float[] CalCosineSimilarity(float[] vec, List records) + public NDArray CalCosineSimilarity(float[] vec, List records) { - var similarities = new float[records.Count]; - var a = vec; - var normA = np.linalg.norm(a); + var recordsArray = np.zeros((records.Count, records[0].Vector.Length), dtype: np.float32); for (int i = 0; i < records.Count; i++) { - var b = records[i].Vector; - similarities[i] = np.dot(a, b) / (normA * np.linalg.norm(b)); + recordsArray[i] = records[i].Vector; } - return similarities; + var vecArray = np.expand_dims(np.array(vec, dtype: np.float32), axis: 0); // [1. 300] + + (var normVecArray, var _) = SafeNormalize(vecArray); + (var normRecordsArray, var _) = SafeNormalize(recordsArray); + + var simiMatix = tf.matmul(tf.cast(normVecArray, tf.float32), tf.transpose(tf.cast(normRecordsArray, tf.float32))).numpy(); // [1, num_records] + + simiMatix = np.squeeze(simiMatix, axis: 0); + + return simiMatix; + } + + public int[] CalCosineSimilarityTopK(float[] vec, List records, int topK = 10, float filterProb = 0.75f) + { + var simiMatix = CalCosineSimilarity(vec, records); + + var topIndex = np.argsort(simiMatix)["::-1"][$":{topK}"]; + + var resIndex = new List(); + + for (int i = 0; i < topK; i++) + { + var index = topIndex[i]; + var value = simiMatix[index]; + + if (value > filterProb) + { + resIndex.Add(topIndex[i]); + } + } + + return resIndex.ToArray(); + } + + public (NDArray, NDArray) SafeNormalize(NDArray x, double eps = 2.223E-15) + { + var squaredX = np.sum(np.multiply(x, x), axis: 1); + var normX = np.sqrt(squaredX); + + var epsTensor = tf.cast(tf.convert_to_tensor(eps), dtype: tf.float32); + var normXTensor = tf.cast(normX, tf.float32); + var contantMask = (normXTensor < epsTensor); + var divideTensor = tf.ones_like(normXTensor, dtype: tf.float32); + + normX = tf.where(contantMask, divideTensor, normXTensor).numpy(); + normX = np.expand_dims(normX, axis: 1); + + return (x / normX, normX); } }