Skip to content

Commit

Permalink
server : implement "verbose_json" format with token details (openai#1781)
Browse files Browse the repository at this point in the history

* examples/server: implement "verbose_json" format with token details.

This is intended to mirror the format of openai's Python
whisper.transcribe() return values.

* server: don't write WAV to a temporary file if not converting

* server: use std::lock_guard instead of manual lock/unlock
  • Loading branch information
rmmh authored Jan 18, 2024
1 parent fb466b3 commit c0329ac
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 30 deletions.
6 changes: 6 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,12 @@ bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector

fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (fname.size() > 256 || fname.size() > 40 && fname.substr(0, 4) == "RIFF" && fname.substr(8, 4) == "WAVE") {
if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
return false;
}
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
Expand Down
1 change: 1 addition & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
//

// Read WAV audio file and store the PCM data into pcmf32
// fname can be a buffer of WAV data instead of a filename
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
Expand Down
97 changes: 67 additions & 30 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#endif

using namespace httplib;
using json = nlohmann::json;
using json = nlohmann::ordered_json;

namespace {

Expand Down Expand Up @@ -556,15 +556,14 @@ int main(int argc, char ** argv) {

svr.Post(sparams.request_path + "/inference", [&](const Request &req, Response &res){
// acquire whisper model mutex lock
whisper_mutex.lock();
std::lock_guard<std::mutex> lock(whisper_mutex);

// first check user requested fields of the request
if (!req.has_file("file"))
{
fprintf(stderr, "error: no 'file' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
auto audio_file = req.get_file_value("file");
Expand All @@ -579,35 +578,42 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

// write to temporary file
const std::string temp_filename = "whisper_server_temp_file.wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
temp_file.close();

// if file is not wav, convert to wav

if (sparams.ffmpeg_converter) {
// if file is not wav, convert to wav
// write to temporary file
const std::string temp_filename = "whisper_server_temp_file.wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
temp_file.close();

std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
const bool is_converted = convert_to_wav(temp_filename, error_resp);
if (!is_converted) {
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}

// read wav content into pcmf32
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
// read wav content into pcmf32
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize))
{
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
std::remove(temp_filename.c_str());
return;
}
// remove temp file
std::remove(temp_filename.c_str());
whisper_mutex.unlock();
return;
} else {
if (!::read_wav(audio_file.content, pcmf32, pcmf32s, params.diarize))
{
fprintf(stderr, "error: failed to read WAV file\n");
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
return;
}
}
// remove temp file
std::remove(temp_filename.c_str());


printf("Successfully loaded %s\n", filename.c_str());

Expand Down Expand Up @@ -681,6 +687,7 @@ int main(int argc, char ** argv) {
wparams.logprob_thold = params.logprob_thold;

wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;

whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

Expand Down Expand Up @@ -724,7 +731,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
Expand Down Expand Up @@ -778,6 +784,43 @@ int main(int argc, char ** argv) {
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "text/vtt");
} else if (params.response_format == vjson_format) {
/* try to match openai/whisper's Python format */
std::string results = output_str(ctx, params, pcmf32s);
json jres = json{{"text", results}};
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i)
{
json segment = json{
{"id", i},
{"text", whisper_full_get_segment_text(ctx, i)},
};

if (!params.no_timestamps) {
segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01;
segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01;
}

const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
if (token.id >= whisper_token_eot(ctx)) {
continue;
}

segment["tokens"].push_back(token.id);
json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}};
if (!params.no_timestamps) {
word["start"] = token.t0 * 0.01;
word["end"] = token.t1 * 0.01;
}
word["probability"] = token.p;
segment["words"].push_back(word);
}
jres["segments"].push_back(segment);
}
res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
}
// TODO add more output formats
else
Expand All @@ -792,18 +835,14 @@ int main(int argc, char ** argv) {

// reset params to their defaults
params = default_params;

// return whisper model mutex lock
whisper_mutex.unlock();
});
svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
whisper_mutex.lock();
std::lock_guard<std::mutex> lock(whisper_mutex);
if (!req.has_file("model"))
{
fprintf(stderr, "error: no 'model' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
std::string model = req.get_file_value("model").content;
Expand All @@ -812,7 +851,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
const std::string error_resp = "{\"error\":\"model not found!\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}

Expand All @@ -835,7 +873,6 @@ int main(int argc, char ** argv) {
res.set_content(success, "application/text");

// check if the model is in the file system
whisper_mutex.unlock();
});

svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
Expand Down

0 comments on commit c0329ac

Please sign in to comment.