Skip to content

Commit

Permalink
Re-factor the models used
Browse files Browse the repository at this point in the history
  • Loading branch information
younesselrag committed Sep 9, 2024
1 parent c8f5bce commit 375a609
Show file tree
Hide file tree
Showing 12 changed files with 1,039 additions and 1,017 deletions.
4 changes: 2 additions & 2 deletions src/audiostream/AudioStream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import sounddevice as sd
from transformers import HfArgumentParser
from .Argument import ListenAndPlayArguments
import panads as pd
from Argument import ListenAndPlayArguments


class AudioStreamer:
"""Handles sending and receiving audio data over a network."""
Expand Down
138 changes: 138 additions & 0 deletions src/sTs/Handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import socket
import logging
from time import perf_counter

# Module-level logger used by the handlers defined in this file.
logger = logging.getLogger(__name__)


class BaseHandler:
    """
    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.

    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the
    specific requirements of the implemented pipeline part.
    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END"
    in the input queue.
    Objects placed in the input queue will be processed by the `process` method, and the yielded
    results will be placed in the output queue.
    The cleanup method handles stopping the handler, and b"END" is placed in the output queue.
    """

    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs=None):
        """
        Store the pipeline plumbing and run the subclass-specific `setup`.

        Args:
            stop_event: object exposing `is_set()`; checked before each queue read in `run`.
            queue_in: queue the handler reads its inputs from (`get()` is called on it).
            queue_out: queue the handler writes its outputs to (`put()` is called on it).
            setup_args: positional arguments forwarded to `setup`.
            setup_kwargs: keyword arguments forwarded to `setup`; `None` means no keyword
                arguments (a `None` default avoids sharing one dict between instances).
        """
        self.stop_event = stop_event
        self.queue_in = queue_in
        self.queue_out = queue_out
        # Created before `setup` runs so that subclass setup code can already rely on it.
        self._times = []
        self.setup(*setup_args, **(setup_kwargs or {}))

    def setup(self):
        """Hook for subclass-specific initialisation; the base implementation does nothing."""
        pass

    def process(self, input_data):
        """
        Turn one input item into an iterable of outputs.

        `run` iterates over the returned value, so subclasses must return an iterable
        (typically by being a generator).

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def run(self):
        """
        Consume the input queue until the stop event is set or b'END' is received.

        Each output yielded by `process` is timed, logged at debug level and placed on the
        output queue. When the loop ends, `cleanup` is called and b'END' is forwarded on the
        output queue.
        """
        while not self.stop_event.is_set():
            input_data = self.queue_in.get()
            if isinstance(input_data, bytes) and input_data == b'END':
                # Sentinel signal to avoid queue deadlock
                logger.debug("Stopping thread")
                break
            start_time = perf_counter()
            for output in self.process(input_data):
                self._times.append(perf_counter() - start_time)
                logger.debug("%s: %.3f s", self.__class__.__name__, self.last_time)
                self.queue_out.put(output)
                # Restart the clock so each output is timed on its own.
                start_time = perf_counter()

        self.cleanup()
        self.queue_out.put(b'END')

    @property
    def last_time(self):
        """
        Duration in seconds recorded for the most recently produced output.

        Raises:
            IndexError: if no output has been produced yet.
        """
        return self._times[-1]

    def cleanup(self):
        """Hook called once when `run` leaves its loop; the base implementation does nothing."""
        pass


class SocketReceiver:
    """
    Handles reception of the audio packets from the client.

    `run` accepts a single TCP connection, reads fixed-size chunks from it and places them on
    `queue_out`. Chunks read while `should_listen` is not set are discarded.
    """

    def __init__(
        self,
        stop_event,
        queue_out,
        should_listen,
        host='0.0.0.0',
        port=12345,
        chunk_size=1024
    ):
        """
        Args:
            stop_event: object exposing `is_set()`; ends the receive loop once set.
            queue_out: queue on which received chunks (and the b'END' sentinel) are put.
            should_listen: event-like object; set once a client connects and checked before
                each chunk is forwarded.
            host: address the listening socket binds to.
            port: port the listening socket binds to.
            chunk_size: number of bytes that make up one chunk.
        """
        self.stop_event = stop_event
        self.queue_out = queue_out
        self.should_listen = should_listen
        self.chunk_size = chunk_size
        self.host = host
        self.port = port

    def receive_full_chunk(self, conn, chunk_size):
        """
        Read exactly `chunk_size` bytes from `conn`.

        Args:
            conn: connected socket-like object exposing `recv(n)`.
            chunk_size: number of bytes to read.

        Returns:
            The chunk as `bytes`, or `None` if the peer closed the connection before a full
            chunk arrived (any partial data is dropped).
        """
        # A bytearray is extended in place, avoiding a fresh bytes object per packet.
        received = bytearray()
        while len(received) < chunk_size:
            packet = conn.recv(chunk_size - len(received))
            if not packet:
                return None  # Connection closed
            received += packet
        return bytes(received)

    def run(self):
        """
        Accept one client and forward its audio chunks until stopped or disconnected.

        When the client disconnects, b'END' is placed on `queue_out`. Both the connection and
        the listening socket are closed on exit, including when an exception is raised.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.bind((self.host, self.port))
            self.socket.listen(1)
            logger.info('Receiver waiting to be connected...')
            self.conn, _ = self.socket.accept()
            logger.info("Receiver connected")

            try:
                self.should_listen.set()
                while not self.stop_event.is_set():
                    audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
                    if audio_chunk is None:
                        self.queue_out.put(b'END')
                        break
                    if self.should_listen.is_set():
                        self.queue_out.put(audio_chunk)
            finally:
                self.conn.close()
        finally:
            # The listening socket was previously never released.
            self.socket.close()
        logger.info("Receiver closed")


