From 74987b5daed9aca735f26c2f97948fcf0dcaaa80 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 26 May 2022 09:56:27 +0800 Subject: [PATCH] Fix pip install (#4) * Fix installation with `pip install`. * Minor fixes. * Small fixes. --- README.md | 8 +- cmake/__init__.py | 0 ...decode_mainifest.py => decode_manifest.py} | 79 +++++++++++++++++-- 3 files changed, 78 insertions(+), 9 deletions(-) create mode 100644 cmake/__init__.py rename sherpa/bin/{decode_mainifest.py => decode_manifest.py} (58%) diff --git a/README.md b/README.md index 998708e16..c3c968f24 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ the following methods. pip install --verbose k2-sherpa ``` +or + +```bash +pip install --verbose git+https://github.com/k2-fsa/shera +``` + ### Option 2: Build from source with `setup.py` ```bash @@ -138,7 +144,7 @@ sherpa/bin/offline_client.py \ ### RTF test -We provide a demo [./sherpa/bin/decode_mainifest.py](./sherpa/bin/decode_mainifest.py) +We provide a demo [./sherpa/bin/decode_manifest.py](./sherpa/bin/decode_manifest.py) to decode the `test-clean` dataset from the LibriSpeech corpus. It creates 50 connections to the server using websockets and sends audio files diff --git a/cmake/__init__.py b/cmake/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa/bin/decode_mainifest.py b/sherpa/bin/decode_manifest.py similarity index 58% rename from sherpa/bin/decode_mainifest.py rename to sherpa/bin/decode_manifest.py index f98c93bb8..8b820622a 100755 --- a/sherpa/bin/decode_mainifest.py +++ b/sherpa/bin/decode_manifest.py @@ -26,6 +26,7 @@ (Note: You have to first start the server before starting the client) """ +import argparse import asyncio import time @@ -34,13 +35,64 @@ from icefall.utils import store_transcripts, write_error_stats from lhotse import CutSet, load_manifest +DEFAULT_MANIFEST_FILENAME = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/data/fbank/cuts_test-clean.json.gz" -async def send(cuts: CutSet, name: str): + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "--manifest-filename", + type=str, + default=DEFAULT_MANIFEST_FILENAME, + help="Path to the manifest for decoding", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=50, + help="Number of tasks to use for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + return parser.parse_args() + + +async def send( + cuts: CutSet, + name: str, + server_addr: str, + server_port: int, + log_interval: int, +): total_duration = 0.0 results = [] - async with websockets.connect("ws://localhost:6006") as websocket: + async with websockets.connect(f"ws://{server_addr}:{server_port}") as websocket: for i, c in enumerate(cuts): - if i % 5 == 0: + if i % log_interval == 0: print(f"{name}: {i}/{len(cuts)}") samples = c.load_audio().reshape(-1).astype(np.float32) @@ -64,15 +116,22 @@ async def send(cuts: CutSet, name: str): async def main(): - filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/data/fbank/cuts_test-clean.json.gz" + args = get_args() + filename = args.manifest_filename + server_addr = args.server_addr + server_port = args.server_port + num_tasks = args.num_tasks + log_interval = args.log_interval + cuts = load_manifest(filename) - num_tasks = 50 # we start this number of tasks to send the requests cuts_list = cuts.split(num_tasks) tasks = [] start_time = time.time() for i in range(num_tasks): - task = asyncio.create_task(send(cuts_list[i], f"task-{i}")) + task = asyncio.create_task( + send(cuts_list[i], f"task-{i}", server_addr, server_port, log_interval) + ) tasks.append(task) ans_list = await asyncio.gather(*tasks) @@ -90,14 +149,18 @@ async def main(): print(f"RTF: {rtf:.4f}") print( - f"total_duration: {total_duration:.2f} seconds " + f"total_duration: {total_duration:.3f} seconds " f"({total_duration/3600:.2f} hours)" ) print(f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)") store_transcripts(filename="recogs-test-clean.txt", texts=results) with open("errs-test-clean.txt", "w") as f: - wer = write_error_stats(f, "test-set", results, enable_log=True) + write_error_stats(f, "test-set", results, enable_log=True) + + with open("errs-test-clean.txt", "r") as f: + print(f.readline()) # WER + print(f.readline()) # Detailed errors if __name__ == "__main__":