Skip to content

Commit

Permalink
ByteSequenceMatcher + other review changes
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Debreceni <[email protected]>
  • Loading branch information
martinzink and adamdebreceni committed Oct 16, 2024
1 parent fd59183 commit 6643503
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 70 deletions.
134 changes: 81 additions & 53 deletions extensions/standard-processors/processors/SplitContent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
#include "utils/gsl.h"

namespace org::apache::nifi::minifi::processors {

constexpr size_t BUFFER_TARGET_SIZE = 1024;

void SplitContent::initialize() {
setSupportedProperties(Properties);
setSupportedRelationships(Relationships);
Expand All @@ -39,60 +36,59 @@ void SplitContent::initialize() {
void SplitContent::onSchedule(core::ProcessContext& context, core::ProcessSessionFactory&) {
auto byte_sequence_str = utils::getRequiredPropertyOrThrow<std::string>(context, ByteSequence.name);
const auto byte_sequence_format = utils::parseEnumProperty<ByteSequenceFormat>(context, ByteSequenceFormatProperty);
std::vector<std::byte> byte_sequence{};
if (byte_sequence_format == ByteSequenceFormat::Hexadecimal) {
byte_sequence_ = utils::string::from_hex(byte_sequence_str);
byte_sequence = utils::string::from_hex(byte_sequence_str);
} else {
byte_sequence_.resize(byte_sequence_str.size());
std::ranges::transform(byte_sequence_str, byte_sequence_.begin(), [](char c) { return static_cast<std::byte>(c); });
}
if (byte_sequence_.empty()) {
throw Exception(PROCESS_SCHEDULE_EXCEPTION, "Cannot operate without byte sequence");
byte_sequence.resize(byte_sequence_str.size());
std::ranges::transform(byte_sequence_str, byte_sequence.begin(), [](char c) { return static_cast<std::byte>(c); });
}
if (byte_sequence.empty()) { throw Exception(PROCESS_SCHEDULE_EXCEPTION, "Cannot operate without byte sequence"); }
byte_sequence_matcher_.emplace(ByteSequenceMatcher(std::move(byte_sequence)));
byte_sequence_location_ = utils::parseEnumProperty<ByteSequenceLocation>(context, ByteSequenceLocationProperty);
keep_byte_sequence = utils::getRequiredPropertyOrThrow<bool>(context, KeepByteSequence.name);
}

std::shared_ptr<core::FlowFile> SplitContent::createNewSplit(core::ProcessSession& session) const {
auto next_split = session.create();
if (!next_split) {
throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile");
}
if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Leading) {
session.appendBuffer(next_split, byte_sequence_);
}
if (!next_split) { throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile"); }
if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Leading) { session.appendBuffer(next_split, byte_sequence_matcher_->getByteSequence()); }
return next_split;
}

void SplitContent::finalizeLatestSplitContent(core::ProcessSession& session, const std::shared_ptr<core::FlowFile>& latest_split, const std::vector<std::byte>& buffer) const {
const std::span<const std::byte> data_without_byte_sequence{buffer.data(), buffer.size() - byte_sequence_.size()};
const std::span<const std::byte> data_without_byte_sequence{buffer.data(), buffer.size() - getByteSequenceSize()};
session.appendBuffer(latest_split, data_without_byte_sequence);
if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Trailing) {
session.appendBuffer(latest_split, byte_sequence_);
}
if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Trailing) { session.appendBuffer(latest_split, byte_sequence_matcher_->getByteSequence()); }
}

