Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run LSTM recognition in multiple threads #4275

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 92 additions & 41 deletions src/ccmain/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cstdint> // for int16_t, int32_t
#include <cstdio> // for fclose, fopen, FILE
#include <ctime> // for clock
#include <future>
#include "control.h"
#ifndef DISABLED_LEGACY_ENGINE
# include "docqual.h"
Expand Down Expand Up @@ -194,36 +195,42 @@ void Tesseract::SetupWordPassN(int pass_n, WordData *word) {
}
}

// Runs word recognition on all the words.
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT *pr_it,
std::vector<WordData> *words) {
// TODO(rays) Before this loop can be parallelized (it would yield a massive
// speed-up) all remaining member globals need to be converted to local/heap
// (eg set_pass1 and set_pass2) and an intermediate adaption pass needs to be
// added. The results will be significantly different with adaption on, and
// deterioration will need investigation.
pr_it->restart_page();
for (unsigned w = 0; w < words->size(); ++w) {
WordData *word = &(*words)[w];
if (w > 0) {
word->prev_word = &(*words)[w - 1];
bool Tesseract::RecogWordsSegment(std::vector<WordData>::iterator start,
std::vector<WordData>::iterator end,
int pass_n,
ETEXT_DESC *monitor,
PAGE_RES *page_res,
LSTMRecognizer *lstm_recognizer,
std::atomic<int>& words_done,
int total_words,
std::shared_ptr<std::mutex> recog_words_mutex) {
PAGE_RES_IT pr_it(page_res, recog_words_mutex);
// Process a segment of the words vector
pr_it.restart_page();

for (auto it = start; it != end; ++it, ++words_done) {
WordData *word = &(*it);
if (it != start) {
word->prev_word = &(*(it - 1));
}
if (monitor != nullptr) {
std::lock_guard<std::mutex> lock(*recog_words_mutex);
monitor->ocr_alive = true;
if (pass_n == 1) {
monitor->progress = 70 * w / words->size();
monitor->progress = 70 * words_done / total_words;
} else {
monitor->progress = 70 + 30 * w / words->size();
monitor->progress = 70 + 30 * words_done / total_words;
}
// Only call the progress callback for the first thread.
if (monitor->progress_callback2 != nullptr) {
TBOX box = pr_it->word()->word->bounding_box();
TBOX box = pr_it.word()->word->bounding_box();
(*monitor->progress_callback2)(monitor, box.left(), box.right(), box.top(), box.bottom());
}
if (monitor->deadline_exceeded() ||
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, words->size()))) {
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, total_words))) {
// Timeout. Fake out the rest of the words.
for (; w < words->size(); ++w) {
(*words)[w].word->SetupFake(unicharset);
for (; it != end; ++it) {
it->word->SetupFake(unicharset);
}
return false;
}
Expand All @@ -238,31 +245,72 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT
}
}
// Sync pr_it with the WordData.
while (pr_it->word() != nullptr && pr_it->word() != word->word) {
pr_it->forward();
}
ASSERT_HOST(pr_it->word() != nullptr);
pr_it.forward_to_word(word->word);
ASSERT_HOST(pr_it.word() != nullptr);
bool make_next_word_fuzzy = false;
#ifndef DISABLED_LEGACY_ENGINE
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, &pr_it, &make_next_word_fuzzy)) {
// Needs to be setup again to see the new outlines in the chopped_word.
SetupWordPassN(pass_n, word);
}
#endif // ndef DISABLED_LEGACY_ENGINE

classify_word_and_language(pass_n, pr_it, word);
classify_word_and_language(pass_n, &pr_it, word, lstm_recognizer);
if (tessedit_dump_choices || debug_noise_removal) {
tprintf("Pass%d: %s [%s]\n", pass_n, word->word->best_choice->unichar_string().c_str(),
word->word->best_choice->debug_string().c_str());
}
pr_it->forward();
if (make_next_word_fuzzy && pr_it->word() != nullptr) {
pr_it->MakeCurrentWordFuzzy();
pr_it.forward();
if (make_next_word_fuzzy && pr_it.word() != nullptr) {
pr_it.MakeCurrentWordFuzzy();
}
}
return true;
}

// Runs word recognition on all the words.
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES *page_res,
std::vector<WordData> *words) {
int total_words = words->size();
int segment_size = std::max(total_words / lstm_num_threads, 1);
std::atomic<int> words_done(0);
std::shared_ptr<std::mutex> recog_words_mutex = std::make_shared<std::mutex>();
std::vector<std::future<bool>> futures;

// Launch multiple threads to recognize the words in parallel
auto segment_start = words->begin() + segment_size;
for (int i = 1; i < lstm_num_threads && segment_start != words->end(); ++i) {
auto segment_end = segment_start + segment_size;
if (i == lstm_num_threads - 1 ||
std::distance(segment_start, words->end()) < segment_size) {
segment_end = words->end();
}
futures.push_back(std::async(
std::launch::async, &Tesseract::RecogWordsSegment, this, segment_start,
segment_end, pass_n, monitor, page_res, lstm_recognizers_[i],
std::ref(words_done), total_words, recog_words_mutex));
segment_start = segment_end;
}

// Process the first segment in this thread
bool overall_result = RecogWordsSegment(words->begin(),
words->begin() + segment_size,
pass_n,
monitor,
page_res,
lstm_recognizers_[0],
std::ref(words_done),
total_words,
recog_words_mutex);

// Wait for all threads to complete and aggregate results
for (auto &f : futures) {
overall_result &= f.get();
}

return overall_result;
}