class SocketSender:
    """
    Handles sending generated audio packets to the client.

    `run` accepts a single TCP connection and writes every item taken from `queue_in` to it.
    """

    def __init__(
        self,
        stop_event,
        queue_in,
        host='0.0.0.0',
        port=12346
    ):
        """
        Args:
            stop_event: object exposing `is_set()`; ends the send loop once set.
            queue_in: queue from which the data to send is taken.
            host: address the listening socket binds to.
            port: port the listening socket binds to.
        """
        self.stop_event = stop_event
        self.queue_in = queue_in
        self.host = host
        self.port = port

    def run(self):
        """
        Accept one client and stream queued audio to it until stopped or b'END' is dequeued.

        Both the connection and the listening socket are closed on exit, including when an
        exception is raised.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.bind((self.host, self.port))
            self.socket.listen(1)
            logger.info('Sender waiting to be connected...')
            self.conn, _ = self.socket.accept()
            logger.info("Sender connected")

            try:
                while not self.stop_event.is_set():
                    audio_chunk = self.queue_in.get()
                    # The item is sent before the sentinel check, so b'END' itself is also
                    # transmitted to the client.
                    self.conn.sendall(audio_chunk)
                    if isinstance(audio_chunk, bytes) and audio_chunk == b'END':
                        break
            finally:
                self.conn.close()
        finally:
            # The listening socket was previously never released.
            self.socket.close()
        logger.info("Sender closed")
212 changes: 209 additions & 3 deletions src/sTs/SpeechArgument.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class ModuleArguments:
log_level: str = field(
Expand All @@ -9,7 +9,6 @@ class ModuleArguments:
}
)


@dataclass
class SocketReceiverArguments:
recv_host: str = field(
Expand Down Expand Up @@ -47,4 +46,211 @@ class SocketSenderArguments:
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
}
)
)

@dataclass
class VADHandlerArguments:
    """Command-line options for the voice activity detection (VAD) stage."""

    thresh: float = field(
        default=0.3,
        metadata={
            "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
        }
    )
    sample_rate: int = field(
        default=16000,
        metadata={
            "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
        }
    )
    # The help text previously advertised 1000 ms although the default is 250 ms.
    min_silence_ms: int = field(
        default=250,
        metadata={
            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
        }
    )
    min_speech_ms: int = field(
        default=500,
        metadata={
            "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
        }
    )
    max_speech_ms: float = field(
        default=float('inf'),
        metadata={
            "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
        }
    )
    speech_pad_ms: int = field(
        default=30,
        metadata={
            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 30 ms."
        }
    )

@dataclass
class WhisperSTTHandlerArguments:
    """Command-line options for the Whisper speech-to-text (STT) stage."""

    stt_model_name: str = field(
        default="distil-whisper/distil-large-v3",
        metadata={
            "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
        }
    )
    stt_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    stt_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    # Optional because the default is None.
    stt_compile_mode: Optional[str] = field(
        default=None,
        metadata={
            "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
        }
    )
    stt_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "The maximum number of new tokens to generate. Default is 128."
        }
    )
    stt_gen_num_beams: int = field(
        default=1,
        metadata={
            "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
        }
    )
    stt_gen_return_timestamps: bool = field(
        default=False,
        metadata={
            "help": "Whether to return timestamps with transcriptions. Default is False."
        }
    )
    stt_gen_task: str = field(
        default="transcribe",
        metadata={
            "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
        }
    )
    stt_gen_language: str = field(
        default="en",
        metadata={
            "help": "The language of the speech to transcribe. Default is 'en' for English."
        }
    )

@dataclass
class LanguageModelHandlerArguments:
    """Command-line options for the language model (LM) stage."""

    # The help text previously named a different model than the actual default.
    lm_model_name: str = field(
        default="TinyLlama/TinyLlama_v1.1",
        metadata={
            "help": "The pretrained language model to use. Default is 'TinyLlama/TinyLlama_v1.1'."
        }
    )
    lm_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    lm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        }
    )
    # Optional because the default is None; the help text previously claimed 'system'.
    init_chat_role: Optional[str] = field(
        default=None,
        metadata={
            "help": "Initial role for setting up the chat context. Default is None."
        }
    )
    init_chat_prompt: str = field(
        default="You are a helpful AI assistant.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
        }
    )
    # The help text previously advertised 128 although the default is 64.
    lm_gen_max_new_tokens: int = field(
        default=64,
        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 64."}
    )
    lm_gen_temperature: float = field(
        default=0.0,
        metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
    )
    lm_gen_do_sample: bool = field(
        default=False,
        metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
    )
    # Optional because the help text documents None as an accepted value.
    chat_size: Optional[int] = field(
        default=3,
        metadata={"help": "Number of messages of the messages to keep for the chat. None for no limitations."}
    )

@dataclass
class ParlerTTSHandlerArguments:
    """Command-line options for the Parler text-to-speech (TTS) stage."""

    tts_model_name: str = field(
        default="ylacombe/parler-tts-mini-jenny-30H",
        metadata={
            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
        }
    )
    tts_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    tts_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    # Optional because the default is None.
    tts_compile_mode: Optional[str] = field(
        default=None,
        metadata={
            "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
        }
    )
    # The help text previously described a maximum with default 10; this is the minimum
    # and its default is None.
    tts_gen_min_new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "Minimum number of new tokens to generate in a single completion. Default is None."}
    )
    # The help text previously advertised 256 although the default is 512.
    tts_gen_max_new_tokens: int = field(
        default=512,
        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs"}
    )
    description: str = field(
        default=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        metadata={
            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
        }
    )
    # The help text previously advertised 0.5 seconds although the default is 0.2.
    play_steps_s: float = field(
        default=0.2,
        metadata={
            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.2 seconds."
        }
    )
    max_prompt_pad_length: int = field(
        default=8,
        metadata={
            "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
        }
    )
Loading

0 comments on commit 375a609

Please sign in to comment.