void SplitContent::finalizeLastSplitContent(core::ProcessSession& session, std::vector<std::shared_ptr<core::FlowFile>>& splits, const std::vector<std::byte>& buffer,
const bool ended_with_byte_sequence) const {
void SplitContent::finalizeLastSplitContent(
core::ProcessSession& session, std::vector<std::shared_ptr<core::FlowFile>>& splits, const std::vector<std::byte>& buffer, const bool ended_with_byte_sequence) const {
if (ended_with_byte_sequence && splits.back()->getSize() != 0) {
if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Leading) {
const auto last_split = session.create();
if (!last_split) {
throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile");
}
if (!last_split) { throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile"); }
splits.push_back(last_split);
session.appendBuffer(splits.back(), byte_sequence_);
session.appendBuffer(splits.back(), byte_sequence_matcher_->getByteSequence());
}
} else {
session.appendBuffer(splits.back(), buffer);
}
}

/**
 * Returns a read-only view of the byte sequence held by the matcher.
 * Asserts that byte_sequence_matcher_ holds a value before it is dereferenced.
 */
std::span<const std::byte> SplitContent::getByteSequence() const {
  const auto& matcher = byte_sequence_matcher_;
  gsl_Assert(matcher.has_value());
  return matcher->getByteSequence();
}

/**
 * Returns the number of bytes in the byte sequence, obtained through getByteSequence().
 */
SplitContent::size_type SplitContent::getByteSequenceSize() const {
  const std::span<const std::byte> sequence = getByteSequence();
  return sequence.size();
}

namespace {
std::shared_ptr<core::FlowFile> createFirstSplit(core::ProcessSession& session) {
auto first_split = session.create();
if (!first_split) {
throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile");
}
if (!first_split) { throw Exception(PROCESSOR_EXCEPTION, "Couldn't create FlowFile"); }
return first_split;
}

Expand All @@ -113,21 +109,63 @@ bool lastSplitIsEmpty(const std::vector<std::shared_ptr<core::FlowFile>>& splits
}
} // namespace

/**
 * Prepares the list of splits for further data.
 *
 * If lastSplitIsEmpty(splits) is false, a fresh split is created through createNewSplit() and appended
 * to the list. Otherwise the existing last split is reused, and the byte sequence is appended to it
 * when it must be kept in Leading position.
 *
 * @param session session used to create the new split or to append to the existing one
 * @param splits  splits produced so far; may gain one element
 */
void SplitContent::endedWithByteSequenceWithMoreDataToCome(core::ProcessSession& session, std::vector<std::shared_ptr<core::FlowFile>>& splits) const {
  if (!lastSplitIsEmpty(splits)) {
    splits.push_back(createNewSplit(session));
    return;
  }
  if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Leading) {
    // Go through the asserting accessor instead of dereferencing the optional matcher directly.
    session.appendBuffer(splits.back(), getByteSequence());
  }
}


/**
 * Builds the matcher state for the given byte sequence.
 *
 * One node is created per byte of the sequence, preceded by a leading node for the
 * "zero bytes matched" state, so node index i corresponds to "i bytes matched".
 * The caches and fallback values of the nodes start out empty.
 *
 * @param byte_sequence the sequence to search for; taken by value and moved into the matcher
 */
SplitContent::ByteSequenceMatcher::ByteSequenceMatcher(std::vector<std::byte> byte_sequence) : byte_sequence_(std::move(byte_sequence)) {
  // The final node count is known up front, so allocate once instead of growing repeatedly.
  byte_sequence_nodes_.reserve(byte_sequence_.size() + 1);
  // Node 0 represents the empty match; its byte is a value-initialized placeholder.
  byte_sequence_nodes_.push_back(node{.byte = {}, .cache = {}, .previous_max_match = {}});
  for (const auto& byte : byte_sequence_) {
    byte_sequence_nodes_.push_back(node{.byte = byte, .cache = {}, .previous_max_match = {}});
  }
}

/**
 * Computes how many leading bytes of the sequence are matched after consuming next_byte.
 *
 * Results are memoized per (state, byte) pair in the node's cache. When next_byte does not
 * extend the current match, the answer is derived from the fallback state returned by
 * getPreviousMaxMatch(), so an overlapping partial match is not lost.
 *
 * @param number_of_currently_matching_bytes number of bytes matched before next_byte; must be a valid node index
 * @param next_byte                          the byte that was just read
 * @return the number of matching bytes once next_byte has been taken into account
 */
SplitContent::size_type SplitContent::ByteSequenceMatcher::getNumberOfMatchingBytes(const size_type number_of_currently_matching_bytes, const std::byte next_byte) {
  // The state is used as an index below, so it has to be strictly smaller than the node count.
  gsl_Assert(number_of_currently_matching_bytes < byte_sequence_nodes_.size());
  auto& curr_go = byte_sequence_nodes_[number_of_currently_matching_bytes].cache;
  if (const auto cached = curr_go.find(next_byte); cached != curr_go.end()) { return cached->second; }

  const size_type next_state = number_of_currently_matching_bytes + 1;
  // A fully matched sequence has no following node; only compare when one exists.
  if (next_state < byte_sequence_nodes_.size() && next_byte == byte_sequence_nodes_[next_state].byte) {
    curr_go[next_byte] = next_state;
    return next_state;
  }
  if (number_of_currently_matching_bytes == 0) {
    curr_go[next_byte] = 0;
    return 0;
  }

  // Mismatch with a non-empty partial match: retry from the fallback state.
  const size_type fallback_result = getNumberOfMatchingBytes(getPreviousMaxMatch(number_of_currently_matching_bytes), next_byte);
  curr_go[next_byte] = fallback_result;
  return fallback_result;
}

