Skip to content

Commit

Permalink
[fix](inverted index) special characters cause buffer overflow in Unicode tokenization. (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzxl1993 authored May 1, 2024
1 parent 847f460 commit f10bc3f
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 61 deletions.
10 changes: 10 additions & 0 deletions src/core/CLucene/analysis/AnalysisHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "CLucene/util/VoidList.h"
#include "CLucene/LuceneThreads.h"

#include <unordered_set>

CL_CLASS_DEF(util,Reader)
CL_CLASS_DEF(util,IReader)

Expand Down Expand Up @@ -297,6 +299,11 @@ class CLUCENE_EXPORT Analyzer{
virtual void set_lowercase(bool lowercase) {
_lowercase = lowercase;
}

virtual void set_stopwords(std::unordered_set<std::string_view>* stopwords) {
_stopwords = stopwords;
}

private:

DEFINE_MUTEX(THIS_LOCK)
Expand All @@ -313,7 +320,9 @@ class CLUCENE_EXPORT Analyzer{
* to save a TokenStream for later re-use by the same
* thread. */
virtual void setPreviousTokenStream(TokenStream* obj);

bool _lowercase = false;
std::unordered_set<std::string_view>* _stopwords = nullptr;

public:
/**
Expand Down Expand Up @@ -350,6 +359,7 @@ class CLUCENE_EXPORT Tokenizer:public TokenStream {
/** The text source for this Tokenizer. */
CL_NS(util)::Reader* input;
bool lowercase = false;
std::unordered_set<std::string_view>* stopwords = nullptr;

public:
/** Construct a tokenizer with null input. */
Expand Down
16 changes: 7 additions & 9 deletions src/core/CLucene/analysis/standard95/StandardAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ namespace lucene::analysis::standard95 {

class StandardAnalyzer : public Analyzer {
public:
StandardAnalyzer() : Analyzer() { _lowercase = true; }
StandardAnalyzer() : Analyzer() {
_lowercase = true;
_stopwords = nullptr;
}

bool isSDocOpt() override { return true; }

TokenStream* tokenStream(const TCHAR* fieldName,
lucene::util::Reader* reader) override {
return _CLNEW StandardTokenizer(reader, useStopWords_, _lowercase);
return _CLNEW StandardTokenizer(reader, _lowercase, _stopwords);
}

TokenStream* reusableTokenStream(const TCHAR* fieldName,
lucene::util::Reader* reader) override {
if (tokenizer_ == nullptr) {
tokenizer_ = new StandardTokenizer(reader, useStopWords_, _lowercase);
tokenizer_ = new StandardTokenizer(reader, _lowercase, _stopwords);
} else {
tokenizer_->reset(reader);
}
Expand All @@ -31,13 +35,7 @@ class StandardAnalyzer : public Analyzer {
}
}

void useStopWords(bool useStopWords) {
useStopWords_ = useStopWords;
}

private:
bool useStopWords_ = true;

StandardTokenizer* tokenizer_ = nullptr;
};

Expand Down
14 changes: 7 additions & 7 deletions src/core/CLucene/analysis/standard95/StandardTokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ static std::unordered_set<std::string_view> stop_words = {

class StandardTokenizer : public Tokenizer {
public:
StandardTokenizer(lucene::util::Reader* in, bool useStopWords)
: Tokenizer(in), useStopWords_(useStopWords) {
StandardTokenizer(lucene::util::Reader* in)
: Tokenizer(in) {
scanner_ = std::make_unique<StandardTokenizerImpl>(in);
Tokenizer::lowercase = true;
Tokenizer::stopwords = nullptr;
}
StandardTokenizer(lucene::util::Reader* in, bool useStopWords, bool lowercase)
: Tokenizer(in), useStopWords_(useStopWords) {
StandardTokenizer(lucene::util::Reader* in, bool lowercase, std::unordered_set<std::string_view>* stopwords)
: Tokenizer(in) {
scanner_ = std::make_unique<StandardTokenizerImpl>(in);
Tokenizer::lowercase = lowercase;
Tokenizer::stopwords = stopwords;
}

Token* next(Token* token) override {
Expand All @@ -47,7 +49,7 @@ class StandardTokenizer : public Tokenizer {
std::transform(term.begin(), term.end(), const_cast<char*>(term.data()),
[](char c) { return to_lower(c); });
}
if (useStopWords_ && stop_words.count(term)) {
if (stopwords && stopwords->count(term)) {
skippedPositions++;
continue;
}
Expand All @@ -70,8 +72,6 @@ class StandardTokenizer : public Tokenizer {
};

private:
bool useStopWords_ = true;

std::unique_ptr<StandardTokenizerImpl> scanner_;

int32_t skippedPositions = 0;
Expand Down
52 changes: 10 additions & 42 deletions src/core/CLucene/analysis/standard95/StandardTokenizerImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ const std::vector<std::string> StandardTokenizerImpl::ZZ_ERROR_MSG = {
"Error: pushback value was too large"};

StandardTokenizerImpl::StandardTokenizerImpl(lucene::util::Reader* reader)
: zzBuffer(ZZ_BUFFERSIZE), zzReader(reader) {}
: zzReader(reader), zzBuffer((reader == nullptr) ? 0 : reader->size()) {}

std::string_view StandardTokenizerImpl::getText() {
return std::string_view(zzBuffer.data() + zzStartRead,
Expand All @@ -67,53 +67,20 @@ std::string_view StandardTokenizerImpl::getText() {

bool StandardTokenizerImpl::zzRefill() {
if (zzStartRead > 0) {
zzEndRead += zzFinalHighSurrogate;
zzFinalHighSurrogate = 0;
std::copy_n(zzBuffer.begin() + zzStartRead, zzEndRead - zzStartRead,
zzBuffer.begin());

zzEndRead -= zzStartRead;
zzCurrentPos -= zzStartRead;
zzMarkedPos -= zzStartRead;
zzStartRead = 0;
}

int32_t requested = zzBuffer.size() - zzEndRead - zzFinalHighSurrogate;
if (requested == 0) {
return true;
return true;
}

int32_t numRead = zzReader->readCopy(zzBuffer.data(), zzEndRead, requested);
if (numRead == 0) {
_CLTHROWA(CL_ERR_Runtime,
"Reader returned 0 characters. See JFlex examples/zero-reader "
"for a workaround.");
}
int32_t numRead = zzReader->readCopy(zzBuffer.data(), 0, zzBuffer.size());
if (numRead > 0) {
zzEndRead += numRead;

int32_t n =
StringUtil::validate_utf8(std::string_view(zzBuffer.data(), zzEndRead));
if (n == -1) {
yyResetPosition();
return true;
}
assert(zzBuffer.size() == numRead);
zzEndRead += numRead;

if (n != 0) {
if (numRead == requested) {
zzEndRead -= n;
zzFinalHighSurrogate = n;
} else {
int32_t c = zzReader->read();
if (c == -1) {
int32_t n = StringUtil::validate_utf8(std::string_view(zzBuffer.data(), zzBuffer.size()));
if (n != 0) {
return true;
} else {
_CLTHROWA(CL_ERR_Runtime, "Why did you come here");
}
}
}

return false;
return false;
}

return true;
Expand All @@ -126,6 +93,7 @@ void StandardTokenizerImpl::yyclose() {

void StandardTokenizerImpl::yyreset(lucene::util::Reader* reader) {
zzReader = reader;
zzBuffer.resize(reader->size());
yyResetPosition();
zzLexicalState = YYINITIAL;
}
Expand Down Expand Up @@ -181,7 +149,7 @@ int32_t StandardTokenizerImpl::getNextToken() {

{
while (true) {
if (zzCurrentPosL < zzEndReadL) {
if (zzCurrentPosL < zzEndReadL && (zzCurrentPosL - zzStartRead) < ZZ_BUFFERSIZE) {
size_t len = 0;
zzInput = decodeUtf8ToCodepoint(
std::string_view(zzBufferL.data() + zzCurrentPosL, zzEndReadL),
Expand Down
4 changes: 2 additions & 2 deletions src/core/CLucene/util/stringUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ class StringUtil {
} else {
if ((c & 0xC0) != 0x80) return -1;
codepoint = (codepoint << 6) | (c & 0x3F);
if (!is_valid_codepoint(codepoint)) {
bytes_in_char--;
if (bytes_in_char == 0 && !is_valid_codepoint(codepoint)) {
return -1;
}
bytes_in_char--;
surplus_bytes++;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/test/analysis/TestStandard95.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

#include "CLucene/_ApiHeader.h"
#include "CLucene/analysis/standard95/StandardAnalyzer.h"
#include "CLucene/analysis/standard95/StandardTokenizer.h"
#include "test.h"

void testCut(const std::string &str, std::vector<std::string> &tokens) {
auto standard =
std::make_unique<lucene::analysis::standard95::StandardAnalyzer>();
standard->set_stopwords(&lucene::analysis::standard95::stop_words);
auto tokenizer =
static_cast<lucene::analysis::standard95::StandardTokenizer *>(
standard->tokenStream(L"name", nullptr));
Expand All @@ -28,7 +30,7 @@ void testCut(const std::string &str, std::vector<std::string> &tokens) {
void testCutLines(std::vector<std::string>& datas, std::vector<std::string> &tokens) {
auto standard =
std::make_unique<lucene::analysis::standard95::StandardAnalyzer>();
standard->useStopWords(false);
standard->set_stopwords(nullptr);
auto tokenizer =
static_cast<lucene::analysis::standard95::StandardTokenizer *>(
standard->tokenStream(L"name", nullptr));
Expand Down

0 comments on commit f10bc3f

Please sign in to comment.