Skip to content

Commit

Permalink
Add multiprocessing benchmark
Browse files Browse the repository at this point in the history
Disable C++ threading.
  • Loading branch information
syoyo committed Jan 22, 2024
1 parent 7a19c2b commit b630ec0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 30 deletions.
88 changes: 88 additions & 0 deletions benchmark/run-multiprocess-jdepp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import time
import sys
import signal

from tqdm import tqdm
import concurrent.futures

import jdepp

# Input file read at startup; its "EOS"-terminated line groups become the
# sentences handed to parse_from_postagged_batch().
input_filename = "output-wiki-postagged.txt"
#input_filename = "test-postagged.txt"

# Module-level flag; handler() sets it to True when a signal is delivered.
interrupted = False

def handler(signum, frame):
    """Signal callback: report the signal and raise the module-level flag.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame passed in by the signal machinery (unused).
    """
    global interrupted

    # Graceful shutdown: only record that a stop was requested.
    print('Signal handler called with signal', signum)
    interrupted = True

print("reading test data:", input_filename)
# Read through a context manager so the file handle is closed as soon as
# the lines are in memory (the bare open().readlines() never closed it).
with open(input_filename, 'r', encoding='utf8') as f:
    lines = f.readlines()


# Group the lines into sentences: every run of lines up to and including an
# "EOS" line is joined into one string and stored in input_sents.
ninput_sentences = 0
input_sents = []
sents = []
for line in tqdm(lines, desc="[prepare]"):
    if line == '\n':
        # newline only line is not allowed.
        continue

    if not line.endswith('\n'):
        # Only the final line of a file can lack its newline. Normalize it so
        # a trailing "EOS" without a newline still terminates the last
        # sentence instead of silently dropping it.
        line += '\n'

    sents.append(line)

    if line == "EOS\n":
        # sents is List[str]; one sentence is the concatenation of its lines.
        input_sents.append(''.join(sents))

        ninput_sentences += 1
        sents = []

# Start of the timed region. The elapsed time reported at the end of the
# script is measured from here, so it includes the model loading below.
s = time.time()

# Module-level parser instance; run_task() reads this global.
parser = jdepp.Jdepp()

# Model path passed to load_model(); the commented-out lines are alternatives.
#model_path = "model/knbc"
#model_path = "model_2ndpoly/model/knbc"
model_path = "model_3rdpoly/model/knbc"
parser.load_model(model_path)

def run_task(in_sents):
    """Parse one batch with the module-level parser and return the result count.

    Args:
        in_sents: Batch of sentence strings passed unchanged to
            ``parser.parse_from_postagged_batch``.

    Returns:
        ``len()`` of the object the parser returned for this batch.
    """
    parsed = parser.parse_from_postagged_batch(in_sents)
    print(parsed)

    # Only the count is returned; the parsed objects themselves are not
    # handed back to the caller.
    return len(parsed)

# Guard the driver: under the "spawn" start method worker processes re-import
# this module, and without the guard each worker would try to create its own
# process pool. The module-level parser setup above still runs in every
# worker, which is what run_task() relies on.
if __name__ == "__main__":
    signal.signal(signal.SIGINT, handler)

    # Number of sentences submitted per task.
    nbatches = 1000

    # Ceiling division: the last, possibly shorter, batch is a task too, so a
    # floored total would make the progress bar run past 100%.
    total_ticks = (len(input_sents) + nbatches - 1) // nbatches

    # os.cpu_count() may return None, and halving a single core gives 0,
    # which ProcessPoolExecutor rejects; always use at least one worker.
    nprocs = max(1, (os.cpu_count() or 1) // 2)

    with tqdm(total=total_ticks) as pbar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=nprocs) as executor:
            # Maps each future to the start offset of the batch it processes.
            futures = {
                executor.submit(run_task, input_sents[i:i + nbatches]): i
                for i in range(0, len(input_sents), nbatches)
            }

            for future in concurrent.futures.as_completed(futures):
                # result() re-raises any exception raised inside the worker.
                future.result()
                pbar.update(1)

                if interrupted:
                    # handler() saw a signal: drop the batches that have not
                    # started yet. Leaving the `with` block still waits for
                    # the batches that are already running.
                    for pending in futures:
                        pending.cancel()
                    break

    e = time.time()
    proc_sec = e - s
    # Avoid ZeroDivisionError when the input contained no complete sentence.
    ms_per_sentence = 1000.0 * proc_sec / float(ninput_sentences) if ninput_sentences else 0.0
    sys.stderr.write("J.DepP: Total {} secs({} sentences. {} ms per sentence))\n".format(proc_sec, ninput_sentences, ms_per_sentence))
3 changes: 2 additions & 1 deletion jdepp/jdepp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def dottyfy (binfo, graph_name: str = "jdepp", label_name = "# S-ID; 1", prob: b

# define nodes
for b in binfo:
s += " bunsetsu{} [label=\"{}\"];\n".format(b.id, b.morph)
# Escape double quotes so they cannot terminate the quoted label early.
# NOTE: '\"' is the same string as '"' in Python, so the replacement has to
# be spelled '\\"' to actually insert the backslash.
s += " bunsetsu{} [label=\"{}\"];\n".format(b.id, b.morph.replace('"', '\\"'))

s += '\n'

Expand Down
34 changes: 5 additions & 29 deletions jdepp/python-binding-jdepp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -758,36 +758,12 @@ class PyJdepp {
std::vector<PySentence> parse_from_postagged_batch(const std::vector<std::string> &input_postagged_array) const {
std::vector<PySentence> sents;

uint32_t num_threads = (_nthreads == 0)
? uint32_t(std::thread::hardware_concurrency())
: _nthreads;
num_threads = (std::max)(
1u, (std::min)(static_cast<uint32_t>(num_threads), kMaxThreads));

size_t num_inputs = input_postagged_array.size();

if (num_inputs < 128) {
// Assume input is too small
num_threads = 1;
}

std::vector<std::thread> workers;
std::atomic<uint32_t> i{0};

// NOTE: threading is not supported in J.DepP (its internal state gets corrupted in threaded execution).
// Use serial execution for now.
sents.resize(input_postagged_array.size());

for (size_t t = 0; t < static_cast<size_t>(num_threads); t++) {
workers.push_back(std::thread([&, t]() {
size_t k = 0;

while ((k = i++) < num_inputs) {
sents[k] = std::move(parse_from_postagged(input_postagged_array[k]));
}
}));
}

for (auto &t : workers) {
t.join();
for (size_t k = 0; k < input_postagged_array.size(); k++) {
sents[k] = parse_from_postagged(input_postagged_array[k]);
}

return sents;
Expand All @@ -809,7 +785,7 @@ class PyJdepp {
for (auto &v : _argv_str) {
_argv.push_back(const_cast<char *>(v.c_str()));
}
_argv.push_back(nullptr); // must add 'nullptr' at the end, otherwise out-of-bounds access will happen in optparse
_argv.push_back(nullptr); // must add 'nullptr' at the end, otherwise out-of-bounds access will happen in optparse

for (auto &v : _learner_argv_str) {
_learner_argv.push_back(const_cast<char *>(v.c_str()));
Expand Down

0 comments on commit b630ec0

Please sign in to comment.