/**
 * Returns the fallback state for the given state, i.e. the state to continue from when the
 * next byte does not extend the current match.
 *
 * The value is computed once per node and stored in previous_max_match. States 0 and 1 fall
 * back to 0; for larger states the value is obtained by feeding the node's own byte into the
 * transition from the fallback of the preceding state.
 *
 * @param number_of_currently_matching_bytes the state to look up; must be a valid node index
 * @return the fallback state
 */
SplitContent::size_type SplitContent::ByteSequenceMatcher::getPreviousMaxMatch(const size_type number_of_currently_matching_bytes) {
  // The state is used as an index below, so it has to be strictly smaller than the node count.
  gsl_Assert(number_of_currently_matching_bytes < byte_sequence_nodes_.size());
  auto& prev_max_match = byte_sequence_nodes_[number_of_currently_matching_bytes].previous_max_match;
  if (prev_max_match) { return *prev_max_match; }
  if (number_of_currently_matching_bytes <= 1) {
    prev_max_match = 0;
    return 0;
  }
  prev_max_match = getNumberOfMatchingBytes(getPreviousMaxMatch(number_of_currently_matching_bytes - 1), byte_sequence_nodes_[number_of_currently_matching_bytes].byte);
  return *prev_max_match;
}

void SplitContent::onTrigger(core::ProcessContext& context, core::ProcessSession& session) {
gsl_Assert(!byte_sequence_.empty());
gsl_Assert(byte_sequence_matcher_);
const auto original = session.get();
if (!original) {
context.yield();
return;
}

const auto ff_content_stream = session.getFlowFileContentStream(*original);
if (!ff_content_stream) {
throw Exception(PROCESSOR_EXCEPTION, fmt::format("Couldn't access the ContentStream of {}", original->getUUID().to_string()));
}
if (!ff_content_stream) { throw Exception(PROCESSOR_EXCEPTION, fmt::format("Couldn't access the ContentStream of {}", original->getUUID().to_string())); }
std::vector<std::byte> buffer{};
buffer.reserve(BUFFER_TARGET_SIZE + byte_sequence_.size());
size_t matching_bytes = 0;
buffer.reserve(BUFFER_TARGET_SIZE + getByteSequenceSize());
size_type matching_bytes = 0;

bool ended_with_byte_sequence = false;
std::vector<std::shared_ptr<core::FlowFile>> splits{};
splits.push_back(createFirstSplit(session));
Expand All @@ -136,27 +174,17 @@ void SplitContent::onTrigger(core::ProcessContext& context, core::ProcessSession
buffer.push_back(*latest_byte);
if (ended_with_byte_sequence) {
ended_with_byte_sequence = false;
if (!lastSplitIsEmpty(splits)) {
splits.push_back(createNewSplit(session));
} else if (keep_byte_sequence && byte_sequence_location_ == ByteSequenceLocation::Leading) {
session.appendBuffer(splits.back(), byte_sequence_);
}
endedWithByteSequenceWithMoreDataToCome(session, splits);
}
if (latest_byte == byte_sequence_[matching_bytes]) {
matching_bytes++;
if (matching_bytes == byte_sequence_.size()) {
// Found the Byte Sequence
finalizeLatestSplitContent(session, splits.back(), buffer);
ended_with_byte_sequence = true;
matching_bytes = 0;
buffer.clear();
}
} else {
matching_bytes = byte_sequence_matcher_->getNumberOfMatchingBytes(matching_bytes, *latest_byte);
if (matching_bytes == getByteSequenceSize()) {
finalizeLatestSplitContent(session, splits.back(), buffer);
ended_with_byte_sequence = true;
matching_bytes = 0;
if (buffer.size() >= BUFFER_TARGET_SIZE) {
session.appendBuffer(splits.back(), buffer);
buffer.clear();
}
buffer.clear();
} else if (buffer.size() >= BUFFER_TARGET_SIZE) {
session.appendBuffer(splits.back(), std::span<const std::byte>(buffer.data(), buffer.size() - matching_bytes));
buffer.assign(getByteSequence().begin(), getByteSequence().begin() + gsl::narrow<std::vector<std::byte>::difference_type>(matching_bytes));
}
}

Expand Down
30 changes: 27 additions & 3 deletions extensions/standard-processors/processors/SplitContent.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class SplitContent final : public core::Processor {
public:
explicit SplitContent(const std::string_view name, const utils::Identifier& uuid = {}) : Processor(name, uuid) {}

using size_type = std::vector<std::byte>::size_type;
enum class ByteSequenceFormat { Hexadecimal, Text };

enum class ByteSequenceLocation { Trailing, Leading };

EXTENSIONAPI static constexpr auto Description = "Splits incoming FlowFiles by a specified byte sequence";
Expand Down Expand Up @@ -96,16 +96,40 @@ class SplitContent final : public core::Processor {
EXTENSIONAPI static constexpr bool IsSingleThreaded = false;
ADD_COMMON_VIRTUAL_FUNCTIONS_FOR_PROCESSORS

static constexpr size_type BUFFER_TARGET_SIZE = 4096;

void onSchedule(core::ProcessContext& context, core::ProcessSessionFactory& session_factory) override;
void onTrigger(core::ProcessContext& context, core::ProcessSession& session) override;
void initialize() override;


private:
std::shared_ptr<core::FlowFile> createNewSplit(core::ProcessSession& session) const;
void finalizeLatestSplitContent(core::ProcessSession& session, const std::shared_ptr<core::FlowFile>& latest_split, const std::vector<std::byte>& buffer) const;
void finalizeLastSplitContent(core::ProcessSession& session, std::vector<std::shared_ptr<core::FlowFile>>& splits, const std::vector<std::byte>& buffer, bool ended_with_byte_sequence) const;

std::vector<std::byte> byte_sequence_{};
void endedWithByteSequenceWithMoreDataToCome(core::ProcessSession& session, std::vector<std::shared_ptr<core::FlowFile>>& splits) const;

// Incremental matcher for a fixed byte sequence.
// It keeps one node per "number of bytes matched so far" state and lazily memoizes, per state,
// both the transition for each observed byte and the fallback state used on a mismatch.
class ByteSequenceMatcher {
public:
// Takes ownership of the byte sequence and creates the per-state nodes.
explicit ByteSequenceMatcher(std::vector<std::byte> byte_sequence);
// Returns the number of matching bytes after next_byte follows a match of the given length.
// Not const: it fills the per-node caches.
size_type getNumberOfMatchingBytes(size_type number_of_currently_matching_bytes, std::byte next_byte);
// Returns the state to fall back to when the match of the given length cannot be extended.
// Not const: it stores the computed value in the node.
size_type getPreviousMaxMatch(size_type number_of_currently_matching_bytes);
// Read-only view of the stored byte sequence; valid as long as this matcher is alive.
[[nodiscard]] std::span<const std::byte> getByteSequence() const { return byte_sequence_; }

private:
// State for "i bytes matched", where i is the node's index in byte_sequence_nodes_.
struct node {
// The sequence byte that leads into this state (placeholder value for the first node).
std::byte byte;
// Memoized results of getNumberOfMatchingBytes for this state, keyed by the next byte.
std::unordered_map<std::byte, size_type> cache;
// Memoized result of getPreviousMaxMatch for this state; empty until first computed.
std::optional<size_type> previous_max_match;
};
// Holds byte_sequence_.size() + 1 nodes (see the constructor).
std::vector<node> byte_sequence_nodes_;
// NOTE(review): being const, this member is copied rather than moved when the matcher itself is moved.
const std::vector<std::byte> byte_sequence_;
};

std::span<const std::byte> getByteSequence() const;
size_type getByteSequenceSize() const;

std::optional<ByteSequenceMatcher> byte_sequence_matcher_;
bool keep_byte_sequence = false;
ByteSequenceLocation byte_sequence_location_ = ByteSequenceLocation::Trailing;
std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<SplitContent>::getLogger(uuid_);
Expand Down
62 changes: 53 additions & 9 deletions extensions/standard-processors/tests/unit/SplitContentTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,17 @@ TEST_CASE("ByteSequenceAtBufferTargetSize") {
const auto split_content = std::make_shared<SplitContent>("SplitContent");
minifi::test::SingleProcessorTestController controller{split_content};

auto [pre_fix_size, separator_size, post_fix_size] = GENERATE(
std::make_tuple(1020, 1020, 1020),
std::make_tuple(10, 10, 1020),
std::make_tuple(10, 1020, 10),
std::make_tuple(10, 10, 1020),
std::make_tuple(10, 1020, 1020),
std::make_tuple(1020, 10, 1020),
std::make_tuple(1020, 1020, 10),
std::make_tuple(2000, 1020, 10));
auto x = SplitContent::BUFFER_TARGET_SIZE-10;

auto [pre_fix_size, separator_size, post_fix_size] = GENERATE_COPY(
std::make_tuple(x, x, x),
std::make_tuple(10, 10, x),
std::make_tuple(10, x, 10),
std::make_tuple(10, 10, x),
std::make_tuple(10, x, x),
std::make_tuple(x, 10, x),
std::make_tuple(x, x, 10),
std::make_tuple(2*x, x, 10));


const std::string pre_fix = utils::string::repeat("a", pre_fix_size);
Expand All @@ -435,4 +437,46 @@ TEST_CASE("ByteSequenceAtBufferTargetSize") {
CHECK(controller.plan->getContent(original[0]) == input);
}

// The input "aaabc" contains the separator "aab" starting at the second byte, right after a
// partial match ("aa" followed by another 'a'). With the separator kept in Leading position
// the expected splits are "a" and "aabc".
TEST_CASE("TrickyWithLeading", "[NiFi]") {
  const auto processor = std::make_shared<SplitContent>("SplitContent");
  minifi::test::SingleProcessorTestController controller{processor};
  processor->setProperty(SplitContent::ByteSequenceFormatProperty, magic_enum::enum_name(SplitContent::ByteSequenceFormat::Text));
  processor->setProperty(SplitContent::ByteSequence, "aab");
  processor->setProperty(SplitContent::KeepByteSequence, "true");
  processor->setProperty(SplitContent::ByteSequenceLocationProperty, magic_enum::enum_name(SplitContent::ByteSequenceLocation::Leading));

  auto results = controller.trigger("aaabc");
  auto original_flow_files = results.at(processors::SplitContent::Original);
  auto split_flow_files = results.at(processors::SplitContent::Splits);

  // The counts are hard requirements because the content checks below index into the vectors.
  REQUIRE(original_flow_files.size() == 1);
  REQUIRE(split_flow_files.size() == 2);

  CHECK(controller.plan->getContent(original_flow_files[0]) == "aaabc");

  CHECK(controller.plan->getContent(split_flow_files[0]) == "a");
  CHECK(controller.plan->getContent(split_flow_files[1]) == "aabc");
}

// Same overlapping input as the Leading variant ("aab" inside "aaabc"), but the separator is
// kept in Trailing position, so the expected splits are "aaab" and "c".
TEST_CASE("TrickyWithTrailing", "[NiFi]") {
  const auto processor = std::make_shared<SplitContent>("SplitContent");
  minifi::test::SingleProcessorTestController controller{processor};
  processor->setProperty(SplitContent::ByteSequenceFormatProperty, magic_enum::enum_name(SplitContent::ByteSequenceFormat::Text));
  processor->setProperty(SplitContent::ByteSequence, "aab");
  processor->setProperty(SplitContent::KeepByteSequence, "true");
  processor->setProperty(SplitContent::ByteSequenceLocationProperty, magic_enum::enum_name(SplitContent::ByteSequenceLocation::Trailing));

  auto results = controller.trigger("aaabc");
  auto original_flow_files = results.at(processors::SplitContent::Original);
  auto split_flow_files = results.at(processors::SplitContent::Splits);

  // The counts are hard requirements because the content checks below index into the vectors.
  REQUIRE(original_flow_files.size() == 1);
  REQUIRE(split_flow_files.size() == 2);

  CHECK(controller.plan->getContent(original_flow_files[0]) == "aaabc");

  CHECK(controller.plan->getContent(split_flow_files[0]) == "aaab");
  CHECK(controller.plan->getContent(split_flow_files[1]) == "c");
}

} // namespace org::apache::nifi::minifi::processors::test
9 changes: 4 additions & 5 deletions libminifi/test/libtest/unit/TestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,17 +676,16 @@ std::vector<std::byte> TestPlan::getContentAsBytes(const core::FlowFile& flow_fi
const auto content_claim = flow_file.getResourceClaim();
const auto content_stream = content_repo_->read(*content_claim);
const auto output_stream = std::make_shared<minifi::io::BufferStream>();
std::ignore = minifi::InputStreamPipe{*output_stream}(content_stream);
auto content = output_stream->getBuffer().subspan(flow_file.getOffset(), flow_file.getSize());
return ranges::to<std::vector>(content);
minifi::internal::pipe(*content_stream, *output_stream);
return ranges::to<std::vector>(output_stream->getBuffer());
}

std::string TestPlan::getContent(const minifi::core::FlowFile& file) const {
const auto content_claim = file.getResourceClaim();
const auto content_stream = content_repo_->read(*content_claim);
const auto output_stream = std::make_shared<minifi::io::BufferStream>();
std::ignore = minifi::InputStreamPipe{*output_stream}(content_stream);
return utils::span_to<std::string>(minifi::utils::as_span<const char>(output_stream->getBuffer()).subspan(file.getOffset(), file.getSize()));
minifi::internal::pipe(*content_stream, *output_stream);
return utils::span_to<std::string>(minifi::utils::as_span<const char>(output_stream->getBuffer()));
}

TestController::TestController()
Expand Down

0 comments on commit 6643503

Please sign in to comment.