diff --git a/.clang-format b/.clang-format
new file mode 100644
index 000000000..c65e7720f
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,9 @@
+---
+BasedOnStyle: Google
+---
+Language: Cpp
+Cpp11BracedListStyle: true
+Standard: Cpp11
+DerivePointerAlignment: false
+PointerAlignment: Right
+---
diff --git a/.github/scripts/install-kaldifeat.sh b/.github/scripts/install-kaldifeat.sh
new file mode 100755
index 000000000..6666a5064
--- /dev/null
+++ b/.github/scripts/install-kaldifeat.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+# This script installs kaldifeat into the directory ~/tmp/kaldifeat
+# which is cached by GitHub actions for later runs.
+
+mkdir -p ~/tmp
+cd ~/tmp
+git clone https://github.com/csukuangfj/kaldifeat
+cd kaldifeat
+mkdir build
+cd build
+cmake -DCMAKE_BUILD_TYPE=Release ..
+make -j2 _kaldifeat
diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml
new file mode 100644
index 000000000..5194680c9
--- /dev/null
+++ b/.github/workflows/publish_to_pypi.yml
@@ -0,0 +1,38 @@
+name: Publish to PyPI
+
+on:
+  push:
+    tags:
+      - '*'
+
+jobs:
+  pypi:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+
+      - name: Install Python dependencies
+        shell: bash
+        run: |
+          python3 -m pip install --upgrade pip
+          python3 -m pip install wheel twine setuptools
+
+      - name: Build
+        shell: bash
+        run: |
+          python3 setup.py sdist
+          ls -l dist/*
+
+      - name: Publish sdist to PyPI
+        env:
+          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+        run: |
+          twine upload dist/k2-sherpa-*.tar.gz
diff --git a/.github/workflows/run-test.yaml b/.github/workflows/run-test.yaml
new file mode 100644
index 000000000..497b2b5c0
--- /dev/null
+++ b/.github/workflows/run-test.yaml
@@ -0,0 +1,122 @@
+
+# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
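+#
+# This workflow builds sherpa from source on Linux and macOS, starts the
+# offline server with a pretrained model, and decodes a few test waves
+# with the offline client.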
+#
+name: Run tests
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  run_tests:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-18.04, macos-10.15]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        python-version: [3.7, 3.8, 3.9]
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install GCC 7
+        if: startsWith(matrix.os, 'ubuntu')
+        run: |
+          sudo apt-get install -y gcc-7 g++-7
+          echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV
+          echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV
+
+      - name: Install PyTorch ${{ matrix.torch }}
+        shell: bash
+        if: startsWith(matrix.os, 'ubuntu')
+        run: |
+          python3 -m pip install -qq --upgrade pip
+          python3 -m pip install -qq wheel twine typing_extensions websockets 'sentencepiece>=0.1.96'
+          python3 -m pip install -qq torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html
+
+      - name: Install PyTorch ${{ matrix.torch }}
+        shell: bash
+        if: startsWith(matrix.os, 'macos')
+        run: |
+          python3 -m pip install -qq --upgrade pip
+          python3 -m pip install -qq wheel twine typing_extensions websockets 'sentencepiece>=0.1.96'
+          python3 -m pip install -qq torch==${{ matrix.torch }} torchaudio==${{ matrix.torchaudio }} numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-${{ matrix.os }}
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Install sherpa
+        shell: bash
+        run: |
+          python3 setup.py install
+
+      - name: Download pretrained model and test-data
+        shell: bash
+        run: |
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+      - name: Start server
+        shell: bash
+        run: |
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          sherpa/bin/offline_server.py \
+            --port 6006 \
+            --num-device 0 \
+            --max-batch-size 10 \
+            --max-wait-ms 5 \
+            --feature-extractor-pool-size 5 \
+            --nn-pool-size 1 \
+            --nn-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/cpu_jit.pt \
+            --bpe-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model &
+          echo "Sleep 10 seconds to wait for the server startup"
+          sleep 10
+
+      - name: Start client
+        shell: bash
+        run: |
+          sherpa/bin/offline_client.py \
+            --server-addr localhost \
+            --server-port 6006 \
+            icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
+            icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
+            icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..31a3a989e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+path.sh
+build
+dist
+__pycache__
+*.egg-info
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 000000000..aa5f484d4
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,40 @@
+cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
+project(sherpa)
+
+set(SHERPA_VERSION "0.1")
+
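+# Collect libraries and executables under build/lib and build/bin, and use
+# $ORIGIN-based RPATHs so the shared libraries can be located at runtime
+# without setting LD_LIBRARY_PATH.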
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") + +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(BUILD_RPATH_USE_ORIGIN TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +set(CMAKE_INSTALL_RPATH "$ORIGIN") +set(CMAKE_BUILD_RPATH "$ORIGIN") + +set(BUILD_SHARED_LIBS ON) +if(WIN32) + message(STATUS "Set BUILD_SHARED_LIBS to OFF for Windows") + set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) +endif() + +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) + +include(pybind11) +include(torch) + +include_directories(${CMAKE_SOURCE_DIR}) + +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +add_subdirectory(sherpa) diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..ee06cfc77 --- /dev/null +++ b/LICENSE @@ -0,0 +1,211 @@ + + Legal Notices + + NOTE (this is not from the Apache License): The copyright model is that + authors (or their employers, if noted in individual files) own their + individual contributions. The authors' contributions can be discerned + from the git history. + + ------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..09b52372a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include LICENSE +include README.md +include CMakeLists.txt +recursive-include sherpa *.* +recursive-include cmake *.* diff --git a/README.md b/README.md index 923d71b52..998708e16 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,149 @@ ## Introduction -An ASR server framework supporting both streaming and non-streaming recognition. +An ASR server framework in **Python**, aiming to support both streaming +and non-streaming recognition. -Most parts will be implemented in Python, while CPU-bound tasks are implemented -in C++, which are called by Python threads with the GIL being released. +**Note**: Only non-streaming recognition is implemented at present. We +will add streaming recognition later. -## TODOs +CPU-bound tasks, such as neural network computation, are implemented in +C++; while IO-bound tasks, such as socket communication, are implemented +in Python. 
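+
+To illustrate the idea, here is a minimal, hypothetical sketch (it is not
+code from this repository; the actual server lives in
+`sherpa/bin/offline_server.py`): PyTorch kernels are implemented in C++ and
+release the GIL while they run, so plain Python threads can execute the
+CPU-bound part in parallel.
+
+```python
+# Hypothetical sketch only. PyTorch ops run in C++ with the GIL
+# released, so a thread pool yields real parallelism here.
+from concurrent.futures import ThreadPoolExecutor
+
+import torch
+
+
+class DummyEncoder(torch.nn.Module):
+    """A stand-in for a real TorchScript acoustic model."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.relu(x @ x.transpose(1, 2))
+
+
+model = torch.jit.script(DummyEncoder())
+
+
+def run_nn(features: torch.Tensor) -> torch.Tensor:
+    # The heavy lifting happens inside C++ with the GIL released.
+    with torch.no_grad():
+        return model(features)
+
+
+with ThreadPoolExecutor(max_workers=4) as pool:
+    batches = [torch.rand(1, 100, 80) for _ in range(8)]
+    results = list(pool.map(run_nn, batches))
+print(len(results))  # 8
+```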
-- [ ] Support non-streaming recognition
-- [ ] Documentation for installation and usage
-- [ ] Support streaming recognition
+**Caution**: We assume the model is trained using pruned stateless RNN-T
+from [icefall][icefall] and comes from a recipe directory like
+`pruned_transducer_statelessX`, where `X` >= 2.
+
+## Installation
+
+First, install `PyTorch` and `torchaudio`. PyTorch 1.10 is known
+to work; other versions may also work.
+
+Second, clone this repository:
+
+```bash
+git clone https://github.com/k2-fsa/sherpa
+cd sherpa
+pip install -r ./requirements.txt
+```
+
+Third, install the C++ extension of `sherpa`. You can use one of
+the following methods.
+
+### Option 1: Use `pip`
+
+```bash
+pip install --verbose k2-sherpa
+```
+
+### Option 2: Build from source with `setup.py`
+
+```bash
+python3 setup.py install
+```
+
+### Option 3: Build from source with `cmake`
+
+```bash
+mkdir build
+cd build
+cmake ..
+make -j
+export PYTHONPATH=$PWD/../sherpa/python:$PWD/lib:$PYTHONPATH
+```
+
+## Usage
+
+First, check that `sherpa` has been installed successfully:
+
+```bash
+python3 -c "import sherpa; print(sherpa.__version__)"
+```
+
+It should print the version of `sherpa`.
+
+### Start the server
+
+To start the server, you first need to generate two files:
+
+- (1) The TorchScript model file. You can use `export.py --jit=1` in
+`pruned_transducer_statelessX` from [icefall][icefall] to generate it.
+
+- (2) The BPE model file. You can find it in `data/lang_bpe_XXX/bpe.model`
+in [icefall][icefall], where `XXX` is the number of BPE tokens used in
+training.
+
+With the above two files ready, you can start the server with the
+following command:
+
+```bash
+sherpa/bin/offline_server.py \
+  --port 6006 \
+  --num-device 0 \
+  --max-batch-size 10 \
+  --max-wait-ms 5 \
+  --feature-extractor-pool-size 5 \
+  --nn-pool-size 1 \
+  --nn-model-filename ./path/to/exp/cpu_jit.pt \
+  --bpe-model-filename ./path/to/data/lang_bpe_500/bpe.model &
+```
+
+You can use `./sherpa/bin/offline_server.py --help` to view the help message.
+
+We provide a model pretrained on the LibriSpeech dataset at
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
+
+The following shows how to use the above pretrained model to start the server.
+
+```bash
+git lfs install
+git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+sherpa/bin/offline_server.py \
+  --port 6006 \
+  --num-device 0 \
+  --max-batch-size 10 \
+  --max-wait-ms 5 \
+  --feature-extractor-pool-size 5 \
+  --nn-pool-size 1 \
+  --nn-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/cpu_jit.pt \
+  --bpe-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model
+```
+
+### Start the client
+
+After starting the server, you can use the following command to start the client:
+
+```bash
+./sherpa/bin/offline_client.py \
+  --server-addr localhost \
+  --server-port 6006 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+```
+
+You can use `./sherpa/bin/offline_client.py --help` to view the usage message.
+
+The following shows how to use the client to send some test waves to the server
+for recognition.
+
+```bash
+sherpa/bin/offline_client.py \
+  --server-addr localhost \
+  --server-port 6006 \
+  icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
+  icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
+  icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
+```
+
+### RTF test
+
+We provide a demo [./sherpa/bin/decode_mainifest.py](./sherpa/bin/decode_mainifest.py)
+to decode the `test-clean` dataset from the LibriSpeech corpus.
+
+It creates 50 connections to the server using websockets and sends audio files
+to the server for recognition.
+
+At the end, it will display the RTF and the WER.
+
+[icefall]: https://github.com/k2-fsa/icefall/
diff --git a/cmake/Modules/FetchContent.cmake b/cmake/Modules/FetchContent.cmake
new file mode 100644
index 000000000..98cdf6cb9
--- /dev/null
+++ b/cmake/Modules/FetchContent.cmake
@@ -0,0 +1,916 @@
+# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+#[=======================================================================[.rst:
+FetchContent
+------------------
+
+.. only:: html
+
+   .. contents::
+
+Overview
+^^^^^^^^
+
+This module enables populating content at configure time via any method
+supported by the :module:`ExternalProject` module. Whereas
+:command:`ExternalProject_Add` downloads at build time, the
+``FetchContent`` module makes content available immediately, allowing the
+configure step to use the content in commands like :command:`add_subdirectory`,
+:command:`include` or :command:`file` operations.
+
+Content population details would normally be defined separately from the
+command that performs the actual population. Projects should also
+check whether the content has already been populated somewhere else in the
+project hierarchy. Typical usage would look something like this:
+
+.. code-block:: cmake
+
+  FetchContent_Declare(
+    googletest
+    GIT_REPOSITORY https://github.com/google/googletest.git
+    GIT_TAG        release-1.8.0
+  )
+
+  FetchContent_GetProperties(googletest)
+  if(NOT googletest_POPULATED)
+    FetchContent_Populate(googletest)
+    add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR})
+  endif()
+
+When using the above pattern with a hierarchical project arrangement,
+projects at higher levels in the hierarchy are able to define or override
+the population details of content specified anywhere lower in the project
+hierarchy. The ability to detect whether content has already been
+populated ensures that even if multiple child projects want certain content
+to be available, the first one to populate it wins. The other child project
+can simply make use of the already available content instead of repeating
+the population for itself. See the
+:ref:`Examples <fetch-content-examples>` section which demonstrates
+this scenario.
+
+The ``FetchContent`` module also supports defining and populating
+content in a single call, with no check for whether the content has been
+populated elsewhere in the project already. This is a more low level
+operation and would not normally be the way the module is used, but it is
+sometimes useful as part of implementing some higher level feature or to
+populate some content in CMake's script mode.
+
+
+Declaring Content Details
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. command:: FetchContent_Declare
+
+  .. code-block:: cmake
+
+    FetchContent_Declare(<name> <contentOptions>...)
+
+  The ``FetchContent_Declare()`` function records the options that describe
+  how to populate the specified content, but if such details have already
+  been recorded earlier in this project (regardless of where in the project
+  hierarchy), this and all later calls for the same content ``<name>`` are
+  ignored. This "first to record, wins" approach is what allows hierarchical
+  projects to have parent projects override content details of child projects.
+
+  The content ``<name>`` can be any string without spaces, but good practice
+  would be to use only letters, numbers and underscores. The name will be
+  treated case-insensitively and it should be obvious for the content it
+  represents, often being the name of the child project or the value given
+  to its top level :command:`project` command (if it is a CMake project).
+  For well-known public projects, the name should generally be the official
+  name of the project. Choosing an unusual name makes it unlikely that other
+  projects needing that same content will use the same name, leading to
+  the content being populated multiple times.
+
+  The ``<contentOptions>`` can be any of the download or update/patch options
+  that the :command:`ExternalProject_Add` command understands. The configure,
+  build, install and test steps are explicitly disabled and therefore options
+  related to them will be ignored. In most cases, ``<contentOptions>`` will
+  just be a couple of options defining the download method and method-specific
+  details like a commit tag or archive hash. For example:
+
+  .. code-block:: cmake
+
+    FetchContent_Declare(
+      googletest
+      GIT_REPOSITORY https://github.com/google/googletest.git
+      GIT_TAG        release-1.8.0
+    )
+
+    FetchContent_Declare(
+      myCompanyIcons
+      URL      https://intranet.mycompany.com/assets/iconset_1.12.tar.gz
+      URL_HASH 5588a7b18261c20068beabfb4f530b87
+    )
+
+    FetchContent_Declare(
+      myCompanyCertificates
+      SVN_REPOSITORY svn+ssh://svn.mycompany.com/srv/svn/trunk/certs
+      SVN_REVISION   -r12345
+    )
+
+Populating The Content
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. command:: FetchContent_Populate
+
+  .. code-block:: cmake
+
+    FetchContent_Populate( <name> )
+
+  In most cases, the only argument given to ``FetchContent_Populate()`` is the
+  ``<name>``. When used this way, the command assumes the content details have
+  been recorded by an earlier call to :command:`FetchContent_Declare`. The
+  details are stored in a global property, so they are unaffected by things
+  like variable or directory scope. Therefore, it doesn't matter where in the
+  project the details were previously declared, as long as they have been
+  declared before the call to ``FetchContent_Populate()``. Those saved details
+  are then used to construct a call to :command:`ExternalProject_Add` in a
+  private sub-build to perform the content population immediately. The
+  implementation of ``ExternalProject_Add()`` ensures that if the content has
+  already been populated in a previous CMake run, that content will be reused
+  rather than repopulating them again. For the common case where population
+  involves downloading content, the cost of the download is only paid once.
+
+  An internal global property records when a particular content population
+  request has been processed. If ``FetchContent_Populate()`` is called more
+  than once for the same content name within a configure run, the second call
+  will halt with an error. Projects can and should check whether content
+  population has already been processed with the
+  :command:`FetchContent_GetProperties` command before calling
+  ``FetchContent_Populate()``.
+
+  ``FetchContent_Populate()`` will set three variables in the scope of the
+  caller; ``<lcName>_POPULATED``, ``<lcName>_SOURCE_DIR`` and
+  ``<lcName>_BINARY_DIR``, where ``<lcName>`` is the lowercased ``<name>``.
+  ``<lcName>_POPULATED`` will always be set to ``True`` by the call.
+  ``<lcName>_SOURCE_DIR`` is the location where the
+  content can be found upon return (it will have already been populated), while
+  ``<lcName>_BINARY_DIR`` is a directory intended for use as a corresponding
+  build directory. The main use case for the two directory variables is to
+  call :command:`add_subdirectory` immediately after population, i.e.:
+
+  .. code-block:: cmake
+
+    FetchContent_Populate(FooBar ...)
+    add_subdirectory(${foobar_SOURCE_DIR} ${foobar_BINARY_DIR})
+
+  The values of the three variables can also be retrieved from anywhere in the
+  project hierarchy using the :command:`FetchContent_GetProperties` command.
+
+  A number of cache variables influence the behavior of all content population
+  performed using details saved from a :command:`FetchContent_Declare` call:
+
+  ``FETCHCONTENT_BASE_DIR``
+    In most cases, the saved details do not specify any options relating to the
+    directories to use for the internal sub-build, final source and build areas.
+    It is generally best to leave these decisions up to the ``FetchContent``
+    module to handle on the project's behalf. The ``FETCHCONTENT_BASE_DIR``
+    cache variable controls the point under which all content population
+    directories are collected, but in most cases developers would not need to
+    change this. The default location is ``${CMAKE_BINARY_DIR}/_deps``, but if
+    developers change this value, they should aim to keep the path short and
+    just below the top level of the build tree to avoid running into path
+    length problems on Windows.
+
+  ``FETCHCONTENT_QUIET``
+    The logging output during population can be quite verbose, making the
+    configure stage quite noisy. This cache option (``ON`` by default) hides
+    all population output unless an error is encountered. If experiencing
+    problems with hung downloads, temporarily switching this option off may
+    help diagnose which content population is causing the issue.
+
+  ``FETCHCONTENT_FULLY_DISCONNECTED``
+    When this option is enabled, no attempt is made to download or update
+    any content. It is assumed that all content has already been populated in
+    a previous run or the source directories have been pointed at existing
+    contents the developer has provided manually (using options described
+    further below). When the developer knows that no changes have been made to
+    any content details, turning this option ``ON`` can significantly speed up
+    the configure stage. It is ``OFF`` by default.
+
+  ``FETCHCONTENT_UPDATES_DISCONNECTED``
+    This is a less severe download/update control compared to
+    ``FETCHCONTENT_FULLY_DISCONNECTED``. Instead of bypassing all download and
+    update logic, the ``FETCHCONTENT_UPDATES_DISCONNECTED`` only disables the
+    update stage. Therefore, if content has not been downloaded previously,
+    it will still be downloaded when this option is enabled. This can speed up
+    the configure stage, but not as much as
+    ``FETCHCONTENT_FULLY_DISCONNECTED``. It is ``OFF`` by default.
+
+  In addition to the above cache variables, the following cache variables are
+  also defined for each content name (``<ucName>`` is the uppercased value of
+  ``<name>``):
+
+  ``FETCHCONTENT_SOURCE_DIR_<ucName>``
+    If this is set, no download or update steps are performed for the specified
+    content and the ``<lcName>_SOURCE_DIR`` variable returned to the caller is
+    pointed at this location.
This gives developers a way to have a separate
+    checkout of the content that they can modify freely without interference
+    from the build. The build simply uses that existing source, but it still
+    defines ``<lcName>_BINARY_DIR`` to point inside its own build area.
+    Developers are strongly encouraged to use this mechanism rather than
+    editing the sources populated in the default location, as changes to
+    sources in the default location can be lost when content population details
+    are changed by the project.
+
+  ``FETCHCONTENT_UPDATES_DISCONNECTED_<ucName>``
+    This is the per-content equivalent of
+    ``FETCHCONTENT_UPDATES_DISCONNECTED``. If the global option or this option
+    is ``ON``, then updates will be disabled for the named content.
+    Disabling updates for individual content can be useful for content whose
+    details rarely change, while still leaving other frequently changing
+    content with updates enabled.
+
+
+  The ``FetchContent_Populate()`` command also supports a syntax allowing the
+  content details to be specified directly rather than using any saved
+  details. This is more low-level and use of this form is generally to be
+  avoided in favour of using saved content details as outlined above.
+  Nevertheless, in certain situations it can be useful to invoke the content
+  population as an isolated operation (typically as part of implementing some
+  other higher level feature or when using CMake in script mode):
+
+  .. code-block:: cmake
+
+    FetchContent_Populate( <name>
+      [QUIET]
+      [SUBBUILD_DIR <subBuildDir>]
+      [SOURCE_DIR <srcDir>]
+      [BINARY_DIR <binDir>]
+      ...
+    )
+
+  This form has a number of key differences to that where only ``<name>`` is
+  provided:
+
+  - All required population details are assumed to have been provided directly
+    in the call to ``FetchContent_Populate()``. Any saved details for
+    ``<name>`` are ignored.
+  - No check is made for whether content for ``<name>`` has already been
+    populated.
+  - No global property is set to record that the population has occurred.
+  - No global properties record the source or binary directories used for the
+    populated content.
+  - The ``FETCHCONTENT_FULLY_DISCONNECTED`` and
+    ``FETCHCONTENT_UPDATES_DISCONNECTED`` cache variables are ignored.
+
+  The ``<lcName>_SOURCE_DIR`` and ``<lcName>_BINARY_DIR`` variables are still
+  returned to the caller, but since these locations are not stored as global
+  properties when this form is used, they are only available to the calling
+  scope and below rather than the entire project hierarchy. No
+  ``<lcName>_POPULATED`` variable is set in the caller's scope with this form.
+
+  The supported options for ``FetchContent_Populate()`` are the same as those
+  for :command:`FetchContent_Declare()`. Those few options shown just
+  above are either specific to ``FetchContent_Populate()`` or their behavior is
+  slightly modified from how :command:`ExternalProject_Add` treats them.
+
+  ``QUIET``
+    The ``QUIET`` option can be given to hide the output associated with
+    populating the specified content. If the population fails, the output will
+    be shown regardless of whether this option was given or not so that the
+    cause of the failure can be diagnosed. The global ``FETCHCONTENT_QUIET``
+    cache variable has no effect on ``FetchContent_Populate()`` calls where the
+    content details are provided directly.
+
+  ``SUBBUILD_DIR``
+    The ``SUBBUILD_DIR`` argument can be provided to change the location of the
+    sub-build created to perform the population. The default value is
+    ``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-subbuild`` and it would be unusual
+    to need to override this default.
If a relative path is specified, it will
+    be interpreted as relative to :variable:`CMAKE_CURRENT_BINARY_DIR`.
+
+  ``SOURCE_DIR``, ``BINARY_DIR``
+    The ``SOURCE_DIR`` and ``BINARY_DIR`` arguments are supported by
+    :command:`ExternalProject_Add`, but different default values are used by
+    ``FetchContent_Populate()``. ``SOURCE_DIR`` defaults to
+    ``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-src`` and ``BINARY_DIR`` defaults to
+    ``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-build``. If a relative path is
+    specified, it will be interpreted as relative to
+    :variable:`CMAKE_CURRENT_BINARY_DIR`.
+
+  In addition to the above explicit options, any other unrecognized options are
+  passed through unmodified to :command:`ExternalProject_Add` to perform the
+  download, patch and update steps. The following options are explicitly
+  prohibited (they are disabled by the ``FetchContent_Populate()`` command):
+
+  - ``CONFIGURE_COMMAND``
+  - ``BUILD_COMMAND``
+  - ``INSTALL_COMMAND``
+  - ``TEST_COMMAND``
+
+  If using ``FetchContent_Populate()`` within CMake's script mode, be aware
+  that the implementation sets up a sub-build which therefore requires a CMake
+  generator and build tool to be available. If these cannot be found by
+  default, then the :variable:`CMAKE_GENERATOR` and/or
+  :variable:`CMAKE_MAKE_PROGRAM` variables will need to be set appropriately
+  on the command line invoking the script.
+
+
+Retrieve Population Properties
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. command:: FetchContent_GetProperties
+
+  When using saved content details, a call to :command:`FetchContent_Populate`
+  records information in global properties which can be queried at any time.
+  This information includes the source and binary directories associated with
+  the content and also whether or not the content population has been processed
+  during the current configure run.
+
+  .. code-block:: cmake
+
+    FetchContent_GetProperties( <name>
+      [SOURCE_DIR <srcDirVar>]
+      [BINARY_DIR <binDirVar>]
+      [POPULATED <doneVar>]
+    )
+
+  The ``SOURCE_DIR``, ``BINARY_DIR`` and ``POPULATED`` options can be used to
+  specify which properties should be retrieved. Each option accepts a value
+  which is the name of the variable in which to store that property. Most of
+  the time though, only ``<name>`` is given, in which case the call will then
+  set the same variables as a call to
+  :command:`FetchContent_Populate(name) <FetchContent_Populate>`. This allows
+  the following canonical pattern to be used, which ensures that the relevant
+  variables will always be defined regardless of whether or not the population
+  has been performed elsewhere in the project already:
+
+  .. code-block:: cmake
+
+    FetchContent_GetProperties(foobar)
+    if(NOT foobar_POPULATED)
+      FetchContent_Populate(foobar)
+
+      # Set any custom variables, etc. here, then
+      # populate the content as part of this build
+
+      add_subdirectory(${foobar_SOURCE_DIR} ${foobar_BINARY_DIR})
+    endif()
+
+  The above pattern allows other parts of the overall project hierarchy to
+  re-use the same content and ensure that it is only populated once.
+
+
+.. _`fetch-content-examples`:
+
+Examples
+^^^^^^^^
+
+Consider a project hierarchy where ``projA`` is the top level project and it
+depends on projects ``projB`` and ``projC``. Both ``projB`` and ``projC``
+can be built standalone and they also both depend on another project
+``projD``. For simplicity, this example will assume that all four projects
+are available on a company git server. The ``CMakeLists.txt`` of each project
+might have sections like the following:
+
+*projA*:
+
+.. code-block:: cmake
+
+  include(FetchContent)
+  FetchContent_Declare(
+    projB
+    GIT_REPOSITORY git@mycompany.com/git/projB.git
+    GIT_TAG        4a89dc7e24ff212a7b5167bef7ab079d
+  )
+  FetchContent_Declare(
+    projC
+    GIT_REPOSITORY git@mycompany.com/git/projC.git
+    GIT_TAG        4ad4016bd1d8d5412d135cf8ceea1bb9
+  )
+  FetchContent_Declare(
+    projD
+    GIT_REPOSITORY git@mycompany.com/git/projD.git
+    GIT_TAG        origin/integrationBranch
+  )
+
+  FetchContent_GetProperties(projB)
+  if(NOT projb_POPULATED)
+    FetchContent_Populate(projB)
+    add_subdirectory(${projb_SOURCE_DIR} ${projb_BINARY_DIR})
+  endif()
+
+  FetchContent_GetProperties(projC)
+  if(NOT projc_POPULATED)
+    FetchContent_Populate(projC)
+    add_subdirectory(${projc_SOURCE_DIR} ${projc_BINARY_DIR})
+  endif()
+
+*projB*:
+
+.. code-block:: cmake
+
+  include(FetchContent)
+  FetchContent_Declare(
+    projD
+    GIT_REPOSITORY git@mycompany.com/git/projD.git
+    GIT_TAG        20b415f9034bbd2a2e8216e9a5c9e632
+  )
+
+  FetchContent_GetProperties(projD)
+  if(NOT projd_POPULATED)
+    FetchContent_Populate(projD)
+    add_subdirectory(${projd_SOURCE_DIR} ${projd_BINARY_DIR})
+  endif()
+
+
+*projC*:
+
+.. code-block:: cmake
+
+  include(FetchContent)
+  FetchContent_Declare(
+    projD
+    GIT_REPOSITORY git@mycompany.com/git/projD.git
+    GIT_TAG        7d9a17ad2c962aa13e2fbb8043fb6b8a
+  )
+
+  FetchContent_GetProperties(projD)
+  if(NOT projd_POPULATED)
+    FetchContent_Populate(projD)
+    add_subdirectory(${projd_SOURCE_DIR} ${projd_BINARY_DIR})
+  endif()
+
+A few key points should be noted in the above:
+
+- ``projB`` and ``projC`` define different content details for ``projD``,
+  but ``projA`` also defines a set of content details for ``projD`` and
+  because ``projA`` will define them first, the details from ``projB`` and
+  ``projC`` will not be used. The override details defined by ``projA``
+  are not required to match either of those from ``projB`` or ``projC``, but
+  it is up to the higher level project to ensure that the details it does
+  define still make sense for the child projects.
+- While ``projA`` defined content details for ``projD``, it did not need
+  to explicitly call ``FetchContent_Populate(projD)`` itself. Instead, it
+  leaves that to a child project to do (in this case it will be ``projB``
+  since it is added to the build ahead of ``projC``). If ``projA`` needed to
+  customize how the ``projD`` content was brought into the build as well
+  (e.g. define some CMake variables before calling
+  :command:`add_subdirectory` after populating), it would do the call to
+  ``FetchContent_Populate()``, etc. just as it did for the ``projB`` and
+  ``projC`` content. For higher level projects, it is usually enough to
+  just define the override content details and leave the actual population
+  to the child projects. This saves repeating the same thing at each level
+  of the project hierarchy unnecessarily.
+- Even though ``projA`` is the top level project in this example, it still
+  checks whether ``projB`` and ``projC`` have already been populated before
+  going ahead to do those populations. This makes ``projA`` able to be more
+  easily incorporated as a child of some other higher level project in the
+  future if required. Always protect a call to
+  :command:`FetchContent_Populate` with a check to
+  :command:`FetchContent_GetProperties`, even in what may be considered a top
+  level project at the time.
+
+
+The following example demonstrates how one might download and unpack a
+firmware tarball using CMake's :manual:`script mode <cmake(1)>`.
The call to +:command:`FetchContent_Populate` specifies all the content details and the +unpacked firmware will be placed in a ``firmware`` directory below the +current working directory. + +*getFirmware.cmake*: + +.. code-block:: cmake + + # NOTE: Intended to be run in script mode with cmake -P + include(FetchContent) + FetchContent_Populate( + firmware + URL https://mycompany.com/assets/firmware-1.23-arm.tar.gz + URL_HASH MD5=68247684da89b608d466253762b0ff11 + SOURCE_DIR firmware + ) + +#]=======================================================================] + + +set(__FetchContent_privateDir "${CMAKE_CURRENT_LIST_DIR}/FetchContent") + +#======================================================================= +# Recording and retrieving content details for later population +#======================================================================= + +# Internal use, projects must not call this directly. It is +# intended for use by FetchContent_Declare() only. +# +# Sets a content-specific global property (not meant for use +# outside of functions defined here in this file) which can later +# be retrieved using __FetchContent_getSavedDetails() with just the +# same content name. If there is already a value stored in the +# property, it is left unchanged and this call has no effect. +# This allows parent projects to define the content details, +# overriding anything a child project may try to set (properties +# are not cached between runs, so the first thing to set it in a +# build will be in control). +function(__FetchContent_declareDetails contentName) + + string(TOLOWER ${contentName} contentNameLower) + set(propertyName "_FetchContent_${contentNameLower}_savedDetails") + get_property(alreadyDefined GLOBAL PROPERTY ${propertyName} DEFINED) + if(NOT alreadyDefined) + define_property(GLOBAL PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} ${ARGN}) + endif() + +endfunction() + + +# Internal use, projects must not call this directly. It is +# intended for use by the FetchContent_Declare() function. +# +# Retrieves details saved for the specified content in an +# earlier call to __FetchContent_declareDetails(). +function(__FetchContent_getSavedDetails contentName outVar) + + string(TOLOWER ${contentName} contentNameLower) + set(propertyName "_FetchContent_${contentNameLower}_savedDetails") + get_property(alreadyDefined GLOBAL PROPERTY ${propertyName} DEFINED) + if(NOT alreadyDefined) + message(FATAL_ERROR "No content details recorded for ${contentName}") + endif() + get_property(propertyValue GLOBAL PROPERTY ${propertyName}) + set(${outVar} "${propertyValue}" PARENT_SCOPE) + +endfunction() + + +# Saves population details of the content, sets defaults for the +# SOURCE_DIR and BUILD_DIR. +function(FetchContent_Declare contentName) + + set(options "") + set(oneValueArgs SVN_REPOSITORY) + set(multiValueArgs "") + + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + unset(srcDirSuffix) + unset(svnRepoArgs) + if(ARG_SVN_REPOSITORY) + # Add a hash of the svn repository URL to the source dir. This works + # around the problem where if the URL changes, the download would + # fail because it tries to checkout/update rather than switch the + # old URL to the new one. 
We limit the hash to the first 7 characters + # so that the source path doesn't get overly long (which can be a + # problem on windows due to path length limits). + string(SHA1 urlSHA ${ARG_SVN_REPOSITORY}) + string(SUBSTRING ${urlSHA} 0 7 urlSHA) + set(srcDirSuffix "-${urlSHA}") + set(svnRepoArgs SVN_REPOSITORY ${ARG_SVN_REPOSITORY}) + endif() + + string(TOLOWER ${contentName} contentNameLower) + __FetchContent_declareDetails( + ${contentNameLower} + SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src${srcDirSuffix}" + BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build" + ${svnRepoArgs} + # List these last so they can override things we set above + ${ARG_UNPARSED_ARGUMENTS} + ) + +endfunction() + + +#======================================================================= +# Set/get whether the specified content has been populated yet. +# The setter also records the source and binary dirs used. +#======================================================================= + +# Internal use, projects must not call this directly. It is +# intended for use by the FetchContent_Populate() function to +# record when FetchContent_Populate() is called for a particular +# content name. +function(__FetchContent_setPopulated contentName sourceDir binaryDir) + + string(TOLOWER ${contentName} contentNameLower) + set(prefix "_FetchContent_${contentNameLower}") + + set(propertyName "${prefix}_sourceDir") + define_property(GLOBAL PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} ${sourceDir}) + + set(propertyName "${prefix}_binaryDir") + define_property(GLOBAL PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} ${binaryDir}) + + set(propertyName "${prefix}_populated") + define_property(GLOBAL PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} True) + +endfunction() + + +# Set variables in the calling scope for any of the retrievable +# properties. If no specific properties are requested, variables +# will be set for all retrievable properties. +# +# This function is intended to also be used by projects as the canonical +# way to detect whether they should call FetchContent_Populate() +# and pull the populated source into the build with add_subdirectory(), +# if they are using the populated content in that way. 
+function(FetchContent_GetProperties contentName) + + string(TOLOWER ${contentName} contentNameLower) + + set(options "") + set(oneValueArgs SOURCE_DIR BINARY_DIR POPULATED) + set(multiValueArgs "") + + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if(NOT ARG_SOURCE_DIR AND + NOT ARG_BINARY_DIR AND + NOT ARG_POPULATED) + # No specific properties requested, provide them all + set(ARG_SOURCE_DIR ${contentNameLower}_SOURCE_DIR) + set(ARG_BINARY_DIR ${contentNameLower}_BINARY_DIR) + set(ARG_POPULATED ${contentNameLower}_POPULATED) + endif() + + set(prefix "_FetchContent_${contentNameLower}") + + if(ARG_SOURCE_DIR) + set(propertyName "${prefix}_sourceDir") + get_property(value GLOBAL PROPERTY ${propertyName}) + if(value) + set(${ARG_SOURCE_DIR} ${value} PARENT_SCOPE) + endif() + endif() + + if(ARG_BINARY_DIR) + set(propertyName "${prefix}_binaryDir") + get_property(value GLOBAL PROPERTY ${propertyName}) + if(value) + set(${ARG_BINARY_DIR} ${value} PARENT_SCOPE) + endif() + endif() + + if(ARG_POPULATED) + set(propertyName "${prefix}_populated") + get_property(value GLOBAL PROPERTY ${propertyName} DEFINED) + set(${ARG_POPULATED} ${value} PARENT_SCOPE) + endif() + +endfunction() + + +#======================================================================= +# Performing the population +#======================================================================= + +# The value of contentName will always have been lowercased by the caller. +# All other arguments are assumed to be options that are understood by +# ExternalProject_Add(), except for QUIET and SUBBUILD_DIR. +function(__FetchContent_directPopulate contentName) + + set(options + QUIET + ) + set(oneValueArgs + SUBBUILD_DIR + SOURCE_DIR + BINARY_DIR + # Prevent the following from being passed through + CONFIGURE_COMMAND + BUILD_COMMAND + INSTALL_COMMAND + TEST_COMMAND + ) + set(multiValueArgs "") + + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if(NOT ARG_SUBBUILD_DIR) + message(FATAL_ERROR "Internal error: SUBBUILD_DIR not set") + elseif(NOT IS_ABSOLUTE "${ARG_SUBBUILD_DIR}") + set(ARG_SUBBUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_SUBBUILD_DIR}") + endif() + + if(NOT ARG_SOURCE_DIR) + message(FATAL_ERROR "Internal error: SOURCE_DIR not set") + elseif(NOT IS_ABSOLUTE "${ARG_SOURCE_DIR}") + set(ARG_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_SOURCE_DIR}") + endif() + + if(NOT ARG_BINARY_DIR) + message(FATAL_ERROR "Internal error: BINARY_DIR not set") + elseif(NOT IS_ABSOLUTE "${ARG_BINARY_DIR}") + set(ARG_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_BINARY_DIR}") + endif() + + # Ensure the caller can know where to find the source and build directories + # with some convenient variables. Doing this here ensures the caller sees + # the correct result in the case where the default values are overridden by + # the content details set by the project. + set(${contentName}_SOURCE_DIR "${ARG_SOURCE_DIR}" PARENT_SCOPE) + set(${contentName}_BINARY_DIR "${ARG_BINARY_DIR}" PARENT_SCOPE) + + # The unparsed arguments may contain spaces, so build up ARG_EXTRA + # in such a way that it correctly substitutes into the generated + # CMakeLists.txt file with each argument quoted. + unset(ARG_EXTRA) + foreach(arg IN LISTS ARG_UNPARSED_ARGUMENTS) + set(ARG_EXTRA "${ARG_EXTRA} \"${arg}\"") + endforeach() + + # Hide output if requested, but save it to a variable in case there's an + # error so we can show the output upon failure. 
When not quiet, don't + # capture the output to a variable because the user may want to see the + # output as it happens (e.g. progress during long downloads). Combine both + # stdout and stderr in the one capture variable so the output stays in order. + if (ARG_QUIET) + set(outputOptions + OUTPUT_VARIABLE capturedOutput + ERROR_VARIABLE capturedOutput + ) + else() + set(capturedOutput) + set(outputOptions) + message(STATUS "Populating ${contentName}") + endif() + + if(CMAKE_GENERATOR) + set(generatorOpts "-G${CMAKE_GENERATOR}") + if(CMAKE_GENERATOR_PLATFORM) + list(APPEND generatorOpts "-A${CMAKE_GENERATOR_PLATFORM}") + endif() + if(CMAKE_GENERATOR_TOOLSET) + list(APPEND generatorOpts "-T${CMAKE_GENERATOR_TOOLSET}") + endif() + + if(CMAKE_MAKE_PROGRAM) + list(APPEND generatorOpts "-DCMAKE_MAKE_PROGRAM:FILEPATH=${CMAKE_MAKE_PROGRAM}") + endif() + + else() + # Likely we've been invoked via CMake's script mode where no + # generator is set (and hence CMAKE_MAKE_PROGRAM could not be + # trusted even if provided). We will have to rely on being + # able to find the default generator and build tool. + unset(generatorOpts) + endif() + + # Create and build a separate CMake project to carry out the population. + # If we've already previously done these steps, they will not cause + # anything to be updated, so extra rebuilds of the project won't occur. + # Make sure to pass through CMAKE_MAKE_PROGRAM in case the main project + # has this set to something not findable on the PATH. + configure_file("${__FetchContent_privateDir}/CMakeLists.cmake.in" + "${ARG_SUBBUILD_DIR}/CMakeLists.txt") + execute_process( + COMMAND ${CMAKE_COMMAND} ${generatorOpts} . + RESULT_VARIABLE result + ${outputOptions} + WORKING_DIRECTORY "${ARG_SUBBUILD_DIR}" + ) + if(result) + if(capturedOutput) + message("${capturedOutput}") + endif() + message(FATAL_ERROR "CMake step for ${contentName} failed: ${result}") + endif() + execute_process( + COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + ${outputOptions} + WORKING_DIRECTORY "${ARG_SUBBUILD_DIR}" + ) + if(result) + if(capturedOutput) + message("${capturedOutput}") + endif() + message(FATAL_ERROR "Build step for ${contentName} failed: ${result}") + endif() + +endfunction() + + +option(FETCHCONTENT_FULLY_DISCONNECTED "Disables all attempts to download or update content and assumes source dirs already exist") +option(FETCHCONTENT_UPDATES_DISCONNECTED "Enables UPDATE_DISCONNECTED behavior for all content population") +option(FETCHCONTENT_QUIET "Enables QUIET option for all content population" ON) +set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/_deps" CACHE PATH "Directory under which to collect all populated content") + +# Populate the specified content using details stored from +# an earlier call to FetchContent_Declare(). 
+function(FetchContent_Populate contentName) + + if(NOT contentName) + message(FATAL_ERROR "Empty contentName not allowed for FetchContent_Populate()") + endif() + + string(TOLOWER ${contentName} contentNameLower) + + if(ARGN) + # This is the direct population form with details fully specified + # as part of the call, so we already have everything we need + __FetchContent_directPopulate( + ${contentNameLower} + SUBBUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-subbuild" + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-build" + ${ARGN} # Could override any of the above ..._DIR variables + ) + + # Pass source and binary dir variables back to the caller + set(${contentNameLower}_SOURCE_DIR "${${contentNameLower}_SOURCE_DIR}" PARENT_SCOPE) + set(${contentNameLower}_BINARY_DIR "${${contentNameLower}_BINARY_DIR}" PARENT_SCOPE) + + # Don't set global properties, or record that we did this population, since + # this was a direct call outside of the normal declared details form. + # We only want to save values in the global properties for content that + # honours the hierarchical details mechanism so that projects are not + # robbed of the ability to override details set in nested projects. + return() + endif() + + # No details provided, so assume they were saved from an earlier call + # to FetchContent_Declare(). Do a check that we haven't already + # populated this content before in case the caller forgot to check. + FetchContent_GetProperties(${contentName}) + if(${contentNameLower}_POPULATED) + message(FATAL_ERROR "Content ${contentName} already populated in ${${contentNameLower}_SOURCE_DIR}") + endif() + + string(TOUPPER ${contentName} contentNameUpper) + set(FETCHCONTENT_SOURCE_DIR_${contentNameUpper} + "${FETCHCONTENT_SOURCE_DIR_${contentNameUpper}}" + CACHE PATH "When not empty, overrides where to find pre-populated content for ${contentName}") + + if(FETCHCONTENT_SOURCE_DIR_${contentNameUpper}) + # The source directory has been explicitly provided in the cache, + # so no population is required + set(${contentNameLower}_SOURCE_DIR "${FETCHCONTENT_SOURCE_DIR_${contentNameUpper}}") + set(${contentNameLower}_BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build") + + elseif(FETCHCONTENT_FULLY_DISCONNECTED) + # Bypass population and assume source is already there from a previous run + set(${contentNameLower}_SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src") + set(${contentNameLower}_BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build") + + else() + # Support both a global "disconnect all updates" and a per-content + # update test (either one being set disables updates for this content). 
+ option(FETCHCONTENT_UPDATES_DISCONNECTED_${contentNameUpper} + "Enables UPDATE_DISCONNECTED behavior just for population of ${contentName}") + if(FETCHCONTENT_UPDATES_DISCONNECTED OR + FETCHCONTENT_UPDATES_DISCONNECTED_${contentNameUpper}) + set(disconnectUpdates True) + else() + set(disconnectUpdates False) + endif() + + if(FETCHCONTENT_QUIET) + set(quietFlag QUIET) + else() + unset(quietFlag) + endif() + + __FetchContent_getSavedDetails(${contentName} contentDetails) + if("${contentDetails}" STREQUAL "") + message(FATAL_ERROR "No details have been set for content: ${contentName}") + endif() + + __FetchContent_directPopulate( + ${contentNameLower} + ${quietFlag} + UPDATE_DISCONNECTED ${disconnectUpdates} + SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-subbuild" + SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src" + BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build" + # Put the saved details last so they can override any of the + # the options we set above (this can include SOURCE_DIR or + # BUILD_DIR) + ${contentDetails} + ) + endif() + + __FetchContent_setPopulated( + ${contentName} + ${${contentNameLower}_SOURCE_DIR} + ${${contentNameLower}_BINARY_DIR} + ) + + # Pass variables back to the caller. The variables passed back here + # must match what FetchContent_GetProperties() sets when it is called + # with just the content name. + set(${contentNameLower}_SOURCE_DIR "${${contentNameLower}_SOURCE_DIR}" PARENT_SCOPE) + set(${contentNameLower}_BINARY_DIR "${${contentNameLower}_BINARY_DIR}" PARENT_SCOPE) + set(${contentNameLower}_POPULATED True PARENT_SCOPE) + +endfunction() diff --git a/cmake/Modules/FetchContent/CMakeLists.cmake.in b/cmake/Modules/FetchContent/CMakeLists.cmake.in new file mode 100644 index 000000000..9a7a7715a --- /dev/null +++ b/cmake/Modules/FetchContent/CMakeLists.cmake.in @@ -0,0 +1,21 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +cmake_minimum_required(VERSION ${CMAKE_VERSION}) + +# We name the project and the target for the ExternalProject_Add() call +# to something that will highlight to the user what we are working on if +# something goes wrong and an error message is produced. + +project(${contentName}-populate NONE) + +include(ExternalProject) +ExternalProject_Add(${contentName}-populate + ${ARG_EXTRA} + SOURCE_DIR "${ARG_SOURCE_DIR}" + BINARY_DIR "${ARG_BINARY_DIR}" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/Modules/README.md b/cmake/Modules/README.md new file mode 100644 index 000000000..c8d275f11 --- /dev/null +++ b/cmake/Modules/README.md @@ -0,0 +1,5 @@ + +## FetchContent + +`FetchContent.cmake` and `FetchContent/CMakeLists.cmake.in` +are copied from `cmake/3.11.0/share/cmake-3.11/Modules`. 
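For reference, the vendored module is consumed the same way as CMake's built-in FetchContent. The snippet below is a minimal sketch under an assumed hypothetical dependency `foo` (the name, URL, and hash are placeholders); `cmake/pybind11.cmake` later in this patch follows exactly this shape:

```cmake
# Minimal sketch of consuming the vendored FetchContent module.
# "foo", its URL, and the SHA256 hash are placeholders, not part of this patch.
if(CMAKE_VERSION VERSION_LESS 3.11)
  # Make the copied FetchContent.cmake visible to include()
  list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()

include(FetchContent)

FetchContent_Declare(foo
  URL      https://example.com/foo-1.0.tar.gz
  URL_HASH SHA256=0000000000000000000000000000000000000000000000000000000000000000
)

FetchContent_GetProperties(foo)
if(NOT foo_POPULATED)
  FetchContent_Populate(foo)  # downloads and unpacks into ${foo_SOURCE_DIR}
endif()
add_subdirectory(${foo_SOURCE_DIR} ${foo_BINARY_DIR} EXCLUDE_FROM_ALL)
```

Note that setting the cache variable `FETCHCONTENT_SOURCE_DIR_FOO` to a local checkout skips the download entirely; that is the override path handled inside `FetchContent_Populate()` above.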
diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py
new file mode 100644
index 000000000..12930f663
--- /dev/null
+++ b/cmake/cmake_extension.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021-2022 Xiaomi Corporation (author: Fangjun Kuang)
+
+import glob
+import os
+import platform
+import shutil
+import sys
+from pathlib import Path
+
+import setuptools
+from setuptools.command.build_ext import build_ext
+
+
+def is_for_pypi():
+    ans = os.environ.get("SHERPA_IS_FOR_PYPI", None)
+    return ans is not None
+
+
+def is_macos():
+    return platform.system() == "Darwin"
+
+
+def is_windows():
+    return platform.system() == "Windows"
+
+
+try:
+    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+    class bdist_wheel(_bdist_wheel):
+        def finalize_options(self):
+            _bdist_wheel.finalize_options(self)
+            # In this case, the generated wheel has a name of the form
+            # sherpa-xxx-pyxx-none-any.whl
+            if is_for_pypi() and not is_macos():
+                self.root_is_pure = True
+            else:
+                # The generated wheel has a name ending with
+                # -linux_x86_64.whl
+                self.root_is_pure = False
+
+
+except ImportError:
+    bdist_wheel = None
+
+
+def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
+    kwargs["language"] = "c++"
+    sources = []
+    return setuptools.Extension(name, sources, *args, **kwargs)
+
+
+class BuildExtension(build_ext):
+    def build_extension(self, ext: setuptools.extension.Extension):
+        # build/temp.linux-x86_64-3.8
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        # build/lib.linux-x86_64-3.8
+        os.makedirs(self.build_lib, exist_ok=True)
+
+        sherpa_dir = Path(__file__).parent.parent.resolve()
+
+        cmake_args = os.environ.get("SHERPA_CMAKE_ARGS", "")
+        make_args = os.environ.get("SHERPA_MAKE_ARGS", "")
+        system_make_args = os.environ.get("MAKEFLAGS", "")
+
+        if cmake_args == "":
+            cmake_args = "-DCMAKE_BUILD_TYPE=Release"
+
+        if "PYTHON_EXECUTABLE" not in cmake_args:
+            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
+            cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"
+
+        if is_windows():
+            build_cmd = f"""
+              cmake {cmake_args} -B {self.build_temp} -S {sherpa_dir}
+              cmake --build {self.build_temp} --target _sherpa --config Release -- -m
+            """
+            print(f"build command is:\n{build_cmd}")
+            ret = os.system(f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_dir}")
+            if ret != 0:
+                raise Exception("Failed to configure sherpa")
+
+            ret = os.system(
+                f"cmake --build {self.build_temp} --target _sherpa --config Release -- -m"
+            )
+            if ret != 0:
+                raise Exception("Failed to build sherpa")
+        else:
+            if make_args == "" and system_make_args == "":
+                print("For fast compilation, run:")
+                print('export SHERPA_MAKE_ARGS="-j"; python setup.py install')
+
+            build_cmd = f"""
+                cd {self.build_temp}
+
+                cmake {cmake_args} {sherpa_dir}
+
+
+                make {make_args} _sherpa
+            """
+            print(f"build command is:\n{build_cmd}")
+
+            ret = os.system(build_cmd)
+            if ret != 0:
+                raise Exception(
+                    "\nBuild sherpa failed. 
Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/k2-fsa/sherpa/issues/new\n" # noqa + ) + + lib_so = glob.glob(f"{self.build_temp}/lib/*sherpa*.so") + lib_so += glob.glob(f"{self.build_temp}/lib/*sherpa*.dylib") # macOS + + # bin/Release/_sherpa.cp38-win_amd64.pyd + lib_so += glob.glob( + f"{self.build_temp}/**/*sherpa*.pyd", recursive=True + ) # windows + + # lib/Release/*.lib + lib_so += glob.glob( + f"{self.build_temp}/**/*sherpa*.lib", recursive=True + ) # windows + for so in lib_so: + print(f"Copying {so} to {self.build_lib}/") + shutil.copy(f"{so}", f"{self.build_lib}/") diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake new file mode 100644 index 000000000..76cf090e4 --- /dev/null +++ b/cmake/pybind11.cmake @@ -0,0 +1,25 @@ +function(download_pybind11) + if(CMAKE_VERSION VERSION_LESS 3.11) + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) + endif() + + include(FetchContent) + + set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.9.2.tar.gz") + set(pybind11_HASH "SHA256=6bd528c4dbe2276635dc787b6b1f2e5316cf6b49ee3e150264e455a0d68d19c1") + + FetchContent_Declare(pybind11 + URL ${pybind11_URL} + URL_HASH ${pybind11_HASH} + ) + + FetchContent_GetProperties(pybind11) + if(NOT pybind11_POPULATED) + message(STATUS "Downloading pybind11") + FetchContent_Populate(pybind11) + endif() + message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") + add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_pybind11() diff --git a/cmake/torch.cmake b/cmake/torch.cmake new file mode 100644 index 000000000..9b55345ee --- /dev/null +++ b/cmake/torch.cmake @@ -0,0 +1,28 @@ +# PYTHON_EXECUTABLE is set by cmake/pybind11.cmake +message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") + +execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import os; import torch; print(os.path.dirname(torch.__file__))" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE TORCH_DIR +) + +message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") +message(STATUS "TORCH_DIR: ${TORCH_DIR}") + +list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}") +find_package(Torch REQUIRED) +message(STATUS "TORCH_LIBRARIES: ${TORCH_LIBRARIES}") + + +# set the global CMAKE_CXX_FLAGS so that +# sherpa uses the same ABI flag as PyTorch +string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS} ") + +execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE TORCH_VERSION +) + +message(STATUS "PyTorch version: ${TORCH_VERSION}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..130cbcf29 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +websockets +kaldifeat +sentencepiece>=0.1.96 diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..196b2ac12 --- /dev/null +++ b/setup.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +import re +import sys + +import setuptools + +from cmake.cmake_extension import BuildExtension, bdist_wheel, cmake_extension + +if sys.version_info < (3,): + print("Python 2 has reached end-of-life and is no longer supported by sherpa.") + sys.exit(-1) + +if sys.version_info < (3, 7): + print( + "Python 3.6 has reached end-of-life on December 31st, 2021 " + "and is no longer supported by sherpa." 
+ ) + sys.exit(-1) + + +def read_long_description(): + with open("README.md", encoding="utf8") as f: + readme = f.read() + return readme + + +def get_package_version(): + with open("CMakeLists.txt") as f: + content = f.read() + + match = re.search(r"set\(SHERPA_VERSION (.*)\)", content) + latest_version = match.group(1).strip('"') + return latest_version + + +package_name = "k2-sherpa" + +with open("sherpa/python/sherpa/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + +setuptools.setup( + name=package_name, + version=get_package_version(), + author="Fangjun Kuang", + author_email="csukuangfj@gmail.com", + package_dir={ + "sherpa": "sherpa/python/sherpa", + }, + packages=["sherpa"], + url="https://github.com/k2-fsa/sherpa", + long_description=read_long_description(), + long_description_content_type="text/markdown", + ext_modules=[cmake_extension("_sherpa")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, + zip_safe=False, + classifiers=[ + "Programming Language :: C++", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + python_requires=">=3.7.0", + license="Apache licensed, as found in the LICENSE file", +) + +# remove the line __version__ from sherpa/python/sherpa/__init__.py +with open("sherpa/python/sherpa/__init__.py", "r") as f: + lines = f.readlines() + +with open("sherpa/python/sherpa/__init__.py", "w") as f: + for line in lines: + if "__version__" not in line: + f.write(line) diff --git a/sherpa/CMakeLists.txt b/sherpa/CMakeLists.txt new file mode 100644 index 000000000..c70d00c60 --- /dev/null +++ b/sherpa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(csrc) +add_subdirectory(python) diff --git a/sherpa/bin/.gitignore b/sherpa/bin/.gitignore new file mode 100644 index 000000000..5afadd3a8 --- /dev/null +++ b/sherpa/bin/.gitignore @@ -0,0 +1,4 @@ +log +errs-* +recogs-* +test_wavs diff --git a/sherpa/bin/decode_mainifest.py b/sherpa/bin/decode_mainifest.py new file mode 100755 index 000000000..f98c93bb8 --- /dev/null +++ b/sherpa/bin/decode_mainifest.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script loads a manifest in lhotse format and sends it to the server +for decoding, in parallel. 
+ +Usage: + + ./decode_mainifest.py + +(Note: You have to first start the server before starting the client) +""" + +import asyncio +import time + +import numpy as np +import websockets +from icefall.utils import store_transcripts, write_error_stats +from lhotse import CutSet, load_manifest + + +async def send(cuts: CutSet, name: str): + total_duration = 0.0 + results = [] + async with websockets.connect("ws://localhost:6006") as websocket: + for i, c in enumerate(cuts): + if i % 5 == 0: + print(f"{name}: {i}/{len(cuts)}") + + samples = c.load_audio().reshape(-1).astype(np.float32) + num_bytes = samples.nbytes + + await websocket.send((num_bytes).to_bytes(8, "big", signed=True)) + + frame_size = (2 ** 20) // 4 # max payload is 1MB + start = 0 + while start < samples.size: + end = start + frame_size + await websocket.send(samples.data[start:end]) + start = end + decoding_results = await websocket.recv() + + total_duration += c.duration + + results.append((c.supervisions[0].text.split(), decoding_results.split())) + + return total_duration, results + + +async def main(): + filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/data/fbank/cuts_test-clean.json.gz" + cuts = load_manifest(filename) + num_tasks = 50 # we start this number of tasks to send the requests + cuts_list = cuts.split(num_tasks) + tasks = [] + + start_time = time.time() + for i in range(num_tasks): + task = asyncio.create_task(send(cuts_list[i], f"task-{i}")) + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + results = [] + total_duration = 0.0 + for ans in ans_list: + total_duration += ans[0] + results += ans[1] + + rtf = elapsed / total_duration + + print(f"RTF: {rtf:.4f}") + print( + f"total_duration: {total_duration:.2f} seconds " + f"({total_duration/3600:.2f} hours)" + ) + print(f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)") + + store_transcripts(filename="recogs-test-clean.txt", texts=results) + with open("errs-test-clean.txt", "w") as f: + wer = write_error_stats(f, "test-set", results, enable_log=True) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sherpa/bin/offline_client.py b/sherpa/bin/offline_client.py new file mode 100755 index 000000000..3c4ba2749 --- /dev/null +++ b/sherpa/bin/offline_client.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A client for offline ASR recognition. 
+ +Usage: + ./offline_client.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(Note: You have to first start the server before starting the client) +""" +import argparse +import asyncio + +import torchaudio +import websockets + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser.parse_args() + + +async def main(): + args = get_args() + assert len(args.sound_files) > 0, f"Empty sound files" + + server_addr = args.server_addr + server_port = args.server_port + + async with websockets.connect(f"ws://{server_addr}:{server_port}") as websocket: + for test_wav in args.sound_files: + print(f"Sending {test_wav}") + wave, sample_rate = torchaudio.load(test_wav) + assert sample_rate == 16000, sample_rate + + wave = wave.squeeze(0) + num_bytes = wave.numel() * wave.element_size() + await websocket.send((num_bytes).to_bytes(8, "big", signed=True)) + + frame_size = (2 ** 20) // 4 # max payload is 1MB + start = 0 + while start < wave.numel(): + end = start + frame_size + await websocket.send(wave.numpy().data[start:end]) + start = end + decoding_results = await websocket.recv() + print(test_wav, "\n", decoding_results) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sherpa/bin/offline_server.py b/sherpa/bin/offline_server.py new file mode 100755 index 000000000..ee940bd84 --- /dev/null +++ b/sherpa/bin/offline_server.py @@ -0,0 +1,492 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A server for offline ASR recognition. Offline means you send all the content +of the audio for recognition. It supports multiple clients sending at +the same time. 
+
+Usage:
+    ./offline_server.py
+"""
+
+import argparse
+import asyncio
+import logging
+import math
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Optional
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import websockets
+from _sherpa import RnntModel, greedy_search
+from torch.nn.utils.rnn import pad_sequence
+
+LOG_EPS = math.log(1e-10)
+
+# You can use
+#   icefall/egs/librispeech/ASR/pruned_transducer_statelessX/export.py --jit 1
+# to generate the following model
+DEFAULT_NN_MODEL_FILENAME = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt"  # noqa
+DEFAULT_BPE_MODEL_FILENAME = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/data/lang_bpe_500/bpe.model"
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=6006,
+        help="The server will listen on this port",
+    )
+
+    parser.add_argument(
+        "--num-device",
+        type=int,
+        default=1,
+        help="""Number of GPU devices to use. Set it to 0 to use CPU
+        for computation. If positive, then GPUs with ID 0, 1, ..., num_device-1
+        will be used for computation. You can use the environment variable
+        CUDA_VISIBLE_DEVICES to map available GPU devices.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=25,
+        help="""Max batch size for computation. Note: if there are not enough
+        requests in the queue, it waits for up to max_wait_ms. After that,
+        even if there are still not enough requests, it sends the
+        available requests in the queue for computation.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-wait-ms",
+        type=float,
+        default=5,
+        help="""Max time in milliseconds to wait to build batches for inference.
+        If there are not enough requests in the feature queue to build a batch
+        of max_batch_size, it waits up to this time before fetching available
+        requests for computation.
+        """,
+    )
+
+    parser.add_argument(
+        "--feature-extractor-pool-size",
+        type=int,
+        default=5,
+        help="""Number of threads for feature extraction. By default, feature
+        extraction is run on the CPU.
+        """,
+    )
+
+    parser.add_argument(
+        "--nn-pool-size",
+        type=int,
+        default=1,
+        help="""Number of threads for NN computation and decoding.
+        Note: In general, it should be less than or equal to num_device
+        if num_device is positive.
+        """,
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        default=DEFAULT_NN_MODEL_FILENAME,
+        help="""The torchscript model. You can use
+          icefall/egs/librispeech/ASR/pruned_transducer_statelessX/export.py --jit=1
+        to generate this model.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model-filename",
+        type=str,
+        default=DEFAULT_BPE_MODEL_FILENAME,
+        help="""The BPE model.
+        You can find it in the directory egs/librispeech/ASR/data/lang_bpe_xxx
+        where xxx is the number of BPE tokens you used to train the model.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def run_model_and_do_greedy_search(
+    model: torch.jit.ScriptModule,
+    features: List[torch.Tensor],
+) -> List[List[int]]:
+    """Run the RNN-T model with the given features and use greedy search
+    to decode the output of the model.
+
+    Args:
+      model:
+        The RNN-T model.
+      features:
+        A list of 2-D tensors. Each entry is of shape (num_frames, feature_dim).
+    Returns:
+      Return a list-of-list containing the decoding token IDs.
+    """
+    features_length = torch.tensor([f.size(0) for f in features], dtype=torch.int64)
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=LOG_EPS,
+    )
+
+    device = model.device
+    features = features.to(device)
+    features_length = features_length.to(device)
+
+    encoder_out, encoder_out_length = model.encoder(
+        features=features,
+        features_length=features_length,
+    )
+
+    hyp_tokens = greedy_search(
+        model=model,
+        encoder_out=encoder_out,
+        encoder_out_length=encoder_out_length.cpu(),
+    )
+    return hyp_tokens
+
+
+class OfflineServer:
+    def __init__(
+        self,
+        nn_model_filename: str,
+        bpe_model_filename: str,
+        num_device: int,
+        batch_size: int,
+        max_wait_ms: float,
+        feature_extractor_pool_size: int = 3,
+        nn_pool_size: int = 3,
+    ):
+        """
+        Args:
+          nn_model_filename:
+            Path to the torch script model.
+          bpe_model_filename:
+            Path to the BPE model.
+          num_device:
+            If 0, use CPU for neural network computation and decoding.
+            If positive, it means the number of GPUs to use for NN computation
+            and decoding. For each device, there will be a corresponding
+            torchscript model. We assume available device IDs are
+            0, 1, ... , num_device - 1. You can use the environment variable
+            CUDA_VISIBLE_DEVICES to achieve this.
+          batch_size:
+            Max batch size for inference.
+          max_wait_ms:
+            Max wait time in milliseconds in order to build a batch of
+            `batch_size`.
+          feature_extractor_pool_size:
+            Number of threads to create for the feature extractor thread pool.
+          nn_pool_size:
+            Number of threads for the thread pool that is used for NN
+            computation and decoding.
+        """
+        self.feature_extractor = self._build_feature_extractor()
+        self.nn_models = self._build_nn_model(nn_model_filename, num_device)
+
+        assert nn_pool_size > 0
+
+        self.feature_extractor_pool = ThreadPoolExecutor(
+            max_workers=feature_extractor_pool_size,
+            thread_name_prefix="feature",
+        )
+        self.nn_pool = ThreadPoolExecutor(
+            max_workers=nn_pool_size,
+            thread_name_prefix="nn",
+        )
+
+        self.feature_queue = asyncio.Queue()
+
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.load(bpe_model_filename)
+
+        self.counter = 0
+
+        self.max_wait_ms = max_wait_ms
+        self.batch_size = batch_size
+
+    def _build_feature_extractor(self):
+        """Build a fbank feature extractor for extracting features.
+
+        TODO:
+          Pass the options as arguments
+        """
+        opts = kaldifeat.FbankOptions()
+        opts.device = "cpu"  # Note: It also supports CUDA, e.g., "cuda:0"
+        opts.frame_opts.dither = 0
+        opts.frame_opts.snip_edges = False
+        opts.frame_opts.samp_freq = 16000
+        opts.mel_opts.num_bins = 80
+
+        fbank = kaldifeat.Fbank(opts)
+
+        return fbank
+
+    def _build_nn_model(
+        self, nn_model_filename: str, num_device: int
+    ) -> List[RnntModel]:
+        """Build a torch script model for each given device.
+
+        Args:
+          nn_model_filename:
+            The path to the torch script model.
+          num_device:
+            Number of devices to use for NN computation and decoding.
+            If it is 0, only the CPU is used and a single model on CPU is
+            returned. If it is positive, it creates a model for each device
+            and returns them.
+        Returns:
+          Return a list of torch script models.
+        """
+        if num_device < 1:
+            model = RnntModel(
+                filename=nn_model_filename,
+                device="cpu",
+                optimize_for_inference=False,
+            )
+            return [model]
+
+        ans = []
+        for i in range(num_device):
+            device = torch.device("cuda", i)
+            model = RnntModel(
+                filename=nn_model_filename,
+                device=device,
+                optimize_for_inference=False,
+            )
+            ans.append(model)
+
+        return ans
+
+    async def loop(self, port: int):
+        logging.info("started")
+        task = asyncio.create_task(self.feature_consumer_task())
+
+        # If you use multiple GPUs, you can create multiple
+        # feature consumer tasks.
+        #  asyncio.create_task(self.feature_consumer_task())
+        #  asyncio.create_task(self.feature_consumer_task())
+
+        async with websockets.serve(self.handle_connection, "", port):
+            await asyncio.Future()  # run forever
+        await task
+
+    async def recv_audio_samples(
+        self,
+        socket: websockets.WebSocketServerProtocol,
+    ) -> Optional[torch.Tensor]:
+        """Receive a tensor from the client.
+
+        The message from the client has the following format:
+
+            - a header of 8 bytes, containing the number of bytes of the tensor.
+              The header is in big endian format.
+            - a binary representation of the 1-D torch.float32 tensor.
+
+        Args:
+          socket:
+            The socket for communicating with the client.
+        Returns:
+          Return a 1-D torch.float32 tensor.
+        """
+        expected_num_bytes = None
+        received = b""
+        async for message in socket:
+            if expected_num_bytes is None:
+                assert len(message) >= 8, (len(message), message)
+                expected_num_bytes = int.from_bytes(message[:8], "big", signed=True)
+                received += message[8:]
+                if len(received) == expected_num_bytes:
+                    break
+            else:
+                received += message
+                if len(received) == expected_num_bytes:
+                    break
+        if not received:
+            return None
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            # PyTorch warns that the underlying buffer is not writable.
+            # We ignore it here as we are not going to write it anyway.
+            return torch.frombuffer(received, dtype=torch.float32)
+
+    async def feature_consumer_task(self):
+        """This function extracts features from the feature_queue,
+        batches them up, and sends them to the RNN-T model for computation
+        and decoding.
+        """
+        while True:
+            if self.feature_queue.empty():
+                await asyncio.sleep(self.max_wait_ms / 1000)
+                continue
+            batch = []
+            try:
+                while len(batch) < self.batch_size:
+                    item = self.feature_queue.get_nowait()
+                    batch.append(item)
+            except asyncio.QueueEmpty:
+                pass
+
+            feature_list = [b[0] for b in batch]
+
+            loop = asyncio.get_running_loop()
+            self.counter = (self.counter + 1) % len(self.nn_models)
+            model = self.nn_models[self.counter]
+
+            hyp_tokens = await loop.run_in_executor(
+                self.nn_pool,
+                run_model_and_do_greedy_search,
+                model,
+                feature_list,
+            )
+
+            for i, hyp in enumerate(hyp_tokens):
+                self.feature_queue.task_done()
+                future = batch[i][1]
+                loop.call_soon(future.set_result, hyp)
+
+    async def compute_features(self, samples: torch.Tensor) -> torch.Tensor:
+        """Compute the fbank features for the given audio samples.
+
+        Args:
+          samples:
+            A 1-D torch.float32 tensor containing the audio samples. Its
+            sampling rate should be the one expected by the feature
+            extractor. Also, its range should match the one used during
+            training.
+        Returns:
+          Return a 2-D tensor of shape (num_frames, feature_dim) containing
+          the features.
+        """
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(
+            self.feature_extractor_pool,
+            self.feature_extractor,  # it releases the GIL
+            samples,
+        )
+
+    async def compute_and_decode(
+        self,
+        features: torch.Tensor,
+    ) -> List[int]:
+        """Run the RNN-T model on the features and do greedy search.
+
+        Args:
+          features:
+            A 2-D tensor of shape (num_frames, feature_dim).
+        Returns:
+          Return a list of token IDs containing the decoded results.
+        """
+        loop = asyncio.get_running_loop()
+        future = loop.create_future()
+        await self.feature_queue.put((features, future))
+        await future
+        return future.result()
+
+    async def handle_connection(
+        self,
+        socket: websockets.WebSocketServerProtocol,
+    ):
+        """Receive audio samples from the client, process them, and send
+        the decoding results back to the client.
+
+        Args:
+          socket:
+            The socket for communicating with the client.
+        """
+        logging.info(f"Connected: {socket.remote_address}")
+        while True:
+            samples = await self.recv_audio_samples(socket)
+            if samples is None:
+                break
+            features = await self.compute_features(samples)
+            hyp = await self.compute_and_decode(features)
+            result = self.sp.decode(hyp)
+            await socket.send(result)
+
+        logging.info(f"Disconnected: {socket.remote_address}")
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+
+    nn_model_filename = args.nn_model_filename
+    bpe_model_filename = args.bpe_model_filename
+    port = args.port
+    num_device = args.num_device
+    max_wait_ms = args.max_wait_ms
+    batch_size = args.max_batch_size
+    feature_extractor_pool_size = args.feature_extractor_pool_size
+    nn_pool_size = args.nn_pool_size
+
+    offline_server = OfflineServer(
+        nn_model_filename=nn_model_filename,
+        bpe_model_filename=bpe_model_filename,
+        num_device=num_device,
+        max_wait_ms=max_wait_ms,
+        batch_size=batch_size,
+        feature_extractor_pool_size=feature_extractor_pool_size,
+        nn_pool_size=nn_pool_size,
+    )
+    asyncio.run(offline_server.loop(port))
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# See https://github.com/pytorch/pytorch/issues/38342
+# and https://github.com/pytorch/pytorch/issues/33354
+#
+# If we don't do this, the delay increases whenever there is
+# a new request that changes the actual batch size.
+# If you use `py-spy dump --pid <server-pid> --native`, you will
+# see a lot of time is spent in re-compiling the torch script model.
+torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +""" +// Use the following in C++ +torch::jit::getExecutorMode() = false; +torch::jit::getProfilingMode() = false; +torch::jit::setGraphExecutorOptimize(false); +""" + + +if __name__ == "__main__": + torch.manual_seed(20220519) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt new file mode 100644 index 000000000..81759af9e --- /dev/null +++ b/sherpa/csrc/CMakeLists.txt @@ -0,0 +1,8 @@ +# Please sort the filenames alphabetically +set(sherpa_srcs + rnnt_beam_search.cc + rnnt_model.cc +) + +add_library(sherpa_core ${sherpa_srcs}) +target_link_libraries(sherpa_core PUBLIC ${TORCH_LIBRARIES}) diff --git a/sherpa/csrc/rnnt_beam_search.cc b/sherpa/csrc/rnnt_beam_search.cc new file mode 100644 index 000000000..720308484 --- /dev/null +++ b/sherpa/csrc/rnnt_beam_search.cc @@ -0,0 +1,144 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa/csrc/rnnt_model.h" +#include "torch/all.h" + +namespace sherpa { + +/** + * Construct the decoder input from the current hypothesis. + * + * @param hyps A list-of-list of token IDs containing the current decoding + * results. Its length is `batch_size` + * @param decoder_input A 2-D tensor of shape (batch_size, context_size). 
+ */
+static void BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
+                              torch::Tensor *decoder_input) {
+  int32_t batch_size = decoder_input->size(0);
+  int32_t context_size = decoder_input->size(1);
+  int64_t *p = decoder_input->data_ptr<int64_t>();
+  for (int32_t i = 0; i != batch_size; ++i) {
+    auto start = hyps[i].end() - context_size;
+    auto end = hyps[i].end();
+    std::copy(start, end, p);
+    p += context_size;
+  }
+}
+
+std::vector<std::vector<int32_t>> GreedySearch(
+    RnntModel &model, torch::Tensor encoder_out,
+    torch::Tensor encoder_out_length) {
+  TORCH_CHECK(encoder_out.dim() == 3, "encoder_out.dim() is ",
+              encoder_out.dim(), "Expected is 3");
+  TORCH_CHECK(encoder_out.scalar_type() == torch::kFloat,
+              "encoder_out.scalar_type() is ", encoder_out.scalar_type());
+
+  TORCH_CHECK(encoder_out_length.dim() == 1, "encoder_out_length.dim() is",
+              encoder_out_length.dim());
+  TORCH_CHECK(encoder_out_length.scalar_type() == torch::kLong,
+              "encoder_out_length.scalar_type() is ",
+              encoder_out_length.scalar_type());
+
+  TORCH_CHECK(encoder_out_length.is_cpu());
+
+  torch::Device device = model.Device();
+  encoder_out = encoder_out.to(device);
+
+  torch::nn::utils::rnn::PackedSequence packed_seq =
+      torch::nn::utils::rnn::pack_padded_sequence(encoder_out,
+                                                  encoder_out_length,
+                                                  /*batch_first*/ true,
+                                                  /*enforce_sorted*/ false);
+
+  auto projected_encoder_out = model.ForwardEncoderProj(packed_seq.data());
+
+  int32_t blank_id = model.BlankId();
+  int32_t unk_id = model.UnkId();
+  int32_t context_size = model.ContextSize();
+
+  int32_t batch_size = encoder_out_length.size(0);
+
+  std::vector<int32_t> blanks(context_size, blank_id);
+  std::vector<std::vector<int32_t>> hyps(batch_size, blanks);
+
+  auto decoder_input =
+      torch::full({batch_size, context_size}, blank_id,
+                  torch::dtype(torch::kLong)
+                      .memory_format(torch::MemoryFormat::Contiguous));
+  auto decoder_out = model.ForwardDecoder(decoder_input.to(device));
+  decoder_out = model.ForwardDecoderProj(decoder_out);
+  // decoder_out's shape is (batch_size, 1, joiner_dim)
+
+  using torch::indexing::Slice;
+  auto batch_sizes_accessor = packed_seq.batch_sizes().accessor<int64_t, 1>();
+  int32_t num_batches = packed_seq.batch_sizes().numel();
+  int32_t offset = 0;
+  for (int32_t i = 0; i != num_batches; ++i) {
+    int32_t cur_batch_size = batch_sizes_accessor[i];
+    int32_t start = offset;
+    int32_t end = start + cur_batch_size;
+    auto cur_encoder_out = projected_encoder_out.index({Slice(start, end)});
+    offset = end;
+
+    cur_encoder_out = cur_encoder_out.unsqueeze(1).unsqueeze(1);
+    // Now cur_encoder_out's shape is (cur_batch_size, 1, 1, joiner_dim)
+    if (cur_batch_size < decoder_out.size(0)) {
+      decoder_out = decoder_out.index({Slice(0, cur_batch_size)});
+    }
+
+    auto logits =
+        model.ForwardJoiner(cur_encoder_out, decoder_out.unsqueeze(1));
+    // logits' shape is (cur_batch_size, 1, 1, vocab_size)
+
+    logits = logits.squeeze(1).squeeze(1);
+    auto max_indices = logits.argmax(/*dim*/ -1).cpu();
+    auto max_indices_accessor = max_indices.accessor<int64_t, 1>();
+    bool emitted = false;
+    for (int32_t k = 0; k != cur_batch_size; ++k) {
+      auto index = max_indices_accessor[k];
+      if (index != blank_id && index != unk_id) {
+        emitted = true;
+        hyps[k].push_back(index);
+      }
+    }
+
+    if (emitted) {
+      if (cur_batch_size < decoder_input.size(0)) {
+        decoder_input = decoder_input.index({Slice(0, cur_batch_size)});
+      }
+      BuildDecoderInput(hyps, &decoder_input);
+      decoder_out = model.ForwardDecoder(decoder_input.to(device));
+      decoder_out = model.ForwardDecoderProj(decoder_out);
+    }
+  }
+
+  auto unsorted_indices = packed_seq.unsorted_indices().cpu();
+  auto unsorted_indices_accessor = unsorted_indices.accessor<int64_t, 1>();
+
+  std::vector<std::vector<int32_t>> ans(batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    torch::ArrayRef<int32_t> arr(hyps[unsorted_indices_accessor[i]]);
+    ans[i] = arr.slice(context_size).vec();
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/rnnt_beam_search.h b/sherpa/csrc/rnnt_beam_search.h
new file mode 100644
index 000000000..72df87cf6
--- /dev/null
+++ b/sherpa/csrc/rnnt_beam_search.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SHERPA_CSRC_RNNT_BEAM_SEARCH_H_
+#define SHERPA_CSRC_RNNT_BEAM_SEARCH_H_
+
+#include <vector>
+
+#include "sherpa/csrc/rnnt_model.h"
+
+namespace sherpa {
+
+/** RNN-T Greedy search decoding by limiting the max symbol per frame to one.
+ *
+ * @param model The RNN-T model.
+ *
+ * @param encoder_out Output from the encoder network. Its shape is
+ *                    (batch_size, T, encoder_out_dim) and its dtype is
+ *                    torch::kFloat.
+ *
+ * @param encoder_out_length A 1-D tensor containing the valid frames before
+ *                           padding in `encoder_out`. Its dtype is
+ *                           torch::kLong and its shape is (batch_size,).
+ *                           Also, it must be on CPU.
+ *
+ * @return Return a list-of-list of token IDs containing the decoding results.
+ *         The returned vector has size `batch_size` and each entry contains the
+ *         decoding results for the corresponding input in encoder_out.
+ */
+std::vector<std::vector<int32_t>> GreedySearch(
+    RnntModel &model, torch::Tensor encoder_out,
+    torch::Tensor encoder_out_length);
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_RNNT_BEAM_SEARCH_H_
diff --git a/sherpa/csrc/rnnt_model.cc b/sherpa/csrc/rnnt_model.cc
new file mode 100644
index 000000000..dd91389cb
--- /dev/null
+++ b/sherpa/csrc/rnnt_model.cc
@@ -0,0 +1,84 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sherpa/csrc/rnnt_model.h"
+
+namespace sherpa {
+
+RnntModel::RnntModel(const std::string &filename,
+                     torch::Device device /*=torch::kCPU*/,
+                     bool optimize_for_inference /*=false*/)
+    : device_(device) {
+  model_ = torch::jit::load(filename, device);
+  model_.eval();
+  if (optimize_for_inference) {
+    model_ = torch::jit::optimize_for_inference(model_);
+  }
+
+  encoder_ = model_.attr("encoder").toModule();
+  decoder_ = model_.attr("decoder").toModule();
+  joiner_ = model_.attr("joiner").toModule();
+
+  encoder_proj_ = joiner_.attr("encoder_proj").toModule();
+  decoder_proj_ = joiner_.attr("decoder_proj").toModule();
+
+  blank_id_ = decoder_.attr("blank_id").toInt();
+
+  unk_id_ = blank_id_;
+  if (decoder_.hasattr("unk_id")) {
+    unk_id_ = decoder_.attr("unk_id").toInt();
+  }
+
+  context_size_ = decoder_.attr("context_size").toInt();
+}
+
+std::pair<torch::Tensor, torch::Tensor> RnntModel::ForwardEncoder(
+    const torch::Tensor &features, const torch::Tensor &features_length) {
+  auto outputs = model_.attr("encoder")
+                     .toModule()
+                     .run_method("forward", features, features_length)
+                     .toTuple();
+
+  auto encoder_out = outputs->elements()[0].toTensor();
+  auto encoder_out_length = outputs->elements()[1].toTensor();
+
+  return {encoder_out, encoder_out_length};
+}
+
+torch::Tensor RnntModel::ForwardDecoder(const torch::Tensor &decoder_input) {
+  return decoder_.run_method("forward", decoder_input, /*need_pad*/ false)
+      .toTensor();
+}
+
+torch::Tensor RnntModel::ForwardJoiner(
+    const torch::Tensor &projected_encoder_out,
+    const torch::Tensor &projected_decoder_out) {
+  return joiner_
+      .run_method("forward", projected_encoder_out, projected_decoder_out,
+                  /*project_input*/ false)
+      .toTensor();
+}
+
+torch::Tensor RnntModel::ForwardEncoderProj(const torch::Tensor &encoder_out) {
+  return encoder_proj_.run_method("forward", encoder_out).toTensor();
+}
+
+torch::Tensor RnntModel::ForwardDecoderProj(const torch::Tensor &decoder_out) {
+  return decoder_proj_.run_method("forward", decoder_out).toTensor();
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/rnnt_model.h b/sherpa/csrc/rnnt_model.h
new file mode 100644
index 000000000..024b82c41
--- /dev/null
+++ b/sherpa/csrc/rnnt_model.h
@@ -0,0 +1,132 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SHERPA_CSRC_RNNT_MODEL_H_
+#define SHERPA_CSRC_RNNT_MODEL_H_
+
+#include <string>
+#include <utility>
+
+#include "torch/script.h"
+
+namespace sherpa {
+
+/** It wraps a torch script model, which is from
+ *  pruned_transducer_stateless2/model.py within icefall.
+ */
+class RnntModel {
+ public:
+  /**
+   * @param filename Path name of the torch script model.
+   * @param device The model will be moved to this device.
+   * @param optimize_for_inference true to invoke
+   *                               torch::jit::optimize_for_inference().
+   */
+  explicit RnntModel(const std::string &filename,
+                     torch::Device device = torch::kCPU,
+                     bool optimize_for_inference = false);
+
+  ~RnntModel() = default;
+
+  torch::Device Device() const { return device_; }
+
+  int32_t BlankId() const { return blank_id_; }
+  int32_t UnkId() const { return unk_id_; }
+  int32_t ContextSize() const { return context_size_; }
+
+  /** Run the encoder network.
+   *
+   * @param features A 3-D tensor of shape (N, T, C).
+   * @param features_length A 1-D tensor of shape (N,) containing the number of
+   *                        valid frames in `features`.
+   * @return Return a tuple containing two tensors:
+   *         - encoder_out, a 3-D tensor of shape (N, T, C)
+   *         - encoder_out_length, a 1-D tensor of shape (N,) containing the
+   *           number of valid frames in `encoder_out`.
+   */
+  std::pair<torch::Tensor, torch::Tensor> ForwardEncoder(
+      const torch::Tensor &features, const torch::Tensor &features_length);
+
+  /** Run the decoder network.
+   *
+   * @param decoder_input A 2-D tensor of shape (N, U).
+   * @return Return a tensor of shape (N, U, decoder_dim).
+   */
+  torch::Tensor ForwardDecoder(const torch::Tensor &decoder_input);
+
+  /** Run the joiner network.
+   *
+   * @param projected_encoder_out A 3-D tensor of shape (N, T, C).
+   * @param projected_decoder_out A 3-D tensor of shape (N, U, C).
+   * @return Return a tensor of shape (N, T, U, vocab_size).
+   */
+  torch::Tensor ForwardJoiner(const torch::Tensor &projected_encoder_out,
+                              const torch::Tensor &projected_decoder_out);
+
+  /** Run the joiner.encoder_proj network.
+   *
+   * @param encoder_out The output from the encoder, which is of shape (N,T,C).
+   * @return Return a tensor of shape (N, T, joiner_dim).
+   */
+  torch::Tensor ForwardEncoderProj(const torch::Tensor &encoder_out);
+
+  /** Run the joiner.decoder_proj network.
+   *
+   * @param decoder_out The output from the decoder, which is of shape (N,U,C).
+   * @return Return a tensor of shape (N, U, joiner_dim).
+   */
+  torch::Tensor ForwardDecoderProj(const torch::Tensor &decoder_out);
+
+  /** TODO(fangjun): Implement it
+   *
+   * Run the encoder network in a streaming fashion.
+   *
+   * @param features A 3-D tensor of shape (N, T, C).
+   * @param features_length A 1-D tensor of shape (N,) containing the number of
+   *                        valid frames in `features`.
+   * @param prev_state It contains the previous state from the encoder network.
+   *
+   * @return Return a tuple containing 3 entries:
+   *         - encoder_out, a 3-D tensor of shape (N, T, C)
+   *         - encoder_out_length, a 1-D tensor of shape (N,) containing the
+   *           number of valid frames in encoder_out
+   *         - next_state, the state for the encoder network.
+   */
+  std::tuple<torch::Tensor, torch::Tensor, torch::IValue>
+  StreamingForwardEncoder(const torch::Tensor &features,
+                          const torch::Tensor &feature_lengths,
+                          torch::IValue prev_state);
+
+ private:
+  torch::jit::Module model_;
+
+  // The following modules are just aliases to modules in model_
+  torch::jit::Module encoder_;
+  torch::jit::Module decoder_;
+  torch::jit::Module joiner_;
+  torch::jit::Module encoder_proj_;
+  torch::jit::Module decoder_proj_;
+
+  torch::Device device_;
+  int32_t blank_id_;
+  int32_t unk_id_;
+  int32_t context_size_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_RNNT_MODEL_H_
diff --git a/sherpa/python/CMakeLists.txt b/sherpa/python/CMakeLists.txt
new file mode 100644
index 000000000..86735ca28
--- /dev/null
+++ b/sherpa/python/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(csrc)
diff --git a/sherpa/python/csrc/CMakeLists.txt b/sherpa/python/csrc/CMakeLists.txt
new file mode 100644
index 000000000..282b1b771
--- /dev/null
+++ b/sherpa/python/csrc/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
+
+# Please sort files alphabetically
+pybind11_add_module(_sherpa
+  rnnt_beam_search.cc
+  rnnt_model.cc
+  sherpa.cc
+)
+
+target_link_libraries(_sherpa PRIVATE sherpa_core)
+if(UNIX AND NOT APPLE)
+  target_link_libraries(_sherpa PUBLIC ${TORCH_DIR}/lib/libtorch_python.so)
+  target_link_libraries(_sherpa PUBLIC ${PYTHON_LIBRARY})
+elseif(WIN32)
+  target_link_libraries(_sherpa PUBLIC ${TORCH_DIR}/lib/torch_python.lib)
+  target_link_libraries(_sherpa PUBLIC ${PYTHON_LIBRARIES})
+endif()
diff --git a/sherpa/python/csrc/rnnt_beam_search.cc b/sherpa/python/csrc/rnnt_beam_search.cc
new file mode 100644
index 000000000..c61819a19
--- /dev/null
+++ b/sherpa/python/csrc/rnnt_beam_search.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sherpa/csrc/rnnt_beam_search.h"
+
+#include "sherpa/python/csrc/rnnt_beam_search.h"
+#include "torch/torch.h"
+
+namespace sherpa {
+
+void PybindRnntBeamSearch(py::module &m) {
+  m.def("greedy_search", &GreedySearch, py::arg("model"),
+        py::arg("encoder_out"), py::arg("encoder_out_length"),
+        py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace sherpa
diff --git a/sherpa/python/csrc/rnnt_beam_search.h b/sherpa/python/csrc/rnnt_beam_search.h
new file mode 100644
index 000000000..a35367bbf
--- /dev/null
+++ b/sherpa/python/csrc/rnnt_beam_search.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SHERPA_PYTHON_CSRC_RNNT_BEAM_SEARCH_H_
+#define SHERPA_PYTHON_CSRC_RNNT_BEAM_SEARCH_H_
+
+#include "sherpa/python/csrc/sherpa.h"
+
+namespace sherpa {
+
+void PybindRnntBeamSearch(py::module &m);
+
+}  // namespace sherpa
+
+#endif  // SHERPA_PYTHON_CSRC_RNNT_BEAM_SEARCH_H_
diff --git a/sherpa/python/csrc/rnnt_model.cc b/sherpa/python/csrc/rnnt_model.cc
new file mode 100644
index 000000000..089d45cdb
--- /dev/null
+++ b/sherpa/python/csrc/rnnt_model.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa/python/csrc/rnnt_model.h"
+
+#include <memory>
+
+#include "sherpa/csrc/rnnt_model.h"
+#include "torch/torch.h"
+
+namespace sherpa {
+
+void PybindRnntModel(py::module &m) {
+  using PyClass = RnntModel;
+  py::class_<PyClass>(m, "RnntModel")
+      .def(py::init([](const std::string &filename,
+                       py::object device = py::str("cpu"),
+                       bool optimize_for_inference =
+                           false) -> std::unique_ptr<PyClass> {
+             std::string device_str =
+                 device.is_none() ? "cpu" : py::str(device);
+             return std::make_unique<PyClass>(
+                 filename, torch::Device(device_str), optimize_for_inference);
+           }),
+           py::arg("filename"), py::arg("device") = py::str("cpu"),
+           py::arg("optimize_for_inference") = false)
+      .def("encoder", &PyClass::ForwardEncoder, py::arg("features"),
+           py::arg("features_length"), py::call_guard<py::gil_scoped_release>())
+      .def_property_readonly("device", [](const PyClass &self) -> py::object {
+        py::object ans = py::module_::import("torch").attr("device");
+        return ans(self.Device().str());
+      });
+}
+
+}  // namespace sherpa
diff --git a/sherpa/python/csrc/rnnt_model.h b/sherpa/python/csrc/rnnt_model.h
new file mode 100644
index 000000000..5a8ce700a
--- /dev/null
+++ b/sherpa/python/csrc/rnnt_model.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef SHERPA_PYTHON_CSRC_RNNT_MODEL_H_ +#define SHERPA_PYTHON_CSRC_RNNT_MODEL_H_ + +#include "sherpa/python/csrc/sherpa.h" + +namespace sherpa { + +void PybindRnntModel(py::module &m); + +} // namespace sherpa + +#endif // SHERPA_PYTHON_CSRC_RNNT_MODEL_H_ diff --git a/sherpa/python/csrc/sherpa.cc b/sherpa/python/csrc/sherpa.cc new file mode 100644 index 000000000..5e5c90ef6 --- /dev/null +++ b/sherpa/python/csrc/sherpa.cc @@ -0,0 +1,33 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa/python/csrc/sherpa.h" + +#include "sherpa/python/csrc/rnnt_beam_search.h" +#include "sherpa/python/csrc/rnnt_model.h" + +namespace sherpa { + +PYBIND11_MODULE(_sherpa, m) { + m.doc() = "pybind11 binding of sherpa"; + + PybindRnntModel(m); + PybindRnntBeamSearch(m); +} + +} // namespace sherpa diff --git a/sherpa/python/csrc/sherpa.h b/sherpa/python/csrc/sherpa.h new file mode 100644 index 000000000..74171da4a --- /dev/null +++ b/sherpa/python/csrc/sherpa.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SHERPA_PYTHON_CSRC_SHERPA_H_ +#define SHERPA_PYTHON_CSRC_SHERPA_H_ + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +#endif // SHERPA_PYTHON_CSRC_SHERPA_H_ diff --git a/sherpa/python/sherpa/__init__.py b/sherpa/python/sherpa/__init__.py new file mode 100644 index 000000000..b15022aeb --- /dev/null +++ b/sherpa/python/sherpa/__init__.py @@ -0,0 +1,3 @@ +import torch + +from _sherpa import RnntModel, greedy_search
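The bindings above expose exactly two symbols, `RnntModel` and `greedy_search`, which is all the offline server needs. As a self-check, here is a minimal, untested offline-decoding sketch stitched together from the pieces in this patch; the model, BPE, and wave paths are placeholders, and the call signatures are the ones used by `sherpa/bin/offline_server.py`:

```python
#!/usr/bin/env python3
# Minimal sketch of driving the _sherpa bindings directly.
# All file paths below are placeholders.
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio

from _sherpa import RnntModel, greedy_search

model = RnntModel(filename="exp/cpu_jit.pt", device="cpu")

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

# Same fbank settings as OfflineServer._build_feature_extractor()
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

wave, sample_rate = torchaudio.load("test.wav")
assert sample_rate == 16000, sample_rate

with torch.no_grad():
    features = fbank(wave.squeeze(0))  # (num_frames, 80)
    features_length = torch.tensor([features.size(0)], dtype=torch.int64)
    features = features.unsqueeze(0)  # batch of one: (1, num_frames, 80)

    encoder_out, encoder_out_length = model.encoder(
        features=features,
        features_length=features_length,
    )
    hyp_tokens = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_length=encoder_out_length.cpu(),
    )

print(sp.decode(hyp_tokens)[0])
```

Unlike the server, this sketch skips batching, padding, and the thread pools; it is only meant to show how the two exported symbols fit together.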