Skip to content

Commit

Permalink
Add support for GEOSEARCH and GEOSEARCHSTORE (#1533)
Browse files Browse the repository at this point in the history
  • Loading branch information
uds5501 committed Jul 24, 2023
1 parent 8cd6f59 commit 99cb709
Show file tree
Hide file tree
Showing 6 changed files with 478 additions and 74 deletions.
258 changes: 254 additions & 4 deletions src/commands/cmd_geo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
*
*/

#include "command_parser.h"
#include "commander.h"
#include "error_constants.h"
#include "server/server.h"
#include "types/geohash.h"
#include "types/redis_geo.h"

namespace redis {
Expand Down Expand Up @@ -52,10 +54,14 @@ class CommandGeoBase : public Commander {
*longitude = *long_stat;
*latitude = *lat_stat;

return ValidateLongLat(longitude, latitude);
}

static Status ValidateLongLat(double *longitude, double *latitude) {
if (*longitude < GEO_LONG_MIN || *longitude > GEO_LONG_MAX || *latitude < GEO_LAT_MIN || *latitude > GEO_LAT_MAX) {
return {Status::RedisParseErr, "invalid longitude,latitude pair " + longitude_para + "," + latitude_para};
return {Status::RedisParseErr,
"invalid longitude,latitude pair " + std::to_string(*longitude) + "," + std::to_string(*latitude)};
}

return Status::OK();
}

Expand Down Expand Up @@ -355,6 +361,249 @@ class CommandGeoRadius : public CommandGeoBase {
double latitude_ = 0;
};

class CommandGeoSearch : public CommandGeoBase {
public:
CommandGeoSearch() : CommandGeoBase() {}

Status Parse(const std::vector<std::string> &args) override {
CommandParser parser(args, 1);
key_ = GET_OR_RET(parser.TakeStr());

while (parser.Good()) {
if (parser.EatEqICase("frommember")) {
auto s = setOriginType(kMember);
if (!s.IsOK()) return s;

member_ = GET_OR_RET(parser.TakeStr());
} else if (parser.EatEqICase("fromlonlat")) {
auto s = setOriginType(kLongLat);
if (!s.IsOK()) return s;

longitude_ = GET_OR_RET(parser.TakeFloat());
latitude_ = GET_OR_RET(parser.TakeFloat());
s = ValidateLongLat(&longitude_, &latitude_);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("byradius")) {
auto s = setShapeType(kGeoShapeTypeCircular);
if (!s.IsOK()) return s;
radius_ = GET_OR_RET(parser.TakeFloat());
std::string distance_raw = GET_OR_RET(parser.TakeStr());
s = ParseDistanceUnit(distance_raw);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("bybox")) {
auto s = setShapeType(kGeoShapeTypeRectangular);
if (!s.IsOK()) return s;
width_ = GET_OR_RET(parser.TakeFloat());
height_ = GET_OR_RET(parser.TakeFloat());
std::string distance_raw = GET_OR_RET(parser.TakeStr());
s = ParseDistanceUnit(distance_raw);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("asc") && sort_ == kSortNone) {
sort_ = kSortASC;
} else if (parser.EatEqICase("desc") && sort_ == kSortNone) {
sort_ = kSortDESC;
} else if (parser.EatEqICase("count")) {
count_ = GET_OR_RET(parser.TakeInt<int>(NumericRange<int>{1, std::numeric_limits<int>::max()}));
} else if (parser.EatEqICase("withcoord")) {
with_coord_ = true;
} else if (parser.EatEqICase("withdist")) {
with_dist_ = true;
} else if (parser.EatEqICase("withhash")) {
with_hash_ = true;
} else {
return {Status::RedisParseErr, "Invalid argument given"};
}
}

if (member_ != "" && longitude_ != 0 && latitude_ != 0) {
return {Status::RedisParseErr, "please use only one of FROMMEMBER or FROMLONLAT"};
}

auto s = createGeoShape();
if (!s.IsOK()) {
return s;
}
return Commander::Parse(args);
}

Status Execute(Server *svr, Connection *conn, std::string *output) override {
std::vector<GeoPoint> geo_points;
redis::Geo geo_db(svr->storage, conn->GetNamespace());

auto s = geo_db.Search(args_[1], geo_shape_, origin_point_type_, member_, count_, sort_, false, GetUnitConversion(),
&geo_points);

if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
*output = generateOutput(geo_points);

return Status::OK();
}

protected:
double radius_ = 0;
double height_ = 0;
double width_ = 0;
int count_ = 0;
double longitude_ = 0;
double latitude_ = 0;
std::string member_;
std::string key_;
DistanceSort sort_ = kSortNone;
GeoShapeType shape_type_ = kGeoShapeTypeNone;
OriginPointType origin_point_type_ = kNone;
GeoShape geo_shape_;

Status setShapeType(GeoShapeType shape_type) {
if (shape_type_ != kGeoShapeTypeNone) {
return {Status::RedisParseErr, "please use only one of BYBOX or BYRADIUS"};
}
shape_type_ = shape_type;
return Status::OK();
}

Status setOriginType(OriginPointType origin_point_type) {
if (origin_point_type_ != kNone) {
return {Status::RedisParseErr, "please use only one of FROMMEMBER or FROMLONLAT"};
}
origin_point_type_ = origin_point_type;
return Status::OK();
}

Status createGeoShape() {
if (shape_type_ == kGeoShapeTypeNone) {
return {Status::RedisParseErr, "please use BYBOX or BYRADIUS"};
}
geo_shape_.type = shape_type_;
geo_shape_.conversion = GetUnitConversion();

if (shape_type_ == kGeoShapeTypeCircular) {
geo_shape_.radius = radius_;
} else {
geo_shape_.width = width_;
geo_shape_.height = height_;
}

if (origin_point_type_ == kLongLat) {
geo_shape_.xy[0] = longitude_;
geo_shape_.xy[1] = latitude_;
}
return Status::OK();
}

std::string generateOutput(const std::vector<GeoPoint> &geo_points) {
int result_length = static_cast<int>(geo_points.size());
int returned_items_count = (count_ == 0 || result_length < count_) ? result_length : count_;
std::vector<std::string> output;
output.reserve(returned_items_count);
for (int i = 0; i < returned_items_count; i++) {
auto geo_point = geo_points[i];
if (!with_coord_ && !with_hash_ && !with_dist_) {
output.emplace_back(redis::BulkString(geo_point.member));
} else {
std::vector<std::string> one;
one.emplace_back(redis::BulkString(geo_point.member));
if (with_dist_) {
one.emplace_back(redis::BulkString(util::Float2String(GetDistanceByUnit(geo_point.dist))));
}
if (with_hash_) {
one.emplace_back(redis::BulkString(util::Float2String(geo_point.score)));
}
if (with_coord_) {
one.emplace_back(redis::MultiBulkString(
{util::Float2String(geo_point.longitude), util::Float2String(geo_point.latitude)}));
}
output.emplace_back(redis::Array(one));
}
}
return redis::Array(output);
}

private:
bool with_coord_ = false;
bool with_dist_ = false;
bool with_hash_ = false;
};

class CommandGeoSearchStore : public CommandGeoSearch {
public:
Status Parse(const std::vector<std::string> &args) override {
CommandParser parser(args, 1);
store_key_ = GET_OR_RET(parser.TakeStr());
key_ = GET_OR_RET(parser.TakeStr());

while (parser.Good()) {
if (parser.EatEqICase("frommember")) {
auto s = setOriginType(kMember);
if (!s.IsOK()) return s;
member_ = GET_OR_RET(parser.TakeStr());
} else if (parser.EatEqICase("fromlonlat")) {
auto s = setOriginType(kLongLat);
if (!s.IsOK()) return s;

longitude_ = GET_OR_RET(parser.TakeFloat());
latitude_ = GET_OR_RET(parser.TakeFloat());
s = ValidateLongLat(&longitude_, &latitude_);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("byradius")) {
auto s = setShapeType(kGeoShapeTypeCircular);
if (!s.IsOK()) return s;
radius_ = GET_OR_RET(parser.TakeFloat());
std::string distance_raw = GET_OR_RET(parser.TakeStr());
s = ParseDistanceUnit(distance_raw);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("bybox")) {
auto s = setShapeType(kGeoShapeTypeRectangular);
if (!s.IsOK()) return s;
width_ = GET_OR_RET(parser.TakeFloat());
height_ = GET_OR_RET(parser.TakeFloat());
std::string distance_raw = GET_OR_RET(parser.TakeStr());
s = ParseDistanceUnit(distance_raw);
if (!s.IsOK()) return s;
} else if (parser.EatEqICase("asc") && sort_ == kSortNone) {
sort_ = kSortASC;
} else if (parser.EatEqICase("desc") && sort_ == kSortNone) {
sort_ = kSortDESC;
} else if (parser.EatEqICase("count")) {
count_ = GET_OR_RET(parser.TakeInt<int>(NumericRange<int>{1, std::numeric_limits<int>::max()}));
} else if (parser.EatEqICase("storedist")) {
store_distance_ = true;
} else {
return {Status::RedisParseErr, "Invalid argument given"};
}
}

if (member_ != "" && longitude_ != 0 && latitude_ != 0) {
return {Status::RedisParseErr, "please use only one of FROMMEMBER or FROMLONLAT"};
}

auto s = createGeoShape();
if (!s.IsOK()) {
return s;
}
return Commander::Parse(args);
}

Status Execute(Server *svr, Connection *conn, std::string *output) override {
std::vector<GeoPoint> geo_points;
redis::Geo geo_db(svr->storage, conn->GetNamespace());

auto s = geo_db.SearchStore(args_[2], geo_shape_, origin_point_type_, member_, count_, sort_, store_key_,
store_distance_, GetUnitConversion(), &geo_points);

if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
*output = redis::Integer(geo_points.size());
return Status::OK();
}

private:
bool store_distance_ = false;
std::string store_key_;
};

class CommandGeoRadiusByMember : public CommandGeoRadius {
public:
CommandGeoRadiusByMember() = default;
Expand Down Expand Up @@ -406,7 +655,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandGeoAdd>("geoadd", -5, "write", 1, 1,
MakeCmdAttr<CommandGeoRadius>("georadius", -6, "write", 1, 1, 1),
MakeCmdAttr<CommandGeoRadiusByMember>("georadiusbymember", -5, "write", 1, 1, 1),
MakeCmdAttr<CommandGeoRadiusReadonly>("georadius_ro", -6, "read-only", 1, 1, 1),
MakeCmdAttr<CommandGeoRadiusByMemberReadonly>("georadiusbymember_ro", -5, "read-only", 1, 1,
1), )
MakeCmdAttr<CommandGeoRadiusByMemberReadonly>("georadiusbymember_ro", -5, "read-only", 1, 1, 1),
MakeCmdAttr<CommandGeoSearch>("geosearch", -7, "read-only", 1, 1, 1),
MakeCmdAttr<CommandGeoSearchStore>("geosearchstore", -8, "write", 1, 1, 1))

} // namespace redis
Loading

0 comments on commit 99cb709

Please sign in to comment.