/**
* recog_all_words()
*
Expand Down Expand Up @@ -340,7 +388,7 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,

most_recently_used_ = this;
// Run pass 1 word recognition.
if (!RecogAllWordsPassN(1, monitor, &page_res_it, &words)) {
if (!RecogAllWordsPassN(1, monitor, page_res, &words)) {
return false;
}
// Pass 1 post-processing.
Expand Down Expand Up @@ -380,11 +428,10 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,
}
most_recently_used_ = this;
// Run pass 2 word recognition.
if (!RecogAllWordsPassN(2, monitor, &page_res_it, &words)) {
if (!RecogAllWordsPassN(2, monitor, page_res, &words)) {
return false;
}
}

// The next passes are only required for Tess-only.
if (AnyTessLang() && !AnyLSTMLang()) {
// ****************** Pass 3 *******************
Expand Down Expand Up @@ -871,14 +918,15 @@ static int SelectBestWords(double rating_ratio, double certainty_margin, bool de
// Returns positive if this recognizer found more new best words than the
// number kept from best_words.
int Tesseract::RetryWithLanguage(const WordData &word_data, WordRecognizer recognizer, bool debug,
WERD_RES **in_word, PointerVector<WERD_RES> *best_words) {
WERD_RES **in_word, PointerVector<WERD_RES> *best_words,
LSTMRecognizer *lstm_recognizer) {
if (debug) {
tprintf("Trying word using lang %s, oem %d\n", lang.c_str(),
static_cast<int>(tessedit_ocr_engine_mode));
}
// Run the recognizer on the word.
PointerVector<WERD_RES> new_words;
(this->*recognizer)(word_data, in_word, &new_words);
(this->*recognizer)(word_data, in_word, &new_words, lstm_recognizer);
if (new_words.empty()) {
// Transfer input word to new_words, as the classifier must have put
// the result back in the input.
Expand Down Expand Up @@ -1300,7 +1348,10 @@ float Tesseract::ClassifyBlobAsWord(int pass_n, PAGE_RES_IT *pr_it, C_BLOB *blob
// Recognizes in the current language, and if successful that is all.
// If recognition was not successful, tries all available languages until
// it gets a successful result or runs out of languages. Keeps the best result.
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data) {
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data,
LSTMRecognizer *lstm_recognizer_thread_local) {
LSTMRecognizer *lstm_recognizer = lstm_recognizer_thread_local ? lstm_recognizer_thread_local
: lstm_recognizer_;
#ifdef DISABLED_LEGACY_ENGINE
WordRecognizer recognizer = &Tesseract::classify_word_pass1;
#else
Expand Down Expand Up @@ -1333,19 +1384,19 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
}
}
most_recently_used_->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[sub],
&best_words);
&best_words, lstm_recognizer);
Tesseract *best_lang_tess = most_recently_used_;
if (!WordsAcceptable(best_words)) {
// Try all the other languages to see if they are any better.
if (most_recently_used_ != this &&
this->RetryWithLanguage(*word_data, recognizer, debug,
&word_data->lang_words[sub_langs_.size()], &best_words) > 0) {
&word_data->lang_words[sub_langs_.size()], &best_words, lstm_recognizer) > 0) {
best_lang_tess = this;
}
for (unsigned i = 0; !WordsAcceptable(best_words) && i < sub_langs_.size(); ++i) {
if (most_recently_used_ != sub_langs_[i] &&
sub_langs_[i]->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[i],
&best_words) > 0) {
&best_words, lstm_recognizer) > 0) {
best_lang_tess = sub_langs_[i];
}
}
Expand Down Expand Up @@ -1378,7 +1429,7 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
*/

void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_word,
PointerVector<WERD_RES> *out_words) {
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
ROW *row = word_data.row;
BLOCK *block = word_data.block;
prev_word_best_choice_ =
Expand All @@ -1390,14 +1441,14 @@ void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_wor
tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) {
#endif // def DISABLED_LEGACY_ENGINE
if (!(*in_word)->odd_size || tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
LSTMRecognizeWord(*block, row, *in_word, out_words);
LSTMRecognizeWord(*block, row, *in_word, out_words, lstm_recognizer);
if (!out_words->empty()) {
return; // Successful lstm recognition.
}
}
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
// No fallback allowed, so use a fake.
(*in_word)->SetupFake(lstm_recognizer_->GetUnicharset());
(*in_word)->SetupFake(lstm_recognizer->GetUnicharset());
return;
}

Expand Down Expand Up @@ -1534,7 +1585,7 @@ bool Tesseract::TestNewNormalization(int original_misfits, float baseline_shift,
*/

void Tesseract::classify_word_pass2(const WordData &word_data, WERD_RES **in_word,
PointerVector<WERD_RES> *out_words) {
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
// Return if we do not want to run Tesseract.
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
return;
Expand Down
Loading