diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f0e8440a..1a6a7e25f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,6 +171,10 @@ if (MSVC) endif() message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}") + # DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler + # See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856 + # Remove this definition once the conda msvcp140.dll dll is updated. + add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) endif() if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON) @@ -442,7 +446,9 @@ endif() if(OCOS_ENABLE_BERT_TOKENIZER) # Bert set(_HAS_TOKENIZER ON) - file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*") + file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" + "operators/tokenizer/bert_tokenizer.*" + "operators/tokenizer/bert_tokenizer_decoder.*") list(APPEND TARGET_SRC ${bert_TARGET_SRC}) endif() @@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE) endif() target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS}) +target_include_directories(ortcustomops PUBLIC "$") target_include_directories(ortcustomops PUBLIC "$") + target_link_libraries(ortcustomops PUBLIC ocos_operators) if(_BUILD_SHARED_LIBRARY) @@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY) standardize_output_folder(extensions_shared) if(LINUX OR ANDROID) - set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") + set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS + " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") # strip if not a debug build if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s") diff --git a/docs/c_api.md b/docs/c_api.md new file mode 100644 index 000000000..1a3d4613b --- /dev/null +++ b/docs/c_api.md @@ -0,0 +1,19 @@ +# ONNXRuntime Extensions C ABI + +ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs. + +The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available: + +- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models. +- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels. +- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model. + +## ABI QuickStart + +Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization. + +**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c). + +**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75). + +**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16). diff --git a/include/ortx_c_helper.h b/include/ortx_cpp_helper.h similarity index 100% rename from include/ortx_c_helper.h rename to include/ortx_cpp_helper.h diff --git a/include/ortx_extractor.h b/include/ortx_extractor.h new file mode 100644 index 000000000..13901666b --- /dev/null +++ b/include/ortx_extractor.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// C ABI header file for the onnxruntime-extensions tokenization module + +#pragma once + +#include "ortx_utils.h" + +typedef OrtxObject OrtxFeatureExtractor; +typedef OrtxObject OrtxRawAudios; +typedef OrtxObject OrtxTensorResult; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Creates a feature extractor object. + * + * This function creates a feature extractor object based on the provided feature definition. + * + * @param[out] extractor Pointer to a pointer to the created feature extractor object. + * @param[in] fe_def The feature definition used to create the feature extractor. + * + * @return An error code indicating the result of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def); + +/** + * Loads a collection of audio files into memory. + * + * This function loads a collection of audio files specified by the `audio_paths` array + * into memory and returns a pointer to the loaded audio data in the `audios` parameter. + * + * @param audios A pointer to a pointer that will be updated with the loaded audio data. + * The caller is responsible for freeing the memory allocated for the audio data. + * @param audio_paths An array of strings representing the paths to the audio files to be loaded. + * @param num_audios The number of audio files to be loaded. + * + * @return An `extError_t` value indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios); + +/** + * @brief Creates an array of raw audio objects. + * + * This function creates an array of raw audio objects based on the provided data and sizes. + * + * @param audios Pointer to the variable that will hold the created raw audio objects. + * @param data Array of pointers to the audio data. + * @param sizes Array of pointers to the sizes of the audio data. + * @param num_audios Number of audio objects to create. + * + * @return extError_t Error code indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios); + +/** + * @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor. + * + * This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct, + * and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using + * the specified feature extractor and stores the result in the provided log_mel pointer. + * + * @param extractor The feature extractor to use for calculating the log mel spectrogram. + * @param audio The raw audio data to process. + * @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored. + * @return An extError_t value indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel); + +#ifdef __cplusplus +} +#endif diff --git a/include/ortx_processor.h b/include/ortx_processor.h index 6dcc5a84e..b42a6c4f2 100644 --- a/include/ortx_processor.h +++ b/include/ortx_processor.h @@ -10,7 +10,6 @@ // typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting typedef OrtxObject OrtxProcessor; typedef OrtxObject OrtxRawImages; -typedef OrtxObject OrtxImageProcessorResult; #ifdef __cplusplus extern "C" { @@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images, size_t* num_images_loaded); + /** - * @brief Preprocesses the given raw images using the specified processor. + * @brief Creates raw images from the provided data. + * + * This function creates raw images from the provided data. The raw images are stored in the `images` parameter. + * + * @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images. + * @param data Array of pointers to the data for each image. + * @param sizes Array of pointers to the sizes of each image. + * @param num_images Number of images to create. + * @return An `extError_t` value indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images); + +/** + * @brief Pre-processes the given raw images using the specified processor. * * This function applies preprocessing operations on the raw images using the provided processor. * The result of the preprocessing is stored in the `OrtxImageProcessorResult` object. @@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima * @return An `extError_t` value indicating the success or failure of the preprocessing operation. */ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images, - OrtxImageProcessorResult** result); - -/** - * @brief Retrieves the image processor result at the specified index. - * - * @param result Pointer to the OrtxImageProcessorResult structure to store the result. - * @param index The index of the result to retrieve. - * @return extError_t The error code indicating the success or failure of the operation. - */ -extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor); - -/** \brief Clear the outputs of the processor - * - * \param processor The processor object - * \param result The result object to clear - * \return Error code indicating the success or failure of the operation - */ -extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result); + OrtxTensorResult** result); #ifdef __cplusplus } diff --git a/include/ortx_utils.h b/include/ortx_utils.h index e6c0af9aa..ed1ca2ec5 100644 --- a/include/ortx_utils.h +++ b/include/ortx_utils.h @@ -17,19 +17,22 @@ typedef enum { kOrtxKindDetokenizerCache = 0x778B, kOrtxKindProcessor = 0x778C, kOrtxKindRawImages = 0x778D, - kOrtxKindImageProcessorResult = 0x778E, + kOrtxKindTensorResult = 0x778E, kOrtxKindProcessorResult = 0x778F, kOrtxKindTensor = 0x7790, + kOrtxKindFeatureExtractor = 0x7791, + kOrtxKindRawAudios = 0x7792, kOrtxKindEnd = 0x7999 } extObjectKind_t; // all object managed by the library should be 'derived' from this struct // which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C typedef struct { - int ext_kind_; + extObjectKind_t ext_kind_; } OrtxObject; typedef OrtxObject OrtxTensor; +typedef OrtxObject OrtxTensorResult; // C, instead of C++ doesn't cast automatically, // so we need to use a macro to cast the object to the correct type @@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object); */ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object); +/** + * @brief Retrieves the tensor at the specified index from the given tensor result. + * + * This function allows you to access a specific tensor from a tensor result object. + * + * @param result The tensor result object from which to retrieve the tensor. + * @param index The index of the tensor to retrieve. + * @param tensor A pointer to a variable that will hold the retrieved tensor. + * @return An error code indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor); + /** \brief Get the data from the tensor * * \param tensor The tensor object diff --git a/onnxruntime_extensions/_torch_cvt.py b/onnxruntime_extensions/_torch_cvt.py index a17b1bb2d..10b85c1a6 100644 --- a/onnxruntime_extensions/_torch_cvt.py +++ b/onnxruntime_extensions/_torch_cvt.py @@ -17,7 +17,7 @@ from ._ortapi2 import make_onnx_model from ._cuops import SingleOpGraph from ._hf_cvt import HFTokenizerConverter -from .util import remove_unused_initializers +from .util import remove_unused_initializers, mel_filterbank class _WhisperHParams: @@ -30,53 +30,15 @@ class _WhisperHParams: N_FRAMES = N_SAMPLES // HOP_LENGTH -def _mel_filterbank( - n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32): - """ - Compute a Mel-filterbank. The filters are stored in the rows, the columns, - and it is Slaney normalized mel-scale filterbank. - """ - fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype) - - # the centers of the frequency bins for the DFT - freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) - - mel = np.linspace(min_mel, max_mel, n_mels + 2) - # Fill in the linear scale - f_min = 0.0 - f_sp = 200.0 / 3 - freqs = f_min + f_sp * mel - - # And now the nonlinear scale - min_log_hz = 1000.0 # beginning of log region (Hz) - min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = np.log(6.4) / 27.0 # step size for log region - - log_t = mel >= min_log_mel - freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel)) - mel_bins = freqs - - mel_spacing = np.diff(mel_bins) - - ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1) - for i in range(n_mels): - left = -ramps[i] / mel_spacing[i] - right = ramps[i + 2] / mel_spacing[i + 1] - - # intersect them with each other and zero - fbank[i] = np.maximum(0, np.minimum(left, right)) - - energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels]) - fbank *= energy_norm[:, np.newaxis] - return fbank - - class CustomOpStftNorm(torch.autograd.Function): @staticmethod def symbolic(g, self, n_fft, hop_length, window): - t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) - t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64)) - t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) + t_n_fft = g.op('Constant', value_t=torch.tensor( + n_fft, dtype=torch.int64)) + t_hop_length = g.op('Constant', value_t=torch.tensor( + hop_length, dtype=torch.int64)) + t_frame_size = g.op( + 'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size) @staticmethod @@ -97,7 +59,7 @@ def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT, self.n_fft = n_fft self.window = torch.hann_window(n_fft) self.mel_filters = torch.from_numpy( - _mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels)) + mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels)) def forward(self, audio_pcm: torch.Tensor): stft_norm = CustomOpStftNorm.apply(audio_pcm, @@ -112,7 +74,8 @@ def forward(self, audio_pcm: torch.Tensor): spec_shape = log_spec.shape padding_spec = torch.ones(spec_shape[0], spec_shape[1], - self.n_samples // self.hop_length - spec_shape[2], + self.n_samples // self.hop_length - + spec_shape[2], dtype=torch.float) padding_spec *= spec_min log_spec = torch.cat((log_spec, padding_spec), dim=2) @@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft): make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0', 'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'], name='slice_1'), - make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0), - make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1), + make_node('Constant', inputs=[], outputs=[ + 'const0_output_0'], name='const0', value_int=0), + make_node('Constant', inputs=[], outputs=[ + 'const1_output_0'], name='const1', value_int=1), make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'], name='gather_4', axis=3), make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'], name='gather_5', axis=3), - make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'), - make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'), - make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'), + make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[ + 'mul_output_0'], name='mul0'), + make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[ + 'mul_1_output_0'], name='mul1'), + make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[ + stft_norm_node.output[0]], name='add0'), ] new_stft_nodes.extend(onnx_model.graph.node[:node_idx]) new_stft_nodes.extend(replaced_nodes) @@ -253,9 +221,11 @@ def post_processing(self, **kwargs): del g.node[:] g.node.extend(nodes) - inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])] + inputs = [onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])] del g.input[:] g.input.extend(inputs) - g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text'])) + g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto( + onnx.TensorProto.STRING, ['N', 'text'])) return make_onnx_model(g, opset_version=self.opset_version) diff --git a/operators/audio/audio.cc b/operators/audio/audio.cc index ecd7d4cf3..2a0d3eb0f 100644 --- a/operators/audio/audio.cc +++ b/operators/audio/audio.cc @@ -3,17 +3,15 @@ #include "ocos.h" #ifdef ENABLE_DR_LIBS -#include "audio_decoder.hpp" +#include "audio_decoder.h" #endif // ENABLE_DR_LIBS -FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& { +FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& { static OrtOpLoader op_loader( - []() { return nullptr; } #ifdef ENABLE_DR_LIBS - , - CustomCpuStructV2("AudioDecoder", AudioDecoder) + CustomCpuStructV2("AudioDecoder", AudioDecoder), #endif - ); + []() { return nullptr; }); return op_loader.GetCustomOps(); }; diff --git a/operators/audio/audio_decoder.cc b/operators/audio/audio_decoder.cc new file mode 100644 index 000000000..b9e92dcd9 --- /dev/null +++ b/operators/audio/audio_decoder.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include +#include +#include + +#include "audio_decoder.h" + +#define DR_FLAC_IMPLEMENTATION +#include "dr_flac.h" +#define DR_MP3_IMPLEMENTATION 1 +#define DR_MP3_FLOAT_OUTPUT 1 +#include "dr_mp3.h" +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" + +#include "narrow.h" +#include "string_utils.h" +#include "string_tensor.h" +#include "sampling.h" + +OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_); + if (!status) { + status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_); + } + + return status; +} + +AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, + OrtxStatus& status) const { + const std::map format_mapping = {{"default", AudioStreamType::kDefault}, + {"wav", AudioStreamType::kWAV}, + {"mp3", AudioStreamType::kMP3}, + {"flac", AudioStreamType::kFLAC}}; + + AudioStreamType stream_format = AudioStreamType::kDefault; + if (str_format.length() > 0) { + auto pos = format_mapping.find(str_format); + if (pos == format_mapping.end()) { + status = {kOrtxErrorInvalidArgument, + MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()}; + return stream_format; + } + stream_format = pos->second; + } + + if (stream_format == AudioStreamType::kDefault) { + auto p_stream = reinterpret_cast(p_data); + std::string_view marker(p_stream, 4); + if (marker == "fLaC") { + stream_format = AudioStreamType::kFLAC; + } else if (marker == "RIFF") { + stream_format = AudioStreamType::kWAV; + } else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) { + // http://www.mp3-tech.org/programmer/frame_header.html + // only detect the 8 + 3 bits sync word + stream_format = AudioStreamType::kMP3; + } else { + status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"}; + } + } + + return stream_format; +} + +template +static size_t DrReadFrames(std::list>& frames, FX_DECODER fx, TY_AUDIO& obj) { + const size_t default_chunk_size = 1024 * 256; + int64_t total_buf_size = 0; + + for (;;) { + std::vector buf; + buf.resize(default_chunk_size * obj.channels); + auto n_frames = fx(&obj, default_chunk_size, buf.data()); + if (n_frames <= 0) { + break; + } + auto data_size = n_frames * obj.channels; + total_buf_size += data_size; + buf.resize(data_size); + frames.emplace_back(std::move(buf)); + } + + return total_buf_size; +} + +OrtxStatus AudioDecoder::Compute(const ortc::Tensor& input, const std::optional format, + ortc::Tensor& output0) const { + const uint8_t* p_data = input.Data(); + auto input_dim = input.Shape(); + OrtxStatus status; + if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) { + return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."}; + } + + std::string str_format; + if (format) { + str_format = *format; + } + auto stream_format = ReadStreamFormat(p_data, str_format, status); + if (status) { + return status; + } + + int64_t total_buf_size = 0; + std::list> lst_frames; + int64_t orig_sample_rate = 0; + int64_t orig_channels = 0; + + if (stream_format == AudioStreamType::kMP3) { + auto mp3_obj_ptr = std::make_unique(); + if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."}; + return status; + } + orig_sample_rate = mp3_obj_ptr->sampleRate; + orig_channels = mp3_obj_ptr->channels; + total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr); + + } else if (stream_format == AudioStreamType::kFLAC) { + drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr); + auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); }); + if (flac_obj == nullptr) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."}; + return status; + } + orig_sample_rate = flac_obj->sampleRate; + orig_channels = flac_obj->channels; + total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj); + + } else { + drwav wav_obj; + if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."}; + return status; + } + orig_sample_rate = wav_obj.sampleRate; + orig_channels = wav_obj.channels; + total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj); + } + + if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."}; + return status; + } + + // join all frames + std::vector buf; + buf.resize(total_buf_size); + int64_t offset = 0; + for (auto& _b : lst_frames) { + std::copy(_b.begin(), _b.end(), buf.begin() + offset); + offset += _b.size(); + } + + // mix the stereo channels into mono channel + if (stereo_mixer_ && orig_channels > 1) { + if (buf.size() > 1) { + for (size_t i = 0; i < buf.size() / 2; ++i) { + buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2; + } + buf.resize(buf.size() / 2); + } + } + + if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) { + // A lowpass filter on buf audio data to remove high frequency noise + ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate); + std::vector filtered_buf = filter.Process(buf); + // downsample the audio data + KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_); + } + + std::vector dim_out = {1, ort_extensions::narrow(buf.size())}; + float* p_output = output0.Allocate(dim_out); + std::copy(buf.begin(), buf.end(), p_output); + return status; +} diff --git a/operators/audio/audio_decoder.h b/operators/audio/audio_decoder.h new file mode 100644 index 000000000..cecfcffdb --- /dev/null +++ b/operators/audio/audio_decoder.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "ocos.h" + +#include +#include + +struct AudioDecoder { + public: + OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info); + + template + OrtxStatus Init(const DictT& attrs) { + // in API mode, the default value is 1 + downsample_rate_ = 16000; + stereo_mixer_ = 1; + for (const auto& [key, value] : attrs) { + if (key == "target_sample_rate") { + downsample_rate_ = std::get(value); + } else if (key == "stereo_to_mono") { + stereo_mixer_ = std::get(value); + } else { + return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Invalid argument"}; + } + } + + return {}; + } + + enum class AudioStreamType { kDefault = 0, kWAV, kMP3, kFLAC }; + + AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtxStatus& status) const; + OrtxStatus Compute(const ortc::Tensor& input, const std::optional format, + ortc::Tensor& output0) const; + OrtxStatus ComputeNoOpt(const ortc::Tensor& input, ortc::Tensor& output0) { + return Compute(input, std::nullopt, output0); + } + + private: + int64_t downsample_rate_{}; + int64_t stereo_mixer_{}; +}; diff --git a/operators/audio/audio_decoder.hpp b/operators/audio/audio_decoder.hpp deleted file mode 100644 index 06e61172e..000000000 --- a/operators/audio/audio_decoder.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "ocos.h" - -#include -#include -#include -#include -#define DR_FLAC_IMPLEMENTATION -#include "dr_flac.h" -#define DR_MP3_IMPLEMENTATION 1 -#define DR_MP3_FLOAT_OUTPUT 1 -#include "dr_mp3.h" -#define DR_WAV_IMPLEMENTATION -#include "dr_wav.h" - -#include -#include "narrow.h" -#include "string_utils.h" -#include "string_tensor.h" -#include "sampling.h" - -struct AudioDecoder{ - public: - - OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { - auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_); - if (!status) { - status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_); - } - - return status; - } - - enum class AudioStreamType { - kDefault = 0, - kWAV, - kMP3, - kFLAC - }; - - AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const { - static const std::map format_mapping = { - {"default", AudioStreamType::kDefault}, - {"wav", AudioStreamType::kWAV}, - {"mp3", AudioStreamType::kMP3}, - {"flac", AudioStreamType::kFLAC}}; - - AudioStreamType stream_format = AudioStreamType::kDefault; - if (str_format.length() > 0) { - auto pos = format_mapping.find(str_format); - if (pos == format_mapping.end()) { - status = OrtW::CreateStatus(MakeString( - "[AudioDecoder]: Unknown audio stream format: ", str_format) - .c_str(), - ORT_INVALID_ARGUMENT); - return stream_format; - } - stream_format = pos->second; - } - - if (stream_format == AudioStreamType::kDefault) { - auto p_stream = reinterpret_cast(p_data); - std::string_view marker(p_stream, 4); - if (marker == "fLaC") { - stream_format = AudioStreamType::kFLAC; - } else if (marker == "RIFF") { - stream_format = AudioStreamType::kWAV; - } else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) { - // http://www.mp3-tech.org/programmer/frame_header.html - // only detect the 8 + 3 bits sync word - stream_format = AudioStreamType::kMP3; - } else { - status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT); - } - } - - return stream_format; - } - - template - static size_t DrReadFrames(std::list>& frames, FX_DECODER fx, TY_AUDIO& obj) { - const size_t default_chunk_size = 1024 * 256; - int64_t total_buf_size = 0; - - for (;;) { - std::vector buf; - buf.resize(default_chunk_size * obj.channels); - auto n_frames = fx(&obj, default_chunk_size, buf.data()); - if (n_frames <= 0) { - break; - } - auto data_size = n_frames * obj.channels; - total_buf_size += data_size; - buf.resize(data_size); - frames.emplace_back(std::move(buf)); - } - - return total_buf_size; - } - - OrtStatusPtr Compute(const ortc::Tensor& input, - const std::optional format, - ortc::Tensor& output0) const { - const uint8_t* p_data = input.Data(); - auto input_dim = input.Shape(); - OrtStatusPtr status = nullptr; - if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) { - status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT); - return status; - } - - std::string str_format; - if (format) { - str_format = *format; - } - auto stream_format = ReadStreamFormat(p_data, str_format, status); - if (status) { - return status; - } - - int64_t total_buf_size = 0; - std::list> lst_frames; - int64_t orig_sample_rate = 0; - int64_t orig_channels = 0; - - if (stream_format == AudioStreamType::kMP3) { - auto mp3_obj_ptr = std::make_unique(); - if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = mp3_obj_ptr->sampleRate; - orig_channels = mp3_obj_ptr->channels; - total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr); - - } else if (stream_format == AudioStreamType::kFLAC) { - drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr); - auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); }); - if (flac_obj == nullptr) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = flac_obj->sampleRate; - orig_channels = flac_obj->channels; - total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj); - - } else { - drwav wav_obj; - if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = wav_obj.sampleRate; - orig_channels = wav_obj.channels; - total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj); - } - - if (downsample_rate_ != 0 && - orig_sample_rate < downsample_rate_) { - status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT); - return status; - } - - // join all frames - std::vector buf; - buf.resize(total_buf_size); - int64_t offset = 0; - for (auto& _b : lst_frames) { - std::copy(_b.begin(), _b.end(), buf.begin() + offset); - offset += _b.size(); - } - - // mix the stereo channels into mono channel - if (stereo_mixer_ && orig_channels > 1) { - if (buf.size() > 1) { - for (size_t i = 0; i < buf.size() / 2; ++i) { - buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2; - } - buf.resize(buf.size() / 2); - } - } - - if (downsample_rate_ != 0 && - downsample_rate_ != orig_sample_rate) { - // A lowpass filter on buf audio data to remove high frequency noise - ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate); - std::vector filtered_buf = filter.Process(buf); - // downsample the audio data - KaiserWindowInterpolation::Process(filtered_buf, buf, - 1.0f * orig_sample_rate, 1.0f * downsample_rate_); - } - - std::vector dim_out = {1, ort_extensions::narrow(buf.size())}; - float* p_output = output0.Allocate(dim_out); - std::copy(buf.begin(), buf.end(), p_output); - return status; - } - - private: - int64_t downsample_rate_{}; - int64_t stereo_mixer_{}; -}; diff --git a/operators/math/dlib/stft_norm.hpp b/operators/math/dlib/stft_norm.hpp index 4a16b1d9c..9e8d59051 100644 --- a/operators/math/dlib/stft_norm.hpp +++ b/operators/math/dlib/stft_norm.hpp @@ -6,37 +6,32 @@ #include "ocos.h" #include -struct StftNormal{ +struct StftNormal { StftNormal() = default; OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { return OrtW::GetOpAttribute(info, "onesided", onesided_); } - OrtStatusPtr Compute(const ortc::Tensor& input0, - int64_t n_fft, - int64_t hop_length, - const ortc::Span& input3, - int64_t frame_length, - ortc::Tensor& output0) const { + OrtxStatus Compute(const ortc::Tensor& input0, int64_t n_fft, int64_t hop_length, + const ortc::Span& input3, int64_t frame_length, ortc::Tensor& output0) const { auto X = input0.Data(); auto window = input3.data_; auto dimensions = input0.Shape(); auto win_length = input3.size(); if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) { - return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT); + return {kOrtxErrorInvalidArgument, "[Stft] Only batch == 1 tensor supported."}; } if (frame_length != n_fft) { - return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT); + return {kOrtxErrorInvalidArgument, "[Stft] Only support size of FFT equals the frame length."}; } dlib::matrix dm_x = dlib::mat(X, 1, dimensions[1]); dlib::matrix hann_win = dlib::mat(window, 1, win_length); - auto m_stft = dlib::stft( - dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, - n_fft, win_length, hop_length); + auto m_stft = + dlib::stft(dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, n_fft, win_length, hop_length); if (onesided_) { m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1); @@ -49,7 +44,7 @@ struct StftNormal{ auto out0 = output0.Allocate(outdim); memcpy(out0, result.steal_memory().get(), result_size * sizeof(float)); - return nullptr; + return {}; } private: diff --git a/shared/api/c_api_feature_extraction.cc b/shared/api/c_api_feature_extraction.cc new file mode 100644 index 000000000..8ffde2455 --- /dev/null +++ b/shared/api/c_api_feature_extraction.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "speech_extractor.h" + +#include "c_api_utils.hpp" + +using namespace ort_extensions; + +class RawAudiosObject : public OrtxObjectImpl { + public: + RawAudiosObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindRawAudios) {} + ~RawAudiosObject() override = default; + + std::unique_ptr audios_; + size_t num_audios_; +}; + +extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** raw_audios, const char* const* audio_paths, size_t num_audios) { + if (raw_audios == nullptr || audio_paths == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto audios_obj = std::make_unique(); + auto [audios, num] = + ort_extensions::LoadRawData(audio_paths, audio_paths + num_audios); + audios_obj->audios_ = std::move(audios); + audios_obj->num_audios_ = num; + + *raw_audios = static_cast(audios_obj.release()); + return extError_t(); +} + +extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* def) { + if (extractor == nullptr || def == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto extractor_ptr = std::make_unique(); + ReturnableStatus status = extractor_ptr->Init(def); + if (status.IsOk()) { + *extractor = static_cast(extractor_ptr.release()); + } else { + *extractor = nullptr; + } + + return status.Code(); +} + +extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios, + OrtxTensorResult** result) { + if (extractor == nullptr || raw_audios == nullptr || result == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto extractor_ptr = static_cast(extractor); + auto audios_obj = static_cast(raw_audios); + + auto ts_result = std::make_unique(); + std::unique_ptr> log_mel[1]; + ReturnableStatus status = + extractor_ptr->DoCall(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), log_mel[0]); + if (status.IsOk()) { + std::vector> tensors; + std::transform(log_mel, log_mel + 1, std::back_inserter(tensors), + [](auto& ts) { return std::unique_ptr(ts.release()); }); + ts_result->SetTensors(std::move(tensors)); + *result = ts_result.release(); + } else { + *result = nullptr; + } + + return status.Code(); +} diff --git a/shared/api/c_api_processor.cc b/shared/api/c_api_processor.cc index 2beb90a13..8e2e12598 100644 --- a/shared/api/c_api_processor.cc +++ b/shared/api/c_api_processor.cc @@ -4,6 +4,8 @@ #include "ortx_processor.h" #include "image_processor.h" +#include "c_api_utils.hpp" + using namespace ort_extensions; extError_t OrtxCreateProcessor(OrtxProcessor** processor, const char* def) { @@ -37,19 +39,19 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima } auto images_obj = std::make_unique(); - auto [img, num] = LoadRawImages(image_paths, image_paths + num_images); + auto [img, num] = LoadRawData(image_paths, image_paths + num_images); images_obj->images = std::move(img); images_obj->num_images = num; if (num_images_loaded != nullptr) { *num_images_loaded = num; } - + *images = static_cast(images_obj.release()); return extError_t(); } extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images, - OrtxImageProcessorResult** result) { + OrtxTensorResult** result) { if (processor == nullptr || images == nullptr || result == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; @@ -67,59 +69,14 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm return status.Code(); } - auto result_ptr = std::make_unique(); + auto result_ptr = std::make_unique(); status = processor_ptr->PreProcess(ort_extensions::span(images_ptr->images.get(), images_ptr->num_images), *result_ptr); if (status.IsOk()) { - *result = static_cast(result_ptr.release()); + *result = static_cast(result_ptr.release()); } else { *result = nullptr; } return {}; } - -extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor) { - if (result == nullptr || tensor == nullptr) { - ReturnableStatus::last_error_message_ = "Invalid argument"; - return kOrtxErrorInvalidArgument; - } - - auto result_ptr = static_cast(result); - ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult)); - if (!status.IsOk()) { - return status.Code(); - } - - if (index >= result_ptr->results.size()) { - ReturnableStatus::last_error_message_ = "Index out of range"; - return kOrtxErrorInvalidArgument; - } - - auto tensor_ptr = std::make_unique>(); - tensor_ptr->SetObject(result_ptr->results[index].get()); - *tensor = static_cast(tensor_ptr.release()); - return extError_t(); -} - -extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result) { - if (processor == nullptr || result == nullptr) { - ReturnableStatus::last_error_message_ = "Invalid argument"; - return kOrtxErrorInvalidArgument; - } - - const auto processor_ptr = static_cast(processor); - ReturnableStatus status(processor_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindProcessor)); - if (!status.IsOk()) { - return status.Code(); - } - - auto result_ptr = static_cast(result); - status = result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult); - if (!status.IsOk()) { - return status.Code(); - } - - ImageProcessor::ClearOutputs(result_ptr); - return extError_t(); -} diff --git a/shared/api/c_api_tokenizer.cc b/shared/api/c_api_tokenizer.cc index 22c24defc..e3d2fd8de 100644 --- a/shared/api/c_api_tokenizer.cc +++ b/shared/api/c_api_tokenizer.cc @@ -6,7 +6,7 @@ #include "c_api_utils.hpp" #include "tokenizer_impl.h" -namespace ort_extensions { +using namespace ort_extensions; class DetokenizerCache : public OrtxObjectImpl { public: @@ -17,29 +17,20 @@ class DetokenizerCache : public OrtxObjectImpl { std::string last_text_{}; // last detokenized text }; -template<> -OrtxObject* OrtxObjectFactory::CreateForward() { - return std::make_unique().release(); -} - -template<> -void OrtxObjectFactory::DisposeForward(OrtxObject* obj) { - Dispose(obj); +template <> +OrtxObject* OrtxObjectFactory::CreateForward() { + return Create(); } -} // namespace ort_extensions - -using namespace ort_extensions; -extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, - const char* input[], size_t batch_size, OrtxTokenId2DArray** output) { +extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, + OrtxTokenId2DArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; } auto token_ptr = static_cast(tokenizer); - ReturnableStatus status = - token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer); + ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer); if (!status.IsOk()) { return status.Code(); } @@ -61,8 +52,8 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, return extError_t(); } -extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, - const OrtxTokenId2DArray* input, OrtxStringArray** output) { +extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input, + OrtxStringArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; @@ -81,11 +72,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, } std::vector> t_ids; - std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), - std::back_inserter(t_ids), - [](const std::vector& vec) { - return span(vec.data(), vec.size()); - }); + std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), std::back_inserter(t_ids), + [](const std::vector& vec) { return span(vec.data(), vec.size()); }); std::vector output_text; status = token_ptr->Detokenize(t_ids, output_text); @@ -101,9 +89,7 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, ; } -extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, - const extTokenId_t* input, - size_t len, +extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len, OrtxStringArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; @@ -186,8 +172,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to return extError_t(); } -extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, - size_t index, const extTokenId_t** item, size_t* length) { +extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index, + const extTokenId_t** item, size_t* length) { if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; @@ -210,9 +196,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok return extError_t(); } -extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, - OrtxDetokenizerCache* cache, - extTokenId_t next_id, const char** text_out) { +extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id, + const char** text_out) { if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; diff --git a/shared/api/c_api_utils.cc b/shared/api/c_api_utils.cc index 0345fdb23..3fc376efe 100644 --- a/shared/api/c_api_utils.cc +++ b/shared/api/c_api_utils.cc @@ -10,6 +10,8 @@ using namespace ort_extensions; +class DetokenizerCache; // forward definition in tokenizer_impl.cc + thread_local std::string ReturnableStatus::last_error_message_; OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const { @@ -37,7 +39,7 @@ extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, . va_start(args, object); if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) { - *object = OrtxObjectFactory::CreateForward(); + *object = OrtxObjectFactory::CreateForward(); } else if (kind == extObjectKind_t::kOrtxKindTokenizer) { return OrtxCreateTokenizer(static_cast(object), va_arg(args, const char*)); } @@ -80,8 +82,8 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) { return kOrtxErrorInvalidArgument; } - if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) { - OrtxObjectFactory::Dispose(object); + /* if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) { + OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) { OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) { @@ -94,6 +96,11 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) { OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindProcessor) { OrtxObjectFactory::Dispose(object); + } */ + if (Ortx_object->ortx_kind() >= kOrtxKindBegin && Ortx_object->ortx_kind() < kOrtxKindEnd) { + OrtxObjectFactory::Dispose(object); + } else { + return kOrtxErrorInvalidArgument; } return extError_t(); @@ -113,6 +120,30 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) { return err; } +extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor) { + if (result == nullptr || tensor == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto result_ptr = static_cast(result); + ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTensorResult)); + if (!status.IsOk()) { + return status.Code(); + } + + ortc::TensorBase* ts = result_ptr->GetAt(index); + if (ts == nullptr) { + ReturnableStatus::last_error_message_ = "Cannot get the tensor at the specified index from the result"; + return kOrtxErrorInvalidArgument; + } + + auto tensor_ptr = std::make_unique>(); + tensor_ptr->SetObject(ts); + *tensor = static_cast(tensor_ptr.release()); + return extError_t(); +} + extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape, size_t* num_dims) { if (tensor == nullptr) { @@ -120,7 +151,7 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data return kOrtxErrorInvalidArgument; } - auto tensor_impl = static_cast*>(tensor); + auto tensor_impl = static_cast*>(tensor); if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; diff --git a/shared/api/c_api_utils.hpp b/shared/api/c_api_utils.hpp index d7794b610..c1fed4727 100644 --- a/shared/api/c_api_utils.hpp +++ b/shared/api/c_api_utils.hpp @@ -3,8 +3,10 @@ #pragma once #include +#include #include "ortx_utils.h" +#include "file_sys.h" #include "ext_status.h" #include "op_def_struct.h" @@ -12,7 +14,7 @@ namespace ort_extensions { class OrtxObjectImpl : public OrtxObject { public: explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() { - ext_kind_ = static_cast(kind); + ext_kind_ = kind; }; virtual ~OrtxObjectImpl() = default; @@ -24,30 +26,21 @@ class OrtxObjectImpl : public OrtxObject { } return static_cast(ext_kind_); } - - template - struct Type2Kind { - static const extObjectKind_t value = kOrtxKindUnknown; - }; -}; - -template <> -struct OrtxObjectImpl::Type2Kind { - static const extObjectKind_t value = kOrtxKindTensor; }; -template +// A wrapper class to store a object pointer which is readonly. i.e. unowned. +template class OrtxObjectWrapper : public OrtxObjectImpl { public: - OrtxObjectWrapper() : OrtxObjectImpl(OrtxObjectImpl::Type2Kind::value) {} + OrtxObjectWrapper() : OrtxObjectImpl(kind) {} ~OrtxObjectWrapper() override = default; - void SetObject(T* t) { stored_object_ = t; } + void SetObject(const T* t) { stored_object_ = t; } - [[nodiscard]] T* GetObject() const { return stored_object_; } + [[nodiscard]] const T* GetObject() const { return stored_object_; } private: - T* stored_object_{}; + const T* stored_object_{}; }; template @@ -100,6 +93,35 @@ class StringArray : public OrtxObjectImpl { std::vector strings_; }; +class TensorResult : public OrtxObjectImpl { + public: + TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {} + ~TensorResult() override = default; + + void SetTensors(std::vector>&& tensors) { tensors_ = std::move(tensors); } + + [[nodiscard]] const std::vector>& tensors() const { return tensors_; } + + [[nodiscard]] std::vector GetTensors() const { + std::vector ts; + ts.reserve(tensors_.size()); + for (auto& t : tensors_) { + ts.push_back(t.get()); + } + return ts; + } + + ortc::TensorBase* GetAt(size_t i) const { + if (i < tensors_.size()) { + return tensors_[i].get(); + } + return nullptr; + } + + private: + std::vector> tensors_; +}; + struct ReturnableStatus { public: thread_local static std::string last_error_message_; @@ -123,24 +145,25 @@ struct ReturnableStatus { OrtxStatus status_; }; -template class OrtxObjectFactory { public: - static std::unique_ptr Create() { return std::make_unique(); } - - static OrtxObject* CreateForward(); - static void DisposeForward(OrtxObject* object); + template + static OrtxObject* Create() { + return std::make_unique().release(); + } + template static void Dispose(OrtxObject* object) { auto obj_ptr = static_cast(object); std::unique_ptr ptr(obj_ptr); ptr.reset(); } -}; - -class DetokenizerCache; // forward definition in tokenizer_impl.cc -class ProcessorResult; // forward definition in image_processor.h + // Forward declaration for creating an object which isn't visible to c_api_utils.cc + // and the definition is in the corresponding .cc file. + template + static OrtxObject* CreateForward(); +}; class CppAllocator : public ortc::IAllocator { public: @@ -157,4 +180,25 @@ class CppAllocator : public ortc::IAllocator { } }; +template +std::tuple, size_t> LoadRawData(It begin, It end) { + auto raw_data = std::make_unique(end - begin); + size_t n = 0; + for (auto it = begin; it != end; ++it) { + std::ifstream ifs = path(*it).open(std::ios::binary | std::ios::in); + if (!ifs.is_open()) { + break; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + + T& datum = raw_data[n++]; + datum.resize(size); + ifs.read(reinterpret_cast(datum.data()), size); + } + + return std::make_tuple(std::move(raw_data), n); +} } // namespace ort_extensions diff --git a/shared/api/image_processor.cc b/shared/api/image_processor.cc index 1cbab6e10..5383c10fe 100644 --- a/shared/api/image_processor.cc +++ b/shared/api/image_processor.cc @@ -7,6 +7,7 @@ #include "file_sys.h" #include "image_processor.h" +#include "c_api_utils.hpp" #include "cv2/imgcodecs/imdecode.hpp" #include "image_transforms.hpp" #include "image_transforms_phi_3.hpp" @@ -14,38 +15,11 @@ using namespace ort_extensions; using json = nlohmann::json; -namespace ort_extensions { -template -std::tuple, size_t> LoadRawImages(It begin, It end) { - auto raw_images = std::make_unique(end - begin); - size_t n = 0; - for (auto it = begin; it != end; ++it) { - std::ifstream ifs = path(*it).open(std::ios::binary); - if (!ifs.is_open()) { - break; - } - - ifs.seekg(0, std::ios::end); - size_t size = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - - ImageRawData& raw_image = raw_images[n++]; - raw_image.resize(size); - ifs.read(reinterpret_cast(raw_image.data()), size); - } - - return std::make_tuple(std::move(raw_images), n); -} - -std::tuple, size_t> LoadRawImages( - const std::initializer_list& image_paths) { - return LoadRawImages(image_paths.begin(), image_paths.end()); +std::tuple, size_t> +ort_extensions::LoadRawImages(const std::initializer_list& image_paths) { + return ort_extensions::LoadRawData(image_paths.begin(), image_paths.end()); } -template std::tuple, size_t> LoadRawImages(char const**, char const**); - -} // namespace ort_extensions - Operation::KernelRegistry ImageProcessor::kernel_registry_ = { {"DecodeImage", []() { return CreateKernelInstance(image_decoder); }}, {"Resize", []() { return CreateKernelInstance(&Resize::Compute); }}, @@ -97,9 +71,7 @@ OrtxStatus ImageProcessor::Init(std::string_view processor_def) { return {}; } -ImageProcessor::ImageProcessor() - : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) { -} +ImageProcessor::ImageProcessor() : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {} template static ortc::Tensor* StackTensor(const std::vector& arg_lists, int axis, ortc::IAllocator* allocator) { @@ -136,39 +108,6 @@ static ortc::Tensor* StackTensor(const std::vector& arg_lists, in return output.release(); } -static OrtxStatus StackTensors(const std::vector& arg_lists, std::vector& outputs, - ortc::IAllocator* allocator) { - if (arg_lists.empty()) { - return {}; - } - - size_t batch_size = arg_lists.size(); - size_t num_outputs = arg_lists[0].size(); - for (size_t axis = 0; axis < num_outputs; ++axis) { - std::vector ts_ptrs; - ts_ptrs.reserve(arg_lists.size()); - std::vector shape = arg_lists[0][axis]->Shape(); - for (auto& ts : arg_lists) { - if (shape != ts[axis]->Shape()) { - return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."}; - } - ts_ptrs.push_back(ts[axis]); - } - - std::vector output_shape = shape; - output_shape.insert(output_shape.begin(), batch_size); - std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape); - for (size_t i = 0; i < batch_size; ++i) { - auto ts = ts_ptrs[i]; - const std::byte* ts_buff = reinterpret_cast(ts->DataRaw()); - auto ts_size = ts->SizeInBytes(); - std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size); - } - } - - return {}; -} - std::tuple ImageProcessor::PreProcess(ort_extensions::span image_data, ortc::Tensor** pixel_values, ortc::Tensor** image_sizes, @@ -209,7 +148,7 @@ std::tuple ImageProcessor::PreProcess(ort_extension return {status, std::move(r)}; } -OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_data, ImageProcessorResult& r) const { +OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_data, TensorResult& r) const { std::vector inputs; inputs.resize(image_data.size()); for (size_t i = 0; i < image_data.size(); ++i) { @@ -235,9 +174,13 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_d } } - r.results = operations_.back()->AllocateOutputs(allocator_); - status = StackTensors(outputs, r.results, allocator_); + auto img_result = operations_.back()->AllocateOutputs(allocator_); + status = OrtxRunner::StackTensors(outputs, img_result, allocator_); operations_.back()->ResetTensors(allocator_); + if (status.IsOk()) { + r.SetTensors(std::move(img_result)); + } + return status; } @@ -257,14 +200,3 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) { r->num_img_takens = nullptr; } } - -void ort_extensions::ImageProcessor::ClearOutputs(ImageProcessorResult* r) { - if (r == nullptr) { - return; - } - - for (auto& ts : r->results) { - ts.reset(); - } - r->results.clear(); // clear the vector -} diff --git a/shared/api/image_processor.h b/shared/api/image_processor.h index 534e811d6..02eee4e32 100644 --- a/shared/api/image_processor.h +++ b/shared/api/image_processor.h @@ -16,9 +16,6 @@ namespace ort_extensions { using ImageRawData = std::vector; -template -std::tuple, size_t> LoadRawImages(It begin, It end); - std::tuple, size_t> LoadRawImages( const std::initializer_list& image_paths); @@ -29,13 +26,6 @@ class ProcessorResult : public OrtxObjectImpl { ortc::Tensor* image_sizes{}; ortc::Tensor* num_img_takens{}; }; - -class ImageProcessorResult : public OrtxObjectImpl { - public: - ImageProcessorResult() : OrtxObjectImpl(kOrtxKindImageProcessorResult) {} - std::vector results; -}; - class ImageProcessor : public OrtxObjectImpl { public: ImageProcessor(); @@ -43,15 +33,16 @@ class ImageProcessor : public OrtxObjectImpl { OrtxStatus Init(std::string_view processor_def); + // Deprecated, using the next function instead std::tuple PreProcess(ort_extensions::span image_data, ortc::Tensor** pixel_values, ortc::Tensor** image_sizes, ortc::Tensor** num_img_takens) const; - OrtxStatus PreProcess(ort_extensions::span image_data, ImageProcessorResult& r) const; + OrtxStatus PreProcess(ort_extensions::span image_data, TensorResult& r) const; + // Deprecated, using the next function instead static void ClearOutputs(ProcessorResult* r); - static void ClearOutputs(ImageProcessorResult* r); static Operation::KernelRegistry kernel_registry_; diff --git a/shared/api/runner.hpp b/shared/api/runner.hpp index 3590190bb..1b6a01ddf 100644 --- a/shared/api/runner.hpp +++ b/shared/api/runner.hpp @@ -28,7 +28,8 @@ class KernelDef { virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0; virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0; - using AttrType = std::variant>; + using AttrType = + std::variant, std::vector, std::vector>; using AttrDict = std::unordered_map; template @@ -98,7 +99,7 @@ class KernelDef { template class KernelFunction : public KernelDef { public: - KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){}; + KernelFunction(OrtxStatus (*body)(Args...)) : body_(body) {}; virtual ~KernelFunction() = default; TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { @@ -132,7 +133,7 @@ class KernelFunction : public KernelDef { template class KernelStruct : public KernelDef { public: - KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){}; + KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body) {}; virtual ~KernelStruct() = default; TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { @@ -167,8 +168,18 @@ class KernelStruct : public KernelDef { attr_dict[key] = value.template get(); } else if (value.is_number_float()) { attr_dict[key] = value.template get(); - } else if (value.is_array()) { - attr_dict[key] = value.template get>(); + } else if (value.is_array() && value.size() > 0) { + auto& elem_0 = value.at(0); + if (elem_0.is_number_float()) { + attr_dict[key] = value.template get>(); + } else if (elem_0.is_string()) { + attr_dict[key] = value.template get>(); + } else if (elem_0.is_number_integer() || elem_0.is_number_unsigned()) { + attr_dict[key] = value.template get>(); + } else { + return {kOrtxErrorCorruptData, "Unsupported mix types in attribute value."}; + } + } else { return {kOrtxErrorCorruptData, "Invalid attribute type."}; } @@ -309,6 +320,39 @@ class OrtxRunner { return {}; } + static OrtxStatus StackTensors(const std::vector& arg_lists, std::vector& outputs, + ortc::IAllocator* allocator) { + if (arg_lists.empty()) { + return {}; + } + + size_t batch_size = arg_lists.size(); + size_t num_outputs = arg_lists[0].size(); + for (size_t axis = 0; axis < num_outputs; ++axis) { + std::vector ts_ptrs; + ts_ptrs.reserve(arg_lists.size()); + std::vector shape = arg_lists[0][axis]->Shape(); + for (auto& ts : arg_lists) { + if (shape != ts[axis]->Shape()) { + return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."}; + } + ts_ptrs.push_back(ts[axis]); + } + + std::vector output_shape = shape; + output_shape.insert(output_shape.begin(), batch_size); + std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape); + for (size_t i = 0; i < batch_size; ++i) { + auto ts = ts_ptrs[i]; + const std::byte* ts_buff = reinterpret_cast(ts->DataRaw()); + auto ts_size = ts->SizeInBytes(); + std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size); + } + } + + return {}; + } + private: ortc::IAllocator* allocator_; std::vector ops_; diff --git a/shared/api/speech_extractor.cc b/shared/api/speech_extractor.cc new file mode 100644 index 000000000..5cd005f8a --- /dev/null +++ b/shared/api/speech_extractor.cc @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "speech_extractor.h" + +#include "audio/audio_decoder.h" +#include "speech_features.hpp" + +using namespace ort_extensions; + +Operation::KernelRegistry SpeechFeatureExtractor::kernel_registry_ = { + {"AudioDecoder", []() { return CreateKernelInstance(&AudioDecoder::ComputeNoOpt); }}, + {"STFTNorm", []() { return CreateKernelInstance(&SpeechFeatures::STFTNorm); }}, + {"LogMelSpectrum", []() { return CreateKernelInstance(&LogMel::Compute); }}, +}; + +SpeechFeatureExtractor::SpeechFeatureExtractor() + : OrtxObjectImpl(extObjectKind_t::kOrtxKindFeatureExtractor), allocator_(&CppAllocator::Instance()) {} + +OrtxStatus SpeechFeatureExtractor::Init(std::string_view extractor_def) { + std::string fe_def_str; + if (extractor_def.size() >= 5 && extractor_def.substr(extractor_def.size() - 5) == ".json") { + std::ifstream ifs = path({extractor_def.data(), extractor_def.size()}).open(); + if (!ifs.is_open()) { + return {kOrtxErrorInvalidArgument, std::string("[ImageProcessor]: failed to open ") + std::string(extractor_def)}; + } + fe_def_str = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator()); + extractor_def = fe_def_str.c_str(); + } + + // pase the extraction_def by json + auto fe_json = json::parse(extractor_def, nullptr, false); + if (fe_json.is_discarded()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: failed to parse extractor json configuration."}; + } + + auto fe_root = fe_json.at("feature_extraction"); + if (!fe_root.is_object()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: feature_extraction field is missing."}; + } + + auto op_sequence = fe_root.at("sequence"); + if (!op_sequence.is_array() || op_sequence.empty()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: sequence field is missing."}; + } + + operations_.reserve(op_sequence.size()); + for (auto mod_iter = op_sequence.begin(); mod_iter != op_sequence.end(); ++mod_iter) { + auto op = std::make_unique(kernel_registry_); + auto status = op->Init(mod_iter->dump()); + if (!status.IsOk()) { + return status; + } + + operations_.push_back(std::move(op)); + } + + return {}; +} + +OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span raw_speech, + std::unique_ptr>& log_mel) const { + // setup the input tensors + std::vector inputs; + inputs.resize(raw_speech.size()); + for (size_t i = 0; i < raw_speech.size(); ++i) { + auto& ts_input = inputs[i]; + AudioRawData& speech = raw_speech[i]; + std::vector shape = {static_cast(speech.size())}; + ts_input.push_back(std::make_unique>(shape, speech.data()).release()); + } + + std::vector outputs; + std::vector ops(operations_.size()); + std::transform(operations_.begin(), operations_.end(), ops.begin(), [](auto& op) { return op.get(); }); + OrtxRunner runner(allocator_, ops.data(), ops.size()); + auto status = runner.Run(inputs, outputs); + if (!status.IsOk()) { + return status; + } + + // clear the input tensors + for (auto& input : inputs) { + for (auto& ts : input) { + std::unique_ptr(ts).reset(); + } + } + + auto results = operations_.back()->AllocateOutputs(allocator_); + status = OrtxRunner::StackTensors(outputs, results, allocator_); + if (status.IsOk()) { + log_mel.reset(static_cast*>(results[0].release())); + operations_.back()->ResetTensors(allocator_); + } + + return status; +} diff --git a/shared/api/speech_extractor.h b/shared/api/speech_extractor.h new file mode 100644 index 000000000..3219da6eb --- /dev/null +++ b/shared/api/speech_extractor.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "ortx_extractor.h" +#include "c_api_utils.hpp" +#include "runner.hpp" + + +namespace ort_extensions { + +typedef std::vector AudioRawData; + +class SpeechFeatureExtractor : public OrtxObjectImpl { + public: + SpeechFeatureExtractor(); + + virtual ~SpeechFeatureExtractor() = default; + + public: + OrtxStatus Init(std::string_view extractor_def); + + OrtxStatus DoCall(ort_extensions::span raw_speech, std::unique_ptr>& log_mel) const; + + static Operation::KernelRegistry kernel_registry_; + + private: + std::vector> operations_; + ortc::IAllocator* allocator_; +}; + +} // namespace ort_extensions diff --git a/shared/api/speech_features.hpp b/shared/api/speech_features.hpp new file mode 100644 index 000000000..acc368a12 --- /dev/null +++ b/shared/api/speech_features.hpp @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +namespace ort_extensions { + +class SpeechFeatures { + public: + template + OrtxStatus Init(const DictT& attrs) { + for (const auto& [key, value] : attrs) { + if (key == "n_fft") { + n_fft_ = std::get(value); + } else if (key == "hop_length") { + hop_length_ = std::get(value); + } else if (key == "frame_length") { + frame_length_ = std::get(value); + } else if (key == "hann_win") { + auto& win = std::get>(value); + hann_win_.resize(win.size()); + std::transform(win.begin(), win.end(), hann_win_.begin(), [](double x) { return static_cast(x); }); + } else if (key != "_comment") { + return {kOrtxErrorInvalidArgument, "[AudioFeatures]: Invalid key in the JSON configuration."}; + } + } + + if (hann_win_.empty()) { + hann_win_ = hann_window(frame_length_); + } + return {}; + } + + OrtxStatus STFTNorm(const ortc::Tensor& pcm, ortc::Tensor& stft_norm) { + return stft_norm_.Compute(pcm, n_fft_, hop_length_, {hann_win_.data(), hann_win_.size()}, frame_length_, stft_norm); + } + + static std::vector hann_window(int N) { + std::vector window(N); + + for (int n = 0; n < N; ++n) { + // this formula leads to more rounding errors than the one below + // window[n] = static_cast(0.5 * (1 - std::cos(2 * M_PI * n / (N - 1)))); + double n_sin = std::sin(M_PI * n / N); + window[n] = static_cast(n_sin * n_sin); + } + + return window; + } + + private: + StftNormal stft_norm_; + int64_t n_fft_{}; + int64_t hop_length_{}; + int64_t frame_length_{}; + std::vector hann_win_; +}; + +class LogMel { + public: + template + OrtxStatus Init(const DictT& attrs) { + int n_fft = 0; + int n_mel = 0; + int chunk_size = 0; + for (const auto& [key, value] : attrs) { + if (key == "hop_length") { + hop_length_ = std::get(value); + } else if (key == "n_fft") { + n_fft = std::get(value); + } else if (key == "n_mel") { + n_mel = std::get(value); + } else if (key == "chunk_size") { + chunk_size = std::get(value); + } else { + return {kOrtxErrorInvalidArgument, "[LogMel]: Invalid key in the JSON configuration."}; + } + } + + n_samples_ = n_sr_ * chunk_size; + mel_filters_ = MelFilterBank(n_fft, n_mel, n_sr_); + return {}; + } + + OrtxStatus Compute(const ortc::Tensor& stft_norm, ortc::Tensor& logmel) { + // Compute the Mel spectrogram by following Python code + /* + magnitudes = stft_norm[:, :, :-1] + mel_spec = self.mel_filters @ magnitudes + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + spec_min = log_spec.max() - 8.0 + log_spec = torch.maximum(log_spec, spec_min) + spec_shape = log_spec.shape + padding_spec = torch.ones(spec_shape[0], + spec_shape[1], + self.n_samples // self.hop_length - spec_shape[2], + dtype=torch.float) + padding_spec *= spec_min + log_spec = torch.cat((log_spec, padding_spec), dim=2) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + */ + assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1); + std::vector stft_shape = stft_norm.Shape(); + dlib::matrix magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1); + for (int i = 0; i < magnitudes.nr(); ++i) { + std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1, + magnitudes.begin() + i * magnitudes.nc()); + } + + dlib::matrix mel_spec = mel_filters_ * magnitudes; + for (int i = 0; i < mel_spec.nr(); ++i) { + for (int j = 0; j < mel_spec.nc(); ++j) { + mel_spec(i, j) = std::max(1e-10f, mel_spec(i, j)); + } + } + + dlib::matrix log_spec = dlib::log10(mel_spec); + float log_spec_min = dlib::max(log_spec) - 8.0f; + for (int i = 0; i < log_spec.nr(); ++i) { + for (int j = 0; j < log_spec.nc(); ++j) { + float v = std::max(log_spec(i, j), log_spec_min); + v = (v + 4.0f) / 4.0f; + log_spec(i, j) = v; + } + } + + std::vector shape = {mel_filters_.nr(), n_samples_ / hop_length_}; + float* buff = logmel.Allocate(shape); + std::fill(buff, buff + logmel.NumberOfElement(), (log_spec_min + 4.0f) / 4.0f); + for (int i = 0; i < log_spec.nr(); ++i) { + auto row_len = log_spec.nc() * i; + std::copy(log_spec.begin() + i * log_spec.nc(), log_spec.begin() + (i + 1) * log_spec.nc(), buff + i * shape[1]); + } + + return {}; + } + + // Function to compute the Mel filterbank + static dlib::matrix MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0, + float max_mel = 45.245640471924965) { + // Initialize the filterbank matrix + dlib::matrix fbank(n_mels, n_fft / 2 + 1); + memset(fbank.begin(), 0, fbank.size() * sizeof(float)); + + // Compute the frequency bins for the DFT + std::vector freq_bins(n_fft / 2 + 1); + for (int i = 0; i <= n_fft / 2; ++i) { + freq_bins[i] = i * sr / static_cast(n_fft); + } + + // Compute the Mel scale frequencies + std::vector mel(n_mels + 2); + for (int i = 0; i < n_mels + 2; ++i) { + mel[i] = min_mel + i * (max_mel - min_mel) / (n_mels + 1); + } + + // Fill in the linear scale + float f_min = 0.0f; + float f_sp = 200.0f / 3.0f; + std::vector freqs(n_mels + 2); + for (int i = 0; i < n_mels + 2; ++i) { + freqs[i] = f_min + f_sp * mel[i]; + } + + // Nonlinear scale + float min_log_hz = 1000.0f; + float min_log_mel = (min_log_hz - f_min) / f_sp; + float logstep = log(6.4) / 27.0; + + for (int i = 0; i < n_mels + 2; ++i) { + if (mel[i] >= min_log_mel) { + freqs[i] = min_log_hz * exp(logstep * (mel[i] - min_log_mel)); + } + } + + std::vector mel_bins = freqs; + std::vector mel_spacing(n_mels + 1); + for (int i = 0; i < n_mels + 1; ++i) { + mel_spacing[i] = mel_bins[i + 1] - mel_bins[i]; + } + + // Compute the ramps + std::vector> ramps(n_mels + 2, std::vector(n_fft / 2 + 1)); + for (int i = 0; i < n_mels + 2; ++i) { + for (int j = 0; j <= n_fft / 2; ++j) { + ramps[i][j] = mel_bins[i] - freq_bins[j]; + } + } + + for (int i = 0; i < n_mels; ++i) { + for (int j = 0; j <= n_fft / 2; ++j) { + float left = -ramps[i][j] / mel_spacing[i]; + float right = ramps[i + 2][j] / mel_spacing[i + 1]; + fbank(i, j) = std::max(0.0f, std::min(left, right)); + } + } + + // Energy normalization + for (int i = 0; i < n_mels; ++i) { + float energy_norm = 2.0f / (mel_bins[i + 2] - mel_bins[i]); + for (int j = 0; j <= n_fft / 2; ++j) { + fbank(i, j) *= energy_norm; + } + } + + return fbank; + } + + private: + int64_t n_samples_ = {}; // sr * chunk_size + int64_t hop_length_{}; + const int64_t n_sr_{16000}; + dlib::matrix mel_filters_; +}; + +} // namespace ort_extensions diff --git a/test/data/whisper/feature_extraction.json b/test/data/whisper/feature_extraction.json new file mode 100644 index 000000000..f7deaac0f --- /dev/null +++ b/test/data/whisper/feature_extraction.json @@ -0,0 +1,437 @@ +{ + "feature_extraction": { + "sequence": [ + { + "operation": { + "name": "audio_decoder", + "type": "AudioDecoder" + } + }, + { + "operation": { + "name": "STFT", + "type": "STFTNorm", + "attrs": { + "n_fft": 400, + "frame_length": 400, + "hop_length": 160, + "_comment": [ + 0.0, + 0.0000616908073425293, + 0.0002467334270477295, + 0.0005550682544708252, + 0.000986635684967041, + 0.0015413463115692139, + 0.0022190213203430176, + 0.0030195116996765137, + 0.003942638635635376, + 0.004988163709640503, + 0.006155818700790405, + 0.007445335388183594, + 0.008856385946273804, + 0.010388582944869995, + 0.012041628360748291, + 0.013815045356750488, + 0.01570841670036316, + 0.01772129535675049, + 0.019853144884109497, + 0.022103488445281982, + 0.02447172999382019, + 0.026957333087921143, + 0.029559612274169922, + 0.03227800130844116, + 0.03511175513267517, + 0.03806024789810181, + 0.0411226749420166, + 0.044298380613327026, + 0.04758647084236145, + 0.05098623037338257, + 0.05449673533439636, + 0.058117181062698364, + 0.06184667348861694, + 0.0656842589378357, + 0.06962898373603821, + 0.07367992401123047, + 0.0778360664844513, + 0.08209633827209473, + 0.08645972609519958, + 0.09092515707015991, + 0.09549149870872498, + 0.10015767812728882, + 0.10492250323295593, + 0.1097848117351532, + 0.11474338173866272, + 0.11979702115058899, + 0.12494447827339172, + 0.13018447160720825, + 0.1355157196521759, + 0.14093685150146484, + 0.1464466154575348, + 0.15204361081123352, + 0.1577264666557312, + 0.16349375247955322, + 0.16934409737586975, + 0.1752760112285614, + 0.18128803372383118, + 0.18737870454788208, + 0.19354650378227234, + 0.1997898817062378, + 0.20610737800598145, + 0.21249738335609436, + 0.21895831823349, + 0.2254886031150818, + 0.23208662867546082, + 0.23875075578689575, + 0.24547931551933289, + 0.2522706985473633, + 0.25912320613861084, + 0.26603513956069946, + 0.27300477027893066, + 0.2800304591655731, + 0.2871103882789612, + 0.29424285888671875, + 0.30142611265182495, + 0.30865830183029175, + 0.31593772768974304, + 0.3232625722885132, + 0.3306310474872589, + 0.3380413055419922, + 0.34549152851104736, + 0.352979838848114, + 0.3605044484138489, + 0.3680635094642639, + 0.37565508484840393, + 0.38327735662460327, + 0.3909284174442291, + 0.39860638976097107, + 0.4063093662261963, + 0.41403549909591675, + 0.42178282141685486, + 0.4295494258403778, + 0.43733343482017517, + 0.44513291120529175, + 0.45294591784477234, + 0.46077051758766174, + 0.46860480308532715, + 0.4764467775821686, + 0.4842946231365204, + 0.492146372795105, + 0.5, + 0.5078536868095398, + 0.515705406665802, + 0.5235532522201538, + 0.5313953161239624, + 0.5392295718193054, + 0.5470541715621948, + 0.5548672080039978, + 0.562666654586792, + 0.5704506635665894, + 0.5782172679901123, + 0.5859646201133728, + 0.5936906933784485, + 0.6013936996459961, + 0.609071671962738, + 0.6167227625846863, + 0.6243450045585632, + 0.6319366097450256, + 0.6394955515861511, + 0.6470202207565308, + 0.6545085310935974, + 0.6619587540626526, + 0.6693689823150635, + 0.6767374277114868, + 0.6840623021125793, + 0.691341757774353, + 0.6985740065574646, + 0.7057572603225708, + 0.7128896713256836, + 0.719969630241394, + 0.7269952893257141, + 0.7339649796485901, + 0.7408769130706787, + 0.7477294206619263, + 0.7545207738876343, + 0.761249303817749, + 0.7679134607315063, + 0.774511456489563, + 0.7810417413711548, + 0.7875027060508728, + 0.7938927412033081, + 0.800210177898407, + 0.8064535856246948, + 0.8126214146614075, + 0.8187121152877808, + 0.8247240781784058, + 0.8306560516357422, + 0.8365063667297363, + 0.8422735929489136, + 0.8479564785957336, + 0.8535534143447876, + 0.8590631484985352, + 0.8644843101501465, + 0.8698155879974365, + 0.8750555515289307, + 0.8802030086517334, + 0.8852566480636597, + 0.8902152180671692, + 0.8950775265693665, + 0.899842381477356, + 0.9045084714889526, + 0.9090749025344849, + 0.9135403037071228, + 0.9179036617279053, + 0.9221639633178711, + 0.9263200759887695, + 0.9303710460662842, + 0.9343158006668091, + 0.9381533861160278, + 0.941882848739624, + 0.945503294467926, + 0.9490138292312622, + 0.9524135589599609, + 0.9557017087936401, + 0.9588773250579834, + 0.961939811706543, + 0.9648882746696472, + 0.9677220582962036, + 0.9704403877258301, + 0.9730427265167236, + 0.9755282998085022, + 0.9778965711593628, + 0.9801468849182129, + 0.9822787046432495, + 0.9842916131019592, + 0.9861849546432495, + 0.9879584312438965, + 0.9896113872528076, + 0.9911436438560486, + 0.9925546646118164, + 0.9938441514968872, + 0.9950118064880371, + 0.996057391166687, + 0.9969804883003235, + 0.997780978679657, + 0.9984586238861084, + 0.999013364315033, + 0.9994449615478516, + 0.9997532367706299, + 0.9999383091926575, + 1, + 0.9999383091926575, + 0.9997532367706299, + 0.9994449615478516, + 0.999013364315033, + 0.9984586238861084, + 0.997780978679657, + 0.9969804286956787, + 0.9960573315620422, + 0.9950118064880371, + 0.9938441514968872, + 0.9925546646118164, + 0.9911435842514038, + 0.9896113872528076, + 0.9879583716392517, + 0.9861849546432495, + 0.9842915534973145, + 0.9822787046432495, + 0.9801468253135681, + 0.9778964519500732, + 0.9755282402038574, + 0.9730426073074341, + 0.9704403877258301, + 0.9677219390869141, + 0.9648882150650024, + 0.9619396924972534, + 0.9588772654533386, + 0.9557015895843506, + 0.9524134397506714, + 0.9490137100219727, + 0.9455032348632812, + 0.9418827295303345, + 0.9381532669067383, + 0.9343156814575195, + 0.9303709268569946, + 0.9263200759887695, + 0.9221639633178711, + 0.9179036617279053, + 0.913540244102478, + 0.9090747833251953, + 0.9045084714889526, + 0.8998422622680664, + 0.8950774669647217, + 0.8902151584625244, + 0.8852565884590149, + 0.8802029490470886, + 0.8750554919242859, + 0.869815468788147, + 0.8644842505455017, + 0.8590630888938904, + 0.853553295135498, + 0.8479562997817993, + 0.842273473739624, + 0.836506187915802, + 0.8306558728218079, + 0.8247239589691162, + 0.8187118768692017, + 0.8126212358474731, + 0.8064534664154053, + 0.8002099990844727, + 0.793892502784729, + 0.7875025272369385, + 0.7810416221618652, + 0.7745113372802734, + 0.767913281917572, + 0.7612491846084595, + 0.7545205950737, + 0.7477291822433472, + 0.7408767342567444, + 0.7339648008346558, + 0.7269951105117798, + 0.7199694514274597, + 0.7128894925117493, + 0.7057570219039917, + 0.6985738277435303, + 0.6913415789604187, + 0.684062123298645, + 0.6767372488975525, + 0.6693688035011292, + 0.6619585752487183, + 0.6545083522796631, + 0.6470199823379517, + 0.6394953727722168, + 0.6319363117218018, + 0.6243447661399841, + 0.6167224645614624, + 0.6090714335441589, + 0.601393461227417, + 0.5936904549598694, + 0.5859643220901489, + 0.5782170295715332, + 0.5704504251480103, + 0.5626664161682129, + 0.5548669099807739, + 0.5470539331436157, + 0.5392293334007263, + 0.5313950181007385, + 0.5235530138015747, + 0.5157051682472229, + 0.507853627204895, + 0.5, + 0.4921463429927826, + 0.484294593334198, + 0.4764467477798462, + 0.46860471367836, + 0.4607704281806946, + 0.4529458284378052, + 0.4451328217983246, + 0.437333345413208, + 0.42954933643341064, + 0.4217827320098877, + 0.4140354096889496, + 0.4063093066215515, + 0.3986063003540039, + 0.39092832803726196, + 0.3832772672176361, + 0.37565499544143677, + 0.36806342005729675, + 0.3605043888092041, + 0.35297977924346924, + 0.3454914391040802, + 0.338041216135025, + 0.33063095808029175, + 0.3232625126838684, + 0.3159376382827759, + 0.3086581826210022, + 0.3014259934425354, + 0.2942427396774292, + 0.28711026906967163, + 0.2800303101539612, + 0.2730046510696411, + 0.2660350203514099, + 0.2591230869293213, + 0.25227057933807373, + 0.24547919631004333, + 0.2387506067752838, + 0.23208650946617126, + 0.22548848390579224, + 0.21895819902420044, + 0.2124972641468048, + 0.2061072587966919, + 0.19978976249694824, + 0.1935463547706604, + 0.18737855553627014, + 0.18128788471221924, + 0.17527586221694946, + 0.1693439483642578, + 0.16349363327026367, + 0.15772631764411926, + 0.15204349160194397, + 0.14644649624824524, + 0.1409367322921753, + 0.13551557064056396, + 0.1301843225955963, + 0.12494435906410217, + 0.11979690194129944, + 0.11474326252937317, + 0.10978469252586365, + 0.10492238402366638, + 0.10015755891799927, + 0.09549137949943542, + 0.09092503786087036, + 0.08645960688591003, + 0.08209621906280518, + 0.07783591747283936, + 0.07367980480194092, + 0.06962886452674866, + 0.06568413972854614, + 0.06184655427932739, + 0.0581170916557312, + 0.0544966459274292, + 0.05098611116409302, + 0.04758638143539429, + 0.044298261404037476, + 0.04112258553504944, + 0.038060128688812256, + 0.03511166572570801, + 0.03227788209915161, + 0.02955952286720276, + 0.02695724368095398, + 0.024471670389175415, + 0.02210339903831482, + 0.01985308527946472, + 0.017721205949783325, + 0.015708357095718384, + 0.0138150155544281, + 0.012041598558425903, + 0.010388582944869995, + 0.008856356143951416, + 0.007445335388183594, + 0.006155818700790405, + 0.004988163709640503, + 0.003942638635635376, + 0.0030195116996765137, + 0.0022190213203430176, + 0.0015413165092468262, + 0.000986635684967041, + 0.0005550682544708252, + 0.0002467334270477295, + 0.0000616908073425293 + ] + } + } + }, + { + "operation": { + "name": "log_mel_spectrogram", + "type": "LogMelSpectrum", + "attrs": { + "chunk_size": 30, + "hop_length": 160, + "n_fft": 400, + "n_mel": 80 + } + } + } + ] + } +} \ No newline at end of file diff --git a/test/pp_api_test/c_only_test.h b/test/pp_api_test/c_only_test.h index ed414ace6..1a20b8612 100644 --- a/test/pp_api_test/c_only_test.h +++ b/test/pp_api_test/c_only_test.h @@ -4,7 +4,9 @@ #pragma once #include "ortx_tokenizer.h" +// make sure the C only compiler compatibility only. #include "ortx_processor.h" +#include "ortx_extractor.h" #ifdef __cplusplus diff --git a/test/pp_api_test/test_feature_extraction.cc b/test/pp_api_test/test_feature_extraction.cc new file mode 100644 index 000000000..9c3076daf --- /dev/null +++ b/test/pp_api_test/test_feature_extraction.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ortx_cpp_helper.h" +#include "shared/api/speech_extractor.h" + +using namespace ort_extensions; + +TEST(ExtractorTest, TestWhisperFeatureExtraction) { + const char* audio_path[] = {"data/jfk.flac", "data/1272-141231-0002.wav", "data/1272-141231-0002.mp3"}; + OrtxObjectPtr raw_audios; + extError_t err = OrtxLoadAudios(ort_extensions::ptr(raw_audios), audio_path, 3); + ASSERT_EQ(err, kOrtxOK); + + OrtxObjectPtr feature_extractor(OrtxCreateSpeechFeatureExtractor, "data/whisper/feature_extraction.json"); + OrtxObjectPtr result; + err = OrtxSpeechLogMel(feature_extractor.get(), raw_audios.get(), ort_extensions::ptr(result)); + ASSERT_EQ(err, kOrtxOK); + + OrtxObjectPtr tensor; + err = OrtxTensorResultGetAt(result.get(), 0, ort_extensions::ptr(tensor)); + ASSERT_EQ(err, kOrtxOK); + + const float* data{}; + const int64_t* shape{}; + size_t num_dims; + err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims); + ASSERT_EQ(err, kOrtxOK); + ASSERT_EQ(num_dims, 3); + ASSERT_EQ(shape[0], 3); + ASSERT_EQ(shape[1], 80); + ASSERT_EQ(shape[2], 3000); +} diff --git a/test/pp_api_test/test_processor.cc b/test/pp_api_test/test_processor.cc index df06e54e8..1fcf132db 100644 --- a/test/pp_api_test/test_processor.cc +++ b/test/pp_api_test/test_processor.cc @@ -7,7 +7,7 @@ #include #include "gtest/gtest.h" -#include "ortx_c_helper.h" +#include "ortx_cpp_helper.h" #include "shared/api/image_processor.h" using namespace ort_extensions; @@ -85,18 +85,18 @@ TEST(ProcessorTest, TestClipImageProcessing) { } ASSERT_EQ(err, kOrtxOK); - OrtxObjectPtr result; + OrtxObjectPtr result; err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result)); ASSERT_EQ(err, kOrtxOK); - OrtxObjectPtr tensor; - err = OrtxImageGetTensorResult(result.get(), 0, ort_extensions::ptr(tensor)); + OrtxTensor* tensor; + err = OrtxTensorResultGetAt(result.get(), 0, &tensor); ASSERT_EQ(err, kOrtxOK); const float* data{}; const int64_t* shape{}; size_t num_dims; - err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims); + err = OrtxGetTensorDataFloat(tensor, &data, &shape, &num_dims); ASSERT_EQ(err, kOrtxOK); ASSERT_EQ(num_dims, 4); } diff --git a/test/static_test/test_tenor_api.cc b/test/static_test/test_tensor_api.cc similarity index 100% rename from test/static_test/test_tenor_api.cc rename to test/static_test/test_tensor_api.cc