leezhuuu committed on
Commit
e663533
1 Parent(s): 22cf3e2

Upload 7 files

Files changed (7)
  1. LICENSE +201 -0
  2. README.md +105 -10
  3. README_en.md +97 -0
  4. flow_inference.py +142 -0
  5. model_server.py +116 -0
  6. requirements.txt +36 -0
  7. web_demo.py +257 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 GLM-4-Voice Model Team @ Zhipu AI
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,10 +1,105 @@
- ---
- title: GLM 4 Voice
- emoji: 🐢
- colorFrom: indigo
- colorTo: blue
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # GLM-4-Voice
+ Read this in [English](./README_en.md)
+
+ GLM-4-Voice is an end-to-end speech model released by Zhipu AI. It can directly understand and generate Chinese and English speech, hold real-time voice conversations, and follow user instructions to change attributes of its speech such as emotion, intonation, speaking rate, and dialect.
+
+ ## Model Architecture
+ ![Model Architecture](./resources/architecture.jpeg)
+
+ GLM-4-Voice consists of three components:
+ * GLM-4-Voice-Tokenizer: converts continuous speech input into discrete tokens by adding vector quantization to the encoder of [Whisper](https://github.com/openai/whisper) and training it with supervision on ASR data. On average, each second of audio is represented by only 12.5 discrete tokens.
+ * GLM-4-Voice-Decoder: a speech decoder supporting streaming inference, trained on the Flow Matching architecture of [CosyVoice](https://github.com/FunAudioLLM/CosyVoice); it converts discrete speech tokens back into a continuous speech signal. Generation can start from as few as 10 speech tokens, reducing end-to-end conversation latency.
+ * GLM-4-Voice-9B: pre-trained and aligned on the speech modality on top of [GLM-4-9B](https://github.com/THUDM/GLM-4), enabling it to understand and generate discrete speech tokens.
+
+ For pre-training, to overcome the twin challenges of model intelligence and synthesis expressiveness in the speech modality, we decouple the Speech2Speech task into two sub-tasks: "produce a text reply from the user's audio" and "synthesize the reply speech from the text reply and the user's speech". We design a pre-training objective for each, synthesizing interleaved speech-text data from text pre-training corpora and from unsupervised audio data to match the two task forms. Built on the GLM-4-9B base model, GLM-4-Voice-9B was pre-trained on millions of hours of audio and hundreds of billions of tokens of interleaved audio-text data, giving it strong audio understanding and modeling abilities.
+
+ For alignment, to support high-quality voice dialogue we designed a streaming "thinking" architecture: given the user's speech, GLM-4-Voice streams its output by alternating between the text and speech modalities. The speech modality uses the text as a reference to keep the reply quality high, and adjusts the voice according to the user's spoken instructions. This preserves the language model's intelligence as much as possible while retaining end-to-end modeling capability and low latency: speech synthesis can begin after as few as 20 output tokens.
+
+ A more detailed technical report will be published later.
+
+ ## Model List
+
+ | Model | Type | Download |
+ |:---------------------:|:----------------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|
+ | GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-tokenizer) |
+ | GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-9b) |
+ | GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-decoder) |
+
+ ## Usage
+ We provide a Web Demo that can be launched directly. Users can input speech or text, and the model replies with both speech and text.
+
+ ![](resources/web_demo.png)
+
+ ### Preparation
+ First, download the repository
+ ```shell
+ git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
+ cd GLM-4-Voice
+ ```
+ Then install the dependencies.
+ ```shell
+ pip install -r requirements.txt
+ ```
+ Since the Decoder model does not support initialization via `transformers`, its checkpoint needs to be downloaded separately.
+
+ ```shell
+ # Download the model via git; make sure git-lfs is installed
+ git clone https://huggingface.co/THUDM/glm-4-voice-decoder
+ ```
+
+ ### Launch Web Demo
+ First, start the model service
+ ```shell
+ python model_server.py --model-path glm-4-voice-9b
+ ```
+ This command automatically downloads `glm-4-voice-9b`. If your network connection is poor, you can also download it manually and pass the local path via `--model-path`.
+
+ Then start the web service
+ ```shell
+ python web_demo.py
+ ```
+ The web demo is now available at http://127.0.0.1:8888. This command automatically downloads `glm-4-voice-tokenizer` and `glm-4-voice-9b`. If your network connection is poor, you can also download them manually and pass the local paths via `--tokenizer-path` and `--model-path`.
+
+ ### Known Issues
+ * Gradio's streaming audio playback can be unstable. Audio quality is higher if you click the audio in the chat box after generation completes.
+
+ ## Cases
+ We provide some sample conversations with GLM-4-Voice, including emotion control, speaking-rate changes, dialect generation, and more.
+
+ * Guide me to relax with a gentle voice
+
+ https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0
+
+ * Commentate a football match in an excited voice
+
+ https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b
+
+ * Tell a ghost story in a mournful voice
+
+ https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a
+
+ * Describe how cold winter is in Northeastern dialect
+
+ https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037
+
+ * Say "eat grapes without spitting out the skins" in Chongqing dialect
+
+ https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b
+
+ * Recite a tongue twister in a Beijing accent
+
+ https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651
+
+ * Speak faster
+
+ https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f
+
+ * Even faster
+
+ https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d
+
+ ## Acknowledgements
+ Some of the code in this project comes from:
+ * [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
+ * [transformers](https://github.com/huggingface/transformers)
+ * [GLM-4](https://github.com/THUDM/GLM-4)
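The model service started by `model_server.py` streams newline-delimited JSON, one `{"token_id": ..., "error_code": 0}` object per line. A minimal client-side parser sketch, assuming byte chunks such as those yielded by `requests`' `iter_content` (the chunk boundaries below are synthetic, not what the server actually produces):

```python
import json

def parse_token_stream(chunks):
    """Parse the newline-delimited JSON stream emitted by model_server.py.

    Each complete line is an object like {"token_id": 123, "error_code": 0}.
    Yields token ids; raises if the server reports an error.
    """
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        # A chunk may contain several lines, or end mid-line; buffer handles both.
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            if not line.strip():
                continue
            msg = json.loads(line)
            if msg.get("error_code", 0) != 0:
                raise RuntimeError(msg.get("text", "server error"))
            yield msg["token_id"]

# Synthetic chunks with a message split across a chunk boundary
chunks = [b'{"token_id": 1, "error_code": 0}\n{"token_id":',
          b' 2, "error_code": 0}\n']
print(list(parse_token_stream(chunks)))  # [1, 2]
```

The real request would be a streaming POST to the server's `/generate_stream` endpoint with a JSON body containing `prompt` (and optionally `temperature`, `top_p`, `max_new_tokens`).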
README_en.md ADDED
@@ -0,0 +1,97 @@
+ # GLM-4-Voice
+ GLM-4-Voice is an end-to-end voice model launched by Zhipu AI. GLM-4-Voice can directly understand and generate Chinese and English speech, engage in real-time voice conversations, and change attributes such as emotion, intonation, speech rate, and dialect based on user instructions.
+
+ ## Model Architecture
+
+ ![Model Architecture](./resources/architecture.jpeg)
+
+ We provide the three components of GLM-4-Voice:
+ * GLM-4-Voice-Tokenizer: Trained by adding vector quantization to the encoder part of [Whisper](https://github.com/openai/whisper), converting continuous speech input into discrete tokens. Each second of audio is represented by an average of 12.5 discrete tokens.
+ * GLM-4-Voice-9B: Pre-trained and aligned on the speech modality based on [GLM-4-9B](https://github.com/THUDM/GLM-4), enabling understanding and generation of discretized speech.
+ * GLM-4-Voice-Decoder: A speech decoder supporting streaming inference, retrained based on [CosyVoice](https://github.com/FunAudioLLM/CosyVoice), converting discrete speech tokens into continuous speech output. Generation can start with as few as 10 audio tokens, reducing conversation latency.
+
+ A more detailed technical report will be published later.
+
+ ## Model List
+
+ | Model | Type | Download |
+ |:---------------------:|:----------------:|:------------------:|
+ | GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer) |
+ | GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b) |
+ | GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder) |
+
+ ## Usage
+ We provide a Web Demo that can be launched directly. Users can input speech or text, and the model will respond with both speech and text.
+
+ ![](resources/web_demo.png)
+
+ ### Preparation
+ First, download the repository
+ ```shell
+ git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
+ cd GLM-4-Voice
+ ```
+ Then, install the dependencies.
+ ```shell
+ pip install -r requirements.txt
+ ```
+ Since the Decoder model does not support initialization via `transformers`, the checkpoint needs to be downloaded separately.
+
+ ```shell
+ # Download the model via git; please ensure git-lfs is installed
+ git clone https://huggingface.co/THUDM/glm-4-voice-decoder
+ ```
+
+ ### Launch Web Demo
+ First, start the model service
+ ```shell
+ python model_server.py --model-path glm-4-voice-9b
+ ```
+
+ Then, start the web service
+ ```shell
+ python web_demo.py
+ ```
+ You can then access the web demo at http://127.0.0.1:8888.
+
+ ### Known Issues
+ * Gradio's streaming audio playback can be unstable. The audio quality will be higher when clicking on the audio in the dialogue box after generation is complete.
+
+ ## Examples
+ We provide some dialogue cases for GLM-4-Voice, including emotion control, speech rate alteration, dialect generation, etc. (The examples are in Chinese.)
+
+ * Use a gentle voice to guide me to relax
+
+ https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0
+
+ * Use an excited voice to commentate a football match
+
+ https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b
+
+ * Tell a ghost story with a mournful voice
+
+ https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a
+
+ * Introduce how cold winter is with a Northeastern dialect
+
+ https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037
+
+ * Say "Eat grapes without spitting out the skins" in Chongqing dialect
+
+ https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b
+
+ * Recite a tongue twister with a Beijing accent
+
+ https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651
+
+ * Increase the speech rate
+
+ https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f
+
+ * Even faster
+
+ https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d
+
+ ## Acknowledgements
+ Some code in this project is from:
+ * [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
+ * [transformers](https://github.com/huggingface/transformers)
+ * [GLM-4](https://github.com/THUDM/GLM-4)
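The figures quoted above imply some simple latency arithmetic: at 12.5 speech tokens per second of audio, the decoder's 10-token minimum corresponds to 0.8 s of audio per chunk. A quick check (the constant comes from the text above):

```python
# Back-of-the-envelope figures implied by the README:
# 12.5 discrete speech tokens represent one second of audio, and the
# decoder can begin synthesis once it has 10 speech tokens.
TOKENS_PER_SECOND = 12.5

def audio_seconds(n_tokens: float) -> float:
    """Duration of audio represented by n discrete speech tokens."""
    return n_tokens / TOKENS_PER_SECOND

print(audio_seconds(10))    # 0.8 -> the decoder's minimum chunk covers 0.8 s of audio
print(audio_seconds(12.5))  # 1.0 -> one second of audio
```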
flow_inference.py ADDED
@@ -0,0 +1,142 @@
+ import torch
+ import numpy as np
+ from hyperpyyaml import load_hyperpyyaml
+ import uuid
+ from collections import defaultdict
+
+
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
+     device = fade_in_mel.device
+     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+     mel_overlap_len = int(window.shape[0] / 2)
+     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+     return fade_in_mel.to(device)
+
+
+ class AudioDecoder:
+     def __init__(self, config_path, flow_ckpt_path, hift_ckpt_path, device="cuda"):
+         self.device = device
+
+         with open(config_path, 'r') as f:
+             self.scratch_configs = load_hyperpyyaml(f)
+
+         # Load models
+         self.flow = self.scratch_configs['flow']
+         self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device))
+         self.hift = self.scratch_configs['hift']
+         self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
+
+         # Move models to the appropriate device
+         self.flow.to(self.device)
+         self.hift.to(self.device)
+         self.mel_overlap_dict = defaultdict(lambda: None)
+         self.hift_cache_dict = defaultdict(lambda: None)
+         self.token_min_hop_len = 2 * self.flow.input_frame_rate
+         self.token_max_hop_len = 4 * self.flow.input_frame_rate
+         self.token_overlap_len = 5
+         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
+         self.mel_window = np.hamming(2 * self.mel_overlap_len)
+         # hift cache
+         self.mel_cache_len = 1
+         self.source_cache_len = int(self.mel_cache_len * 256)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+
+     def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int32),
+                   prompt_feat=torch.zeros(1, 0, 80), embedding=torch.zeros(1, 192), finalize=False):
+         tts_mel = self.flow.inference(token=token.to(self.device),
+                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                       prompt_token=prompt_token.to(self.device),
+                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                       prompt_feat=prompt_feat.to(self.device),
+                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                       embedding=embedding.to(self.device))
+
+         # mel overlap fade in out
+         if self.mel_overlap_dict[uuid] is not None:
+             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+
+         # keep overlap mel and hift cache
+         if finalize is False:
+             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+         else:
+             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+             del self.hift_cache_dict[uuid]
+             del self.mel_overlap_dict[uuid]
+         return tts_speech, tts_mel
+
+     def offline_inference(self, token):
+         this_uuid = str(uuid.uuid1())
+         tts_speech, tts_mel = self.token2wav(token, uuid=this_uuid, finalize=True)
+         return tts_speech.cpu()
+
+     def stream_inference(self, token):
+         token.to(self.device)
+         this_uuid = str(uuid.uuid1())
+
+         # Prepare other necessary input tensors
+         llm_embedding = torch.zeros(1, 192).to(self.device)
+         prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+         flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
+
+         tts_speechs = []
+         tts_mels = []
+
+         block_size = self.flow.encoder.block_size
+         prev_mel = None
+
+         for idx in range(0, token.size(1), block_size):
+             tts_token = token[:, idx:idx + block_size]
+
+             # Previously generated mels and tokens become the prompt for the next block
+             if prev_mel is not None:
+                 prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
+                 flow_prompt_speech_token = token[:, :idx]
+
+             is_finalize = idx + block_size >= token.size(-1)
+
+             tts_speech, tts_mel = self.token2wav(tts_token, uuid=this_uuid,
+                                                  prompt_token=flow_prompt_speech_token.to(self.device),
+                                                  prompt_feat=prompt_speech_feat.to(self.device), finalize=is_finalize)
+
+             prev_mel = tts_mel
+
+             tts_speechs.append(tts_speech)
+             tts_mels.append(tts_mel)
+
+         # Concatenate the streamed chunks into the final waveform
+         tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
+
+         return tts_speech
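`fade_in_out` above blends the head of each new mel chunk with the cached tail of the previous one: the rising half of a Hamming window weights the new chunk and the falling half weights the old tail. A small NumPy sketch of the same overlap logic on 1-D arrays (the sizes here are illustrative, not the model's real mel dimensions):

```python
import numpy as np

def crossfade(fade_in, fade_out_tail, window):
    """Crossfade the start of `fade_in` with `fade_out_tail`, as in
    fade_in_out() above: rising window half weights the new signal,
    falling half weights the old one."""
    overlap = window.shape[0] // 2
    out = fade_in.copy()
    out[:overlap] = (fade_in[:overlap] * window[:overlap]
                     + fade_out_tail[-overlap:] * window[overlap:])
    return out

overlap = 4
window = np.hamming(2 * overlap)   # symmetric: rises, then falls
prev_tail = np.ones(overlap)       # cached tail of the previous chunk
new_chunk = np.zeros(8)            # start of the next chunk
mixed = crossfade(new_chunk, prev_tail, window)
# Only the first `overlap` samples are blended; the rest pass through.
print(mixed.shape)  # (8,)
```

Because the new chunk is all zeros here, the blended region is exactly the previous tail scaled by the falling window half, which is easy to verify by hand.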
model_server.py ADDED
@@ -0,0 +1,116 @@
+ """
+ A model worker executes the model.
+ """
+ import argparse
+ import json
+ import uuid
+
+ from fastapi import FastAPI, Request
+ from fastapi.responses import StreamingResponse
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import uvicorn
+
+ from transformers.generation.streamers import BaseStreamer
+ from threading import Thread
+ from queue import Queue
+
+
+ class TokenStreamer(BaseStreamer):
+     def __init__(self, skip_prompt: bool = False, timeout=None):
+         self.skip_prompt = skip_prompt
+
+         # variables used in the streaming process
+         self.token_queue = Queue()
+         self.stop_signal = None
+         self.next_tokens_are_prompt = True
+         self.timeout = timeout
+
+     def put(self, value):
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TokenStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         for token in value.tolist():
+             self.token_queue.put(token)
+
+     def end(self):
+         self.token_queue.put(self.stop_signal)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.token_queue.get(timeout=self.timeout)
+         if value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ class ModelWorker:
+     def __init__(self, model_path, device='cuda'):
+         self.device = device
+         self.glm_model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
+                                                    device=device).to(device).eval()
+         self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+     @torch.inference_mode()
+     def generate_stream(self, params):
+         tokenizer, model = self.glm_tokenizer, self.glm_model
+
+         prompt = params["prompt"]
+
+         temperature = float(params.get("temperature", 1.0))
+         top_p = float(params.get("top_p", 1.0))
+         max_new_tokens = int(params.get("max_new_tokens", 256))
+
+         inputs = tokenizer([prompt], return_tensors="pt")
+         inputs = inputs.to(self.device)
+         streamer = TokenStreamer(skip_prompt=True)
+         thread = Thread(target=model.generate,
+                         kwargs=dict(**inputs, max_new_tokens=max_new_tokens,
+                                     temperature=temperature, top_p=top_p,
+                                     streamer=streamer))
+         thread.start()
+         for token_id in streamer:
+             yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()
+
+     def generate_stream_gate(self, params):
+         try:
+             for x in self.generate_stream(params):
+                 yield x
+         except Exception as e:
+             print("Caught Unknown Error", e)
+             ret = {
+                 "text": "Server Error",
+                 "error_code": 1,
+             }
+             yield (json.dumps(ret) + "\n").encode()
+
+
+ app = FastAPI()
+
+
+ @app.post("/generate_stream")
+ async def generate_stream(request: Request):
+     params = await request.json()
+
+     generator = worker.generate_stream_gate(params)
+     return StreamingResponse(generator)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=10000)
+     parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
+     args = parser.parse_args()
+
+     worker = ModelWorker(args.model_path)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
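`TokenStreamer` above is a queue-plus-sentinel bridge: `model.generate` runs in a background thread and `put()`s token ids, while the HTTP handler simply iterates until the stop sentinel arrives. A stripped-down, dependency-free sketch of the same producer/consumer pattern (`fake_generate` is a hypothetical stand-in for `model.generate`):

```python
from queue import Queue
from threading import Thread

class MiniStreamer:
    """Minimal version of the pattern TokenStreamer uses: a producer
    thread put()s token ids, end() enqueues a stop sentinel, and the
    consumer iterates until it sees the sentinel."""
    _STOP = object()

    def __init__(self):
        self._queue = Queue()

    def put(self, token_id):
        self._queue.put(token_id)

    def end(self):
        self._queue.put(self._STOP)

    def __iter__(self):
        return self

    def __next__(self):
        value = self._queue.get()  # blocks until the producer supplies a value
        if value is self._STOP:
            raise StopIteration
        return value

streamer = MiniStreamer()

def fake_generate():
    # Stand-in for model.generate(..., streamer=streamer)
    for token_id in (101, 102, 103):
        streamer.put(token_id)
    streamer.end()

Thread(target=fake_generate).start()
tokens = list(streamer)   # blocks until end() is called
print(tokens)  # [101, 102, 103]
```

The same structure is why `generate_stream` can `yield` tokens as soon as they are produced instead of waiting for generation to finish.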
requirements.txt ADDED
@@ -0,0 +1,36 @@
+ conformer==0.3.2
+ deepspeed==0.14.2; sys_platform == 'linux'
+ diffusers==0.27.2
+ fastapi==0.115.3
+ fastapi-cli==0.0.4
+ gdown==5.1.0
+ gradio==5.3.0
+ grpcio==1.57.0
+ grpcio-tools==1.57.0
+ huggingface_hub==0.25.2
+ hydra-core==1.3.2
+ HyperPyYAML==1.2.2
+ inflect==7.3.1
+ librosa==0.10.2
+ lightning==2.2.4
+ matplotlib==3.7.5
+ modelscope==1.15.0
+ networkx==3.1
+ numpy==1.24.4
+ omegaconf==2.3.0
+ onnxruntime-gpu==1.16.0; sys_platform == 'linux'
+ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'win32'
+ openai-whisper==20231117
+ protobuf==4.25
+ pydantic==2.7.0
+ rich==13.7.1
+ Requests==2.32.3
+ safetensors==0.4.5
+ soundfile==0.12.1
+ tensorboard==2.14.0
+ transformers==4.44.1
+ uvicorn==0.32.0
+ wget==3.2
+ WeTextProcessing==1.0.3
+ torch==2.3.0
+ torchaudio==2.3.0
web_demo.py ADDED
@@ -0,0 +1,257 @@
+import json
+import os.path
+import tempfile
+import sys
+import re
+import uuid
+import requests
+from argparse import ArgumentParser
+
+import torchaudio
+from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoModel
+from speech_tokenizer.modeling_whisper import WhisperVQEncoder
+
+
+sys.path.insert(0, "./cosyvoice")
+sys.path.insert(0, "./third_party/Matcha-TTS")
+
+from speech_tokenizer.utils import extract_speech_token
+
+import gradio as gr
+import torch
+
+audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
+
+from flow_inference import AudioDecoder
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=8888)
+    parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
+    parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
+    parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
+    args = parser.parse_args()
+
+    flow_config = os.path.join(args.flow_path, "config.yaml")
+    flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
+    hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
+    glm_tokenizer = None
+    device = "cuda"
+    audio_decoder: AudioDecoder = None
+    whisper_model, feature_extractor = None, None
+
+    def initialize_fn():
+        global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
+        if audio_decoder is not None:
+            return
+
+        # GLM
+        glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+
+        # Flow & Hift
+        audio_decoder = AudioDecoder(config_path=flow_config, flow_ckpt_path=flow_checkpoint,
+                                     hift_ckpt_path=hift_checkpoint,
+                                     device=device)
+
+        # Speech tokenizer
+        whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
+        feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
+
+    def clear_fn():
+        return [], [], '', '', '', None
+
+    def inference_fn(
+            temperature: float,
+            top_p: float,
+            max_new_token: int,
+            input_mode,
+            audio_path: str | None,
+            input_text: str | None,
+            history: list[dict],
+            previous_input_tokens: str,
+            previous_completion_tokens: str,
+    ):
+
+        if input_mode == "audio":
+            assert audio_path is not None
+            history.append({"role": "user", "content": {"path": audio_path}})
+            audio_tokens = extract_speech_token(
+                whisper_model, feature_extractor, [audio_path]
+            )[0]
+            if len(audio_tokens) == 0:
+                raise gr.Error("No audio tokens extracted")
+            audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
+            audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
+            user_input = audio_tokens
+            system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
+
+        else:
+            assert input_text is not None
+            history.append({"role": "user", "content": input_text})
+            user_input = input_text
+            system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."
+
+        # Gather history
+        inputs = previous_input_tokens + previous_completion_tokens
+        inputs = inputs.strip()
+        if "<|system|>" not in inputs:
+            inputs += f"<|system|>\n{system_prompt}"
+        inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
+
+        with torch.no_grad():
+            response = requests.post(
+                "http://localhost:10000/generate_stream",
+                data=json.dumps({
+                    "prompt": inputs,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "max_new_tokens": max_new_token,
+                }),
+                stream=True
+            )
+            text_tokens, audio_tokens = [], []
+            audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
+            end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
+            complete_tokens = []
+            prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
+            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
+            this_uuid = str(uuid.uuid4())
+            tts_speechs = []
+            tts_mels = []
+            prev_mel = None
+            is_finalize = False
+            block_size = 10
+            for chunk in response.iter_lines():
+                token_id = json.loads(chunk)["token_id"]
+                if token_id == end_token_id:
+                    is_finalize = True
+                if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
+                    block_size = 20
+                    tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
+
+                    if prev_mel is not None:
+                        prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
+
+                    tts_speech, tts_mel = audio_decoder.token2wav(tts_token, uuid=this_uuid,
+                                                                  prompt_token=flow_prompt_speech_token.to(device),
+                                                                  prompt_feat=prompt_speech_feat.to(device),
+                                                                  finalize=is_finalize)
+                    prev_mel = tts_mel
+
+                    tts_speechs.append(tts_speech.squeeze())
+                    tts_mels.append(tts_mel)
+                    yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy())
+                    flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
+                    audio_tokens = []
+                if not is_finalize:
+                    complete_tokens.append(token_id)
+                    if token_id >= audio_offset:
+                        audio_tokens.append(token_id - audio_offset)
+                    else:
+                        text_tokens.append(token_id)
+        tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
+        complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+            torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
+        history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
+        history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, skip_special_tokens=False)})
+        yield history, inputs, complete_text, '', None
+
+    def update_input_interface(input_mode):
+        if input_mode == "audio":
+            return [gr.update(visible=True), gr.update(visible=False)]
+        else:
+            return [gr.update(visible=False), gr.update(visible=True)]
+
+    # Create the Gradio interface
+    with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
+        with gr.Row():
+            temperature = gr.Number(
+                label="Temperature",
+                value=0.2
+            )
+
+            top_p = gr.Number(
+                label="Top p",
+                value=0.8
+            )
+
+            max_new_token = gr.Number(
+                label="Max new tokens",
+                value=2000,
+            )
+
+        chatbot = gr.Chatbot(
+            elem_id="chatbot",
+            bubble_full_width=False,
+            type="messages",
+            scale=1,
+        )
+
+        with gr.Row():
+            with gr.Column():
+                input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
+                audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
+                text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
+
+            with gr.Column():
+                submit_btn = gr.Button("Submit")
+                reset_btn = gr.Button("Clear")
+                output_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True, streaming=True,
+                                        autoplay=True)
+
+        gr.Markdown("""## Debug Info""")
+        with gr.Row():
+            input_tokens = gr.Textbox(
+                label="Input Tokens",
+                interactive=False,
+            )
+
+            completion_tokens = gr.Textbox(
+                label="Completion Tokens",
+                interactive=False,
+            )
+
+            detailed_error = gr.Textbox(
+                label="Detailed Error",
+                interactive=False,
+            )
+
+        history_state = gr.State([])
+
+        respond = submit_btn.click(
+            inference_fn,
+            inputs=[
+                temperature,
+                top_p,
+                max_new_token,
+                input_mode,
+                audio,
+                text_input,
+                history_state,
+                input_tokens,
+                completion_tokens,
+            ],
+            outputs=[history_state, input_tokens, completion_tokens, detailed_error, output_audio]
+        )
+
+        respond.then(lambda s: s, [history_state], chatbot)
+
+        reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio])
+        input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
+
+    initialize_fn()
+    # Launch the interface
+    demo.launch(
+        server_port=args.port,
+        server_name=args.host
+    )
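The token loop in `inference_fn` demultiplexes the server's single token stream: ids at or above the `<|audio_0|>` offset are audio codebook indices, everything below is text, and the `<|user|>` id marks end of turn. A standalone sketch of that separation (the function name `split_tokens` is illustrative, not part of the upload):

```python
def split_tokens(token_ids, audio_offset, end_token_id):
    """Split a streamed token sequence into text ids and audio codebook
    indices, stopping at the end-of-turn token, as inference_fn does."""
    text_tokens, audio_tokens = [], []
    for tid in token_ids:
        if tid == end_token_id:
            break  # end of assistant turn
        if tid >= audio_offset:
            audio_tokens.append(tid - audio_offset)  # rebase to codebook index
        else:
            text_tokens.append(tid)
    return text_tokens, audio_tokens
```

In the demo itself the audio ids are additionally flushed to `audio_decoder.token2wav` in blocks (first 10 tokens, then 20 per block) so playback can start before generation finishes.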