diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..774d406e86f5ef792a2cb10f7e06ade556443753 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel + +ENV PYTHONUNBUFFERED 1 + +WORKDIR /usr/src/app + +# Install packages +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt ./ +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . diff --git a/Dockerfile.arm64 b/Dockerfile.arm64 new file mode 100644 index 0000000000000000000000000000000000000000..147fdeb2acacbfffef81c651d0bb7f3920ffda94 --- /dev/null +++ b/Dockerfile.arm64 @@ -0,0 +1,13 @@ +FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3 + +ENV PYTHONUNBUFFERED 1 + +WORKDIR /usr/src/app + +# Install packages +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt ./ +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..48999a3db93e53341e398c38ce60775614e7972c --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2024] [The HuggingFace Inc. team] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file
diff --git a/LLM/__pycache__/chat.cpython-311.pyc b/LLM/__pycache__/chat.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1b0a04cc377d1405878dfd9184593aecec965dc
Binary files /dev/null and b/LLM/__pycache__/chat.cpython-311.pyc differ
diff --git a/LLM/__pycache__/language_model.cpython-311.pyc b/LLM/__pycache__/language_model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cdcd5908a4594f2a9b7e96cada8cb40a774979b
Binary files /dev/null and b/LLM/__pycache__/language_model.cpython-311.pyc differ
diff --git a/LLM/__pycache__/mlx_language_model.cpython-311.pyc b/LLM/__pycache__/mlx_language_model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13280c51083ba116b6e61c174cc91e5d07785862
Binary files /dev/null and b/LLM/__pycache__/mlx_language_model.cpython-311.pyc differ
diff --git a/LLM/chat.py b/LLM/chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8ac4fbc266d084f2ffc636564c457162a130f1
--- /dev/null
+++ b/LLM/chat.py
@@ -0,0 +1,25 @@
+class Chat:
+    """
+    Handles the chat history, keeping only the most recent exchanges to avoid OOM issues.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        self.init_chat_message = None
+        # the buffer is trimmed in pairs, since each new step adds a user prompt and an assistant answer
+        self.buffer = []
+
+    def append(self, item):
+        self.buffer.append(item)
+        if len(self.buffer) == 2 * (self.size + 1):
+            self.buffer.pop(0)
+            self.buffer.pop(0)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + self.buffer
+        else:
+            return self.buffer
diff --git a/LLM/language_model.py b/LLM/language_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a957625b43193f936898ea68e4b48d810e86101
--- /dev/null
+++ b/LLM/language_model.py
@@ -0,0 +1,144 @@
+from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    pipeline,
+    TextIteratorStreamer,
+)
+import torch
+
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+from nltk import sent_tokenize
+
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
+
+class LanguageModelHandler(BaseHandler):
+    """
+    Handles the language model part.
+ """ + + def setup( + self, + model_name="microsoft/Phi-3-mini-4k-instruct", + device="cuda", + torch_dtype="float16", + gen_kwargs={}, + user_role="user", + chat_size=1, + init_chat_role=None, + init_chat_prompt="You are a helpful AI assistant.", + ): + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch_dtype, trust_remote_code=True + ).to(device) + self.pipe = pipeline( + "text-generation", model=self.model, tokenizer=self.tokenizer, device=device + ) + self.streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + self.gen_kwargs = { + "streamer": self.streamer, + "return_full_text": False, + **gen_kwargs, + } + + self.chat = Chat(chat_size) + if init_chat_role: + if not init_chat_prompt: + raise ValueError( + "An initial promt needs to be specified when setting init_chat_role." + ) + self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) + self.user_role = user_role + + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + dummy_input_text = "Repeat the word 'home'." + dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] + warmup_gen_kwargs = { + "min_new_tokens": self.gen_kwargs["min_new_tokens"], + "max_new_tokens": self.gen_kwargs["max_new_tokens"], + **self.gen_kwargs, + } + + n_steps = 2 + + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + for _ in range(n_steps): + thread = Thread( + target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs + ) + thread.start() + for _ in self.streamer: + pass + + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + + logger.info( + f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def process(self, prompt): + logger.debug("infering language model...") + language_code = None + if isinstance(prompt, tuple): + prompt, language_code = prompt + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. 
" + prompt + + self.chat.append({"role": self.user_role, "content": prompt}) + thread = Thread( + target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs + ) + thread.start() + if self.device == "mps": + generated_text = "" + for new_text in self.streamer: + generated_text += new_text + printable_text = generated_text + torch.mps.empty_cache() + else: + generated_text, printable_text = "", "" + for new_text in self.streamer: + generated_text += new_text + printable_text += new_text + sentences = sent_tokenize(printable_text) + if len(sentences) > 1: + yield (sentences[0], language_code) + printable_text = new_text + + self.chat.append({"role": "assistant", "content": generated_text}) + + # don't forget last sentence + yield (printable_text, language_code) diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py new file mode 100644 index 0000000000000000000000000000000000000000..82de10214299d4ebe68200256cec73f22974fd02 --- /dev/null +++ b/LLM/mlx_language_model.py @@ -0,0 +1,107 @@ +import logging +from LLM.chat import Chat +from baseHandler import BaseHandler +from mlx_lm import load, stream_generate, generate +from rich.console import Console +import torch + +logger = logging.getLogger(__name__) + +console = Console() + +WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { + "en": "english", + "fr": "french", + "es": "spanish", + "zh": "chinese", + "ja": "japanese", + "ko": "korean", +} + +class MLXLanguageModelHandler(BaseHandler): + """ + Handles the language model part. + """ + + def setup( + self, + model_name="microsoft/Phi-3-mini-4k-instruct", + device="mps", + torch_dtype="float16", + gen_kwargs={}, + user_role="user", + chat_size=1, + init_chat_role=None, + init_chat_prompt="You are a helpful AI assistant.", + ): + self.model_name = model_name + self.model, self.tokenizer = load(self.model_name) + self.gen_kwargs = gen_kwargs + + self.chat = Chat(chat_size) + if init_chat_role: + if not init_chat_prompt: + raise ValueError( + "An initial promt needs to be specified when setting init_chat_role." + ) + self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) + self.user_role = user_role + + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + dummy_input_text = "Repeat the word 'home'." + dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] + + n_steps = 2 + + for _ in range(n_steps): + prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False) + generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=self.gen_kwargs["max_new_tokens"], + verbose=False, + ) + + def process(self, prompt): + logger.debug("infering language model...") + language_code = None + + if isinstance(prompt, tuple): + prompt, language_code = prompt + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. 
" + prompt + + self.chat.append({"role": self.user_role, "content": prompt}) + + # Remove system messages if using a Gemma model + if "gemma" in self.model_name.lower(): + chat_messages = [ + msg for msg in self.chat.to_list() if msg["role"] != "system" + ] + else: + chat_messages = self.chat.to_list() + + prompt = self.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + output = "" + curr_output = "" + for t in stream_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=self.gen_kwargs["max_new_tokens"], + ): + output += t + curr_output += t + if curr_output.endswith((".", "?", "!", "<|end|>")): + yield (curr_output.replace("<|end|>", ""), language_code) + curr_output = "" + generated_text = output.replace("<|end|>", "") + torch.mps.empty_cache() + + self.chat.append({"role": "assistant", "content": generated_text}) \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fde24ccfdcb3eb1f4db985f86c2da1bbfa0d39b0 --- /dev/null +++ b/README.md @@ -0,0 +1,244 @@ +
+
+
+
+
+
+# Speech To Speech: an effort for an open-sourced and modular GPT-4o
+
+
+## 📖 Quick Index
+* [Approach](#approach)
+  - [Structure](#structure)
+  - [Modularity](#modularity)
+* [Setup](#setup)
+* [Usage](#usage)
+  - [Docker Server approach](#docker-server)
+  - [Server/Client approach](#serverclient-approach)
+  - [Local approach](#local-approach-mac)
+* [Command-line usage](#command-line-usage)
+  - [Model parameters](#model-parameters)
+  - [Generation parameters](#generation-parameters)
+  - [Notable parameters](#notable-parameters)
+
+## Approach
+
+### Structure
+This repository implements a speech-to-speech cascaded pipeline made of the following consecutive parts:
+1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad)
+2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper))
+3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗
+4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts) 🤗
+
+### Modularity
+The pipeline aims to provide a fully open and modular approach, leveraging models available in the Transformers library via the Hugging Face Hub. The level of modularity intended for each part is as follows:
+- **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad).
+- **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr).
+- **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it.
+- **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used.
+
+The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs.
+
+## Setup
+
+Clone the repository:
+```bash
+git clone https://github.com/huggingface/speech-to-speech.git
+cd speech-to-speech
+```
+
+Install the required dependencies using [uv](https://github.com/astral-sh/uv):
+```bash
+uv pip install -r requirements.txt
+```
+
+For Mac users, use the `requirements_mac.txt` file instead:
+```bash
+uv pip install -r requirements_mac.txt
+```
+
+If you want to use Melo TTS, you also need to run:
+```bash
+python -m unidic download
+```
+
+
+## Usage
+
+The pipeline can be run in two ways:
+- **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
+- **Local approach**: The whole pipeline runs locally on a single machine.
+
+### Docker Server
+
+#### Install the NVIDIA Container Toolkit
+
+https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+
+#### Start the Docker container
+```docker compose up```
+
+### Server/Client Approach
+
+1. Run the pipeline on the server:
+   ```bash
+   python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+   ```
+
+2. Run the client locally to handle microphone input and receive generated audio:
+   ```bash
+   python listen_and_play.py --host <IP address of the server>
+   ```
+
+### Local Approach (Mac)
+
+1. For optimal settings on Mac:
+   ```bash
+   python s2s_pipeline.py --local_mac_optimal_settings
+   ```
+
+This setting:
+  - Adds `--device mps` to use MPS for all models.
+  - Sets LightningWhisperMLX for STT.
+  - Sets MLX LM for the language model.
+  - Sets MeloTTS for TTS.
+
+### Recommended usage with CUDA
+
+Leverage Torch Compile for Whisper and Parler-TTS:
+
+```bash
+python s2s_pipeline.py \
+    --recv_host 0.0.0.0 \
+    --send_host 0.0.0.0 \
+    --lm_model_name microsoft/Phi-3-mini-4k-instruct \
+    --init_chat_role system \
+    --stt_compile_mode reduce-overhead \
+    --tts_compile_mode default
+```
+
+For the moment, compile modes that capture CUDA graphs (`reduce-overhead`, `max-autotune`) are not compatible with streaming Parler-TTS.
+
+
+### Multi-language Support
+
+The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
+
+#### With the server version:
+
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+Or to set one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+#### Local Mac Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
+
+Or to set one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
+
+
+## Command-line Usage
+
+### Model Parameters
+
+`model_name`, `torch_dtype`, and `device` are exposed for each part leveraging the Transformers' implementations: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix:
+- `stt` (Speech to Text)
+- `lm` (Language Model)
+- `tts` (Text to Speech)
+
+For example:
+```bash
+--lm_model_name google/gemma-2b-it
+```
+
+### Generation Parameters
+
+Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for an example).
+
+### Notable Parameters
+
+#### VAD Parameters
+- `--thresh`: Threshold value to trigger voice activity detection.
+- `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech.
+- `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction.
+
+#### Language Model
+- `--init_chat_role`: Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`).
+- `--init_chat_prompt`: The initial chat prompt used to establish context for the language model. Required when setting `--init_chat_role`.
+
+#### Text to Speech
+- `--description`: Sets the description for the Parler-TTS generated voice.
Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."` + +- `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps. + +## Citations + +### Silero VAD +```bibtex +@misc{Silero VAD, + author = {Silero Team}, + title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, + year = {2021}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/snakers4/silero-vad}}, + commit = {insert_some_commit_here}, + email = {hello@silero.ai} +} +``` + +### Distil-Whisper +```bibtex +@misc{gandhi2023distilwhisper, + title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, + author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush}, + year={2023}, + eprint={2311.00430}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + +### Parler-TTS +```bibtex +@misc{lacombe-etal-2024-parler-tts, + author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi}, + title = {Parler-TTS}, + year = {2024}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/parler-tts}} +} +``` diff --git a/STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc b/STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a3ff32252978a327324f2bbd680644167301620 Binary files /dev/null and b/STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc differ diff --git a/STT/__pycache__/paraformer_handler.cpython-311.pyc b/STT/__pycache__/paraformer_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b051713b1646803006f7597b597fd21f12c44298 Binary files /dev/null and b/STT/__pycache__/paraformer_handler.cpython-311.pyc differ diff --git a/STT/__pycache__/whisper_stt_handler.cpython-311.pyc b/STT/__pycache__/whisper_stt_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..745b533aded28d7dd03cede57c63e9dc8784cba0 Binary files /dev/null and b/STT/__pycache__/whisper_stt_handler.cpython-311.pyc differ diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6f9fbb217bb32b227096e4594f6e648fa879d37a --- /dev/null +++ b/STT/lightning_whisper_mlx_handler.py @@ -0,0 +1,85 @@ +import logging +from time import perf_counter +from baseHandler import BaseHandler +from lightning_whisper_mlx import LightningWhisperMLX +import numpy as np +from rich.console import Console +from copy import copy +import torch + +logger = logging.getLogger(__name__) + +console = Console() + +SUPPORTED_LANGUAGES = [ + "en", + "fr", + "es", + "zh", + "ja", + "ko", +] + + +class LightningWhisperSTTHandler(BaseHandler): + """ + Handles the Speech To Text generation using a Whisper model. 
+ """ + + def setup( + self, + model_name="distil-large-v3", + device="mps", + torch_dtype="float16", + compile_mode=None, + language=None, + gen_kwargs={}, + ): + if len(model_name.split("/")) > 1: + model_name = model_name.split("/")[-1] + self.device = device + self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) + self.start_language = language + self.last_language = language + + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 + dummy_input = np.array([0] * 512) + + for _ in range(n_steps): + _ = self.model.transcribe(dummy_input)["text"].strip() + + def process(self, spoken_prompt): + logger.debug("infering whisper...") + + global pipeline_start + pipeline_start = perf_counter() + + if self.start_language != 'auto': + transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language) + else: + transcription_dict = self.model.transcribe(spoken_prompt) + language_code = transcription_dict["language"] + if language_code not in SUPPORTED_LANGUAGES: + logger.warning(f"Whisper detected unsupported language: {language_code}") + if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language + transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language) + else: + transcription_dict = {"text": "", "language": "en"} + else: + self.last_language = language_code + + pred_text = transcription_dict["text"].strip() + language_code = transcription_dict["language"] + torch.mps.empty_cache() + + logger.debug("finished whisper inference") + console.print(f"[yellow]USER: {pred_text}") + logger.debug(f"Language Code Whisper: {language_code}") + + yield (pred_text, language_code) diff --git a/STT/paraformer_handler.py b/STT/paraformer_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..99fd6ac7912cc326472a31541a4ffdd9d8d79649 --- /dev/null +++ b/STT/paraformer_handler.py @@ -0,0 +1,61 @@ +import logging +from time import perf_counter + +from baseHandler import BaseHandler +from funasr import AutoModel +import numpy as np +from rich.console import Console +import torch + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class ParaformerSTTHandler(BaseHandler): + """ + Handles the Speech To Text generation using a Paraformer model. + The default for this model is set to Chinese. + This model was contributed by @wuhongsheng. 
+ """ + + def setup( + self, + model_name="paraformer-zh", + device="cuda", + gen_kwargs={}, + ): + print(model_name) + if len(model_name.split("/")) > 1: + model_name = model_name.split("/")[-1] + self.device = device + self.model = AutoModel(model=model_name, device=device) + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 + dummy_input = np.array([0] * 512, dtype=np.float32) + for _ in range(n_steps): + _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "") + + def process(self, spoken_prompt): + logger.debug("infering paraformer...") + + global pipeline_start + pipeline_start = perf_counter() + + pred_text = ( + self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "") + ) + torch.mps.empty_cache() + + logger.debug("finished paraformer inference") + console.print(f"[yellow]USER: {pred_text}") + + yield pred_text diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..06cf613b01156aadcf0ccee9533c0d7aff039930 --- /dev/null +++ b/STT/whisper_stt_handler.py @@ -0,0 +1,140 @@ +from time import perf_counter +from transformers import ( + AutoProcessor, + AutoModelForSpeechSeq2Seq +) +import torch +from copy import copy +from baseHandler import BaseHandler +from rich.console import Console +import logging + +logger = logging.getLogger(__name__) +console = Console() + +SUPPORTED_LANGUAGES = [ + "en", + "fr", + "es", + "zh", + "ja", + "ko", +] + + +class WhisperSTTHandler(BaseHandler): + """ + Handles the Speech To Text generation using a Whisper model. + """ + + def setup( + self, + model_name="distil-whisper/distil-large-v3", + device="cuda", + torch_dtype="float16", + compile_mode=None, + language=None, + gen_kwargs={}, + ): + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + self.compile_mode = compile_mode + self.gen_kwargs = gen_kwargs + if language == 'auto': + language = None + self.last_language = language + if self.last_language is not None: + self.gen_kwargs["language"] = self.last_language + + self.processor = AutoProcessor.from_pretrained(model_name) + self.model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_name, + torch_dtype=self.torch_dtype, + ).to(device) + + # compile + if self.compile_mode: + self.model.generation_config.cache_implementation = "static" + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) + self.warmup() + + def prepare_model_inputs(self, spoken_prompt): + input_features = self.processor( + spoken_prompt, sampling_rate=16000, return_tensors="pt" + ).input_features + input_features = input_features.to(self.device, dtype=self.torch_dtype) + + return input_features + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 if self.compile_mode == "default" else 2 + dummy_input = torch.randn( + (1, self.model.config.num_mel_bins, 3000), + dtype=self.torch_dtype, + device=self.device, + ) + if self.compile_mode not in (None, "default"): + # generating more tokens than previously will trigger CUDA graphs capture + # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation + # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense + warmup_gen_kwargs = { + 
"min_new_tokens": self.gen_kwargs[ + "max_new_tokens" + ], # Yes, assign max_new_tokens to min_new_tokens + "max_new_tokens": self.gen_kwargs["max_new_tokens"], + **self.gen_kwargs, + } + else: + warmup_gen_kwargs = self.gen_kwargs + + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + for _ in range(n_steps): + _ = self.model.generate(dummy_input, **warmup_gen_kwargs) + + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + + logger.info( + f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def process(self, spoken_prompt): + logger.debug("infering whisper...") + + global pipeline_start + pipeline_start = perf_counter() + + input_features = self.prepare_model_inputs(spoken_prompt) + pred_ids = self.model.generate(input_features, **self.gen_kwargs) + language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" + + if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language + logger.warning("Whisper detected unsupported language:", language_code) + gen_kwargs = copy(self.gen_kwargs) + gen_kwargs['language'] = self.last_language + language_code = self.last_language + pred_ids = self.model.generate(input_features, **gen_kwargs) + else: + self.last_language = language_code + + pred_text = self.processor.batch_decode( + pred_ids, skip_special_tokens=True, decode_with_timestamps=False + )[0] + language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" + + logger.debug("finished whisper inference") + console.print(f"[yellow]USER: {pred_text}") + logger.debug(f"Language Code Whisper: {language_code}") + + yield (pred_text, language_code) diff --git a/TTS/__pycache__/chatTTS_handler.cpython-311.pyc b/TTS/__pycache__/chatTTS_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a33abd8290f4feaf4ebf1229e93209583018540 Binary files /dev/null and b/TTS/__pycache__/chatTTS_handler.cpython-311.pyc differ diff --git a/TTS/__pycache__/melo_handler.cpython-311.pyc b/TTS/__pycache__/melo_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae9a0687c1d0c504c7ec996138d3b364c7f6b98 Binary files /dev/null and b/TTS/__pycache__/melo_handler.cpython-311.pyc differ diff --git a/TTS/__pycache__/parler_handler.cpython-311.pyc b/TTS/__pycache__/parler_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27cbae7a0618e089114cefb5e16a9de6b774047a Binary files /dev/null and b/TTS/__pycache__/parler_handler.cpython-311.pyc differ diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdc6bfe31929c930726df594fdf296ab3e21ce7 --- /dev/null +++ b/TTS/chatTTS_handler.py @@ -0,0 +1,82 @@ +import ChatTTS +import logging +from baseHandler import BaseHandler +import librosa +import numpy as np +from rich.console import Console +import torch + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class ChatTTSHandler(BaseHandler): + def setup( + self, + should_listen, + device="cuda", + gen_kwargs={}, # Unused + stream=True, + chunk_size=512, + ): + self.should_listen = should_listen + self.device = device + self.model = ChatTTS.Chat() + 
self.model.load(compile=False) # Doesn't work for me with True + self.chunk_size = chunk_size + self.stream = stream + rnd_spk_emb = self.model.sample_random_speaker() + self.params_infer_code = ChatTTS.Chat.InferCodeParams( + spk_emb=rnd_spk_emb, + ) + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + _ = self.model.infer("text") + + def process(self, llm_sentence): + console.print(f"[green]ASSISTANT: {llm_sentence}") + if self.device == "mps": + import time + + start = time.time() + torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. + torch.mps.empty_cache() # Frees all memory allocated by the MPS device. + _ = ( + time.time() - start + ) # Removing this line makes it fail more often. I'm looking into it. + + wavs_gen = self.model.infer( + llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream + ) + + if self.stream: + wavs = [np.array([])] + for gen in wavs_gen: + if gen[0] is None or len(gen[0]) == 0: + self.should_listen.set() + return + audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16)[0] + while len(audio_chunk) > self.chunk_size: + yield audio_chunk[: self.chunk_size] # ่ฟ”ๅ›žๅ‰ chunk_size ๅญ—่Š‚็š„ๆ•ฐๆฎ + audio_chunk = audio_chunk[self.chunk_size :] # ็งป้™คๅทฒ่ฟ”ๅ›ž็š„ๆ•ฐๆฎ + yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk))) + else: + wavs = wavs_gen + if len(wavs[0]) == 0: + self.should_listen.set() + return + audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.chunk_size): + yield np.pad( + audio_chunk[i : i + self.chunk_size], + (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), + ) + self.should_listen.set() diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd50f1330d49f6edc7b493bcf292b3b158304a6 --- /dev/null +++ b/TTS/melo_handler.py @@ -0,0 +1,109 @@ +from melo.api import TTS +import logging +from baseHandler import BaseHandler +import librosa +import numpy as np +from rich.console import Console +import torch + +logger = logging.getLogger(__name__) + +console = Console() + +WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { + "en": "EN", + "fr": "FR", + "es": "ES", + "zh": "ZH", + "ja": "JP", + "ko": "KR", +} + +WHISPER_LANGUAGE_TO_MELO_SPEAKER = { + "en": "EN-BR", + "fr": "FR", + "es": "ES", + "zh": "ZH", + "ja": "JP", + "ko": "KR", +} + + +class MeloTTSHandler(BaseHandler): + def setup( + self, + should_listen, + device="mps", + language="en", + speaker_to_id="en", + gen_kwargs={}, # Unused + blocksize=512, + ): + self.should_listen = should_listen + self.device = device + self.language = language + self.model = TTS( + language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device + ) + self.speaker_id = self.model.hps.data.spk2id[ + WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id] + ] + self.blocksize = blocksize + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + _ = self.model.tts_to_file("text", self.speaker_id, quiet=True) + + def process(self, llm_sentence): + language_code = None + + if isinstance(llm_sentence, tuple): + llm_sentence, language_code = llm_sentence + + console.print(f"[green]ASSISTANT: {llm_sentence}") + + if language_code is not None and self.language != language_code: + try: + self.model = TTS( + 
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code], + device=self.device, + ) + self.speaker_id = self.model.hps.data.spk2id[ + WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code] + ] + self.language = language_code + except KeyError: + console.print( + f"[red]Language {language_code} not supported by Melo. Using {self.language} instead." + ) + + if self.device == "mps": + import time + + start = time.time() + torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. + torch.mps.empty_cache() # Frees all memory allocated by the MPS device. + _ = ( + time.time() - start + ) # Removing this line makes it fail more often. I'm looking into it. + + try: + audio_chunk = self.model.tts_to_file( + llm_sentence, self.speaker_id, quiet=True + ) + except (AssertionError, RuntimeError) as e: + logger.error(f"Error in MeloTTSHandler: {e}") + audio_chunk = np.array([]) + if len(audio_chunk) == 0: + self.should_listen.set() + return + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.blocksize): + yield np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), + ) + + self.should_listen.set() diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ac539c7d709cc17533c81a47f155e4a19348b822 --- /dev/null +++ b/TTS/parler_handler.py @@ -0,0 +1,191 @@ +from threading import Thread +from time import perf_counter +from baseHandler import BaseHandler +import numpy as np +import torch +from transformers import ( + AutoTokenizer, +) +from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer +import librosa +import logging +from rich.console import Console +from utils.utils import next_power_of_2 +from transformers.utils.import_utils import ( + is_flash_attn_2_available, +) + +torch._inductor.config.fx_graph_cache = True +# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS +torch._dynamo.config.cache_size_limit = 15 + +logger = logging.getLogger(__name__) + +console = Console() + + +if not is_flash_attn_2_available() and torch.cuda.is_available(): + logger.warn( + """Parler TTS works best with flash attention 2, but is not installed + Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`""" + ) + + +class ParlerTTSHandler(BaseHandler): + def setup( + self, + should_listen, + model_name="ylacombe/parler-tts-mini-jenny-30H", + device="cuda", + torch_dtype="float16", + compile_mode=None, + gen_kwargs={}, + max_prompt_pad_length=8, + description=( + "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " + "She speaks very fast." 
+ ), + play_steps_s=1, + blocksize=512, + ): + self.should_listen = should_listen + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + self.gen_kwargs = gen_kwargs + self.compile_mode = compile_mode + self.max_prompt_pad_length = max_prompt_pad_length + self.description = description + + self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = ParlerTTSForConditionalGeneration.from_pretrained( + model_name, torch_dtype=self.torch_dtype + ).to(device) + + framerate = self.model.audio_encoder.config.frame_rate + self.play_steps = int(framerate * play_steps_s) + self.blocksize = blocksize + + if self.compile_mode not in (None, "default"): + logger.warning( + "Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'" + ) + self.compile_mode = "default" + + if self.compile_mode: + self.model.generation_config.cache_implementation = "static" + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) + + self.warmup() + + def prepare_model_inputs( + self, + prompt, + max_length_prompt=50, + pad=False, + ): + pad_args_prompt = ( + {"padding": "max_length", "max_length": max_length_prompt} if pad else {} + ) + + tokenized_description = self.description_tokenizer( + self.description, return_tensors="pt" + ) + input_ids = tokenized_description.input_ids.to(self.device) + attention_mask = tokenized_description.attention_mask.to(self.device) + + tokenized_prompt = self.prompt_tokenizer( + prompt, return_tensors="pt", **pad_args_prompt + ) + prompt_input_ids = tokenized_prompt.input_ids.to(self.device) + prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) + + gen_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prompt_input_ids": prompt_input_ids, + "prompt_attention_mask": prompt_attention_mask, + **self.gen_kwargs, + } + + return gen_kwargs + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 if self.compile_mode == "default" else 2 + + if self.device == "cuda": + torch.cuda.synchronize() + start_event.record() + if self.compile_mode: + pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] + for pad_length in pad_lengths[::-1]: + model_kwargs = self.prepare_model_inputs( + "dummy prompt", max_length_prompt=pad_length, pad=True + ) + for _ in range(n_steps): + _ = self.model.generate(**model_kwargs) + logger.info(f"Warmed up length {pad_length} tokens!") + else: + model_kwargs = self.prepare_model_inputs("dummy prompt") + for _ in range(n_steps): + _ = self.model.generate(**model_kwargs) + + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def process(self, llm_sentence): + if isinstance(llm_sentence, tuple): + llm_sentence, _ = llm_sentence + + console.print(f"[green]ASSISTANT: {llm_sentence}") + nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) + + pad_args = {} + if self.compile_mode: + # pad to closest upper power of two + pad_length = next_power_of_2(nb_tokens) + logger.debug(f"padding to {pad_length}") + pad_args["pad"] = True + pad_args["max_length_prompt"] = pad_length + + tts_gen_kwargs = self.prepare_model_inputs( + llm_sentence, + **pad_args, + ) + + streamer = ParlerTTSStreamer( + self.model, device=self.device, play_steps=self.play_steps + ) + tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} + torch.manual_seed(0) + thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) + thread.start() + + for i, audio_chunk in enumerate(streamer): + global pipeline_start + if i == 0 and "pipeline_start" in globals(): + logger.info( + f"Time to first audio: {perf_counter() - pipeline_start:.3f}" + ) + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.blocksize): + yield np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), + ) + + self.should_listen.set() diff --git a/VAD/__pycache__/vad_handler.cpython-311.pyc b/VAD/__pycache__/vad_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..921d84b3d80e15338deca4321e22d1ed5d2fbdd0 Binary files /dev/null and b/VAD/__pycache__/vad_handler.cpython-311.pyc differ diff --git a/VAD/__pycache__/vad_handler.cpython-312.pyc b/VAD/__pycache__/vad_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08a5606ffedc86b01ef0c14aebbf571a8fd19e0f Binary files /dev/null and b/VAD/__pycache__/vad_handler.cpython-312.pyc differ diff --git a/VAD/__pycache__/vad_iterator.cpython-311.pyc b/VAD/__pycache__/vad_iterator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..418ab572cf73ee638d6f3b6540162fdea2312a1f Binary files /dev/null and b/VAD/__pycache__/vad_iterator.cpython-311.pyc differ diff --git a/VAD/__pycache__/vad_iterator.cpython-312.pyc b/VAD/__pycache__/vad_iterator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..643c9b89a0b94162303617de4a9f427d1f6df28d Binary files /dev/null and b/VAD/__pycache__/vad_iterator.cpython-312.pyc differ diff --git a/VAD/vad_handler.py b/VAD/vad_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5c6acbe6dde672a5259f57a019b36785e14ed3 --- /dev/null +++ b/VAD/vad_handler.py @@ -0,0 +1,92 @@ +import torchaudio +from VAD.vad_iterator import VADIterator +from baseHandler import BaseHandler +import numpy as np +import torch +from rich.console import Console + +from utils.utils import int2float +from df.enhance import enhance, init_df +import logging + +logger = logging.getLogger(__name__) + +console = Console() + + +class VADHandler(BaseHandler): + """ + Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed + to the following part. 
+ """ + + def setup( + self, + should_listen, + thresh=0.3, + sample_rate=16000, + min_silence_ms=1000, + min_speech_ms=500, + max_speech_ms=float("inf"), + speech_pad_ms=30, + audio_enhancement=False, + ): + self.should_listen = should_listen + self.sample_rate = sample_rate + self.min_silence_ms = min_silence_ms + self.min_speech_ms = min_speech_ms + self.max_speech_ms = max_speech_ms + self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") + self.iterator = VADIterator( + self.model, + threshold=thresh, + sampling_rate=sample_rate, + min_silence_duration_ms=min_silence_ms, + speech_pad_ms=speech_pad_ms, + ) + self.audio_enhancement = audio_enhancement + if audio_enhancement: + self.enhanced_model, self.df_state, _ = init_df() + + def process(self, audio_chunk): + audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) + audio_float32 = int2float(audio_int16) + vad_output = self.iterator(torch.from_numpy(audio_float32)) + if vad_output is not None and len(vad_output) != 0: + logger.debug("VAD: end of speech detected") + array = torch.cat(vad_output).cpu().numpy() + duration_ms = len(array) / self.sample_rate * 1000 + if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: + logger.debug( + f"audio input of duration: {len(array) / self.sample_rate}s, skipping" + ) + else: + self.should_listen.clear() + logger.debug("Stop listening") + if self.audio_enhancement: + if self.sample_rate != self.df_state.sr(): + audio_float32 = torchaudio.functional.resample( + torch.from_numpy(array), + orig_freq=self.sample_rate, + new_freq=self.df_state.sr(), + ) + enhanced = enhance( + self.enhanced_model, + self.df_state, + audio_float32.unsqueeze(0), + ) + enhanced = torchaudio.functional.resample( + enhanced, + orig_freq=self.df_state.sr(), + new_freq=self.sample_rate, + ) + else: + enhanced = enhance( + self.enhanced_model, self.df_state, audio_float32 + ) + array = enhanced.numpy().squeeze() + yield array + + @property + def min_time_to_debug(self): + return 0.00001 diff --git a/VAD/vad_iterator.py b/VAD/vad_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..bd272f1dd7bfb432b5323e8906730e158ad5c2ff --- /dev/null +++ b/VAD/vad_iterator.py @@ -0,0 +1,100 @@ +import torch + + +class VADIterator: + def __init__( + self, + model, + threshold: float = 0.5, + sampling_rate: int = 16000, + min_silence_duration_ms: int = 100, + speech_pad_ms: int = 30, + ): + """ + Mainly taken from https://github.com/snakers4/silero-vad + Class for stream imitation + + Parameters + ---------- + model: preloaded .jit/.onnx silero VAD model + + threshold: float (default - 0.5) + Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. + It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. 
+ + sampling_rate: int (default - 16000) + Currently silero VAD models support 8000 and 16000 sample rates + + min_silence_duration_ms: int (default - 100 milliseconds) + In the end of each speech chunk wait for min_silence_duration_ms before separating it + + speech_pad_ms: int (default - 30 milliseconds) + Final speech chunks are padded by speech_pad_ms each side + """ + + self.model = model + self.threshold = threshold + self.sampling_rate = sampling_rate + self.is_speaking = False + self.buffer = [] + + if sampling_rate not in [8000, 16000]: + raise ValueError( + "VADIterator does not support sampling rates other than [8000, 16000]" + ) + + self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + self.reset_states() + + def reset_states(self): + self.model.reset_states() + self.triggered = False + self.temp_end = 0 + self.current_sample = 0 + + @torch.no_grad() + def __call__(self, x): + """ + x: torch.Tensor + audio chunk (see examples in repo) + + return_seconds: bool (default - False) + whether return timestamps in seconds (default - samples) + """ + + if not torch.is_tensor(x): + try: + x = torch.Tensor(x) + except Exception: + raise TypeError("Audio cannot be casted to tensor. Cast it manually") + + window_size_samples = len(x[0]) if x.dim() == 2 else len(x) + self.current_sample += window_size_samples + + speech_prob = self.model(x, self.sampling_rate).item() + + if (speech_prob >= self.threshold) and self.temp_end: + self.temp_end = 0 + + if (speech_prob >= self.threshold) and not self.triggered: + self.triggered = True + return None + + if (speech_prob < self.threshold - 0.15) and self.triggered: + if not self.temp_end: + self.temp_end = self.current_sample + if self.current_sample - self.temp_end < self.min_silence_samples: + return None + else: + # end of speak + self.temp_end = 0 + self.triggered = False + spoken_utterance = self.buffer + self.buffer = [] + return spoken_utterance + + if self.triggered: + self.buffer.append(x) + + return None diff --git a/arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc b/arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b329701739820ab1cf0e5d86c4d9ed5c0e48c7a0 Binary files /dev/null and b/arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc b/arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ac5ba09592b69cb768313d773d9bb482f167e91 Binary files /dev/null and b/arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc b/arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dddda85a91b66466a687d75d53c55d8c6535544c Binary files /dev/null and b/arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc b/arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44343f167c8356ba017270bf14bc1530fed5398e Binary files /dev/null and b/arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc differ diff 
--git a/arguments_classes/__pycache__/module_arguments.cpython-311.pyc b/arguments_classes/__pycache__/module_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fee196ad5bfe9465484558e99f47d406d0c9650 Binary files /dev/null and b/arguments_classes/__pycache__/module_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc b/arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd9cba8240a11b3f6d4a54fec1c44a67c2836ff Binary files /dev/null and b/arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc b/arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e1ff33f417e9e67c9d86e72d4588f9f4490437 Binary files /dev/null and b/arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc b/arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d391c3baa3f78cc6255f75c8b402799f6b4ccc4 Binary files /dev/null and b/arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc b/arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b5a3ae338e0ed54dcc735406afcb2e3490776c Binary files /dev/null and b/arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/vad_arguments.cpython-311.pyc b/arguments_classes/__pycache__/vad_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7699012203d7dca3573d092a3d74c986166ea47 Binary files /dev/null and b/arguments_classes/__pycache__/vad_arguments.cpython-311.pyc differ diff --git a/arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc b/arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbcca11127d073b97155f90879dab0ddc80b8711 Binary files /dev/null and b/arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc differ diff --git a/arguments_classes/chat_tts_arguments.py b/arguments_classes/chat_tts_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..bccce27176a4e2e818a2285ebdfa2c2cd63d69c9 --- /dev/null +++ b/arguments_classes/chat_tts_arguments.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass, field + + +@dataclass +class ChatTTSHandlerArguments: + chat_tts_stream: bool = field( + default=True, + metadata={"help": "The tts mode is stream Default is 'stream'."}, + ) + chat_tts_device: str = field( + default="cuda", + metadata={ + "help": "The device to be used for speech synthesis. Default is 'cuda'." + }, + ) + chat_tts_chunk_size: int = field( + default=512, + metadata={ + "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512ใ€‚." 
+ }, + ) diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/language_model_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..8680a78fe07ec2d7cb5f6e1a3dcbb0ed3319ac6b --- /dev/null +++ b/arguments_classes/language_model_arguments.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass, field + + +@dataclass +class LanguageModelHandlerArguments: + lm_model_name: str = field( + default="HuggingFaceTB/SmolLM-360M-Instruct", + metadata={ + "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." + }, + ) + lm_device: str = field( + default="cuda", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) + lm_torch_dtype: str = field( + default="float16", + metadata={ + "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." + }, + ) + user_role: str = field( + default="user", + metadata={ + "help": "Role assigned to the user in the chat context. Default is 'user'." + }, + ) + init_chat_role: str = field( + default="system", + metadata={ + "help": "Initial role for setting up the chat context. Default is 'system'." + }, + ) + init_chat_prompt: str = field( + default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", + metadata={ + "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" + }, + ) + lm_gen_max_new_tokens: int = field( + default=128, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 128." + }, + ) + lm_gen_min_new_tokens: int = field( + default=0, + metadata={ + "help": "Minimum number of new tokens to generate in a single completion. Default is 0." + }, + ) + lm_gen_temperature: float = field( + default=0.0, + metadata={ + "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." + }, + ) + lm_gen_do_sample: bool = field( + default=False, + metadata={ + "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." + }, + ) + chat_size: int = field( + default=2, + metadata={ + "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." + }, + ) diff --git a/arguments_classes/melo_tts_arguments.py b/arguments_classes/melo_tts_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..7223489318f843c2919b7d3929580040c88307aa --- /dev/null +++ b/arguments_classes/melo_tts_arguments.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass, field + + +@dataclass +class MeloTTSHandlerArguments: + melo_language: str = field( + default="en", + metadata={ + "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'." + }, + ) + melo_device: str = field( + default="auto", + metadata={ + "help": "The device to be used for speech synthesis. Default is 'auto'." + }, + ) + melo_speaker_to_id: str = field( + default="en", + metadata={ + "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']." 
+ }, + ) diff --git a/arguments_classes/mlx_language_model_arguments.py b/arguments_classes/mlx_language_model_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..0765ec9ce76ac2e4452f13d1d3f7574d977e823b --- /dev/null +++ b/arguments_classes/mlx_language_model_arguments.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass, field + + +@dataclass +class MLXLanguageModelHandlerArguments: + mlx_lm_model_name: str = field( + default="mlx-community/SmolLM-360M-Instruct", + metadata={ + "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." + }, + ) + mlx_lm_device: str = field( + default="mps", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) + mlx_lm_torch_dtype: str = field( + default="float16", + metadata={ + "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." + }, + ) + mlx_lm_user_role: str = field( + default="user", + metadata={ + "help": "Role assigned to the user in the chat context. Default is 'user'." + }, + ) + mlx_lm_init_chat_role: str = field( + default="system", + metadata={ + "help": "Initial role for setting up the chat context. Default is 'system'." + }, + ) + mlx_lm_init_chat_prompt: str = field( + default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", + metadata={ + "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" + }, + ) + mlx_lm_gen_max_new_tokens: int = field( + default=128, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 128." + }, + ) + mlx_lm_gen_temperature: float = field( + default=0.0, + metadata={ + "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." + }, + ) + mlx_lm_gen_do_sample: bool = field( + default=False, + metadata={ + "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." + }, + ) + mlx_lm_chat_size: int = field( + default=2, + metadata={ + "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." + }, + ) diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..bdaa646b4578439e47f7e36104d168910b2fe602 --- /dev/null +++ b/arguments_classes/module_arguments.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModuleArguments: + device: Optional[str] = field( + default=None, + metadata={"help": "If specified, overrides the device for all handlers."}, + ) + mode: Optional[str] = field( + default="socket", + metadata={ + "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'." + }, + ) + local_mac_optimal_settings: bool = field( + default=False, + metadata={ + "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used." + }, + ) + stt: Optional[str] = field( + default="whisper", + metadata={ + "help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'." + }, + ) + llm: Optional[str] = field( + default="transformers", + metadata={ + "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. 
Default is 'transformers'" + }, + ) + tts: Optional[str] = field( + default="parler", + metadata={ + "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'" + }, + ) + log_level: str = field( + default="info", + metadata={ + "help": "Provide logging level. Example --log_level debug, default=warning." + }, + ) diff --git a/arguments_classes/paraformer_stt_arguments.py b/arguments_classes/paraformer_stt_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..a57a66abfbc6eb1f95868364a69e106919299032 --- /dev/null +++ b/arguments_classes/paraformer_stt_arguments.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass, field + + +@dataclass +class ParaformerSTTHandlerArguments: + paraformer_stt_model_name: str = field( + default="paraformer-zh", + metadata={ + "help": "The pretrained model to use. Default is 'paraformer-zh'. Can be choose from https://github.com/modelscope/FunASR" + }, + ) + paraformer_stt_device: str = field( + default="cuda", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) diff --git a/arguments_classes/parler_tts_arguments.py b/arguments_classes/parler_tts_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..515943209d2068bc0e2a451eb08d1cce68cfb67d --- /dev/null +++ b/arguments_classes/parler_tts_arguments.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field + + +@dataclass +class ParlerTTSHandlerArguments: + tts_model_name: str = field( + default="ylacombe/parler-tts-mini-jenny-30H", + metadata={ + "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'." + }, + ) + tts_device: str = field( + default="cuda", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) + tts_torch_dtype: str = field( + default="float16", + metadata={ + "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." + }, + ) + tts_compile_mode: str = field( + default=None, + metadata={ + "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" + }, + ) + tts_gen_min_new_tokens: int = field( + default=64, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs" + }, + ) + tts_gen_max_new_tokens: int = field( + default=512, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs" + }, + ) + description: str = field( + default=( + "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " + "She speaks very fast." + ), + metadata={ + "help": "Description of the speaker's voice and speaking style to guide the TTS model." + }, + ) + play_steps_s: float = field( + default=1.0, + metadata={ + "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds." + }, + ) + max_prompt_pad_length: int = field( + default=8, + metadata={ + "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible." 
+ }, + ) diff --git a/arguments_classes/socket_receiver_arguments.py b/arguments_classes/socket_receiver_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..2884edd89abdcbc2275ff7aa1f8dc08879a963af --- /dev/null +++ b/arguments_classes/socket_receiver_arguments.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass, field + + +@dataclass +class SocketReceiverArguments: + recv_host: str = field( + default="localhost", + metadata={ + "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all " + "available interfaces on the host machine." + }, + ) + recv_port: int = field( + default=12345, + metadata={ + "help": "The port number on which the socket server listens. Default is 12346." + }, + ) + chunk_size: int = field( + default=1024, + metadata={ + "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes." + }, + ) diff --git a/arguments_classes/socket_sender_arguments.py b/arguments_classes/socket_sender_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..8777f6cff6d2edef6e323dfe1db76e884af53ac5 --- /dev/null +++ b/arguments_classes/socket_sender_arguments.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass, field + + +@dataclass +class SocketSenderArguments: + send_host: str = field( + default="localhost", + metadata={ + "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all " + "available interfaces on the host machine." + }, + ) + send_port: int = field( + default=12346, + metadata={ + "help": "The port number on which the socket server listens. Default is 12346." + }, + ) diff --git a/arguments_classes/vad_arguments.py b/arguments_classes/vad_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b523a37d0470c12c34990841c4c25208e87f34 --- /dev/null +++ b/arguments_classes/vad_arguments.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass, field + + +@dataclass +class VADHandlerArguments: + thresh: float = field( + default=0.3, + metadata={ + "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection." + }, + ) + sample_rate: int = field( + default=16000, + metadata={ + "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio." + }, + ) + min_silence_ms: int = field( + default=250, + metadata={ + "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms." + }, + ) + min_speech_ms: int = field( + default=500, + metadata={ + "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms." + }, + ) + max_speech_ms: float = field( + default=float("inf"), + metadata={ + "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments." + }, + ) + speech_pad_ms: int = field( + default=500, + metadata={ + "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." + }, + ) + audio_enhancement: bool = field( + default=False, + metadata={ + "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False." 
+ }, + ) diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc700bf24e2320d0065ab6db40c0adbcf4782b5 --- /dev/null +++ b/arguments_classes/whisper_stt_arguments.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class WhisperSTTHandlerArguments: + stt_model_name: str = field( + default="distil-whisper/distil-large-v3", + metadata={ + "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." + }, + ) + stt_device: str = field( + default="cuda", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) + stt_torch_dtype: str = field( + default="float16", + metadata={ + "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." + }, + ) + stt_compile_mode: str = field( + default=None, + metadata={ + "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" + }, + ) + stt_gen_max_new_tokens: int = field( + default=128, + metadata={ + "help": "The maximum number of new tokens to generate. Default is 128." + }, + ) + stt_gen_num_beams: int = field( + default=1, + metadata={ + "help": "The number of beams for beam search. Default is 1, implying greedy decoding." + }, + ) + stt_gen_return_timestamps: bool = field( + default=False, + metadata={ + "help": "Whether to return timestamps with transcriptions. Default is False." + }, + ) + stt_gen_task: str = field( + default="transcribe", + metadata={ + "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." + }, + ) + language: Optional[str] = field( + default='en', + metadata={ + "help": """The language for the conversation. + Choose between 'en' (english), 'fr' (french), 'es' (spanish), + 'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'. + If using 'auto', the language is automatically detected and can + change during the conversation. Default is 'en'.""" + }, + ) \ No newline at end of file diff --git a/baseHandler.py b/baseHandler.py new file mode 100644 index 0000000000000000000000000000000000000000..61532e4705ed7efee02c4c2b4079a3c68526bfde --- /dev/null +++ b/baseHandler.py @@ -0,0 +1,56 @@ +from time import perf_counter +import logging + +logger = logging.getLogger(__name__) + + +class BaseHandler: + """ + Base class for pipeline parts. Each part of the pipeline has an input and an output queue. + The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part. + To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue. + Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue. + The cleanup method handles stopping the handler, and b"END" is placed in the output queue. 
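+
+    Minimal illustrative subclass (hypothetical, for documentation only):
+
+        class UpperCaseHandler(BaseHandler):
+            def setup(self, suffix="!"):
+                self.suffix = suffix
+
+            def process(self, text):
+                yield text.upper() + self.suffix
+
+    `run()` then pulls items from `queue_in`, feeds each one to `process`, and puts
+    every yielded result on `queue_out` until the b"END" sentinel is received.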
+ """ + + def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}): + self.stop_event = stop_event + self.queue_in = queue_in + self.queue_out = queue_out + self.setup(*setup_args, **setup_kwargs) + self._times = [] + + def setup(self): + pass + + def process(self): + raise NotImplementedError + + def run(self): + while not self.stop_event.is_set(): + input = self.queue_in.get() + if isinstance(input, bytes) and input == b"END": + # sentinelle signal to avoid queue deadlock + logger.debug("Stopping thread") + break + start_time = perf_counter() + for output in self.process(input): + self._times.append(perf_counter() - start_time) + if self.last_time > self.min_time_to_debug: + logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s") + self.queue_out.put(output) + start_time = perf_counter() + + self.cleanup() + self.queue_out.put(b"END") + + @property + def last_time(self): + return self._times[-1] + + @property + def min_time_to_debug(self): + return 0.001 + + def cleanup(self): + pass diff --git a/connections/__pycache__/local_audio_streamer.cpython-311.pyc b/connections/__pycache__/local_audio_streamer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dcf546dd8348434cf729b352134db97547a62bf Binary files /dev/null and b/connections/__pycache__/local_audio_streamer.cpython-311.pyc differ diff --git a/connections/__pycache__/socket_receiver.cpython-311.pyc b/connections/__pycache__/socket_receiver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0e3a97896c257c62ed64ac5cd9a43bd9431c507 Binary files /dev/null and b/connections/__pycache__/socket_receiver.cpython-311.pyc differ diff --git a/connections/__pycache__/socket_sender.cpython-311.pyc b/connections/__pycache__/socket_sender.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1abf8237ad7da3884d91606c101ad77c4fe274f2 Binary files /dev/null and b/connections/__pycache__/socket_sender.cpython-311.pyc differ diff --git a/connections/local_audio_streamer.py b/connections/local_audio_streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..389dcb8e3ee17b7b27c9882004839f136f50f882 --- /dev/null +++ b/connections/local_audio_streamer.py @@ -0,0 +1,44 @@ +import threading +import sounddevice as sd +import numpy as np + +import time +import logging + +logger = logging.getLogger(__name__) + + +class LocalAudioStreamer: + def __init__( + self, + input_queue, + output_queue, + list_play_chunk_size=512, + ): + self.list_play_chunk_size = list_play_chunk_size + + self.stop_event = threading.Event() + self.input_queue = input_queue + self.output_queue = output_queue + + def run(self): + def callback(indata, outdata, frames, time, status): + if self.output_queue.empty(): + self.input_queue.put(indata.copy()) + outdata[:] = 0 * outdata + else: + outdata[:] = self.output_queue.get()[:, np.newaxis] + + logger.debug("Available devices:") + logger.debug(sd.query_devices()) + with sd.Stream( + samplerate=16000, + dtype="int16", + channels=1, + callback=callback, + blocksize=self.list_play_chunk_size, + ): + logger.info("Starting local audio stream") + while not self.stop_event.is_set(): + time.sleep(0.001) + print("Stopping recording") diff --git a/connections/socket_receiver.py b/connections/socket_receiver.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb734cf1fd7967de7ad368b4a26ea4e16880ddb --- /dev/null +++ b/connections/socket_receiver.py @@ -0,0 +1,60 @@ 
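+# Note on the wire format (an assumption based on listen_and_play.py): the client
+# streams raw int16 PCM audio with no extra framing, so receive_full_chunk() below
+# simply reads until exactly `chunk_size` bytes have arrived for each queue item.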
+import socket +from rich.console import Console +import logging + +logger = logging.getLogger(__name__) + +console = Console() + + +class SocketReceiver: + """ + Handles reception of the audio packets from the client. + """ + + def __init__( + self, + stop_event, + queue_out, + should_listen, + host="0.0.0.0", + port=12345, + chunk_size=1024, + ): + self.stop_event = stop_event + self.queue_out = queue_out + self.should_listen = should_listen + self.chunk_size = chunk_size + self.host = host + self.port = port + + def receive_full_chunk(self, conn, chunk_size): + data = b"" + while len(data) < chunk_size: + packet = conn.recv(chunk_size - len(data)) + if not packet: + # connection closed + return None + data += packet + return data + + def run(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind((self.host, self.port)) + self.socket.listen(1) + logger.info("Receiver waiting to be connected...") + self.conn, _ = self.socket.accept() + logger.info("receiver connected") + + self.should_listen.set() + while not self.stop_event.is_set(): + audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) + if audio_chunk is None: + # connection closed + self.queue_out.put(b"END") + break + if self.should_listen.is_set(): + self.queue_out.put(audio_chunk) + self.conn.close() + logger.info("Receiver closed") diff --git a/connections/socket_sender.py b/connections/socket_sender.py new file mode 100644 index 0000000000000000000000000000000000000000..11ed21083b25e9d31de382308324b675ea6d5087 --- /dev/null +++ b/connections/socket_sender.py @@ -0,0 +1,36 @@ +import socket +from rich.console import Console +import logging + +logger = logging.getLogger(__name__) + +console = Console() + + +class SocketSender: + """ + Handles sending generated audio packets to the clients. + """ + + def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): + self.stop_event = stop_event + self.queue_in = queue_in + self.host = host + self.port = port + + def run(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind((self.host, self.port)) + self.socket.listen(1) + logger.info("Sender waiting to be connected...") + self.conn, _ = self.socket.accept() + logger.info("sender connected") + + while not self.stop_event.is_set(): + audio_chunk = self.queue_in.get() + self.conn.sendall(audio_chunk) + if isinstance(audio_chunk, bytes) and audio_chunk == b"END": + break + self.conn.close() + logger.info("Sender closed") diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..a5406c0adfeee62eec680e413e9189f5d0665ada --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,40 @@ +--- +services: + + pipeline: + build: + context: . 
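+      # Optional: point the build at another Dockerfile via the DOCKERFILE
+      # environment variable, e.g. `DOCKERFILE=Dockerfile.custom docker compose build`
+      # (the file name above is only a placeholder).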
+      dockerfile: ${DOCKERFILE:-Dockerfile}
+    command:
+      - python3
+      - s2s_pipeline.py
+      - --recv_host
+      - 0.0.0.0
+      - --send_host
+      - 0.0.0.0
+      - --lm_model_name
+      - microsoft/Phi-3-mini-4k-instruct
+      - --init_chat_role
+      - system
+      - --init_chat_prompt
+      - "You are a helpful assistant"
+      - --stt_compile_mode
+      - reduce-overhead
+      - --tts_compile_mode
+      - default
+    expose:
+      - 12345/tcp
+      - 12346/tcp
+    ports:
+      - 12345:12345/tcp
+      - 12346:12346/tcp
+    volumes:
+      - ./cache/:/root/.cache/
+      - ./s2s_pipeline.py:/usr/src/app/s2s_pipeline.py
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
diff --git a/handler.py b/handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3de3df84d740b5cbf715ad08dc155113b6a03c
--- /dev/null
+++ b/handler.py
@@ -0,0 +1,106 @@
+from typing import Dict, Any, Generator
+import torch
+import os
+import logging
+from s2s_pipeline import main, rename_args, parse_arguments, setup_logger, initialize_queues_and_events, build_pipeline
+import numpy as np
+from queue import Queue
+import threading
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        (
+            self.module_kwargs,
+            self.socket_receiver_kwargs,
+            self.socket_sender_kwargs,
+            self.vad_handler_kwargs,
+            self.whisper_stt_handler_kwargs,
+            self.paraformer_stt_handler_kwargs,
+            self.language_model_handler_kwargs,
+            self.mlx_language_model_handler_kwargs,
+            self.parler_tts_handler_kwargs,
+            self.melo_tts_handler_kwargs,
+            self.chat_tts_handler_kwargs,
+        ) = parse_arguments()
+
+        setup_logger(self.module_kwargs.log_level)
+
+        # rename_args (from s2s_pipeline) strips the per-handler prefixes and collects
+        # generation kwargs, mirroring the renaming step prepare_all_args performs in main().
+        rename_args(self.whisper_stt_handler_kwargs, "stt")
+        rename_args(self.paraformer_stt_handler_kwargs, "paraformer_stt")
+        rename_args(self.language_model_handler_kwargs, "lm")
+        rename_args(self.mlx_language_model_handler_kwargs, "mlx_lm")
+        rename_args(self.parler_tts_handler_kwargs, "tts")
+        rename_args(self.melo_tts_handler_kwargs, "melo")
+        rename_args(self.chat_tts_handler_kwargs, "chat_tts")
+
+        self.queues_and_events = initialize_queues_and_events()
+
+        self.pipeline_manager = build_pipeline(
+            self.module_kwargs,
+            self.socket_receiver_kwargs,
+            self.socket_sender_kwargs,
+            self.vad_handler_kwargs,
+            self.whisper_stt_handler_kwargs,
+            self.paraformer_stt_handler_kwargs,
+            self.language_model_handler_kwargs,
+            self.mlx_language_model_handler_kwargs,
+            self.parler_tts_handler_kwargs,
+            self.melo_tts_handler_kwargs,
+            self.chat_tts_handler_kwargs,
+            self.queues_and_events,
+        )
+
+        self.pipeline_manager.start()
+
+        # Add a new queue for collecting the final output
+        self.final_output_queue = Queue()
+
+        # Start a thread to collect the final output
+        self.output_collector_thread = threading.Thread(target=self._collect_output)
+        self.output_collector_thread.start()
+
+    def _collect_output(self):
+        while True:
+            output = self.queues_and_events['send_audio_chunks_queue'].get()
+            if output == b"END":
+                self.final_output_queue.put(b"END")
+                break
+            self.final_output_queue.put(output)
+
+    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
+        """
+        Args:
+            data (Dict[str, Any]): The input data containing the necessary arguments.
+
+        Returns:
+            Generator[Dict[str, Any], None, None]: A generator yielding output chunks from the model or pipeline.
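+
+        Illustrative call (hypothetical client code, not part of this module):
+
+            handler = EndpointHandler()
+            for out in handler({"input_type": "text", "input": "Hello there"}):
+                audio_chunk = out["output"]  # raw audio produced by the TTS stage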
+ """ + input_type = data.get("input_type", "text") + input_data = data.get("input", "") + + if input_type == "speech": + # Convert input audio data to numpy array + audio_array = np.frombuffer(input_data, dtype=np.int16) + + # Put audio data into the recv_audio_chunks_queue + self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes()) + elif input_type == "text": + # Put text data directly into the text_prompt_queue + self.queues_and_events['text_prompt_queue'].put(input_data) + else: + raise ValueError(f"Unsupported input type: {input_type}") + + # Stream the output chunks + while True: + chunk = self.final_output_queue.get() + if chunk == b"END": + break + yield {"output": chunk} + + def cleanup(self): + # Stop the pipeline + self.pipeline_manager.stop() + + # Stop the output collector thread + self.queues_and_events['send_audio_chunks_queue'].put(b"END") + self.output_collector_thread.join() \ No newline at end of file diff --git a/listen_and_play.py b/listen_and_play.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7cfc9bcacf95a957bfa11adeea63fcf1c9117e --- /dev/null +++ b/listen_and_play.py @@ -0,0 +1,142 @@ +import socket +import threading +from queue import Queue +from dataclasses import dataclass, field +import sounddevice as sd +import argparse +import numpy as np + +@dataclass +class ListenAndPlayArguments: + send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) + recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) + list_play_chunk_size: int = field( + default=1024, + metadata={"help": "The size of data chunks (in bytes). Default is 1024."}, + ) + host: str = field( + default="localhost", + metadata={ + "help": "The hostname or IP address for listening and playing. Default is 'localhost'." + }, + ) + send_port: int = field( + default=12345, + metadata={"help": "The network port for sending data. Default is 12345."}, + ) + recv_port: int = field( + default=12346, + metadata={"help": "The network port for receiving data. 
Default is 12346."}, + ) + +def listen_and_play( + send_rate=16000, + recv_rate=44100, + list_play_chunk_size=1024, + host="localhost", + send_port=12345, + recv_port=12346, +): + send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + send_socket.connect((host, send_port)) + + recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + recv_socket.connect((host, recv_port)) + + print("Recording and streaming...") + + stop_event = threading.Event() + recv_queue = Queue() + send_queue = Queue() + + def callback_recv(outdata, frames, time, status): + if not recv_queue.empty(): + data = recv_queue.get() + # Convert bytes to numpy array + audio_array = np.frombuffer(data, dtype=np.int16) + # Reduce volume to 30% + audio_array = (audio_array * 0.3).astype(np.int16) + # Convert back to bytes + reduced_data = audio_array.tobytes() + outdata[: len(reduced_data)] = reduced_data + outdata[len(reduced_data) :] = b"\x00" * (len(outdata) - len(reduced_data)) + else: + outdata[:] = b"\x00" * len(outdata) + + def callback_send(indata, frames, time, status): + if recv_queue.empty(): + data = bytes(indata) + send_queue.put(data) + + def send(stop_event, send_queue): + while not stop_event.is_set(): + data = send_queue.get() + send_socket.sendall(data) + + def recv(stop_event, recv_queue): + def receive_full_chunk(conn, chunk_size): + data = b"" + while len(data) < chunk_size: + packet = conn.recv(chunk_size - len(data)) + if not packet: + return None + data += packet + return data + + while not stop_event.is_set(): + data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) + if data: + recv_queue.put(data) + + try: + send_stream = sd.RawInputStream( + samplerate=send_rate, + channels=1, + dtype="int16", + blocksize=list_play_chunk_size, + callback=callback_send, + ) + + recv_stream = sd.RawOutputStream( + samplerate=recv_rate, + channels=1, + dtype="int16", + blocksize=list_play_chunk_size, + callback=callback_recv, + device=0, + ) + + threading.Thread(target=send_stream.start).start() + threading.Thread(target=recv_stream.start).start() + + send_thread = threading.Thread(target=send, args=(stop_event, send_queue)) + send_thread.start() + recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue)) + recv_thread.start() + + input("Press Enter to stop...") + + except KeyboardInterrupt: + print("\nProgram interrupted by user. Exiting...") + except Exception as e: + print(f"An error occurred: {e}") + + finally: + stop_event.set() + recv_thread.join() + send_thread.join() + send_socket.close() + recv_socket.close() + print("Connection closed.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Listen and Play Audio") + parser.add_argument("--send_rate", type=int, default=16000, help="In Hz. Default is 16000.") + parser.add_argument("--recv_rate", type=int, default=16000, help="In Hz. Default is 16000.") + parser.add_argument("--list_play_chunk_size", type=int, default=1024, help="The size of data chunks (in bytes). Default is 1024.") + parser.add_argument("--host", type=str, default="localhost", help="The hostname or IP address for listening and playing. Default is 'localhost'.") + parser.add_argument("--send_port", type=int, default=12345, help="The network port for sending data. Default is 12345.") + parser.add_argument("--recv_port", type=int, default=12346, help="The network port for receiving data. 
Default is 12346.") + + args = parser.parse_args() + listen_and_play(**vars(args)) diff --git a/logo.png b/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..bc79ea95c12bea0c18c85e936c79e4e11cc7f452 Binary files /dev/null and b/logo.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fba30cd7f5e716797d29b3dd5890fd1a610d06a2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +nltk==3.9.1 +parler_tts @ git+https://github.com/huggingface/parler-tts.git +melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers +torch==2.4.0 +sounddevice==0.5.0 +ChatTTS>=0.1.1 +funasr>=1.1.6 +modelscope>=1.17.1 +deepfilternet>=0.5.6 diff --git a/requirements_mac.txt b/requirements_mac.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a1c5cbb4a101ce611a2b81e4d52b73259782a0c --- /dev/null +++ b/requirements_mac.txt @@ -0,0 +1,12 @@ +nltk==3.9.1 +parler_tts @ git+https://github.com/huggingface/parler-tts.git +melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers +torch==2.4.0 +sounddevice==0.5.0 +lightning-whisper-mlx>=0.0.10 +mlx-lm>=0.14.0 +ChatTTS>=0.1.1 +funasr>=1.1.6 +modelscope>=1.17.1 +deepfilternet>=0.5.6 + diff --git a/s2s_pipeline.py b/s2s_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0cac1d4658bf0b44505f44664f63b5e275e4ff --- /dev/null +++ b/s2s_pipeline.py @@ -0,0 +1,410 @@ +import logging +import os +import sys +from copy import copy +from pathlib import Path +from queue import Queue +from threading import Event +from typing import Optional +from sys import platform +from VAD.vad_handler import VADHandler +from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments +from arguments_classes.language_model_arguments import LanguageModelHandlerArguments +from arguments_classes.mlx_language_model_arguments import ( + MLXLanguageModelHandlerArguments, +) +from arguments_classes.module_arguments import ModuleArguments +from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments +from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments +from arguments_classes.socket_receiver_arguments import SocketReceiverArguments +from arguments_classes.socket_sender_arguments import SocketSenderArguments +from arguments_classes.vad_arguments import VADHandlerArguments +from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments +from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments +import torch +import nltk +from rich.console import Console +from transformers import ( + HfArgumentParser, +) + +from utils.thread_manager import ThreadManager + +# Ensure that the necessary NLTK resources are available +try: + nltk.data.find("tokenizers/punkt_tab") +except (LookupError, OSError): + nltk.download("punkt_tab") +try: + nltk.data.find("tokenizers/averaged_perceptron_tagger_eng") +except (LookupError, OSError): + nltk.download("averaged_perceptron_tagger_eng") + +# caching allows ~50% compilation time reduction +# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma +CURRENT_DIR = Path(__file__).resolve().parent +os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") + +console = Console() 
+logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
+
+
+def rename_args(args, prefix):
+    """
+    Renames arguments by removing the given prefix and prepares the gen_kwargs dict.
+    """
+    gen_kwargs = {}
+    for key in copy(args.__dict__):
+        if key.startswith(prefix):
+            value = args.__dict__.pop(key)
+            new_key = key[len(prefix) + 1 :]  # Remove prefix and underscore
+            if new_key.startswith("gen_"):
+                gen_kwargs[new_key[4:]] = value  # Remove 'gen_' and add to dict
+            else:
+                args.__dict__[new_key] = value
+
+    args.__dict__["gen_kwargs"] = gen_kwargs
+
+
+def parse_arguments():
+    parser = HfArgumentParser(
+        (
+            ModuleArguments,
+            SocketReceiverArguments,
+            SocketSenderArguments,
+            VADHandlerArguments,
+            WhisperSTTHandlerArguments,
+            ParaformerSTTHandlerArguments,
+            LanguageModelHandlerArguments,
+            MLXLanguageModelHandlerArguments,
+            ParlerTTSHandlerArguments,
+            MeloTTSHandlerArguments,
+            ChatTTSHandlerArguments,
+        )
+    )
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # Parse configurations from a JSON file if specified
+        return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        # Parse arguments from command line if no JSON file is provided
+        return parser.parse_args_into_dataclasses()
+
+
+def setup_logger(log_level):
+    global logger
+    logging.basicConfig(
+        level=log_level.upper(),
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+    logger = logging.getLogger(__name__)
+
+    # torch compile logs
+    if log_level == "debug":
+        torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
+
+
+def optimal_mac_settings(mac_optimal_settings: Optional[bool], *handler_kwargs):
+    if mac_optimal_settings:
+        for kwargs in handler_kwargs:
+            if hasattr(kwargs, "device"):
+                kwargs.device = "mps"
+            if hasattr(kwargs, "mode"):
+                kwargs.mode = "local"
+            if hasattr(kwargs, "stt"):
+                kwargs.stt = "whisper-mlx"
+            if hasattr(kwargs, "llm"):
+                kwargs.llm = "mlx-lm"
+            if hasattr(kwargs, "tts"):
+                kwargs.tts = "melo"
+
+
+def check_mac_settings(module_kwargs):
+    if platform == "darwin":
+        if module_kwargs.device == "cuda":
+            raise ValueError(
+                "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
+            )
+        if module_kwargs.llm != "mlx-lm":
+            logger.warning(
+                "For macOS users, it is recommended to use mlx-lm. You can activate it by passing --llm mlx-lm."
+            )
+        if module_kwargs.tts != "melo":
+            logger.warning(
+                "If you experience issues generating the voice, consider setting the TTS to melo."
+ ) + + +def overwrite_device_argument(common_device: Optional[str], *handler_kwargs): + if common_device: + for kwargs in handler_kwargs: + if hasattr(kwargs, "lm_device"): + kwargs.lm_device = common_device + if hasattr(kwargs, "tts_device"): + kwargs.tts_device = common_device + if hasattr(kwargs, "stt_device"): + kwargs.stt_device = common_device + if hasattr(kwargs, "paraformer_stt_device"): + kwargs.paraformer_stt_device = common_device + + +def prepare_module_args(module_kwargs, *handler_kwargs): + optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs) + if platform == "darwin": + check_mac_settings(module_kwargs) + overwrite_device_argument(module_kwargs.device, *handler_kwargs) + + +def prepare_all_args( + module_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, +): + prepare_module_args( + module_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + ) + + rename_args(whisper_stt_handler_kwargs, "stt") + rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") + rename_args(language_model_handler_kwargs, "lm") + rename_args(mlx_language_model_handler_kwargs, "mlx_lm") + rename_args(parler_tts_handler_kwargs, "tts") + rename_args(melo_tts_handler_kwargs, "melo") + rename_args(chat_tts_handler_kwargs, "chat_tts") + + +def initialize_queues_and_events(): + return { + "stop_event": Event(), + "should_listen": Event(), + "recv_audio_chunks_queue": Queue(), + "send_audio_chunks_queue": Queue(), + "spoken_prompt_queue": Queue(), + "text_prompt_queue": Queue(), + "lm_response_queue": Queue(), + } + + +def build_pipeline( + module_kwargs, + socket_receiver_kwargs, + socket_sender_kwargs, + vad_handler_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + queues_and_events, +): + stop_event = queues_and_events["stop_event"] + should_listen = queues_and_events["should_listen"] + recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"] + send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"] + spoken_prompt_queue = queues_and_events["spoken_prompt_queue"] + text_prompt_queue = queues_and_events["text_prompt_queue"] + lm_response_queue = queues_and_events["lm_response_queue"] + + if module_kwargs.mode == "local": + from connections.local_audio_streamer import LocalAudioStreamer + + local_audio_streamer = LocalAudioStreamer( + input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue + ) + comms_handlers = [local_audio_streamer] + should_listen.set() + else: + from connections.socket_receiver import SocketReceiver + from connections.socket_sender import SocketSender + + comms_handlers = [ + SocketReceiver( + stop_event, + recv_audio_chunks_queue, + should_listen, + host=socket_receiver_kwargs.recv_host, + port=socket_receiver_kwargs.recv_port, + chunk_size=socket_receiver_kwargs.chunk_size, + ), + SocketSender( + stop_event, + send_audio_chunks_queue, + host=socket_sender_kwargs.send_host, + port=socket_sender_kwargs.send_port, + ), + ] + + vad = VADHandler( + stop_event, + queue_in=recv_audio_chunks_queue, + 
queue_out=spoken_prompt_queue, + setup_args=(should_listen,), + setup_kwargs=vars(vad_handler_kwargs), + ) + + stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) + lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs) + tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) + + return ThreadManager([*comms_handlers, vad, stt, lm, tts]) + + +def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs): + if module_kwargs.stt == "whisper": + from STT.whisper_stt_handler import WhisperSTTHandler + return WhisperSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(whisper_stt_handler_kwargs), + ) + elif module_kwargs.stt == "whisper-mlx": + from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler + return LightningWhisperSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(whisper_stt_handler_kwargs), + ) + elif module_kwargs.stt == "paraformer": + from STT.paraformer_handler import ParaformerSTTHandler + return ParaformerSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(paraformer_stt_handler_kwargs), + ) + else: + raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.") + + +def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs): + if module_kwargs.llm == "transformers": + from LLM.language_model import LanguageModelHandler + return LanguageModelHandler( + stop_event, + queue_in=text_prompt_queue, + queue_out=lm_response_queue, + setup_kwargs=vars(language_model_handler_kwargs), + ) + elif module_kwargs.llm == "mlx-lm": + from LLM.mlx_language_model import MLXLanguageModelHandler + return MLXLanguageModelHandler( + stop_event, + queue_in=text_prompt_queue, + queue_out=lm_response_queue, + setup_kwargs=vars(mlx_language_model_handler_kwargs), + ) + else: + raise ValueError("The LLM should be either transformers or mlx-lm") + + +def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs): + if module_kwargs.tts == "parler": + from TTS.parler_handler import ParlerTTSHandler + return ParlerTTSHandler( + stop_event, + queue_in=lm_response_queue, + queue_out=send_audio_chunks_queue, + setup_args=(should_listen,), + setup_kwargs=vars(parler_tts_handler_kwargs), + ) + elif module_kwargs.tts == "melo": + try: + from TTS.melo_handler import MeloTTSHandler + except RuntimeError as e: + logger.error( + "Error importing MeloTTSHandler. 
You might need to run: python -m unidic download" + ) + raise e + return MeloTTSHandler( + stop_event, + queue_in=lm_response_queue, + queue_out=send_audio_chunks_queue, + setup_args=(should_listen,), + setup_kwargs=vars(melo_tts_handler_kwargs), + ) + elif module_kwargs.tts == "chatTTS": + try: + from TTS.chatTTS_handler import ChatTTSHandler + except RuntimeError as e: + logger.error("Error importing ChatTTSHandler") + raise e + return ChatTTSHandler( + stop_event, + queue_in=lm_response_queue, + queue_out=send_audio_chunks_queue, + setup_args=(should_listen,), + setup_kwargs=vars(chat_tts_handler_kwargs), + ) + else: + raise ValueError("The TTS should be either parler, melo or chatTTS") + + +def main(): + ( + module_kwargs, + socket_receiver_kwargs, + socket_sender_kwargs, + vad_handler_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + ) = parse_arguments() + + setup_logger(module_kwargs.log_level) + + prepare_all_args( + module_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + ) + + queues_and_events = initialize_queues_and_events() + + pipeline_manager = build_pipeline( + module_kwargs, + socket_receiver_kwargs, + socket_sender_kwargs, + vad_handler_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + queues_and_events, + ) + + try: + pipeline_manager.start() + except KeyboardInterrupt: + pipeline_manager.stop() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/utils/__pycache__/thread_manager.cpython-311.pyc b/utils/__pycache__/thread_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5357d4f91e459b1bfbba072df9e7c03bcda33f31 Binary files /dev/null and b/utils/__pycache__/thread_manager.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a573ff94a4bbb70cb4f44227fa3d26d177d6ff00 Binary files /dev/null and b/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-312.pyc b/utils/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2311ecd7402b7040c0f639347e3573b46510a0d0 Binary files /dev/null and b/utils/__pycache__/utils.cpython-312.pyc differ diff --git a/utils/thread_manager.py b/utils/thread_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1ca4a074833cb3953ddfa63e71822e26b23143 --- /dev/null +++ b/utils/thread_manager.py @@ -0,0 +1,23 @@ +import threading + + +class ThreadManager: + """ + Manages multiple threads used to execute given handler tasks. 
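+
+    Usage sketch (handler names are placeholders):
+
+        manager = ThreadManager([vad, stt, lm, tts])
+        manager.start()  # one thread per handler, each executing handler.run()
+        ...
+        manager.stop()   # sets every handler's stop_event, then joins all threads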
+ """ + + def __init__(self, handlers): + self.handlers = handlers + self.threads = [] + + def start(self): + for handler in self.handlers: + thread = threading.Thread(target=handler.run) + self.threads.append(thread) + thread.start() + + def stop(self): + for handler in self.handlers: + handler.stop_event.set() + for thread in self.threads: + thread.join() diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac399486a1502fddbec823d2dccbb23f0558be14 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,18 @@ +import numpy as np + + +def next_power_of_2(x): + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def int2float(sound): + """ + Taken from https://github.com/snakers4/silero-vad + """ + + abs_max = np.abs(sound).max() + sound = sound.astype("float32") + if abs_max > 0: + sound *= 1 / 32768 + sound = sound.squeeze() # depends on the use case + return sound
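+
+
+if __name__ == "__main__":
+    # Quick sanity check of the helpers above (illustrative values only).
+    print(next_power_of_2(300))  # -> 512
+    demo = np.array([0, 16384, -32768], dtype=np.int16)
+    print(int2float(demo))  # -> approximately [0.0, 0.5, -1.0]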