From c0329acde8a7d2b03add7e7c8f5e5341b48746ff Mon Sep 17 00:00:00 2001 From: Ryan Hitchman Date: Thu, 18 Jan 2024 13:58:42 -0700 Subject: [PATCH] server : implement "verbose_json" format with token details (#1781) * examples/server: implement "verbose_json" format with token details. This is intended to mirror the format of openai's Python whisper.transcribe() return values. * server: don't write WAV to a temporary file if not converting * server: use std::lock_guard instead of manual lock/unlock --- examples/common.cpp | 6 +++ examples/common.h | 1 + examples/server/server.cpp | 97 ++++++++++++++++++++++++++------------ 3 files changed, 74 insertions(+), 30 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 603c655a1..8404e00e0 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -639,6 +639,12 @@ bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); } + else if (fname.size() > 256 || fname.size() > 40 && fname.substr(0, 4) == "RIFF" && fname.substr(8, 4) == "WAVE") { + if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from fname buffer\n"); + return false; + } + } else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); return false; diff --git a/examples/common.h b/examples/common.h index 54f0b00d0..aebeb0cd4 100644 --- a/examples/common.h +++ b/examples/common.h @@ -136,6 +136,7 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat( // // Read WAV audio file and store the PCM data into pcmf32 +// fname can be a buffer of WAV data instead of a filename // The sample rate of the audio must be equal to COMMON_SAMPLE_RATE // If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM bool read_wav( diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8b6e46952..7de318596 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -18,7 +18,7 @@ #endif using namespace httplib; -using json = nlohmann::json; +using json = nlohmann::ordered_json; namespace { @@ -556,7 +556,7 @@ int main(int argc, char ** argv) { svr.Post(sparams.request_path + "/inference", [&](const Request &req, Response &res){ // acquire whisper model mutex lock - whisper_mutex.lock(); + std::lock_guard lock(whisper_mutex); // first check user requested fields of the request if (!req.has_file("file")) @@ -564,7 +564,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "error: no 'file' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}"; res.set_content(error_resp, "application/json"); - whisper_mutex.unlock(); return; } auto audio_file = req.get_file_value("file"); @@ -579,35 +578,42 @@ int main(int argc, char ** argv) { std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM - // write to temporary file - const std::string temp_filename = "whisper_server_temp_file.wav"; - std::ofstream temp_file{temp_filename, std::ios::binary}; - temp_file << audio_file.content; - temp_file.close(); - - // if file is not wav, convert to wav - if (sparams.ffmpeg_converter) { + // if file is not wav, convert to wav + // write to temporary file + const std::string temp_filename = "whisper_server_temp_file.wav"; + std::ofstream temp_file{temp_filename, std::ios::binary}; + temp_file << audio_file.content; + temp_file.close(); + std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; const bool is_converted = convert_to_wav(temp_filename, error_resp); if (!is_converted) { res.set_content(error_resp, "application/json"); - whisper_mutex.unlock(); return; } - } - // read wav content into pcmf32 - if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) { - fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); - const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; - res.set_content(error_resp, "application/json"); + // read wav content into pcmf32 + if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) + { + fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); + const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; + res.set_content(error_resp, "application/json"); + std::remove(temp_filename.c_str()); + return; + } + // remove temp file std::remove(temp_filename.c_str()); - whisper_mutex.unlock(); - return; + } else { + if (!::read_wav(audio_file.content, pcmf32, pcmf32s, params.diarize)) + { + fprintf(stderr, "error: failed to read WAV file\n"); + const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; + res.set_content(error_resp, "application/json"); + return; + } } - // remove temp file - std::remove(temp_filename.c_str()); + printf("Successfully loaded %s\n", filename.c_str()); @@ -681,6 +687,7 @@ int main(int argc, char ** argv) { wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; + wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; @@ -724,7 +731,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); const std::string error_resp = "{\"error\":\"failed to process audio\"}"; res.set_content(error_resp, "application/json"); - whisper_mutex.unlock(); return; } } @@ -778,6 +784,43 @@ int main(int argc, char ** argv) { ss << speaker << text << "\n\n"; } res.set_content(ss.str(), "text/vtt"); + } else if (params.response_format == vjson_format) { + /* try to match openai/whisper's Python format */ + std::string results = output_str(ctx, params, pcmf32s); + json jres = json{{"text", results}}; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) + { + json segment = json{ + {"id", i}, + {"text", whisper_full_get_segment_text(ctx, i)}, + }; + + if (!params.no_timestamps) { + segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; + segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; + } + + const int n_tokens = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n_tokens; ++j) { + whisper_token_data token = whisper_full_get_token_data(ctx, i, j); + if (token.id >= whisper_token_eot(ctx)) { + continue; + } + + segment["tokens"].push_back(token.id); + json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + if (!params.no_timestamps) { + word["start"] = token.t0 * 0.01; + word["end"] = token.t1 * 0.01; + } + word["probability"] = token.p; + segment["words"].push_back(word); + } + jres["segments"].push_back(segment); + } + res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), + "application/json"); } // TODO add more output formats else @@ -792,18 +835,14 @@ int main(int argc, char ** argv) { // reset params to thier defaults params = default_params; - - // return whisper model mutex lock - whisper_mutex.unlock(); }); svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ - whisper_mutex.lock(); + std::lock_guard lock(whisper_mutex); if (!req.has_file("model")) { fprintf(stderr, "error: no 'model' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}"; res.set_content(error_resp, "application/json"); - whisper_mutex.unlock(); return; } std::string model = req.get_file_value("model").content; @@ -812,7 +851,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "error: 'model': %s not found!\n", model.c_str()); const std::string error_resp = "{\"error\":\"model not found!\"}"; res.set_content(error_resp, "application/json"); - whisper_mutex.unlock(); return; } @@ -835,7 +873,6 @@ int main(int argc, char ** argv) { res.set_content(success, "application/text"); // check if the model is in the file system - whisper_mutex.unlock(); }); svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {