Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): add vector type to kqir::Value #2371

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions src/search/indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,73 @@ StatusOr<FieldValueRetriever> FieldValueRetriever::Create(IndexOnDataType type,
}
}

// placeholders, remove them after vector indexing is implemented
static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; }
static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; }

StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json &val,
const redis::IndexFieldMetadata *type) {
if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
if (!val.is_number() || val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"};
return kqir::MakeValue<kqir::Numeric>(val.as_double());
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
if (val.is_string()) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(val.as_string(), delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else if (val.is_array()) {
std::vector<std::string> strs;
for (size_t i = 0; i < val.size(); ++i) {
if (!val[i].is_string())
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
strs.push_back(val[i].as_string());
}
return kqir::MakeValue<kqir::StringArray>(strs);
} else {
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
}
} else if (IsVectorType(type)) {
size_t dim = GetVectorDim(type);
if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"};
if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"};
std::vector<double> nums;
for (size_t i = 0; i < dim; ++i) {
if (!val[i].is_number() || val[i].is_string())
return {Status::NotOK, "json value should be array of numbers for vector fields"};
nums.push_back(val[i].as_double());
}
return kqir::MakeValue<kqir::NumericArray>(nums);
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}
}

StatusOr<kqir::Value> FieldValueRetriever::ParseFromHash(const std::string &value,
const redis::IndexFieldMetadata *type) {
if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
auto num = GET_OR_RET(ParseFloat(value));
return kqir::MakeValue<kqir::Numeric>(num);
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(value, delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else if (IsVectorType(type)) {
const size_t dim = GetVectorDim(type);
if (value.size() != dim * sizeof(double)) {
return {Status::NotOK, "field value is too short or too long to be parsed as a vector"};
}
std::vector<double> vec;
for (size_t i = 0; i < dim; ++i) {
// TODO: care about endian later
// TODO: currently only support 64bit floating point
vec.push_back(*(reinterpret_cast<const double *>(value.data()) + i));
}
return kqir::MakeValue<kqir::NumericArray>(vec);
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}
}

StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, const redis::IndexFieldMetadata *type) {
if (std::holds_alternative<HashData>(db)) {
auto &[hash, metadata, key] = std::get<HashData>(db);
Expand All @@ -71,17 +138,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
if (s.IsNotFound()) return {Status::NotFound, s.ToString()};
if (!s.ok()) return {Status::NotOK, s.ToString()};

if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
auto num = GET_OR_RET(ParseFloat(value));
return kqir::MakeValue<kqir::Numeric>(num);
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(value, delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}

return ParseFromHash(value, type);
} else if (std::holds_alternative<JsonData>(db)) {
auto &value = std::get<JsonData>(db);

Expand All @@ -91,31 +148,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
return {Status::NotFound, "json value specified by the field (json path) should exist and be unique"};
auto val = s->value[0];

if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
if (val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"};
return kqir::MakeValue<kqir::Numeric>(val.as_double());
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
if (val.is_string()) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(val.as_string(), delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else if (val.is_array()) {
std::vector<std::string> strs;
for (size_t i = 0; i < val.size(); ++i) {
if (!val[i].is_string())
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
strs.push_back(val[i].as_string());
}
return kqir::MakeValue<kqir::StringArray>(strs);
} else {
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
}
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}

return Status::OK();

return ParseFromJson(val, type);
} else {
return {Status::NotOK, "unknown redis data type to retrieve"};
}
Expand Down
3 changes: 3 additions & 0 deletions src/search/indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct FieldValueRetriever {
explicit FieldValueRetriever(JsonValue json) : db(std::in_place_type<JsonData>, std::move(json)) {}

StatusOr<kqir::Value> Retrieve(std::string_view field, const redis::IndexFieldMetadata *type);

static StatusOr<kqir::Value> ParseFromJson(const jsoncons::json &value, const redis::IndexFieldMetadata *type);
static StatusOr<kqir::Value> ParseFromHash(const std::string &value, const redis::IndexFieldMetadata *type);
};

struct IndexUpdater {
Expand Down
9 changes: 7 additions & 2 deletions src/search/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ using String = std::string; // e.g. a single tag
using NumericArray = std::vector<Numeric>; // used for vector fields
using StringArray = std::vector<String>; // used for tag fields, e.g. a list for tags

struct Value : std::variant<Null, Numeric, StringArray> {
using Base = std::variant<Null, Numeric, StringArray>;
struct Value : std::variant<Null, Numeric, StringArray, NumericArray> {
using Base = std::variant<Null, Numeric, StringArray, NumericArray>;

using Base::Base;

Expand Down Expand Up @@ -72,6 +72,9 @@ struct Value : std::variant<Null, Numeric, StringArray> {
} else if (Is<StringArray>()) {
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, sep);
} else if (Is<NumericArray>()) {
return util::StringJoin(
Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }, sep);
}

__builtin_unreachable();
Expand All @@ -87,6 +90,8 @@ struct Value : std::variant<Null, Numeric, StringArray> {
char sep = tag ? tag->separator : ',';
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, std::string(1, sep));
} else if (Is<NumericArray>()) {
return util::StringJoin(Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); });
}

__builtin_unreachable();
Expand Down
Loading