Skip to content

Commit

Permalink
[Cpp]: pass schema options as reference (#109)
Browse files Browse the repository at this point in the history
Signed-off-by: sunby <[email protected]>
  • Loading branch information
sunby committed Jan 4, 2024
1 parent 4440a90 commit 6fe0748
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 67 deletions.
8 changes: 4 additions & 4 deletions cpp/include/milvus-storage/reader/common/delete_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DeleteMergeReader : public arrow::RecordBatchReader {
class DeleteFilterVisitor;

static std::unique_ptr<DeleteMergeReader> Make(std::unique_ptr<arrow::RecordBatchReader> reader,
std::shared_ptr<SchemaOptions> schema_options,
const SchemaOptions& schema_options,
const DeleteFragmentVector& delete_fragments,
const ReadOptions& options);
std::shared_ptr<arrow::Schema> schema() const override;
Expand All @@ -40,18 +40,18 @@ class DeleteMergeReader : public arrow::RecordBatchReader {

DeleteMergeReader(std::unique_ptr<arrow::RecordBatchReader> reader,
DeleteFragmentVector delete_fragments,
std::shared_ptr<SchemaOptions> schema_options,
const SchemaOptions& schema_options,
const ReadOptions& options)
: reader_(std::move(reader)),
delete_fragments_(std::move(delete_fragments)),
schema_options_(std::move(schema_options)),
schema_options_(schema_options),
options_(options) {}

private:
std::unique_ptr<arrow::RecordBatchReader> reader_;
std::shared_ptr<RecordBatchWithDeltedOffsets> filtered_batch_reader_;
DeleteFragmentVector delete_fragments_;
std::shared_ptr<SchemaOptions> schema_options_;
const SchemaOptions schema_options_;
const ReadOptions options_;
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/milvus-storage/storage/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ struct ReadOptions {
};

struct SchemaOptions {
Status Validate(const arrow::Schema* schema);
Status Validate(const arrow::Schema* schema) const;

bool has_version_column() const { return !version_column.empty(); }

std::unique_ptr<schema_proto::SchemaOptions> ToProtobuf();
std::unique_ptr<schema_proto::SchemaOptions> ToProtobuf() const;

void FromProtobuf(const schema_proto::SchemaOptions& options);

Expand Down
8 changes: 4 additions & 4 deletions cpp/include/milvus-storage/storage/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ namespace milvus_storage {
class Schema {
public:
Schema() = default;
Schema(std::shared_ptr<arrow::Schema> schema, std::shared_ptr<SchemaOptions> options);
Schema(std::shared_ptr<arrow::Schema> schema, SchemaOptions& options);

Status Validate();

std::shared_ptr<arrow::Schema> schema();
std::shared_ptr<arrow::Schema> schema() const;

std::shared_ptr<SchemaOptions> options();
const SchemaOptions& options() const;

std::shared_ptr<arrow::Schema> scalar_schema();

Expand All @@ -50,6 +50,6 @@ class Schema {
std::shared_ptr<arrow::Schema> vector_schema_;
std::shared_ptr<arrow::Schema> delete_schema_;

std::shared_ptr<SchemaOptions> options_;
SchemaOptions options_;
};
} // namespace milvus_storage
6 changes: 3 additions & 3 deletions cpp/src/file/delete_fragment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ DeleteFragment::DeleteFragment(arrow::fs::FileSystem& fs, std::shared_ptr<Schema

Status DeleteFragment::Add(std::shared_ptr<arrow::RecordBatch> batch) {
auto schema_options = schema_->options();
auto pk_col = batch->GetColumnByName(schema_options->primary_column);
auto pk_col = batch->GetColumnByName(schema_options.primary_column);
std::shared_ptr<arrow::Int64Array> version_col = nullptr;
if (schema_->options()->has_version_column()) {
auto tmp = batch->GetColumnByName(schema_options->version_column);
if (schema_->options().has_version_column()) {
auto tmp = batch->GetColumnByName(schema_options.version_column);
version_col = std::static_pointer_cast<arrow::Int64Array>(tmp);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/reader/common/combine_offset_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ arrow::Status CombineOffsetReader::ReadNext(std::shared_ptr<arrow::RecordBatch>*

std::vector<std::shared_ptr<arrow::Array>> columns(scalar_batch->columns().begin(), scalar_batch->columns().end());

auto vector_col = table_batch.ValueOrDie()->GetColumnByName(schema_->options()->vector_column);
auto vector_col = table_batch.ValueOrDie()->GetColumnByName(schema_->options().vector_column);
if (!vector_col) {
return arrow::Status::UnknownError("vector column not found");
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/reader/common/combine_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ arrow::Status CombineReader::ReadNext(std::shared_ptr<arrow::RecordBatch>* batch

assert(scalar_batch->num_rows() == vector_batch->num_rows());

auto vec_column = vector_batch->GetColumnByName(schema_->options()->vector_column);
auto vec_column = vector_batch->GetColumnByName(schema_->options().vector_column);
std::vector<std::shared_ptr<arrow::Array>> columns(scalar_batch->columns().begin(), scalar_batch->columns().end());

auto vec_column_idx = schema_->schema()->GetFieldIndex(schema_->options()->vector_column);
auto vec_column_idx = schema_->schema()->GetFieldIndex(schema_->options().vector_column);
columns.insert(columns.begin() + vec_column_idx, vec_column);

*batch = arrow::RecordBatch::Make(schema(), scalar_batch->num_rows(), std::move(columns));
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/reader/common/delete_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace milvus_storage {
std::unique_ptr<DeleteMergeReader> DeleteMergeReader::Make(std::unique_ptr<arrow::RecordBatchReader> reader,
std::shared_ptr<SchemaOptions> schema_options,
const SchemaOptions& schema_options,
const DeleteFragmentVector& delete_fragments,
const ReadOptions& options) {
// DeleteFragmentVector filtered_delete_fragments;
Expand Down Expand Up @@ -50,15 +50,15 @@ arrow::Status DeleteMergeReader::ReadNext(std::shared_ptr<arrow::RecordBatch>* b
return arrow::Status::OK();
}

if (schema_options_->has_version_column()) {
auto version_col = record_batch->GetColumnByName(schema_options_->version_column);
if (schema_options_.has_version_column()) {
auto version_col = record_batch->GetColumnByName(schema_options_.version_column);
if (version_col == nullptr) {
return arrow::Status::Invalid("Version column not found");
}
auto visitor = DeleteFilterVisitor(delete_fragments_, std::static_pointer_cast<arrow::Int64Array>(version_col),
options_.version);

auto pk_col = record_batch->GetColumnByName(schema_options_->primary_column);
auto pk_col = record_batch->GetColumnByName(schema_options_.primary_column);
if (pk_col == nullptr) {
return arrow::Status::Invalid("Primary column not found");
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/reader/record_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ std::unique_ptr<arrow::RecordBatchReader> MakeRecordReader(std::shared_ptr<Manif

bool only_contain_scalar_columns(const std::shared_ptr<Schema> schema, const std::set<std::string>& related_columns) {
for (auto& column : related_columns) {
if (schema->options()->vector_column == column) {
if (schema->options().vector_column == column) {
return false;
}
}
Expand All @@ -82,8 +82,8 @@ bool only_contain_scalar_columns(const std::shared_ptr<Schema> schema, const std

bool only_contain_vector_columns(const std::shared_ptr<Schema> schema, const std::set<std::string>& related_columns) {
for (auto& column : related_columns) {
if (schema->options()->vector_column != column && schema->options()->primary_column != column &&
schema->options()->version_column != column) {
if (schema->options().vector_column != column && schema->options().primary_column != column &&
schema->options().version_column != column) {
return false;
}
}
Expand All @@ -92,8 +92,8 @@ bool only_contain_vector_columns(const std::shared_ptr<Schema> schema, const std

bool filters_only_contain_pk_and_version(std::shared_ptr<Schema> schema, const Filter::FilterSet& filters) {
for (auto& filter : filters) {
if (filter->get_column_name() != schema->options()->primary_column &&
filter->get_column_name() != schema->options()->version_column) {
if (filter->get_column_name() != schema->options().primary_column &&
filter->get_column_name() != schema->options().version_column) {
return false;
}
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/storage/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

namespace milvus_storage {

Status SchemaOptions::Validate(const arrow::Schema* schema) {
Status SchemaOptions::Validate(const arrow::Schema* schema) const {
if (!primary_column.empty()) {
auto primary_field = schema->GetFieldByName(primary_column);
if (!primary_field) {
Expand Down Expand Up @@ -56,7 +56,7 @@ Status SchemaOptions::Validate(const arrow::Schema* schema) {
return Status::OK();
}

std::unique_ptr<schema_proto::SchemaOptions> SchemaOptions::ToProtobuf() {
std::unique_ptr<schema_proto::SchemaOptions> SchemaOptions::ToProtobuf() const{
auto options = std::make_unique<schema_proto::SchemaOptions>();
options->set_primary_column(primary_column);
options->set_version_column(version_column);
Expand Down
25 changes: 12 additions & 13 deletions cpp/src/storage/schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@
#include "common/log.h"
namespace milvus_storage {

Schema::Schema(std::shared_ptr<arrow::Schema> schema, std::shared_ptr<SchemaOptions> options)
Schema::Schema(std::shared_ptr<arrow::Schema> schema, SchemaOptions& options)
: schema_(std::move(schema)), options_(options) {}

Status Schema::Validate() {
RETURN_NOT_OK(options_->Validate(schema_.get()));
RETURN_NOT_OK(options_.Validate(schema_.get()));
RETURN_NOT_OK(BuildScalarSchema());
RETURN_NOT_OK(BuildVectorSchema());
RETURN_NOT_OK(BuildDeleteSchema());
LOG_STORAGE_DEBUG_ << "Schema validate success";
return Status::OK();
}

std::shared_ptr<arrow::Schema> Schema::schema() { return schema_; }
std::shared_ptr<arrow::Schema> Schema::schema() const{ return schema_; }

std::shared_ptr<SchemaOptions> Schema::options() { return options_; }
const SchemaOptions& Schema::options() const{ return options_; }

std::shared_ptr<arrow::Schema> Schema::scalar_schema() { return scalar_schema_; }

Expand All @@ -45,16 +45,15 @@ Result<std::unique_ptr<schema_proto::Schema>> Schema::ToProtobuf() {
auto schema = std::make_unique<schema_proto::Schema>();
ASSIGN_OR_RETURN_NOT_OK(auto arrow_schema, ToProtobufSchema(schema_.get()));

auto options = options_->ToProtobuf();
auto options = options_.ToProtobuf();
schema->set_allocated_arrow_schema(arrow_schema.release());
schema->set_allocated_schema_options(options.release());
return schema;
}

Status Schema::FromProtobuf(const schema_proto::Schema& schema) {
ASSIGN_OR_RETURN_NOT_OK(schema_, FromProtobufSchema(schema.arrow_schema()));
options_ = std::make_shared<SchemaOptions>();
options_->FromProtobuf(schema.schema_options());
options_.FromProtobuf(schema.schema_options());
RETURN_NOT_OK(BuildScalarSchema());
RETURN_NOT_OK(BuildVectorSchema());
RETURN_NOT_OK(BuildDeleteSchema());
Expand All @@ -64,7 +63,7 @@ Status Schema::FromProtobuf(const schema_proto::Schema& schema) {
Status Schema::BuildScalarSchema() {
arrow::SchemaBuilder scalar_schema_builder;
for (const auto& field : schema_->fields()) {
if (field->name() == options_->vector_column) {
if (field->name() == options_.vector_column) {
continue;
}
RETURN_ARROW_NOT_OK(scalar_schema_builder.AddField(field));
Expand All @@ -78,8 +77,8 @@ Status Schema::BuildScalarSchema() {
Status Schema::BuildVectorSchema() {
arrow::SchemaBuilder vector_schema_builder;
for (const auto& field : schema_->fields()) {
if (field->name() == options_->primary_column || field->name() == options_->version_column ||
field->name() == options_->vector_column) {
if (field->name() == options_.primary_column || field->name() == options_.version_column ||
field->name() == options_.vector_column) {
RETURN_ARROW_NOT_OK(vector_schema_builder.AddField(field));
}
}
Expand All @@ -89,10 +88,10 @@ Status Schema::BuildVectorSchema() {

Status Schema::BuildDeleteSchema() {
arrow::SchemaBuilder delete_schema_builder;
auto pk_field = schema_->GetFieldByName(options_->primary_column);
auto version_field = schema_->GetFieldByName(options_->version_column);
auto pk_field = schema_->GetFieldByName(options_.primary_column);
auto version_field = schema_->GetFieldByName(options_.version_column);
RETURN_ARROW_NOT_OK(delete_schema_builder.AddField(pk_field));
if (options_->has_version_column()) {
if (options_.has_version_column()) {
RETURN_ARROW_NOT_OK(delete_schema_builder.AddField(version_field));
}
ASSIGN_OR_RETURN_ARROW_NOT_OK(delete_schema_, delete_schema_builder.Finish());
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/manifest_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ TEST(ManifestTest, ManifestProtoTest) {
ASSERT_TRUE(metadata_status.ok());
auto arrow_schema = schema_builder.Finish().ValueOrDie();

auto schema_options = std::make_shared<SchemaOptions>();
schema_options->primary_column = "pk_field";
schema_options->version_column = "ts_field";
schema_options->vector_column = "vec_field";
SchemaOptions schema_options;
schema_options.primary_column = "pk_field";
schema_options.version_column = "ts_field";
schema_options.vector_column = "vec_field";

// Create Schema
auto space_schema1 = std::make_shared<Schema>(arrow_schema, schema_options);
Expand Down
37 changes: 18 additions & 19 deletions cpp/test/schema_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ TEST(SchemaValidateTest, SchemaValidateNoVersionColTest) {
auto arrow_schema = schema_builder.Finish().ValueOrDie();

// Create Options
auto schema_options = std::make_shared<SchemaOptions>();
schema_options->primary_column = "pk_field";

schema_options->vector_column = "vec_field";
SchemaOptions schema_options;
schema_options.primary_column = "pk_field";
schema_options.vector_column = "vec_field";

// Create Schema
auto space_schema1 = std::make_shared<Schema>(arrow_schema, schema_options);
Expand All @@ -55,18 +54,18 @@ TEST(SchemaValidateTest, SchemaValidateNoVersionColTest) {
auto scalar_schema = space_schema1->scalar_schema();
/// scalar schema has no version column but has offset column
ASSERT_EQ(scalar_schema->num_fields(), 2);
ASSERT_EQ(scalar_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(scalar_schema->field(0)->name(), schema_options.primary_column);
ASSERT_EQ(scalar_schema->field(1)->name(), "off_set");

auto vector_schema = space_schema1->vector_schema();
ASSERT_EQ(vector_schema->num_fields(), 2);
ASSERT_EQ(vector_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(vector_schema->field(0)->name(), schema_options.primary_column);

ASSERT_EQ(vector_schema->field(1)->name(), schema_options->vector_column);
ASSERT_EQ(vector_schema->field(1)->name(), schema_options.vector_column);

auto delete_schema = space_schema1->delete_schema();
ASSERT_EQ(delete_schema->num_fields(), 1);
ASSERT_EQ(delete_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(delete_schema->field(0)->name(), schema_options.primary_column);
}

TEST(SchemaValidateTest, SchemaValidateVersionColTest) {
Expand All @@ -92,10 +91,10 @@ TEST(SchemaValidateTest, SchemaValidateVersionColTest) {
auto arrow_schema = schema_builder.Finish().ValueOrDie();

// Create Options
auto schema_options = std::make_shared<SchemaOptions>();
schema_options->primary_column = "pk_field";
schema_options->version_column = "ts_field";
schema_options->vector_column = "vec_field";
SchemaOptions schema_options;
schema_options.primary_column = "pk_field";
schema_options.version_column = "ts_field";
schema_options.vector_column = "vec_field";

// Create Schema
auto space_schema1 = std::make_shared<Schema>(arrow_schema, schema_options);
Expand All @@ -105,20 +104,20 @@ TEST(SchemaValidateTest, SchemaValidateVersionColTest) {

auto scalar_schema = space_schema1->scalar_schema();
ASSERT_EQ(scalar_schema->num_fields(), 3);
ASSERT_EQ(scalar_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(scalar_schema->field(1)->name(), schema_options->version_column);
ASSERT_EQ(scalar_schema->field(0)->name(), schema_options.primary_column);
ASSERT_EQ(scalar_schema->field(1)->name(), schema_options.version_column);
ASSERT_EQ(scalar_schema->field(2)->name(), "off_set");

auto vector_schema = space_schema1->vector_schema();
ASSERT_EQ(vector_schema->num_fields(), 3);
ASSERT_EQ(vector_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(vector_schema->field(1)->name(), schema_options->version_column);
ASSERT_EQ(vector_schema->field(2)->name(), schema_options->vector_column);
ASSERT_EQ(vector_schema->field(0)->name(), schema_options.primary_column);
ASSERT_EQ(vector_schema->field(1)->name(), schema_options.version_column);
ASSERT_EQ(vector_schema->field(2)->name(), schema_options.vector_column);

auto delete_schema = space_schema1->delete_schema();
ASSERT_EQ(delete_schema->num_fields(), 2);
ASSERT_EQ(delete_schema->field(0)->name(), schema_options->primary_column);
ASSERT_EQ(delete_schema->field(1)->name(), schema_options->version_column);
ASSERT_EQ(delete_schema->field(0)->name(), schema_options.primary_column);
ASSERT_EQ(delete_schema->field(1)->name(), schema_options.version_column);
}

} // namespace milvus_storage
8 changes: 4 additions & 4 deletions cpp/test/space_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ TEST(SpaceTest, SpaceWriteReadTest) {
auto arrow_schema = CreateArrowSchema({"pk_field", "ts_field", "vec_field"},
{arrow::int64(), arrow::int64(), arrow::fixed_size_binary(10)});

auto schema_options = std::make_shared<SchemaOptions>();
schema_options->primary_column = "pk_field";
schema_options->version_column = "ts_field";
schema_options->vector_column = "vec_field";
SchemaOptions schema_options;
schema_options.primary_column = "pk_field";
schema_options.version_column = "ts_field";
schema_options.vector_column = "vec_field";

auto schema = std::make_shared<Schema>(arrow_schema, schema_options);
ASSERT_STATUS_OK(schema->Validate());
Expand Down

0 comments on commit 6fe0748

Please sign in to comment.