Skip to content

Commit

Permalink
Merge pull request #99 from evan-cao-wb/master
Browse files Browse the repository at this point in the history
modify cosine function
  • Loading branch information
Oceania2018 committed Aug 11, 2023
2 parents 1d2704f + bb81651 commit 3f7b2c3
Showing 1 changed file with 53 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using BotSharp.Abstraction.VectorStorage;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace BotSharp.Core.Plugins.MemVecDb;

Expand Down Expand Up @@ -60,18 +62,62 @@ private float[] CalEuclideanDistance(float[] vec, List<VecRecord> records)
return c.ToArray<float>();
}

public float[] CalCosineSimilarity(float[] vec, List<VecRecord> records)
public NDArray CalCosineSimilarity(float[] vec, List<VecRecord> records)
{
var similarities = new float[records.Count];
var a = vec;
var normA = np.linalg.norm(a);
var recordsArray = np.zeros((records.Count, records[0].Vector.Length), dtype: np.float32);

for (int i = 0; i < records.Count; i++)
{
var b = records[i].Vector;
similarities[i] = np.dot(a, b) / (normA * np.linalg.norm(b));
recordsArray[i] = records[i].Vector;
}

return similarities;
var vecArray = np.expand_dims(np.array(vec, dtype: np.float32), axis: 0); // [1. 300]

(var normVecArray, var _) = SafeNormalize(vecArray);
(var normRecordsArray, var _) = SafeNormalize(recordsArray);

var simiMatix = tf.matmul(tf.cast(normVecArray, tf.float32), tf.transpose(tf.cast(normRecordsArray, tf.float32))).numpy(); // [1, num_records]

simiMatix = np.squeeze(simiMatix, axis: 0);

return simiMatix;
}

public int[] CalCosineSimilarityTopK(float[] vec, List<VecRecord> records, int topK = 10, float filterProb = 0.75f)
{
var simiMatix = CalCosineSimilarity(vec, records);

var topIndex = np.argsort(simiMatix)["::-1"][$":{topK}"];

var resIndex = new List<int>();

for (int i = 0; i < topK; i++)
{
var index = topIndex[i];
var value = simiMatix[index];

if (value > filterProb)
{
resIndex.Add(topIndex[i]);
}
}

return resIndex.ToArray();
}

public (NDArray, NDArray) SafeNormalize(NDArray x, double eps = 2.223E-15)
{
var squaredX = np.sum(np.multiply(x, x), axis: 1);
var normX = np.sqrt(squaredX);

var epsTensor = tf.cast(tf.convert_to_tensor(eps), dtype: tf.float32);
var normXTensor = tf.cast(normX, tf.float32);
var contantMask = (normXTensor < epsTensor);
var divideTensor = tf.ones_like(normXTensor, dtype: tf.float32);

normX = tf.where(contantMask, divideTensor, normXTensor).numpy();
normX = np.expand_dims(normX, axis: 1);

return (x / normX, normX);
}
}

0 comments on commit 3f7b2c3

Please sign in to comment.