This view is limited to 50 files because it contains too many changes.  See the raw diff here.
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/bug_report.yml DELETED
@@ -1,50 +0,0 @@
1
- name: "Bug Report"
2
- description: |
3
- Please provide as much details to help address the issue, including logs and screenshots.
4
- labels:
5
- - bug
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To ensure timely help, please confirm the following:"
11
- options:
12
- - label: This template is only for bug reports, usage problems go with 'Help Wanted'.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Environment Details
23
- description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24
- placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
25
- validations:
26
- required: true
27
- - type: textarea
28
- attributes:
29
- label: Steps to Reproduce
30
- description: |
31
- Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32
- placeholder: |
33
- 1. Create a new conda environment.
34
- 2. Clone the repository, install as local editable and properly set up.
35
- 3. Run the command: `accelerate launch src/f5_tts/train/train.py`.
36
- 4. Have following error message... (attach logs).
37
- validations:
38
- required: true
39
- - type: textarea
40
- attributes:
41
- label: ✔️ Expected Behavior
42
- placeholder: Describe what you expected to happen.
43
- validations:
44
- required: false
45
- - type: textarea
46
- attributes:
47
- label: ❌ Actual Behavior
48
- placeholder: Describe what actually happened.
49
- validations:
50
- required: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/config.yml DELETED
@@ -1 +0,0 @@
1
- blank_issues_enabled: false
 
 
.github/ISSUE_TEMPLATE/feature_request.yml DELETED
@@ -1,62 +0,0 @@
1
- name: "Feature Request"
2
- description: |
3
- Some constructive suggestions and new ideas regarding current repo.
4
- labels:
5
- - enhancement
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To help us grasp quickly, please confirm the following:"
11
- options:
12
- - label: This template is only for feature request.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and found not discussion yet.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
23
- description: |
24
- Describe the specific problem or scenario you're facing in detail. For example:
25
- *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
26
- placeholder: Please describe the situation in as much detail as possible.
27
- validations:
28
- required: true
29
-
30
- - type: textarea
31
- attributes:
32
- label: 2. What is your suggested solution?
33
- description: |
34
- Provide a clear description of the feature or enhancement you'd like to propose.
35
- How would this feature solve your issue or improve the project?
36
- placeholder: Describe your idea or proposed solution here.
37
- validations:
38
- required: true
39
-
40
- - type: textarea
41
- attributes:
42
- label: 3. Additional context or comments
43
- description: |
44
- Any other relevant information, links, documents, or screenshots that provide clarity.
45
- Use this section for anything not covered above.
46
- placeholder: Add any extra details here.
47
- validations:
48
- required: false
49
-
50
- - type: checkboxes
51
- attributes:
52
- label: 4. Can you help us with this feature?
53
- description: |
54
- Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
55
- options:
56
- - label: I am interested in contributing to this feature.
57
- required: false
58
-
59
- - type: markdown
60
- attributes:
61
- value: |
62
- **Note:** Please submit only one request per issue to keep discussions focused and manageable.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/help_wanted.yml DELETED
@@ -1,50 +0,0 @@
1
- name: "Help Wanted"
2
- description: |
3
- Please provide as much details to help address the issue, including logs and screenshots.
4
- labels:
5
- - help wanted
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To ensure timely help, please confirm the following:"
11
- options:
12
- - label: This template is only for usage issues encountered.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Environment Details
23
- description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24
- placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
25
- validations:
26
- required: true
27
- - type: textarea
28
- attributes:
29
- label: Steps to Reproduce
30
- description: |
31
- Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32
- placeholder: |
33
- 1. Create a new conda environment.
34
- 2. Clone the repository and install as pip package.
35
- 3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
36
- 4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
37
- validations:
38
- required: true
39
- - type: textarea
40
- attributes:
41
- label: ✔️ Expected Behavior
42
- placeholder: Describe what you expected to happen, e.g. output a generated audio
43
- validations:
44
- required: false
45
- - type: textarea
46
- attributes:
47
- label: ❌ Actual Behavior
48
- placeholder: Describe what actually happened, failure messages, etc.
49
- validations:
50
- required: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/question.yml DELETED
@@ -1,26 +0,0 @@
1
- name: "Question"
2
- description: |
3
- Pure question or inquiry about the project, usage issue goes with "help wanted".
4
- labels:
5
- - question
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To help us grasp quickly, please confirm the following:"
11
- options:
12
- - label: This template is only for question, not feature requests or bug reports.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation and read the related paper(s).
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, no similar questions.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Question details
23
- description: |
24
- Question details, clearly stated using proper markdown syntax.
25
- validations:
26
- required: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/pre-commit.yaml DELETED
@@ -1,14 +0,0 @@
1
- name: pre-commit
2
-
3
- on:
4
- pull_request:
5
- push:
6
- branches: [main]
7
-
8
- jobs:
9
- pre-commit:
10
- runs-on: ubuntu-latest
11
- steps:
12
- - uses: actions/checkout@v3
13
- - uses: actions/setup-python@v3
14
- - uses: pre-commit/action@v3.0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/publish-docker-image.yaml DELETED
@@ -1,60 +0,0 @@
1
- name: Create and publish a Docker image
2
-
3
- # Configures this workflow to run every time a change is pushed to the branch called `release`.
4
- on:
5
- push:
6
- branches: ['main']
7
-
8
- # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
9
- env:
10
- REGISTRY: ghcr.io
11
- IMAGE_NAME: ${{ github.repository }}
12
-
13
- # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
14
- jobs:
15
- build-and-push-image:
16
- runs-on: ubuntu-latest
17
- # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
18
- permissions:
19
- contents: read
20
- packages: write
21
- #
22
- steps:
23
- - name: Checkout repository
24
- uses: actions/checkout@v4
25
- - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
26
- uses: jlumbroso/free-disk-space@main
27
- with:
28
- # This might remove tools that are actually needed, if set to "true" but frees about 6 GB
29
- tool-cache: false
30
-
31
- # All of these default to true, but feel free to set to "false" if necessary for your workflow
32
- android: true
33
- dotnet: true
34
- haskell: true
35
- large-packages: false
36
- swap-storage: false
37
- docker-images: false
38
- # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
39
- - name: Log in to the Container registry
40
- uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
41
- with:
42
- registry: ${{ env.REGISTRY }}
43
- username: ${{ github.actor }}
44
- password: ${{ secrets.GITHUB_TOKEN }}
45
- # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
46
- - name: Extract metadata (tags, labels) for Docker
47
- id: meta
48
- uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
49
- with:
50
- images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
51
- # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
52
- # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
53
- # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
54
- - name: Build and push Docker image
55
- uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
56
- with:
57
- context: .
58
- push: true
59
- tags: ${{ steps.meta.outputs.tags }}
60
- labels: ${{ steps.meta.outputs.labels }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/sync-hf.yaml DELETED
@@ -1,18 +0,0 @@
1
- name: Sync to HF Space
2
-
3
- on:
4
- push:
5
- branches:
6
- - main
7
-
8
- jobs:
9
- trigger_curl:
10
- runs-on: ubuntu-latest
11
-
12
- steps:
13
- - name: Send cURL POST request
14
- run: |
15
- curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
16
- -s \
17
- -H "Content-Type: application/json" \
18
- -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "src/third_party/BigVGAN"]
2
- path = src/third_party/BigVGAN
3
- url = https://github.com/NVIDIA/BigVGAN.git
 
 
 
 
.pre-commit-config.yaml DELETED
@@ -1,14 +0,0 @@
1
- repos:
2
- - repo: https://github.com/astral-sh/ruff-pre-commit
3
- # Ruff version.
4
- rev: v0.7.0
5
- hooks:
6
- # Run the linter.
7
- - id: ruff
8
- args: [--fix]
9
- # Run the formatter.
10
- - id: ruff-format
11
- - repo: https://github.com/pre-commit/pre-commit-hooks
12
- rev: v2.3.0
13
- hooks:
14
- - id: check-yaml
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile DELETED
@@ -1,27 +0,0 @@
1
- FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
2
-
3
- USER root
4
-
5
- ARG DEBIAN_FRONTEND=noninteractive
6
-
7
- LABEL github_repo="https://github.com/SWivid/F5-TTS"
8
-
9
- RUN set -x \
10
- && apt-get update \
11
- && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
12
- && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
13
- && apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \
14
- && rm -rf /var/lib/apt/lists/* \
15
- && apt-get clean
16
-
17
- WORKDIR /workspace
18
-
19
- RUN git clone https://github.com/SWivid/F5-TTS.git \
20
- && cd F5-TTS \
21
- && git submodule update --init --recursive \
22
- && sed -i '7iimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))' src/third_party/BigVGAN/bigvgan.py \
23
- && pip install -e . --no-cache-dir
24
-
25
- ENV SHELL=/bin/bash
26
-
27
- WORKDIR /workspace/F5-TTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: F5-TTS
3
  emoji: 🗣️
4
  colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: true
9
- short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
- sdk_version: 4.44.1
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: E2/F5 TTS
3
  emoji: 🗣️
4
  colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: true
9
+ short_description: 'E2-TTS & F5-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
+ sdk_version: 5.0.2
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_REPO.md DELETED
@@ -1,172 +0,0 @@
1
- # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
-
3
- [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
4
- [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
5
- [![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
6
- [![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
7
- [![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
8
- [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
9
- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
10
-
11
- **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
12
-
13
- **E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
14
-
15
- **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
16
-
17
- ### Thanks to all the contributors !
18
-
19
- ## News
20
- - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
21
-
22
- ## Installation
23
-
24
- ```bash
25
- # Create a python 3.10 conda env (you could also use virtualenv)
26
- conda create -n f5-tts python=3.10
27
- conda activate f5-tts
28
-
29
- # Install pytorch with your CUDA version, e.g.
30
- pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
31
- ```
32
-
33
- Then you can choose from a few options below:
34
-
35
- ### 1. As a pip package (if just for inference)
36
-
37
- ```bash
38
- pip install git+https://github.com/SWivid/F5-TTS.git
39
- ```
40
-
41
- ### 2. Local editable (if also do training, finetuning)
42
-
43
- ```bash
44
- git clone https://github.com/SWivid/F5-TTS.git
45
- cd F5-TTS
46
- # git submodule update --init --recursive # (optional, if need bigvgan)
47
- pip install -e .
48
- ```
49
- If initialize submodule, you should add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`.
50
- ```python
51
- import os
52
- import sys
53
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
54
- ```
55
-
56
- ### 3. Docker usage
57
- ```bash
58
- # Build from Dockerfile
59
- docker build -t f5tts:v1 .
60
-
61
- # Or pull from GitHub Container Registry
62
- docker pull ghcr.io/swivid/f5-tts:main
63
- ```
64
-
65
-
66
- ## Inference
67
-
68
- ### 1. Gradio App
69
-
70
- Currently supported features:
71
-
72
- - Basic TTS with Chunk Inference
73
- - Multi-Style / Multi-Speaker Generation
74
- - Voice Chat powered by Qwen2.5-3B-Instruct
75
- - [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
76
-
77
- ```bash
78
- # Launch a Gradio app (web interface)
79
- f5-tts_infer-gradio
80
-
81
- # Specify the port/host
82
- f5-tts_infer-gradio --port 7860 --host 0.0.0.0
83
-
84
- # Launch a share link
85
- f5-tts_infer-gradio --share
86
- ```
87
-
88
- ### 2. CLI Inference
89
-
90
- ```bash
91
- # Run with flags
92
- # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
93
- f5-tts_infer-cli \
94
- --model "F5-TTS" \
95
- --ref_audio "ref_audio.wav" \
96
- --ref_text "The content, subtitle or transcription of reference audio." \
97
- --gen_text "Some text you want TTS model generate for you."
98
-
99
- # Run with default setting. src/f5_tts/infer/examples/basic/basic.toml
100
- f5-tts_infer-cli
101
- # Or with your own .toml file
102
- f5-tts_infer-cli -c custom.toml
103
-
104
- # Multi voice. See src/f5_tts/infer/README.md
105
- f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
106
- ```
107
-
108
- ### 3. More instructions
109
-
110
- - In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
111
- - The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
112
-
113
-
114
- ## Training
115
-
116
- ### 1. Gradio App
117
-
118
- Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
119
-
120
- ```bash
121
- # Quick start with Gradio web interface
122
- f5-tts_finetune-gradio
123
- ```
124
-
125
-
126
- ## [Evaluation](src/f5_tts/eval)
127
-
128
-
129
- ## Development
130
-
131
- Use pre-commit to ensure code quality (will run linters and formatters automatically)
132
-
133
- ```bash
134
- pip install pre-commit
135
- pre-commit install
136
- ```
137
-
138
- When making a pull request, before each commit, run:
139
-
140
- ```bash
141
- pre-commit run --all-files
142
- ```
143
-
144
- Note: Some model components have linting exceptions for E722 to accommodate tensor notation
145
-
146
-
147
- ## Acknowledgements
148
-
149
- - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
150
- - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
151
- - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
152
- - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
153
- - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
154
- - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
155
- - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
156
- - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
157
- - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
158
- - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
159
-
160
- ## Citation
161
- If our work and codebase is useful for you, please cite as:
162
- ```
163
- @article{chen-etal-2024-f5tts,
164
- title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
165
- author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
166
- journal={arXiv preprint arXiv:2410.06885},
167
- year={2024},
168
- }
169
- ```
170
- ## License
171
-
172
- Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api.py DELETED
@@ -1,132 +0,0 @@
1
- import soundfile as sf
2
- import torch
3
- import tqdm
4
- from cached_path import cached_path
5
-
6
- from model import DiT, UNetT
7
- from model.utils import save_spectrogram
8
-
9
- from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
10
- from model.utils import seed_everything
11
- import random
12
- import sys
13
-
14
-
15
- class F5TTS:
16
- def __init__(
17
- self,
18
- model_type="F5-TTS",
19
- ckpt_file="",
20
- vocab_file="",
21
- ode_method="euler",
22
- use_ema=True,
23
- local_path=None,
24
- device=None,
25
- ):
26
- # Initialize parameters
27
- self.final_wave = None
28
- self.target_sample_rate = 24000
29
- self.n_mel_channels = 100
30
- self.hop_length = 256
31
- self.target_rms = 0.1
32
- self.seed = -1
33
-
34
- # Set device
35
- self.device = device or (
36
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
37
- )
38
-
39
- # Load models
40
- self.load_vocoder_model(local_path)
41
- self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
42
-
43
- def load_vocoder_model(self, local_path):
44
- self.vocos = load_vocoder(local_path is not None, local_path, self.device)
45
-
46
- def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
47
- if model_type == "F5-TTS":
48
- if not ckpt_file:
49
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
50
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
51
- model_cls = DiT
52
- elif model_type == "E2-TTS":
53
- if not ckpt_file:
54
- ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
55
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
56
- model_cls = UNetT
57
- else:
58
- raise ValueError(f"Unknown model type: {model_type}")
59
-
60
- self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
61
-
62
- def export_wav(self, wav, file_wave, remove_silence=False):
63
- sf.write(file_wave, wav, self.target_sample_rate)
64
-
65
- if remove_silence:
66
- remove_silence_for_generated_wav(file_wave)
67
-
68
- def export_spectrogram(self, spect, file_spect):
69
- save_spectrogram(spect, file_spect)
70
-
71
- def infer(
72
- self,
73
- ref_file,
74
- ref_text,
75
- gen_text,
76
- show_info=print,
77
- progress=tqdm,
78
- target_rms=0.1,
79
- cross_fade_duration=0.15,
80
- sway_sampling_coef=-1,
81
- cfg_strength=2,
82
- nfe_step=32,
83
- speed=1.0,
84
- fix_duration=None,
85
- remove_silence=False,
86
- file_wave=None,
87
- file_spect=None,
88
- seed=-1,
89
- ):
90
- if seed == -1:
91
- seed = random.randint(0, sys.maxsize)
92
- seed_everything(seed)
93
- self.seed = seed
94
- wav, sr, spect = infer_process(
95
- ref_file,
96
- ref_text,
97
- gen_text,
98
- self.ema_model,
99
- show_info=show_info,
100
- progress=progress,
101
- target_rms=target_rms,
102
- cross_fade_duration=cross_fade_duration,
103
- nfe_step=nfe_step,
104
- cfg_strength=cfg_strength,
105
- sway_sampling_coef=sway_sampling_coef,
106
- speed=speed,
107
- fix_duration=fix_duration,
108
- device=self.device,
109
- )
110
-
111
- if file_wave is not None:
112
- self.export_wav(wav, file_wave, remove_silence)
113
-
114
- if file_spect is not None:
115
- self.export_spectrogram(spect, file_spect)
116
-
117
- return wav, sr, spect
118
-
119
-
120
- if __name__ == "__main__":
121
- f5tts = F5TTS()
122
-
123
- wav, sr, spect = f5tts.infer(
124
- ref_file="tests/ref_audio/test_en_1_ref_short.wav",
125
- ref_text="some call me nature, others call me mother nature.",
126
- gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
127
- file_wave="tests/out.wav",
128
- file_spect="tests/out.png",
129
- seed=-1, # random seed = -1
130
- )
131
-
132
- print("seed :", f5tts.seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,888 +1,242 @@
1
- # ruff: noqa: E402
2
- # Above allows ruff to ignore E402: module level import not at top of file
3
-
4
- import json
5
  import re
6
- import tempfile
7
- from collections import OrderedDict
8
- from importlib.resources import files
9
-
10
- import click
11
  import gradio as gr
12
  import numpy as np
13
- import soundfile as sf
14
- import torchaudio
 
 
 
 
15
  from cached_path import cached_path
16
- from transformers import AutoModelForCausalLM, AutoTokenizer
17
-
18
- try:
19
- import spaces
20
-
21
- USING_SPACES = True
22
- except ImportError:
23
- USING_SPACES = False
24
-
25
-
26
- def gpu_decorator(func):
27
- if USING_SPACES:
28
- return spaces.GPU(func)
29
- else:
30
- return func
31
-
32
-
33
- from f5_tts.model import DiT, UNetT
34
- from f5_tts.infer.utils_infer import (
35
- load_vocoder,
36
- load_model,
37
- preprocess_ref_audio_text,
38
- infer_process,
39
- remove_silence_for_generated_wav,
40
  save_spectrogram,
41
  )
 
 
 
 
 
42
 
43
 
44
- DEFAULT_TTS_MODEL = "F5-TTS"
45
- tts_model_choice = DEFAULT_TTS_MODEL
46
-
47
- DEFAULT_TTS_MODEL_CFG = [
48
- "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
49
- "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
50
- json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
- ]
52
 
 
53
 
54
- # load models
55
-
56
- vocoder = load_vocoder()
57
-
58
-
59
- def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
60
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
61
- return load_model(DiT, F5TTS_model_cfg, ckpt_path)
62
-
63
-
64
- def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
65
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
66
- return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
67
-
68
-
69
- def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
70
- ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
71
- if ckpt_path.startswith("hf://"):
72
- ckpt_path = str(cached_path(ckpt_path))
73
- if vocab_path.startswith("hf://"):
74
- vocab_path = str(cached_path(vocab_path))
75
- if model_cfg is None:
76
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
77
- return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
78
-
79
-
80
- F5TTS_ema_model = load_f5tts()
81
- E2TTS_ema_model = load_e2tts() if USING_SPACES else None
82
- custom_ema_model, pre_custom_path = None, ""
83
-
84
- chat_model_state = None
85
- chat_tokenizer_state = None
86
 
 
 
 
 
 
 
87
 
88
- @gpu_decorator
89
- def generate_response(messages, model, tokenizer):
90
- """Generate response using Qwen"""
91
- text = tokenizer.apply_chat_template(
92
- messages,
93
- tokenize=False,
94
- add_generation_prompt=True,
95
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
98
- generated_ids = model.generate(
99
- **model_inputs,
100
- max_new_tokens=512,
101
- temperature=0.7,
102
- top_p=0.95,
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- generated_ids = [
106
- output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
107
- ]
108
- return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
109
 
 
 
 
 
110
 
111
- @gpu_decorator
112
- def infer(
113
- ref_audio_orig,
114
- ref_text,
115
- gen_text,
116
- model,
117
- remove_silence,
118
- cross_fade_duration=0.15,
119
- nfe_step=32,
120
- speed=1,
121
- show_info=gr.Info,
122
- ):
123
- if not ref_audio_orig:
124
- gr.Warning("Please provide reference audio.")
125
- return gr.update(), gr.update(), ref_text
126
 
127
- if not gen_text.strip():
128
- gr.Warning("Please enter text to generate.")
129
- return gr.update(), gr.update(), ref_text
130
 
131
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
132
 
133
- if model == "F5-TTS":
134
- ema_model = F5TTS_ema_model
135
- elif model == "E2-TTS":
136
- global E2TTS_ema_model
137
- if E2TTS_ema_model is None:
138
- show_info("Loading E2-TTS model...")
139
- E2TTS_ema_model = load_e2tts()
140
- ema_model = E2TTS_ema_model
141
- elif isinstance(model, list) and model[0] == "Custom":
142
- assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
143
- global custom_ema_model, pre_custom_path
144
- if pre_custom_path != model[1]:
145
- show_info("Loading Custom TTS model...")
146
- custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
147
- pre_custom_path = model[1]
148
- ema_model = custom_ema_model
149
-
150
- final_wave, final_sample_rate, combined_spectrogram = infer_process(
151
- ref_audio,
152
- ref_text,
153
- gen_text,
154
- ema_model,
155
- vocoder,
156
- cross_fade_duration=cross_fade_duration,
157
- nfe_step=nfe_step,
158
- speed=speed,
159
- show_info=show_info,
160
- progress=gr.Progress(),
161
- )
162
-
163
- # Remove silence
164
- if remove_silence:
165
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
166
- sf.write(f.name, final_wave, final_sample_rate)
167
- remove_silence_for_generated_wav(f.name)
168
- final_wave, _ = torchaudio.load(f.name)
169
- final_wave = final_wave.squeeze().cpu().numpy()
170
 
171
- # Save the spectrogram
172
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
173
- spectrogram_path = tmp_spectrogram.name
174
- save_spectrogram(combined_spectrogram, spectrogram_path)
175
 
176
- return (final_sample_rate, final_wave), spectrogram_path, ref_text
177
 
 
178
 
179
- with gr.Blocks() as app_credits:
180
- gr.Markdown("""
181
- # Credits
182
 
183
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
184
- * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
185
- * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
186
  """)
187
- with gr.Blocks() as app_tts:
188
- gr.Markdown("# Batched TTS")
189
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
190
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
 
191
  generate_btn = gr.Button("Synthesize", variant="primary")
192
  with gr.Accordion("Advanced Settings", open=False):
193
- ref_text_input = gr.Textbox(
194
- label="Reference Text",
195
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
196
- lines=2,
197
- )
198
- remove_silence = gr.Checkbox(
199
- label="Remove Silences",
200
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
201
- value=False,
202
- )
203
- speed_slider = gr.Slider(
204
- label="Speed",
205
- minimum=0.3,
206
- maximum=2.0,
207
- value=1.0,
208
- step=0.1,
209
- info="Adjust the speed of the audio.",
210
- )
211
- nfe_slider = gr.Slider(
212
- label="NFE Steps",
213
- minimum=4,
214
- maximum=64,
215
- value=32,
216
- step=2,
217
- info="Set the number of denoising steps.",
218
- )
219
- cross_fade_duration_slider = gr.Slider(
220
- label="Cross-Fade Duration (s)",
221
- minimum=0.0,
222
- maximum=1.0,
223
- value=0.15,
224
- step=0.01,
225
- info="Set the duration of the cross-fade between audio clips.",
226
- )
227
-
228
  audio_output = gr.Audio(label="Synthesized Audio")
229
- spectrogram_output = gr.Image(label="Spectrogram")
230
-
231
- @gpu_decorator
232
- def basic_tts(
233
- ref_audio_input,
234
- ref_text_input,
235
- gen_text_input,
236
- remove_silence,
237
- cross_fade_duration_slider,
238
- nfe_slider,
239
- speed_slider,
240
- ):
241
- audio_out, spectrogram_path, ref_text_out = infer(
242
- ref_audio_input,
243
- ref_text_input,
244
- gen_text_input,
245
- tts_model_choice,
246
- remove_silence,
247
- cross_fade_duration=cross_fade_duration_slider,
248
- nfe_step=nfe_slider,
249
- speed=speed_slider,
250
- )
251
- return audio_out, spectrogram_path, ref_text_out
252
-
253
- generate_btn.click(
254
- basic_tts,
255
- inputs=[
256
- ref_audio_input,
257
- ref_text_input,
258
- gen_text_input,
259
- remove_silence,
260
- cross_fade_duration_slider,
261
- nfe_slider,
262
- speed_slider,
263
- ],
264
- outputs=[audio_output, spectrogram_output, ref_text_input],
265
- )
266
-
267
-
268
- def parse_speechtypes_text(gen_text):
269
- # Pattern to find {speechtype}
270
- pattern = r"\{(.*?)\}"
271
-
272
- # Split the text by the pattern
273
- tokens = re.split(pattern, gen_text)
274
-
275
- segments = []
276
-
277
- current_style = "Regular"
278
-
279
- for i in range(len(tokens)):
280
- if i % 2 == 0:
281
- # This is text
282
- text = tokens[i].strip()
283
- if text:
284
- segments.append({"style": current_style, "text": text})
285
- else:
286
- # This is style
287
- style = tokens[i].strip()
288
- current_style = style
289
-
290
- return segments
291
-
292
-
293
- with gr.Blocks() as app_multistyle:
294
- # New section for multistyle generation
295
- gr.Markdown(
296
- """
297
- # Multiple Speech-Type Generation
298
-
299
- This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
300
- """
301
- )
302
-
303
- with gr.Row():
304
- gr.Markdown(
305
- """
306
- **Example Input:**
307
- {Regular} Hello, I'd like to order a sandwich please.
308
- {Surprised} What do you mean you're out of bread?
309
- {Sad} I really wanted a sandwich though...
310
- {Angry} You know what, darn you and your little shop!
311
- {Whisper} I'll just go back home and cry now.
312
- {Shouting} Why me?!
313
- """
314
- )
315
-
316
- gr.Markdown(
317
- """
318
- **Example Input 2:**
319
- {Speaker1_Happy} Hello, I'd like to order a sandwich please.
320
- {Speaker2_Regular} Sorry, we're out of bread.
321
- {Speaker1_Sad} I really wanted a sandwich though...
322
- {Speaker2_Whisper} I'll give you the last one I was hiding.
323
- """
324
- )
325
-
326
- gr.Markdown(
327
- "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
328
- )
329
-
330
- # Regular speech type (mandatory)
331
- with gr.Row() as regular_row:
332
- with gr.Column():
333
- regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
334
- regular_insert = gr.Button("Insert Label", variant="secondary")
335
- regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
336
- regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
337
-
338
- # Regular speech type (max 100)
339
- max_speech_types = 100
340
- speech_type_rows = [regular_row]
341
- speech_type_names = [regular_name]
342
- speech_type_audios = [regular_audio]
343
- speech_type_ref_texts = [regular_ref_text]
344
- speech_type_delete_btns = [None]
345
- speech_type_insert_btns = [regular_insert]
346
-
347
- # Additional speech types (99 more)
348
- for i in range(max_speech_types - 1):
349
- with gr.Row(visible=False) as row:
350
- with gr.Column():
351
- name_input = gr.Textbox(label="Speech Type Name")
352
- delete_btn = gr.Button("Delete Type", variant="secondary")
353
- insert_btn = gr.Button("Insert Label", variant="secondary")
354
- audio_input = gr.Audio(label="Reference Audio", type="filepath")
355
- ref_text_input = gr.Textbox(label="Reference Text", lines=2)
356
- speech_type_rows.append(row)
357
- speech_type_names.append(name_input)
358
- speech_type_audios.append(audio_input)
359
- speech_type_ref_texts.append(ref_text_input)
360
- speech_type_delete_btns.append(delete_btn)
361
- speech_type_insert_btns.append(insert_btn)
362
-
363
- # Button to add speech type
364
- add_speech_type_btn = gr.Button("Add Speech Type")
365
-
366
- # Keep track of autoincrement of speech types, no roll back
367
- speech_type_count = 1
368
-
369
- # Function to add a speech type
370
- def add_speech_type_fn():
371
- row_updates = [gr.update() for _ in range(max_speech_types)]
372
- global speech_type_count
373
- if speech_type_count < max_speech_types:
374
- row_updates[speech_type_count] = gr.update(visible=True)
375
- speech_type_count += 1
376
- else:
377
- gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
378
- return row_updates
379
-
380
- add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
381
-
382
- # Function to delete a speech type
383
- def delete_speech_type_fn():
384
- return gr.update(visible=False), None, None, None
385
-
386
- # Update delete button clicks
387
- for i in range(1, len(speech_type_delete_btns)):
388
- speech_type_delete_btns[i].click(
389
- delete_speech_type_fn,
390
- outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
391
- )
392
-
393
- # Text input for the prompt
394
- gen_text_input_multistyle = gr.Textbox(
395
- label="Text to Generate",
396
- lines=10,
397
- placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
398
- )
399
-
400
- def make_insert_speech_type_fn(index):
401
- def insert_speech_type_fn(current_text, speech_type_name):
402
- current_text = current_text or ""
403
- speech_type_name = speech_type_name or "None"
404
- updated_text = current_text + f"{{{speech_type_name}}} "
405
- return updated_text
406
-
407
- return insert_speech_type_fn
408
-
409
- for i, insert_btn in enumerate(speech_type_insert_btns):
410
- insert_fn = make_insert_speech_type_fn(i)
411
- insert_btn.click(
412
- insert_fn,
413
- inputs=[gen_text_input_multistyle, speech_type_names[i]],
414
- outputs=gen_text_input_multistyle,
415
- )
416
 
417
- with gr.Accordion("Advanced Settings", open=False):
418
- remove_silence_multistyle = gr.Checkbox(
419
- label="Remove Silences",
420
- value=True,
421
- )
422
-
423
- # Generate button
424
- generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
425
-
426
- # Output audio
427
- audio_output_multistyle = gr.Audio(label="Synthesized Audio")
428
-
429
- @gpu_decorator
430
- def generate_multistyle_speech(
431
- gen_text,
432
- *args,
433
- ):
434
- speech_type_names_list = args[:max_speech_types]
435
- speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
436
- speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
437
- remove_silence = args[3 * max_speech_types]
438
- # Collect the speech types and their audios into a dict
439
- speech_types = OrderedDict()
440
-
441
- ref_text_idx = 0
442
- for name_input, audio_input, ref_text_input in zip(
443
- speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
444
- ):
445
- if name_input and audio_input:
446
- speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
447
- else:
448
- speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
449
- ref_text_idx += 1
450
-
451
- # Parse the gen_text into segments
452
- segments = parse_speechtypes_text(gen_text)
453
-
454
- # For each segment, generate speech
455
- generated_audio_segments = []
456
- current_style = "Regular"
457
-
458
- for segment in segments:
459
- style = segment["style"]
460
- text = segment["text"]
461
-
462
- if style in speech_types:
463
- current_style = style
464
- else:
465
- gr.Warning(f"Type {style} is not available, will use Regular as default.")
466
- current_style = "Regular"
467
-
468
- try:
469
- ref_audio = speech_types[current_style]["audio"]
470
- except KeyError:
471
- gr.Warning(f"Please provide reference audio for type {current_style}.")
472
- return [None] + [speech_types[style]["ref_text"] for style in speech_types]
473
- ref_text = speech_types[current_style].get("ref_text", "")
474
-
475
- # Generate speech for this segment
476
- audio_out, _, ref_text_out = infer(
477
- ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
478
- ) # show_info=print no pull to top when generating
479
- sr, audio_data = audio_out
480
-
481
- generated_audio_segments.append(audio_data)
482
- speech_types[current_style]["ref_text"] = ref_text_out
483
-
484
- # Concatenate all audio segments
485
- if generated_audio_segments:
486
- final_audio_data = np.concatenate(generated_audio_segments)
487
- return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
488
- else:
489
- gr.Warning("No audio generated.")
490
- return [None] + [speech_types[style]["ref_text"] for style in speech_types]
491
-
492
- generate_multistyle_btn.click(
493
- generate_multistyle_speech,
494
- inputs=[
495
- gen_text_input_multistyle,
496
- ]
497
- + speech_type_names
498
- + speech_type_audios
499
- + speech_type_ref_texts
500
- + [
501
- remove_silence_multistyle,
502
- ],
503
- outputs=[audio_output_multistyle] + speech_type_ref_texts,
504
- )
505
-
506
- # Validation function to disable Generate button if speech types are missing
507
- def validate_speech_types(gen_text, regular_name, *args):
508
- speech_type_names_list = args
509
-
510
- # Collect the speech types names
511
- speech_types_available = set()
512
- if regular_name:
513
- speech_types_available.add(regular_name)
514
- for name_input in speech_type_names_list:
515
- if name_input:
516
- speech_types_available.add(name_input)
517
-
518
- # Parse the gen_text to get the speech types used
519
- segments = parse_speechtypes_text(gen_text)
520
- speech_types_in_text = set(segment["style"] for segment in segments)
521
-
522
- # Check if all speech types in text are available
523
- missing_speech_types = speech_types_in_text - speech_types_available
524
-
525
- if missing_speech_types:
526
- # Disable the generate button
527
- return gr.update(interactive=False)
528
- else:
529
- # Enable the generate button
530
- return gr.update(interactive=True)
531
-
532
- gen_text_input_multistyle.change(
533
- validate_speech_types,
534
- inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
535
- outputs=generate_multistyle_btn,
536
- )
537
-
538
-
539
- with gr.Blocks() as app_chat:
540
- gr.Markdown(
541
- """
542
- # Voice Chat
543
- Have a conversation with an AI using your reference voice!
544
- 1. Upload a reference audio clip and optionally its transcript.
545
- 2. Load the chat model.
546
- 3. Record your message through your microphone.
547
- 4. The AI will respond using the reference voice.
548
- """
549
- )
550
-
551
- if not USING_SPACES:
552
- load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
553
-
554
- chat_interface_container = gr.Column(visible=False)
555
-
556
- @gpu_decorator
557
- def load_chat_model():
558
- global chat_model_state, chat_tokenizer_state
559
- if chat_model_state is None:
560
- show_info = gr.Info
561
- show_info("Loading chat model...")
562
- model_name = "Qwen/Qwen2.5-3B-Instruct"
563
- chat_model_state = AutoModelForCausalLM.from_pretrained(
564
- model_name, torch_dtype="auto", device_map="auto"
565
- )
566
- chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
567
- show_info("Chat model loaded.")
568
-
569
- return gr.update(visible=False), gr.update(visible=True)
570
-
571
- load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
572
-
573
- else:
574
- chat_interface_container = gr.Column()
575
-
576
- if chat_model_state is None:
577
- model_name = "Qwen/Qwen2.5-3B-Instruct"
578
- chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
579
- chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
580
-
581
- with chat_interface_container:
582
- with gr.Row():
583
- with gr.Column():
584
- ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
585
- with gr.Column():
586
- with gr.Accordion("Advanced Settings", open=False):
587
- remove_silence_chat = gr.Checkbox(
588
- label="Remove Silences",
589
- value=True,
590
- )
591
- ref_text_chat = gr.Textbox(
592
- label="Reference Text",
593
- info="Optional: Leave blank to auto-transcribe",
594
- lines=2,
595
- )
596
- system_prompt_chat = gr.Textbox(
597
- label="System Prompt",
598
- value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
599
- lines=2,
600
- )
601
-
602
- chatbot_interface = gr.Chatbot(label="Conversation")
603
-
604
- with gr.Row():
605
- with gr.Column():
606
- audio_input_chat = gr.Microphone(
607
- label="Speak your message",
608
- type="filepath",
609
- )
610
- audio_output_chat = gr.Audio(autoplay=True)
611
- with gr.Column():
612
- text_input_chat = gr.Textbox(
613
- label="Type your message",
614
- lines=1,
615
- )
616
- send_btn_chat = gr.Button("Send Message")
617
- clear_btn_chat = gr.Button("Clear Conversation")
618
-
619
- conversation_state = gr.State(
620
- value=[
621
- {
622
- "role": "system",
623
- "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
624
- }
625
- ]
626
- )
627
-
628
- # Modify process_audio_input to use model and tokenizer from state
629
- @gpu_decorator
630
- def process_audio_input(audio_path, text, history, conv_state):
631
- """Handle audio or text input from user"""
632
-
633
- if not audio_path and not text.strip():
634
- return history, conv_state, ""
635
-
636
- if audio_path:
637
- text = preprocess_ref_audio_text(audio_path, text)[1]
638
-
639
- if not text.strip():
640
- return history, conv_state, ""
641
-
642
- conv_state.append({"role": "user", "content": text})
643
- history.append((text, None))
644
-
645
- response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
646
-
647
- conv_state.append({"role": "assistant", "content": response})
648
- history[-1] = (text, response)
649
-
650
- return history, conv_state, ""
651
-
652
- @gpu_decorator
653
- def generate_audio_response(history, ref_audio, ref_text, remove_silence):
654
- """Generate TTS audio for AI response"""
655
- if not history or not ref_audio:
656
- return None
657
-
658
- last_user_message, last_ai_response = history[-1]
659
- if not last_ai_response:
660
- return None
661
-
662
- audio_result, _, ref_text_out = infer(
663
- ref_audio,
664
- ref_text,
665
- last_ai_response,
666
- tts_model_choice,
667
- remove_silence,
668
- cross_fade_duration=0.15,
669
- speed=1.0,
670
- show_info=print, # show_info=print no pull to top when generating
671
- )
672
- return audio_result, ref_text_out
673
-
674
- def clear_conversation():
675
- """Reset the conversation"""
676
- return [], [
677
- {
678
- "role": "system",
679
- "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
680
- }
681
- ]
682
-
683
- def update_system_prompt(new_prompt):
684
- """Update the system prompt and reset the conversation"""
685
- new_conv_state = [{"role": "system", "content": new_prompt}]
686
- return [], new_conv_state
687
-
688
- # Handle audio input
689
- audio_input_chat.stop_recording(
690
- process_audio_input,
691
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
692
- outputs=[chatbot_interface, conversation_state],
693
- ).then(
694
- generate_audio_response,
695
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
696
- outputs=[audio_output_chat, ref_text_chat],
697
- ).then(
698
- lambda: None,
699
- None,
700
- audio_input_chat,
701
- )
702
-
703
- # Handle text input
704
- text_input_chat.submit(
705
- process_audio_input,
706
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
707
- outputs=[chatbot_interface, conversation_state],
708
- ).then(
709
- generate_audio_response,
710
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
711
- outputs=[audio_output_chat, ref_text_chat],
712
- ).then(
713
- lambda: None,
714
- None,
715
- text_input_chat,
716
- )
717
-
718
- # Handle send button
719
- send_btn_chat.click(
720
- process_audio_input,
721
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
722
- outputs=[chatbot_interface, conversation_state],
723
- ).then(
724
- generate_audio_response,
725
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
726
- outputs=[audio_output_chat, ref_text_chat],
727
- ).then(
728
- lambda: None,
729
- None,
730
- text_input_chat,
731
- )
732
-
733
- # Handle clear button
734
- clear_btn_chat.click(
735
- clear_conversation,
736
- outputs=[chatbot_interface, conversation_state],
737
- )
738
-
739
- # Handle system prompt change and reset conversation
740
- system_prompt_chat.change(
741
- update_system_prompt,
742
- inputs=system_prompt_chat,
743
- outputs=[chatbot_interface, conversation_state],
744
- )
745
-
746
-
747
- with gr.Blocks() as app:
748
- gr.Markdown(
749
- """
750
- # E2/F5 TTS
751
-
752
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
753
 
754
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
755
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
756
 
757
- The checkpoints currently support English and Chinese.
758
 
759
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
 
 
 
 
 
760
 
761
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
762
- """
763
- )
764
-
765
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
766
-
767
- def load_last_used_custom():
768
- try:
769
- custom = []
770
- with open(last_used_custom, "r", encoding="utf-8") as f:
771
- for line in f:
772
- custom.append(line.strip())
773
- return custom
774
- except FileNotFoundError:
775
- last_used_custom.parent.mkdir(parents=True, exist_ok=True)
776
- return DEFAULT_TTS_MODEL_CFG
777
-
778
- def switch_tts_model(new_choice):
779
- global tts_model_choice
780
- if new_choice == "Custom": # override in case webpage is refreshed
781
- custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
782
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
783
- return (
784
- gr.update(visible=True, value=custom_ckpt_path),
785
- gr.update(visible=True, value=custom_vocab_path),
786
- gr.update(visible=True, value=custom_model_cfg),
787
- )
788
- else:
789
- tts_model_choice = new_choice
790
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
791
-
792
- def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
793
- global tts_model_choice
794
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
795
- with open(last_used_custom, "w", encoding="utf-8") as f:
796
- f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
797
-
798
- with gr.Row():
799
- if not USING_SPACES:
800
- choose_tts_model = gr.Radio(
801
- choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
802
- )
803
- else:
804
- choose_tts_model = gr.Radio(
805
- choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
806
- )
807
- custom_ckpt_path = gr.Dropdown(
808
- choices=[DEFAULT_TTS_MODEL_CFG[0]],
809
- value=load_last_used_custom()[0],
810
- allow_custom_value=True,
811
- label="Model: local_path | hf://user_id/repo_id/model_ckpt",
812
- visible=False,
813
- )
814
- custom_vocab_path = gr.Dropdown(
815
- choices=[DEFAULT_TTS_MODEL_CFG[1]],
816
- value=load_last_used_custom()[1],
817
- allow_custom_value=True,
818
- label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
819
- visible=False,
820
- )
821
- custom_model_cfg = gr.Dropdown(
822
- choices=[
823
- DEFAULT_TTS_MODEL_CFG[2],
824
- json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
825
- ],
826
- value=load_last_used_custom()[2],
827
- allow_custom_value=True,
828
- label="Config: in a dictionary form",
829
- visible=False,
830
- )
831
-
832
- choose_tts_model.change(
833
- switch_tts_model,
834
- inputs=[choose_tts_model],
835
- outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
836
- show_progress="hidden",
837
- )
838
- custom_ckpt_path.change(
839
- set_custom_model,
840
- inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
841
- show_progress="hidden",
842
- )
843
- custom_vocab_path.change(
844
- set_custom_model,
845
- inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
846
- show_progress="hidden",
847
- )
848
- custom_model_cfg.change(
849
- set_custom_model,
850
- inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
851
- show_progress="hidden",
852
- )
853
-
854
- gr.TabbedInterface(
855
- [app_tts, app_multistyle, app_chat, app_credits],
856
- ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
857
- )
858
-
859
-
860
- @click.command()
861
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
862
- @click.option("--host", "-H", default=None, help="Host to run the app on")
863
- @click.option(
864
- "--share",
865
- "-s",
866
- default=False,
867
- is_flag=True,
868
- help="Share the app via Gradio share link",
869
- )
870
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
871
- @click.option(
872
- "--root_path",
873
- "-r",
874
- default=None,
875
- type=str,
876
- help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
877
- )
878
- def main(port, host, share, api, root_path):
879
- global app
880
- print("Starting app...")
881
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path)
882
 
883
 
884
- if __name__ == "__main__":
885
- if not USING_SPACES:
886
- main()
887
- else:
888
- app.queue().launch()
 
1
+ import os
 
 
 
2
  import re
3
+ import torch
4
+ import torchaudio
 
 
 
5
  import gradio as gr
6
  import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from ema_pytorch import EMA
10
+ from vocos import Vocos
11
+ from pydub import AudioSegment
12
+ from model import CFM, UNetT, DiT, MMDiT
13
  from cached_path import cached_path
14
+ from model.utils import (
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  save_spectrogram,
18
  )
19
+ from transformers import pipeline
20
+ import spaces
21
+ import librosa
22
+ from txtsplit import txtsplit
23
+ from detoxify import Detoxify
24
 
25
 
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
 
 
27
 
28
+ model = Detoxify('original', device=device)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ pipe = pipeline(
32
+ "automatic-speech-recognition",
33
+ model="openai/whisper-large-v3-turbo",
34
+ torch_dtype=torch.float16,
35
+ device=device,
36
+ )
37
 
38
+ # --------------------- Settings -------------------- #
39
+
40
+ target_sample_rate = 24000
41
+ n_mel_channels = 100
42
+ hop_length = 256
43
+ target_rms = 0.1
44
+ nfe_step = 32 # 16, 32
45
+ cfg_strength = 2.0
46
+ ode_method = 'euler'
47
+ sway_sampling_coef = -1.0
48
+ speed = 1.0
49
+ # fix_duration = 27 # None or float (duration in seconds)
50
+ fix_duration = None
51
+
52
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
53
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
+ model = CFM(
56
+ transformer=model_cls(
57
+ **model_cfg,
58
+ text_num_embeds=vocab_size,
59
+ mel_dim=n_mel_channels
60
+ ),
61
+ mel_spec_kwargs=dict(
62
+ target_sample_rate=target_sample_rate,
63
+ n_mel_channels=n_mel_channels,
64
+ hop_length=hop_length,
65
+ ),
66
+ odeint_kwargs=dict(
67
+ method=ode_method,
68
+ ),
69
+ vocab_char_map=vocab_char_map,
70
+ ).to(device)
71
+
72
+ ema_model = EMA(model, include_online_model=False).to(device)
73
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
74
+ ema_model.copy_params_from_ema_to_model()
75
+
76
+ return ema_model, model
77
 
78
+ # load models
79
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
81
+
82
+ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
83
+ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
84
+
85
+ @spaces.GPU
86
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
+ print(gen_text)
88
+ if model.predict(gen_text)['toxicity'] > 0.8:
89
+ print("Flagged for toxicity:", gen_text)
90
+ raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
+ gr.Info("Converting audio...")
92
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
93
+ aseg = AudioSegment.from_file(ref_audio_orig)
94
+ # Convert to mono
95
+ aseg = aseg.set_channels(1)
96
+ audio_duration = len(aseg)
97
+ if audio_duration > 15000:
98
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
99
+ aseg = aseg[:15000]
100
+ aseg.export(f.name, format="wav")
101
+ ref_audio = f.name
102
+ if exp_name == "F5-TTS":
103
+ ema_model = F5TTS_ema_model
104
+ base_model = F5TTS_base_model
105
+ elif exp_name == "E2-TTS":
106
+ ema_model = E2TTS_ema_model
107
+ base_model = E2TTS_base_model
108
+
109
+ if not ref_text.strip():
110
+ gr.Info("No reference text provided, transcribing reference audio...")
111
+ ref_text = outputs = pipe(
112
+ ref_audio,
113
+ chunk_length_s=30,
114
+ batch_size=128,
115
+ generate_kwargs={"task": "transcribe"},
116
+ return_timestamps=False,
117
+ )['text'].strip()
118
+ gr.Info("Finished transcription")
119
+ else:
120
+ gr.Info("Using custom reference text...")
121
+ audio, sr = torchaudio.load(ref_audio)
122
+ # Audio
123
+ if audio.shape[0] > 1:
124
+ audio = torch.mean(audio, dim=0, keepdim=True)
125
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
126
+ if rms < target_rms:
127
+ audio = audio * target_rms / rms
128
+ if sr != target_sample_rate:
129
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
130
+ audio = resampler(audio)
131
+ audio = audio.to(device)
132
+ # Chunk
133
+ chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
134
+ results = []
135
+ generated_mel_specs = []
136
+ for chunk in progress.tqdm(chunks):
137
+ # Prepare the text
138
+ text_list = [ref_text + chunk]
139
+ final_text_list = convert_char_to_pinyin(text_list)
140
+
141
+ # Calculate duration
142
+ ref_audio_len = audio.shape[-1] // hop_length
143
+ # if fix_duration is not None:
144
+ # duration = int(fix_duration * target_sample_rate / hop_length)
145
+ # else:
146
+ zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
148
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
150
+
151
+ # inference
152
+ gr.Info(f"Generating audio using {exp_name}")
153
+ with torch.inference_mode():
154
+ generated, _ = base_model.sample(
155
+ cond=audio,
156
+ text=final_text_list,
157
+ duration=duration,
158
+ steps=nfe_step,
159
+ cfg_strength=cfg_strength,
160
+ sway_sampling_coef=sway_sampling_coef,
161
+ )
162
+
163
+ generated = generated[:, ref_audio_len:, :]
164
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
165
+ gr.Info("Running vocoder")
166
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
167
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
168
+ if rms < target_rms:
169
+ generated_wave = generated_wave * rms / target_rms
170
+
171
+ # wav -> numpy
172
+ generated_wave = generated_wave.squeeze().cpu().numpy()
173
+ results.append(generated_wave)
174
+ generated_wave = np.concatenate(results)
175
+ if remove_silence:
176
+ gr.Info("Removing audio silences... This may take a moment")
177
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
178
+ non_silent_wave = np.array([])
179
+ for interval in non_silent_intervals:
180
+ start, end = interval
181
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
182
+ generated_wave = non_silent_wave
183
 
 
 
 
 
184
 
185
+ # spectogram
186
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
187
+ # spectrogram_path = tmp_spectrogram.name
188
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
189
 
190
+ return (target_sample_rate, generated_wave)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ with gr.Blocks() as app:
193
+ gr.Markdown("""
194
+ # E2/F5 TTS
195
 
196
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
197
 
198
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
199
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
 
 
 
202
 
203
+ The checkpoints support English and Chinese.
204
 
205
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
206
 
207
+ The model is licensed under the CC-BY-NC license, this demo cannot be used for commercial purposes.
 
 
208
 
209
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
 
 
210
  """)
211
+
 
212
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
213
+ gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
214
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
215
  generate_btn = gr.Button("Synthesize", variant="primary")
216
  with gr.Accordion("Advanced Settings", open=False):
217
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
218
+ remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
219
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  audio_output = gr.Audio(label="Synthesized Audio")
221
+ # spectrogram_output = gr.Image(label="Spectrogram")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
224
+ gr.Markdown("""
225
+ ## Run Locally
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ Run this demo locally on CPU, CUDA, or MPS/Apple Silicon (requires macOS >= 14):
 
228
 
229
+ First, ensure `ffmpeg` is installed.
230
 
231
+ ```bash
232
+ git clone https://huggingface.co/spaces/mrfakename/E2-F5-TTS
233
+ cd E2-F5-TTS
234
+ python -m pip install -r requirements.txt
235
+ python app_local.py
236
+ ```
237
 
238
+ """)
239
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
 
242
+ app.queue().launch()
 
 
 
 
app_local.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("WARNING: You are running this unofficial E2/F5 TTS demo locally, it may not be as up-to-date as the hosted version (https://huggingface.co/spaces/mrfakename/E2-F5-TTS)")
2
+
3
+ import os
4
+ import re
5
+ import torch
6
+ import torchaudio
7
+ import gradio as gr
8
+ import numpy as np
9
+ import tempfile
10
+ from einops import rearrange
11
+ from ema_pytorch import EMA
12
+ from vocos import Vocos
13
+ from pydub import AudioSegment, silence
14
+ from model import CFM, UNetT, DiT, MMDiT
15
+ from cached_path import cached_path
16
+ from model.utils import (
17
+ get_tokenizer,
18
+ convert_char_to_pinyin,
19
+ save_spectrogram,
20
+ )
21
+ from transformers import pipeline
22
+ import librosa
23
+ import soundfile as sf
24
+ from txtsplit import txtsplit
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
+
28
+ pipe = pipeline(
29
+ "automatic-speech-recognition",
30
+ model="openai/whisper-large-v3-turbo",
31
+ torch_dtype=torch.float16,
32
+ device=device,
33
+ )
34
+
35
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
36
+
37
+ # --------------------- Settings -------------------- #
38
+
39
+ target_sample_rate = 24000
40
+ n_mel_channels = 100
41
+ hop_length = 256
42
+ target_rms = 0.1
43
+ nfe_step = 32 # 16, 32
44
+ cfg_strength = 2.0
45
+ ode_method = 'euler'
46
+ sway_sampling_coef = -1.0
47
+ speed = 1.0
48
+ # fix_duration = 27 # None or float (duration in seconds)
49
+ fix_duration = None
50
+
51
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
52
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
53
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
54
+ model = CFM(
55
+ transformer=model_cls(
56
+ **model_cfg,
57
+ text_num_embeds=vocab_size,
58
+ mel_dim=n_mel_channels
59
+ ),
60
+ mel_spec_kwargs=dict(
61
+ target_sample_rate=target_sample_rate,
62
+ n_mel_channels=n_mel_channels,
63
+ hop_length=hop_length,
64
+ ),
65
+ odeint_kwargs=dict(
66
+ method=ode_method,
67
+ ),
68
+ vocab_char_map=vocab_char_map,
69
+ ).to(device)
70
+
71
+ ema_model = EMA(model, include_online_model=False).to(device)
72
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
73
+ ema_model.copy_params_from_ema_to_model()
74
+
75
+ return model
76
+
77
+ # load models
78
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
79
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
80
+
81
+ F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
82
+ E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
83
+
84
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
85
+ print(gen_text)
86
+ gr.Info("Converting audio...")
87
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
88
+ aseg = AudioSegment.from_file(ref_audio_orig)
89
+ # remove long silence in reference audio
90
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
91
+ non_silent_wave = AudioSegment.silent(duration=0)
92
+ for non_silent_seg in non_silent_segs:
93
+ non_silent_wave += non_silent_seg
94
+ aseg = non_silent_wave
95
+ # Convert to mono
96
+ aseg = aseg.set_channels(1)
97
+ audio_duration = len(aseg)
98
+ if audio_duration > 15000:
99
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
100
+ aseg = aseg[:15000]
101
+ aseg.export(f.name, format="wav")
102
+ ref_audio = f.name
103
+ if exp_name == "F5-TTS":
104
+ ema_model = F5TTS_ema_model
105
+ elif exp_name == "E2-TTS":
106
+ ema_model = E2TTS_ema_model
107
+
108
+ if not ref_text.strip():
109
+ gr.Info("No reference text provided, transcribing reference audio...")
110
+ ref_text = outputs = pipe(
111
+ ref_audio,
112
+ chunk_length_s=30,
113
+ batch_size=128,
114
+ generate_kwargs={"task": "transcribe"},
115
+ return_timestamps=False,
116
+ )['text'].strip()
117
+ gr.Info("Finished transcription")
118
+ else:
119
+ gr.Info("Using custom reference text...")
120
+ audio, sr = torchaudio.load(ref_audio)
121
+ max_chars = int(len(ref_text) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
122
+ # Audio
123
+ if audio.shape[0] > 1:
124
+ audio = torch.mean(audio, dim=0, keepdim=True)
125
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
126
+ if rms < target_rms:
127
+ audio = audio * target_rms / rms
128
+ if sr != target_sample_rate:
129
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
130
+ audio = resampler(audio)
131
+ audio = audio.to(device)
132
+ # Chunk
133
+ chunks = txtsplit(gen_text, 0.7*max_chars, 0.9*max_chars) # 100 chars preferred, 150 max
134
+ results = []
135
+ generated_mel_specs = []
136
+ for chunk in progress.tqdm(chunks):
137
+ # Prepare the text
138
+ text_list = [ref_text + chunk]
139
+ final_text_list = convert_char_to_pinyin(text_list)
140
+
141
+ # Calculate duration
142
+ ref_audio_len = audio.shape[-1] // hop_length
143
+ # if fix_duration is not None:
144
+ # duration = int(fix_duration * target_sample_rate / hop_length)
145
+ # else:
146
+ zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
148
+ chunk = len(chunk.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * chunk / speed)
150
+
151
+ # inference
152
+ gr.Info(f"Generating audio using {exp_name}")
153
+ with torch.inference_mode():
154
+ generated, _ = ema_model.sample(
155
+ cond=audio,
156
+ text=final_text_list,
157
+ duration=duration,
158
+ steps=nfe_step,
159
+ cfg_strength=cfg_strength,
160
+ sway_sampling_coef=sway_sampling_coef,
161
+ )
162
+
163
+ generated = generated[:, ref_audio_len:, :]
164
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
165
+ gr.Info("Running vocoder")
166
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
167
+ if rms < target_rms:
168
+ generated_wave = generated_wave * rms / target_rms
169
+
170
+ # wav -> numpy
171
+ generated_wave = generated_wave.squeeze().cpu().numpy()
172
+ results.append(generated_wave)
173
+ generated_wave = np.concatenate(results)
174
+ if remove_silence:
175
+ gr.Info("Removing audio silences... This may take a moment")
176
+ # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
177
+ # non_silent_wave = np.array([])
178
+ # for interval in non_silent_intervals:
179
+ # start, end = interval
180
+ # non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
181
+ # generated_wave = non_silent_wave
182
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
183
+ sf.write(f.name, generated_wave, target_sample_rate)
184
+ aseg = AudioSegment.from_file(f.name)
185
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
186
+ non_silent_wave = AudioSegment.silent(duration=0)
187
+ for non_silent_seg in non_silent_segs:
188
+ non_silent_wave += non_silent_seg
189
+ aseg = non_silent_wave
190
+ aseg.export(f.name, format="wav")
191
+ generated_wave, _ = torchaudio.load(f.name)
192
+ generated_wave = generated_wave.squeeze().cpu().numpy()
193
+
194
+ # spectogram
195
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
196
+ # spectrogram_path = tmp_spectrogram.name
197
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
198
+
199
+ return (target_sample_rate, generated_wave)
200
+
201
+ with gr.Blocks() as app:
202
+ gr.Markdown("""
203
+ # E2/F5 TTS
204
+
205
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
206
+
207
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
208
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
209
+
210
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
211
+
212
+ The checkpoints support English and Chinese.
213
+
214
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
215
+
216
+ Long-form/batched inference + speech editing is coming soon!
217
+
218
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
219
+ """)
220
+
221
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
222
+ gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
223
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
224
+ generate_btn = gr.Button("Synthesize", variant="primary")
225
+ with gr.Accordion("Advanced Settings", open=False):
226
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
227
+ remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
228
+
229
+ audio_output = gr.Audio(label="Synthesized Audio")
230
+ # spectrogram_output = gr.Image(label="Spectrogram")
231
+
232
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
233
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
234
+
235
+
236
+ app.queue().launch()
cog.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+
6
+ import os
7
+ import re
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+ import tempfile
12
+ from einops import rearrange
13
+ from ema_pytorch import EMA
14
+ from vocos import Vocos
15
+ from pydub import AudioSegment
16
+ from model import CFM, UNetT, DiT, MMDiT
17
+ from cached_path import cached_path
18
+ from model.utils import (
19
+ get_tokenizer,
20
+ convert_char_to_pinyin,
21
+ save_spectrogram,
22
+ )
23
+ from transformers import pipeline
24
+ import librosa
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
+
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ target_rms = 0.1
32
+ nfe_step = 32 # 16, 32
33
+ cfg_strength = 2.0
34
+ ode_method = 'euler'
35
+ sway_sampling_coef = -1.0
36
+ speed = 1.0
37
+ # fix_duration = 27 # None or float (duration in seconds)
38
+ fix_duration = None
39
+
40
+
41
+ class Predictor(BasePredictor):
42
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
43
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
44
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
45
+ model = CFM(
46
+ transformer=model_cls(
47
+ **model_cfg,
48
+ text_num_embeds=vocab_size,
49
+ mel_dim=n_mel_channels
50
+ ),
51
+ mel_spec_kwargs=dict(
52
+ target_sample_rate=target_sample_rate,
53
+ n_mel_channels=n_mel_channels,
54
+ hop_length=hop_length,
55
+ ),
56
+ odeint_kwargs=dict(
57
+ method=ode_method,
58
+ ),
59
+ vocab_char_map=vocab_char_map,
60
+ ).to(device)
61
+
62
+ ema_model = EMA(model, include_online_model=False).to(device)
63
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
64
+ ema_model.copy_params_from_ema_to_model()
65
+
66
+ return ema_model, model
67
+ def setup(self) -> None:
68
+ """Load the model into memory to make running multiple predictions efficient"""
69
+ # self.model = torch.load("./weights.pth")
70
+ print("Loading Whisper model...")
71
+ self.pipe = pipeline(
72
+ "automatic-speech-recognition",
73
+ model="openai/whisper-large-v3-turbo",
74
+ torch_dtype=torch.float16,
75
+ device=device,
76
+ )
77
+ print("Loading F5-TTS model...")
78
+
79
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
+ self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
81
+
82
+
83
+ def predict(
84
+ self,
85
+ gen_text: str = Input(description="Text to generate"),
86
+ ref_audio_orig: Path = Input(description="Reference audio"),
87
+ remove_silence: bool = Input(description="Remove silences", default=True),
88
+ ) -> Path:
89
+ """Run a single prediction on the model"""
90
+ model_choice = "F5-TTS"
91
+ print(gen_text)
92
+ if len(gen_text) > 200:
93
+ raise gr.Error("Please keep your text under 200 chars.")
94
+ gr.Info("Converting audio...")
95
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
96
+ aseg = AudioSegment.from_file(ref_audio_orig)
97
+ audio_duration = len(aseg)
98
+ if audio_duration > 15000:
99
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
100
+ aseg = aseg[:15000]
101
+ aseg.export(f.name, format="wav")
102
+ ref_audio = f.name
103
+ ema_model = self.F5TTS_ema_model
104
+ base_model = self.F5TTS_base_model
105
+
106
+ if not ref_text.strip():
107
+ gr.Info("No reference text provided, transcribing reference audio...")
108
+ ref_text = outputs = self.pipe(
109
+ ref_audio,
110
+ chunk_length_s=30,
111
+ batch_size=128,
112
+ generate_kwargs={"task": "transcribe"},
113
+ return_timestamps=False,
114
+ )['text'].strip()
115
+ gr.Info("Finished transcription")
116
+ else:
117
+ gr.Info("Using custom reference text...")
118
+ audio, sr = torchaudio.load(ref_audio)
119
+
120
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
121
+ if rms < target_rms:
122
+ audio = audio * target_rms / rms
123
+ if sr != target_sample_rate:
124
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
125
+ audio = resampler(audio)
126
+ audio = audio.to(device)
127
+
128
+ # Prepare the text
129
+ text_list = [ref_text + gen_text]
130
+ final_text_list = convert_char_to_pinyin(text_list)
131
+
132
+ # Calculate duration
133
+ ref_audio_len = audio.shape[-1] // hop_length
134
+ # if fix_duration is not None:
135
+ # duration = int(fix_duration * target_sample_rate / hop_length)
136
+ # else:
137
+ zh_pause_punc = r"。,、;:?!"
138
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
139
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
140
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
141
+
142
+ # inference
143
+ gr.Info(f"Generating audio using F5-TTS")
144
+ with torch.inference_mode():
145
+ generated, _ = base_model.sample(
146
+ cond=audio,
147
+ text=final_text_list,
148
+ duration=duration,
149
+ steps=nfe_step,
150
+ cfg_strength=cfg_strength,
151
+ sway_sampling_coef=sway_sampling_coef,
152
+ )
153
+
154
+ generated = generated[:, ref_audio_len:, :]
155
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
156
+ gr.Info("Running vocoder")
157
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
158
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
159
+ if rms < target_rms:
160
+ generated_wave = generated_wave * rms / target_rms
161
+
162
+ # wav -> numpy
163
+ generated_wave = generated_wave.squeeze().cpu().numpy()
164
+
165
+ if remove_silence:
166
+ gr.Info("Removing audio silences... This may take a moment")
167
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
168
+ non_silent_wave = np.array([])
169
+ for interval in non_silent_intervals:
170
+ start, end = interval
171
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
172
+ generated_wave = non_silent_wave
173
+
174
+
175
+ # spectogram
176
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
177
+ wav_path = tmp_wav.name
178
+ torchaudio.save(wav_path, torch.tensor(generated_wave), target_sample_rate)
179
+
180
+ return wav_path
data/.DS_Store DELETED
Binary file (6.15 kB)
 
data/Emilia_ZH_EN_pinyin/vocab.txt CHANGED
@@ -1,2545 +1,2545 @@
1
-
2
- !
3
- "
4
- #
5
- $
6
- %
7
- &
8
- '
9
- (
10
- )
11
- *
12
- +
13
- ,
14
- -
15
- .
16
- /
17
- 0
18
- 1
19
- 2
20
- 3
21
- 4
22
- 5
23
- 6
24
- 7
25
- 8
26
- 9
27
- :
28
- ;
29
- =
30
- >
31
- ?
32
- @
33
- A
34
- B
35
- C
36
- D
37
- E
38
- F
39
- G
40
- H
41
- I
42
- J
43
- K
44
- L
45
- M
46
- N
47
- O
48
- P
49
- Q
50
- R
51
- S
52
- T
53
- U
54
- V
55
- W
56
- X
57
- Y
58
- Z
59
- [
60
- \
61
- ]
62
- _
63
- a
64
- a1
65
- ai1
66
- ai2
67
- ai3
68
- ai4
69
- an1
70
- an3
71
- an4
72
- ang1
73
- ang2
74
- ang4
75
- ao1
76
- ao2
77
- ao3
78
- ao4
79
- b
80
- ba
81
- ba1
82
- ba2
83
- ba3
84
- ba4
85
- bai1
86
- bai2
87
- bai3
88
- bai4
89
- ban1
90
- ban2
91
- ban3
92
- ban4
93
- bang1
94
- bang2
95
- bang3
96
- bang4
97
- bao1
98
- bao2
99
- bao3
100
- bao4
101
- bei
102
- bei1
103
- bei2
104
- bei3
105
- bei4
106
- ben1
107
- ben2
108
- ben3
109
- ben4
110
- beng
111
- beng1
112
- beng2
113
- beng3
114
- beng4
115
- bi1
116
- bi2
117
- bi3
118
- bi4
119
- bian1
120
- bian2
121
- bian3
122
- bian4
123
- biao1
124
- biao2
125
- biao3
126
- bie1
127
- bie2
128
- bie3
129
- bie4
130
- bin1
131
- bin4
132
- bing1
133
- bing2
134
- bing3
135
- bing4
136
- bo
137
- bo1
138
- bo2
139
- bo3
140
- bo4
141
- bu2
142
- bu3
143
- bu4
144
- c
145
- ca1
146
- cai1
147
- cai2
148
- cai3
149
- cai4
150
- can1
151
- can2
152
- can3
153
- can4
154
- cang1
155
- cang2
156
- cao1
157
- cao2
158
- cao3
159
- ce4
160
- cen1
161
- cen2
162
- ceng1
163
- ceng2
164
- ceng4
165
- cha1
166
- cha2
167
- cha3
168
- cha4
169
- chai1
170
- chai2
171
- chan1
172
- chan2
173
- chan3
174
- chan4
175
- chang1
176
- chang2
177
- chang3
178
- chang4
179
- chao1
180
- chao2
181
- chao3
182
- che1
183
- che2
184
- che3
185
- che4
186
- chen1
187
- chen2
188
- chen3
189
- chen4
190
- cheng1
191
- cheng2
192
- cheng3
193
- cheng4
194
- chi1
195
- chi2
196
- chi3
197
- chi4
198
- chong1
199
- chong2
200
- chong3
201
- chong4
202
- chou1
203
- chou2
204
- chou3
205
- chou4
206
- chu1
207
- chu2
208
- chu3
209
- chu4
210
- chua1
211
- chuai1
212
- chuai2
213
- chuai3
214
- chuai4
215
- chuan1
216
- chuan2
217
- chuan3
218
- chuan4
219
- chuang1
220
- chuang2
221
- chuang3
222
- chuang4
223
- chui1
224
- chui2
225
- chun1
226
- chun2
227
- chun3
228
- chuo1
229
- chuo4
230
- ci1
231
- ci2
232
- ci3
233
- ci4
234
- cong1
235
- cong2
236
- cou4
237
- cu1
238
- cu4
239
- cuan1
240
- cuan2
241
- cuan4
242
- cui1
243
- cui3
244
- cui4
245
- cun1
246
- cun2
247
- cun4
248
- cuo1
249
- cuo2
250
- cuo4
251
- d
252
- da
253
- da1
254
- da2
255
- da3
256
- da4
257
- dai1
258
- dai2
259
- dai3
260
- dai4
261
- dan1
262
- dan2
263
- dan3
264
- dan4
265
- dang1
266
- dang2
267
- dang3
268
- dang4
269
- dao1
270
- dao2
271
- dao3
272
- dao4
273
- de
274
- de1
275
- de2
276
- dei3
277
- den4
278
- deng1
279
- deng2
280
- deng3
281
- deng4
282
- di1
283
- di2
284
- di3
285
- di4
286
- dia3
287
- dian1
288
- dian2
289
- dian3
290
- dian4
291
- diao1
292
- diao3
293
- diao4
294
- die1
295
- die2
296
- die4
297
- ding1
298
- ding2
299
- ding3
300
- ding4
301
- diu1
302
- dong1
303
- dong3
304
- dong4
305
- dou1
306
- dou2
307
- dou3
308
- dou4
309
- du1
310
- du2
311
- du3
312
- du4
313
- duan1
314
- duan2
315
- duan3
316
- duan4
317
- dui1
318
- dui4
319
- dun1
320
- dun3
321
- dun4
322
- duo1
323
- duo2
324
- duo3
325
- duo4
326
- e
327
- e1
328
- e2
329
- e3
330
- e4
331
- ei2
332
- en1
333
- en4
334
- er
335
- er2
336
- er3
337
- er4
338
- f
339
- fa1
340
- fa2
341
- fa3
342
- fa4
343
- fan1
344
- fan2
345
- fan3
346
- fan4
347
- fang1
348
- fang2
349
- fang3
350
- fang4
351
- fei1
352
- fei2
353
- fei3
354
- fei4
355
- fen1
356
- fen2
357
- fen3
358
- fen4
359
- feng1
360
- feng2
361
- feng3
362
- feng4
363
- fo2
364
- fou2
365
- fou3
366
- fu1
367
- fu2
368
- fu3
369
- fu4
370
- g
371
- ga1
372
- ga2
373
- ga3
374
- ga4
375
- gai1
376
- gai2
377
- gai3
378
- gai4
379
- gan1
380
- gan2
381
- gan3
382
- gan4
383
- gang1
384
- gang2
385
- gang3
386
- gang4
387
- gao1
388
- gao2
389
- gao3
390
- gao4
391
- ge1
392
- ge2
393
- ge3
394
- ge4
395
- gei2
396
- gei3
397
- gen1
398
- gen2
399
- gen3
400
- gen4
401
- geng1
402
- geng3
403
- geng4
404
- gong1
405
- gong3
406
- gong4
407
- gou1
408
- gou2
409
- gou3
410
- gou4
411
- gu
412
- gu1
413
- gu2
414
- gu3
415
- gu4
416
- gua1
417
- gua2
418
- gua3
419
- gua4
420
- guai1
421
- guai2
422
- guai3
423
- guai4
424
- guan1
425
- guan2
426
- guan3
427
- guan4
428
- guang1
429
- guang2
430
- guang3
431
- guang4
432
- gui1
433
- gui2
434
- gui3
435
- gui4
436
- gun3
437
- gun4
438
- guo1
439
- guo2
440
- guo3
441
- guo4
442
- h
443
- ha1
444
- ha2
445
- ha3
446
- hai1
447
- hai2
448
- hai3
449
- hai4
450
- han1
451
- han2
452
- han3
453
- han4
454
- hang1
455
- hang2
456
- hang4
457
- hao1
458
- hao2
459
- hao3
460
- hao4
461
- he1
462
- he2
463
- he4
464
- hei1
465
- hen2
466
- hen3
467
- hen4
468
- heng1
469
- heng2
470
- heng4
471
- hong1
472
- hong2
473
- hong3
474
- hong4
475
- hou1
476
- hou2
477
- hou3
478
- hou4
479
- hu1
480
- hu2
481
- hu3
482
- hu4
483
- hua1
484
- hua2
485
- hua4
486
- huai2
487
- huai4
488
- huan1
489
- huan2
490
- huan3
491
- huan4
492
- huang1
493
- huang2
494
- huang3
495
- huang4
496
- hui1
497
- hui2
498
- hui3
499
- hui4
500
- hun1
501
- hun2
502
- hun4
503
- huo
504
- huo1
505
- huo2
506
- huo3
507
- huo4
508
- i
509
- j
510
- ji1
511
- ji2
512
- ji3
513
- ji4
514
- jia
515
- jia1
516
- jia2
517
- jia3
518
- jia4
519
- jian1
520
- jian2
521
- jian3
522
- jian4
523
- jiang1
524
- jiang2
525
- jiang3
526
- jiang4
527
- jiao1
528
- jiao2
529
- jiao3
530
- jiao4
531
- jie1
532
- jie2
533
- jie3
534
- jie4
535
- jin1
536
- jin2
537
- jin3
538
- jin4
539
- jing1
540
- jing2
541
- jing3
542
- jing4
543
- jiong3
544
- jiu1
545
- jiu2
546
- jiu3
547
- jiu4
548
- ju1
549
- ju2
550
- ju3
551
- ju4
552
- juan1
553
- juan2
554
- juan3
555
- juan4
556
- jue1
557
- jue2
558
- jue4
559
- jun1
560
- jun4
561
- k
562
- ka1
563
- ka2
564
- ka3
565
- kai1
566
- kai2
567
- kai3
568
- kai4
569
- kan1
570
- kan2
571
- kan3
572
- kan4
573
- kang1
574
- kang2
575
- kang4
576
- kao1
577
- kao2
578
- kao3
579
- kao4
580
- ke1
581
- ke2
582
- ke3
583
- ke4
584
- ken3
585
- keng1
586
- kong1
587
- kong3
588
- kong4
589
- kou1
590
- kou2
591
- kou3
592
- kou4
593
- ku1
594
- ku2
595
- ku3
596
- ku4
597
- kua1
598
- kua3
599
- kua4
600
- kuai3
601
- kuai4
602
- kuan1
603
- kuan2
604
- kuan3
605
- kuang1
606
- kuang2
607
- kuang4
608
- kui1
609
- kui2
610
- kui3
611
- kui4
612
- kun1
613
- kun3
614
- kun4
615
- kuo4
616
- l
617
- la
618
- la1
619
- la2
620
- la3
621
- la4
622
- lai2
623
- lai4
624
- lan2
625
- lan3
626
- lan4
627
- lang1
628
- lang2
629
- lang3
630
- lang4
631
- lao1
632
- lao2
633
- lao3
634
- lao4
635
- le
636
- le1
637
- le4
638
- lei
639
- lei1
640
- lei2
641
- lei3
642
- lei4
643
- leng1
644
- leng2
645
- leng3
646
- leng4
647
- li
648
- li1
649
- li2
650
- li3
651
- li4
652
- lia3
653
- lian2
654
- lian3
655
- lian4
656
- liang2
657
- liang3
658
- liang4
659
- liao1
660
- liao2
661
- liao3
662
- liao4
663
- lie1
664
- lie2
665
- lie3
666
- lie4
667
- lin1
668
- lin2
669
- lin3
670
- lin4
671
- ling2
672
- ling3
673
- ling4
674
- liu1
675
- liu2
676
- liu3
677
- liu4
678
- long1
679
- long2
680
- long3
681
- long4
682
- lou1
683
- lou2
684
- lou3
685
- lou4
686
- lu1
687
- lu2
688
- lu3
689
- lu4
690
- luan2
691
- luan3
692
- luan4
693
- lun1
694
- lun2
695
- lun4
696
- luo1
697
- luo2
698
- luo3
699
- luo4
700
- lv2
701
- lv3
702
- lv4
703
- lve3
704
- lve4
705
- m
706
- ma
707
- ma1
708
- ma2
709
- ma3
710
- ma4
711
- mai2
712
- mai3
713
- mai4
714
- man1
715
- man2
716
- man3
717
- man4
718
- mang2
719
- mang3
720
- mao1
721
- mao2
722
- mao3
723
- mao4
724
- me
725
- mei2
726
- mei3
727
- mei4
728
- men
729
- men1
730
- men2
731
- men4
732
- meng
733
- meng1
734
- meng2
735
- meng3
736
- meng4
737
- mi1
738
- mi2
739
- mi3
740
- mi4
741
- mian2
742
- mian3
743
- mian4
744
- miao1
745
- miao2
746
- miao3
747
- miao4
748
- mie1
749
- mie4
750
- min2
751
- min3
752
- ming2
753
- ming3
754
- ming4
755
- miu4
756
- mo1
757
- mo2
758
- mo3
759
- mo4
760
- mou1
761
- mou2
762
- mou3
763
- mu2
764
- mu3
765
- mu4
766
- n
767
- n2
768
- na1
769
- na2
770
- na3
771
- na4
772
- nai2
773
- nai3
774
- nai4
775
- nan1
776
- nan2
777
- nan3
778
- nan4
779
- nang1
780
- nang2
781
- nang3
782
- nao1
783
- nao2
784
- nao3
785
- nao4
786
- ne
787
- ne2
788
- ne4
789
- nei3
790
- nei4
791
- nen4
792
- neng2
793
- ni1
794
- ni2
795
- ni3
796
- ni4
797
- nian1
798
- nian2
799
- nian3
800
- nian4
801
- niang2
802
- niang4
803
- niao2
804
- niao3
805
- niao4
806
- nie1
807
- nie4
808
- nin2
809
- ning2
810
- ning3
811
- ning4
812
- niu1
813
- niu2
814
- niu3
815
- niu4
816
- nong2
817
- nong4
818
- nou4
819
- nu2
820
- nu3
821
- nu4
822
- nuan3
823
- nuo2
824
- nuo4
825
- nv2
826
- nv3
827
- nve4
828
- o
829
- o1
830
- o2
831
- ou1
832
- ou2
833
- ou3
834
- ou4
835
- p
836
- pa1
837
- pa2
838
- pa4
839
- pai1
840
- pai2
841
- pai3
842
- pai4
843
- pan1
844
- pan2
845
- pan4
846
- pang1
847
- pang2
848
- pang4
849
- pao1
850
- pao2
851
- pao3
852
- pao4
853
- pei1
854
- pei2
855
- pei4
856
- pen1
857
- pen2
858
- pen4
859
- peng1
860
- peng2
861
- peng3
862
- peng4
863
- pi1
864
- pi2
865
- pi3
866
- pi4
867
- pian1
868
- pian2
869
- pian4
870
- piao1
871
- piao2
872
- piao3
873
- piao4
874
- pie1
875
- pie2
876
- pie3
877
- pin1
878
- pin2
879
- pin3
880
- pin4
881
- ping1
882
- ping2
883
- po1
884
- po2
885
- po3
886
- po4
887
- pou1
888
- pu1
889
- pu2
890
- pu3
891
- pu4
892
- q
893
- qi1
894
- qi2
895
- qi3
896
- qi4
897
- qia1
898
- qia3
899
- qia4
900
- qian1
901
- qian2
902
- qian3
903
- qian4
904
- qiang1
905
- qiang2
906
- qiang3
907
- qiang4
908
- qiao1
909
- qiao2
910
- qiao3
911
- qiao4
912
- qie1
913
- qie2
914
- qie3
915
- qie4
916
- qin1
917
- qin2
918
- qin3
919
- qin4
920
- qing1
921
- qing2
922
- qing3
923
- qing4
924
- qiong1
925
- qiong2
926
- qiu1
927
- qiu2
928
- qiu3
929
- qu1
930
- qu2
931
- qu3
932
- qu4
933
- quan1
934
- quan2
935
- quan3
936
- quan4
937
- que1
938
- que2
939
- que4
940
- qun2
941
- r
942
- ran2
943
- ran3
944
- rang1
945
- rang2
946
- rang3
947
- rang4
948
- rao2
949
- rao3
950
- rao4
951
- re2
952
- re3
953
- re4
954
- ren2
955
- ren3
956
- ren4
957
- reng1
958
- reng2
959
- ri4
960
- rong1
961
- rong2
962
- rong3
963
- rou2
964
- rou4
965
- ru2
966
- ru3
967
- ru4
968
- ruan2
969
- ruan3
970
- rui3
971
- rui4
972
- run4
973
- ruo4
974
- s
975
- sa1
976
- sa2
977
- sa3
978
- sa4
979
- sai1
980
- sai4
981
- san1
982
- san2
983
- san3
984
- san4
985
- sang1
986
- sang3
987
- sang4
988
- sao1
989
- sao2
990
- sao3
991
- sao4
992
- se4
993
- sen1
994
- seng1
995
- sha1
996
- sha2
997
- sha3
998
- sha4
999
- shai1
1000
- shai2
1001
- shai3
1002
- shai4
1003
- shan1
1004
- shan3
1005
- shan4
1006
- shang
1007
- shang1
1008
- shang3
1009
- shang4
1010
- shao1
1011
- shao2
1012
- shao3
1013
- shao4
1014
- she1
1015
- she2
1016
- she3
1017
- she4
1018
- shei2
1019
- shen1
1020
- shen2
1021
- shen3
1022
- shen4
1023
- sheng1
1024
- sheng2
1025
- sheng3
1026
- sheng4
1027
- shi
1028
- shi1
1029
- shi2
1030
- shi3
1031
- shi4
1032
- shou1
1033
- shou2
1034
- shou3
1035
- shou4
1036
- shu1
1037
- shu2
1038
- shu3
1039
- shu4
1040
- shua1
1041
- shua2
1042
- shua3
1043
- shua4
1044
- shuai1
1045
- shuai3
1046
- shuai4
1047
- shuan1
1048
- shuan4
1049
- shuang1
1050
- shuang3
1051
- shui2
1052
- shui3
1053
- shui4
1054
- shun3
1055
- shun4
1056
- shuo1
1057
- shuo4
1058
- si1
1059
- si2
1060
- si3
1061
- si4
1062
- song1
1063
- song3
1064
- song4
1065
- sou1
1066
- sou3
1067
- sou4
1068
- su1
1069
- su2
1070
- su4
1071
- suan1
1072
- suan4
1073
- sui1
1074
- sui2
1075
- sui3
1076
- sui4
1077
- sun1
1078
- sun3
1079
- suo
1080
- suo1
1081
- suo2
1082
- suo3
1083
- t
1084
- ta1
1085
- ta2
1086
- ta3
1087
- ta4
1088
- tai1
1089
- tai2
1090
- tai4
1091
- tan1
1092
- tan2
1093
- tan3
1094
- tan4
1095
- tang1
1096
- tang2
1097
- tang3
1098
- tang4
1099
- tao1
1100
- tao2
1101
- tao3
1102
- tao4
1103
- te4
1104
- teng2
1105
- ti1
1106
- ti2
1107
- ti3
1108
- ti4
1109
- tian1
1110
- tian2
1111
- tian3
1112
- tiao1
1113
- tiao2
1114
- tiao3
1115
- tiao4
1116
- tie1
1117
- tie2
1118
- tie3
1119
- tie4
1120
- ting1
1121
- ting2
1122
- ting3
1123
- tong1
1124
- tong2
1125
- tong3
1126
- tong4
1127
- tou
1128
- tou1
1129
- tou2
1130
- tou4
1131
- tu1
1132
- tu2
1133
- tu3
1134
- tu4
1135
- tuan1
1136
- tuan2
1137
- tui1
1138
- tui2
1139
- tui3
1140
- tui4
1141
- tun1
1142
- tun2
1143
- tun4
1144
- tuo1
1145
- tuo2
1146
- tuo3
1147
- tuo4
1148
- u
1149
- v
1150
- w
1151
- wa
1152
- wa1
1153
- wa2
1154
- wa3
1155
- wa4
1156
- wai1
1157
- wai3
1158
- wai4
1159
- wan1
1160
- wan2
1161
- wan3
1162
- wan4
1163
- wang1
1164
- wang2
1165
- wang3
1166
- wang4
1167
- wei1
1168
- wei2
1169
- wei3
1170
- wei4
1171
- wen1
1172
- wen2
1173
- wen3
1174
- wen4
1175
- weng1
1176
- weng4
1177
- wo1
1178
- wo2
1179
- wo3
1180
- wo4
1181
- wu1
1182
- wu2
1183
- wu3
1184
- wu4
1185
- x
1186
- xi1
1187
- xi2
1188
- xi3
1189
- xi4
1190
- xia1
1191
- xia2
1192
- xia4
1193
- xian1
1194
- xian2
1195
- xian3
1196
- xian4
1197
- xiang1
1198
- xiang2
1199
- xiang3
1200
- xiang4
1201
- xiao1
1202
- xiao2
1203
- xiao3
1204
- xiao4
1205
- xie1
1206
- xie2
1207
- xie3
1208
- xie4
1209
- xin1
1210
- xin2
1211
- xin4
1212
- xing1
1213
- xing2
1214
- xing3
1215
- xing4
1216
- xiong1
1217
- xiong2
1218
- xiu1
1219
- xiu3
1220
- xiu4
1221
- xu
1222
- xu1
1223
- xu2
1224
- xu3
1225
- xu4
1226
- xuan1
1227
- xuan2
1228
- xuan3
1229
- xuan4
1230
- xue1
1231
- xue2
1232
- xue3
1233
- xue4
1234
- xun1
1235
- xun2
1236
- xun4
1237
- y
1238
- ya
1239
- ya1
1240
- ya2
1241
- ya3
1242
- ya4
1243
- yan1
1244
- yan2
1245
- yan3
1246
- yan4
1247
- yang1
1248
- yang2
1249
- yang3
1250
- yang4
1251
- yao1
1252
- yao2
1253
- yao3
1254
- yao4
1255
- ye1
1256
- ye2
1257
- ye3
1258
- ye4
1259
- yi
1260
- yi1
1261
- yi2
1262
- yi3
1263
- yi4
1264
- yin1
1265
- yin2
1266
- yin3
1267
- yin4
1268
- ying1
1269
- ying2
1270
- ying3
1271
- ying4
1272
- yo1
1273
- yong1
1274
- yong2
1275
- yong3
1276
- yong4
1277
- you1
1278
- you2
1279
- you3
1280
- you4
1281
- yu1
1282
- yu2
1283
- yu3
1284
- yu4
1285
- yuan1
1286
- yuan2
1287
- yuan3
1288
- yuan4
1289
- yue1
1290
- yue4
1291
- yun1
1292
- yun2
1293
- yun3
1294
- yun4
1295
- z
1296
- za1
1297
- za2
1298
- za3
1299
- zai1
1300
- zai3
1301
- zai4
1302
- zan1
1303
- zan2
1304
- zan3
1305
- zan4
1306
- zang1
1307
- zang4
1308
- zao1
1309
- zao2
1310
- zao3
1311
- zao4
1312
- ze2
1313
- ze4
1314
- zei2
1315
- zen3
1316
- zeng1
1317
- zeng4
1318
- zha1
1319
- zha2
1320
- zha3
1321
- zha4
1322
- zhai1
1323
- zhai2
1324
- zhai3
1325
- zhai4
1326
- zhan1
1327
- zhan2
1328
- zhan3
1329
- zhan4
1330
- zhang1
1331
- zhang2
1332
- zhang3
1333
- zhang4
1334
- zhao1
1335
- zhao2
1336
- zhao3
1337
- zhao4
1338
- zhe
1339
- zhe1
1340
- zhe2
1341
- zhe3
1342
- zhe4
1343
- zhen1
1344
- zhen2
1345
- zhen3
1346
- zhen4
1347
- zheng1
1348
- zheng2
1349
- zheng3
1350
- zheng4
1351
- zhi1
1352
- zhi2
1353
- zhi3
1354
- zhi4
1355
- zhong1
1356
- zhong2
1357
- zhong3
1358
- zhong4
1359
- zhou1
1360
- zhou2
1361
- zhou3
1362
- zhou4
1363
- zhu1
1364
- zhu2
1365
- zhu3
1366
- zhu4
1367
- zhua1
1368
- zhua2
1369
- zhua3
1370
- zhuai1
1371
- zhuai3
1372
- zhuai4
1373
- zhuan1
1374
- zhuan2
1375
- zhuan3
1376
- zhuan4
1377
- zhuang1
1378
- zhuang4
1379
- zhui1
1380
- zhui4
1381
- zhun1
1382
- zhun2
1383
- zhun3
1384
- zhuo1
1385
- zhuo2
1386
- zi
1387
- zi1
1388
- zi2
1389
- zi3
1390
- zi4
1391
- zong1
1392
- zong2
1393
- zong3
1394
- zong4
1395
- zou1
1396
- zou2
1397
- zou3
1398
- zou4
1399
- zu1
1400
- zu2
1401
- zu3
1402
- zuan1
1403
- zuan3
1404
- zuan4
1405
- zui2
1406
- zui3
1407
- zui4
1408
- zun1
1409
- zuo
1410
- zuo1
1411
- zuo2
1412
- zuo3
1413
- zuo4
1414
- {
1415
- ~
1416
- ¡
1417
- ¢
1418
- £
1419
- ¥
1420
- §
1421
- ¨
1422
- ©
1423
- «
1424
- ®
1425
- ¯
1426
- °
1427
- ±
1428
- ²
1429
- ³
1430
- ´
1431
- µ
1432
- ·
1433
- ¹
1434
- º
1435
- »
1436
- ¼
1437
- ½
1438
- ¾
1439
- ¿
1440
- À
1441
- Á
1442
- Â
1443
- Ã
1444
- Ä
1445
- Å
1446
- Æ
1447
- Ç
1448
- È
1449
- É
1450
- Ê
1451
- Í
1452
- Î
1453
- Ñ
1454
- Ó
1455
- Ö
1456
- ×
1457
- Ø
1458
- Ú
1459
- Ü
1460
- Ý
1461
- Þ
1462
- ß
1463
- à
1464
- á
1465
- â
1466
- ã
1467
- ä
1468
- å
1469
- æ
1470
- ç
1471
- è
1472
- é
1473
- ê
1474
- ë
1475
- ì
1476
- í
1477
- î
1478
- ï
1479
- ð
1480
- ñ
1481
- ò
1482
- ó
1483
- ô
1484
- õ
1485
- ö
1486
- ø
1487
- ù
1488
- ú
1489
- û
1490
- ü
1491
- ý
1492
- Ā
1493
- ā
1494
- ă
1495
- ą
1496
- ć
1497
- Č
1498
- č
1499
- Đ
1500
- đ
1501
- ē
1502
- ė
1503
- ę
1504
- ě
1505
- ĝ
1506
- ğ
1507
- ħ
1508
- ī
1509
- į
1510
- İ
1511
- ı
1512
- Ł
1513
- ł
1514
- ń
1515
- ņ
1516
- ň
1517
- ŋ
1518
- Ō
1519
- ō
1520
- ő
1521
- œ
1522
- ř
1523
- Ś
1524
- ś
1525
- Ş
1526
- ş
1527
- Š
1528
- š
1529
- Ť
1530
- ť
1531
- ũ
1532
- ū
1533
- ź
1534
- Ż
1535
- ż
1536
- Ž
1537
- ž
1538
- ơ
1539
- ư
1540
- ǎ
1541
- ǐ
1542
- ǒ
1543
- ǔ
1544
- ǚ
1545
- ș
1546
- ț
1547
- ɑ
1548
- ɔ
1549
- ɕ
1550
- ə
1551
- ɛ
1552
- ɜ
1553
- ɡ
1554
- ɣ
1555
- ɪ
1556
- ɫ
1557
- ɴ
1558
- ɹ
1559
- ɾ
1560
- ʃ
1561
- ʊ
1562
- ʌ
1563
- ʒ
1564
- ʔ
1565
- ʰ
1566
- ʷ
1567
- ʻ
1568
- ʾ
1569
- ʿ
1570
- ˈ
1571
- ː
1572
- ˙
1573
- ˜
1574
- ˢ
1575
- ́
1576
- ̅
1577
- Α
1578
- Β
1579
- Δ
1580
- Ε
1581
- Θ
1582
- Κ
1583
- Λ
1584
- Μ
1585
- Ξ
1586
- Π
1587
- Σ
1588
- Τ
1589
- Φ
1590
- Χ
1591
- Ψ
1592
- Ω
1593
- ά
1594
- έ
1595
- ή
1596
- ί
1597
- α
1598
- β
1599
- γ
1600
- δ
1601
- ε
1602
- ζ
1603
- η
1604
- θ
1605
- ι
1606
- κ
1607
- λ
1608
- μ
1609
- ν
1610
- ξ
1611
- ο
1612
- π
1613
- ρ
1614
- ς
1615
- σ
1616
- τ
1617
- υ
1618
- φ
1619
- χ
1620
- ψ
1621
- ω
1622
- ϊ
1623
- ό
1624
- ύ
1625
- ώ
1626
- ϕ
1627
- ϵ
1628
- Ё
1629
- А
1630
- Б
1631
- В
1632
- Г
1633
- Д
1634
- Е
1635
- Ж
1636
- З
1637
- И
1638
- Й
1639
- К
1640
- Л
1641
- М
1642
- Н
1643
- О
1644
- П
1645
- Р
1646
- С
1647
- Т
1648
- У
1649
- Ф
1650
- Х
1651
- Ц
1652
- Ч
1653
- Ш
1654
- Щ
1655
- Ы
1656
- Ь
1657
- Э
1658
- Ю
1659
- Я
1660
- а
1661
- б
1662
- в
1663
- г
1664
- д
1665
- е
1666
- ж
1667
- з
1668
- и
1669
- й
1670
- к
1671
- л
1672
- м
1673
- н
1674
- о
1675
- п
1676
- р
1677
- с
1678
- т
1679
- у
1680
- ф
1681
- х
1682
- ц
1683
- ч
1684
- ш
1685
- щ
1686
- ъ
1687
- ы
1688
- ь
1689
- э
1690
- ю
1691
- я
1692
- ё
1693
- і
1694
- ְ
1695
- ִ
1696
- ֵ
1697
- ֶ
1698
- ַ
1699
- ָ
1700
- ֹ
1701
- ּ
1702
- ־
1703
- ׁ
1704
- א
1705
- ב
1706
- ג
1707
- ד
1708
- ה
1709
- ו
1710
- ז
1711
- ח
1712
- ט
1713
- י
1714
- כ
1715
- ל
1716
- ם
1717
- מ
1718
- ן
1719
- נ
1720
- ס
1721
- ע
1722
- פ
1723
- ק
1724
- ר
1725
- ש
1726
- ת
1727
- أ
1728
- ب
1729
- ة
1730
- ت
1731
- ج
1732
- ح
1733
- د
1734
- ر
1735
- ز
1736
- س
1737
- ص
1738
- ط
1739
- ع
1740
- ق
1741
- ك
1742
- ل
1743
- م
1744
- ن
1745
- ه
1746
- و
1747
- ي
1748
- َ
1749
- ُ
1750
- ِ
1751
- ْ
1752
-
1753
-
1754
-
1755
-
1756
-
1757
-
1758
-
1759
-
1760
-
1761
-
1762
-
1763
-
1764
-
1765
-
1766
-
1767
-
1768
-
1769
-
1770
-
1771
-
1772
-
1773
-
1774
-
1775
-
1776
-
1777
-
1778
-
1779
-
1780
-
1781
-
1782
-
1783
-
1784
-
1785
-
1786
-
1787
-
1788
-
1789
-
1790
-
1791
-
1792
-
1793
-
1794
-
1795
-
1796
-
1797
-
1798
-
1799
-
1800
- ế
1801
-
1802
-
1803
-
1804
-
1805
-
1806
-
1807
-
1808
-
1809
-
1810
-
1811
-
1812
-
1813
-
1814
-
1815
-
1816
-
1817
-
1818
-
1819
-
1820
-
1821
-
1822
-
1823
-
1824
-
1825
-
1826
-
1827
-
1828
-
1829
-
1830
-
1831
-
1832
-
1833
-
1834
-
1835
-
1836
-
1837
-
1838
-
1839
-
1840
-
1841
-
1842
-
1843
-
1844
-
1845
-
1846
-
1847
-
1848
-
1849
-
1850
-
1851
-
1852
-
1853
-
1854
-
1855
-
1856
-
1857
-
1858
-
1859
-
1860
-
1861
-
1862
-
1863
-
1864
-
1865
-
1866
-
1867
-
1868
-
1869
-
1870
-
1871
-
1872
-
1873
-
1874
-
1875
-
1876
-
1877
-
1878
-
1879
-
1880
-
1881
-
1882
-
1883
-
1884
-
1885
-
1886
-
1887
-
1888
-
1889
-
1890
-
1891
-
1892
-
1893
-
1894
-
1895
-
1896
-
1897
-
1898
-
1899
-
1900
-
1901
-
1902
-
1903
-
1904
-
1905
-
1906
-
1907
-
1908
-
1909
-
1910
-
1911
-
1912
-
1913
-
1914
-
1915
-
1916
-
1917
-
1918
-
1919
-
1920
-
1921
-
1922
-
1923
-
1924
-
1925
-
1926
-
1927
-
1928
-
1929
-
1930
-
1931
-
1932
-
1933
-
1934
-
1935
-
1936
-
1937
-
1938
-
1939
-
1940
-
1941
-
1942
-
1943
-
1944
-
1945
-
1946
-
1947
-
1948
-
1949
-
1950
-
1951
-
1952
-
1953
-
1954
-
1955
-
1956
-
1957
-
1958
-
1959
-
1960
-
1961
-
1962
-
1963
-
1964
-
1965
-
1966
-
1967
-
1968
-
1969
-
1970
-
1971
-
1972
-
1973
-
1974
-
1975
-
1976
-
1977
-
1978
-
1979
-
1980
-
1981
-
1982
-
1983
-
1984
-
1985
-
1986
-
1987
-
1988
-
1989
-
1990
-
1991
-
1992
-
1993
-
1994
-
1995
-
1996
-
1997
-
1998
-
1999
-
2000
-
2001
-
2002
-
2003
-
2004
-
2005
-
2006
-
2007
-
2008
-
2009
-
2010
-
2011
-
2012
-
2013
-
2014
-
2015
-
2016
-
2017
-
2018
-
2019
-
2020
-
2021
-
2022
-
2023
-
2024
-
2025
-
2026
-
2027
-
2028
-
2029
-
2030
-
2031
-
2032
-
2033
-
2034
-
2035
-
2036
-
2037
-
2038
-
2039
-
2040
-
2041
-
2042
-
2043
-
2044
-
2045
-
2046
-
2047
-
2048
-
2049
-
2050
-
2051
-
2052
-
2053
-
2054
-
2055
-
2056
-
2057
-
2058
-
2059
-
2060
-
2061
-
2062
-
2063
-
2064
-
2065
-
2066
-
2067
-
2068
-
2069
-
2070
-
2071
-
2072
-
2073
-
2074
-
2075
-
2076
-
2077
-
2078
-
2079
-
2080
-
2081
-
2082
-
2083
-
2084
-
2085
-
2086
-
2087
-
2088
-
2089
-
2090
-
2091
-
2092
-
2093
-
2094
-
2095
-
2096
-
2097
-
2098
-
2099
-
2100
-
2101
-
2102
-
2103
-
2104
-
2105
-
2106
-
2107
-
2108
-
2109
-
2110
-
2111
-
2112
-
2113
-
2114
-
2115
-
2116
-
2117
-
2118
-
2119
-
2120
-
2121
-
2122
-
2123
-
2124
-
2125
-
2126
-
2127
-
2128
-
2129
-
2130
-
2131
-
2132
-
2133
-
2134
-
2135
-
2136
-
2137
-
2138
-
2139
-
2140
-
2141
-
2142
-
2143
-
2144
-
2145
-
2146
-
2147
-
2148
-
2149
-
2150
-
2151
-
2152
-
2153
-
2154
-
2155
-
2156
-
2157
-
2158
-
2159
-
2160
-
2161
-
2162
-
2163
-
2164
-
2165
-
2166
-
2167
-
2168
-
2169
-
2170
-
2171
-
2172
-
2173
-
2174
-
2175
-
2176
-
2177
-
2178
-
2179
-
2180
-
2181
-
2182
-
2183
-
2184
-
2185
-
2186
-
2187
-
2188
-
2189
-
2190
-
2191
-
2192
-
2193
-
2194
-
2195
-
2196
-
2197
-
2198
-
2199
-
2200
-
2201
-
2202
-
2203
-
2204
-
2205
-
2206
-
2207
-
2208
-
2209
-
2210
-
2211
-
2212
-
2213
-
2214
-
2215
-
2216
-
2217
-
2218
-
2219
-
2220
-
2221
-
2222
-
2223
-
2224
-
2225
-
2226
-
2227
-
2228
-
2229
-
2230
-
2231
-
2232
-
2233
-
2234
-
2235
-
2236
-
2237
-
2238
-
2239
-
2240
-
2241
-
2242
-
2243
-
2244
-
2245
-
2246
-
2247
-
2248
-
2249
-
2250
-
2251
-
2252
-
2253
-
2254
-
2255
-
2256
-
2257
-
2258
-
2259
-
2260
-
2261
-
2262
-
2263
-
2264
-
2265
-
2266
-
2267
-
2268
-
2269
-
2270
-
2271
-
2272
-
2273
-
2274
-
2275
-
2276
-
2277
-
2278
-
2279
-
2280
-
2281
-
2282
-
2283
-
2284
-
2285
-
2286
-
2287
-
2288
-
2289
-
2290
-
2291
-
2292
-
2293
-
2294
-
2295
-
2296
-
2297
-
2298
-
2299
-
2300
-
2301
-
2302
-
2303
-
2304
-
2305
-
2306
-
2307
-
2308
-
2309
-
2310
-
2311
-
2312
-
2313
-
2314
-
2315
-
2316
-
2317
-
2318
-
2319
-
2320
-
2321
-
2322
-
2323
-
2324
-
2325
-
2326
-
2327
-
2328
-
2329
-
2330
-
2331
-
2332
-
2333
-
2334
-
2335
-
2336
-
2337
-
2338
-
2339
-
2340
-
2341
-
2342
-
2343
-
2344
-
2345
-
2346
-
2347
-
2348
-
2349
-
2350
-
2351
-
2352
-
2353
-
2354
-
2355
-
2356
-
2357
-
2358
-
2359
-
2360
-
2361
-
2362
-
2363
-
2364
-
2365
-
2366
-
2367
-
2368
-
2369
-
2370
-
2371
-
2372
-
2373
-
2374
-
2375
-
2376
-
2377
-
2378
- ���
2379
-
2380
-
2381
-
2382
-
2383
-
2384
-
2385
-
2386
-
2387
-
2388
-
2389
-
2390
-
2391
-
2392
-
2393
-
2394
-
2395
-
2396
-
2397
-
2398
-
2399
-
2400
-
2401
-
2402
-
2403
-
2404
-
2405
-
2406
-
2407
-
2408
-
2409
-
2410
-
2411
-
2412
-
2413
-
2414
-
2415
-
2416
-
2417
-
2418
-
2419
-
2420
-
2421
-
2422
-
2423
-
2424
-
2425
-
2426
-
2427
-
2428
-
2429
-
2430
-
2431
-
2432
-
2433
-
2434
-
2435
-
2436
-
2437
-
2438
-
2439
-
2440
-
2441
-
2442
-
2443
-
2444
-
2445
-
2446
-
2447
-
2448
-
2449
-
2450
-
2451
-
2452
-
2453
-
2454
-
2455
-
2456
-
2457
-
2458
-
2459
-
2460
-
2461
-
2462
-
2463
-
2464
-
2465
-
2466
-
2467
-
2468
-
2469
-
2470
-
2471
-
2472
-
2473
-
2474
-
2475
-
2476
-
2477
-
2478
-
2479
-
2480
-
2481
-
2482
-
2483
-
2484
-
2485
-
2486
-
2487
-
2488
-
2489
-
2490
-
2491
-
2492
-
2493
-
2494
-
2495
-
2496
-
2497
-
2498
-
2499
-
2500
-
2501
-
2502
-
2503
-
2504
-
2505
-
2506
-
2507
-
2508
-
2509
-
2510
-
2511
-
2512
-
2513
-
2514
-
2515
-
2516
-
2517
-
2518
-
2519
-
2520
-
2521
-
2522
-
2523
-
2524
-
2525
-
2526
-
2527
-
2528
-
2529
-
2530
-
2531
-
2532
-
2533
-
2534
-
2535
-
2536
-
2537
-
2538
-
2539
-
2540
-
2541
-
2542
-
2543
-
2544
-
2545
- 𠮶
 
1
+
2
+ !
3
+ "
4
+ #
5
+ $
6
+ %
7
+ &
8
+ '
9
+ (
10
+ )
11
+ *
12
+ +
13
+ ,
14
+ -
15
+ .
16
+ /
17
+ 0
18
+ 1
19
+ 2
20
+ 3
21
+ 4
22
+ 5
23
+ 6
24
+ 7
25
+ 8
26
+ 9
27
+ :
28
+ ;
29
+ =
30
+ >
31
+ ?
32
+ @
33
+ A
34
+ B
35
+ C
36
+ D
37
+ E
38
+ F
39
+ G
40
+ H
41
+ I
42
+ J
43
+ K
44
+ L
45
+ M
46
+ N
47
+ O
48
+ P
49
+ Q
50
+ R
51
+ S
52
+ T
53
+ U
54
+ V
55
+ W
56
+ X
57
+ Y
58
+ Z
59
+ [
60
+ \
61
+ ]
62
+ _
63
+ a
64
+ a1
65
+ ai1
66
+ ai2
67
+ ai3
68
+ ai4
69
+ an1
70
+ an3
71
+ an4
72
+ ang1
73
+ ang2
74
+ ang4
75
+ ao1
76
+ ao2
77
+ ao3
78
+ ao4
79
+ b
80
+ ba
81
+ ba1
82
+ ba2
83
+ ba3
84
+ ba4
85
+ bai1
86
+ bai2
87
+ bai3
88
+ bai4
89
+ ban1
90
+ ban2
91
+ ban3
92
+ ban4
93
+ bang1
94
+ bang2
95
+ bang3
96
+ bang4
97
+ bao1
98
+ bao2
99
+ bao3
100
+ bao4
101
+ bei
102
+ bei1
103
+ bei2
104
+ bei3
105
+ bei4
106
+ ben1
107
+ ben2
108
+ ben3
109
+ ben4
110
+ beng
111
+ beng1
112
+ beng2
113
+ beng3
114
+ beng4
115
+ bi1
116
+ bi2
117
+ bi3
118
+ bi4
119
+ bian1
120
+ bian2
121
+ bian3
122
+ bian4
123
+ biao1
124
+ biao2
125
+ biao3
126
+ bie1
127
+ bie2
128
+ bie3
129
+ bie4
130
+ bin1
131
+ bin4
132
+ bing1
133
+ bing2
134
+ bing3
135
+ bing4
136
+ bo
137
+ bo1
138
+ bo2
139
+ bo3
140
+ bo4
141
+ bu2
142
+ bu3
143
+ bu4
144
+ c
145
+ ca1
146
+ cai1
147
+ cai2
148
+ cai3
149
+ cai4
150
+ can1
151
+ can2
152
+ can3
153
+ can4
154
+ cang1
155
+ cang2
156
+ cao1
157
+ cao2
158
+ cao3
159
+ ce4
160
+ cen1
161
+ cen2
162
+ ceng1
163
+ ceng2
164
+ ceng4
165
+ cha1
166
+ cha2
167
+ cha3
168
+ cha4
169
+ chai1
170
+ chai2
171
+ chan1
172
+ chan2
173
+ chan3
174
+ chan4
175
+ chang1
176
+ chang2
177
+ chang3
178
+ chang4
179
+ chao1
180
+ chao2
181
+ chao3
182
+ che1
183
+ che2
184
+ che3
185
+ che4
186
+ chen1
187
+ chen2
188
+ chen3
189
+ chen4
190
+ cheng1
191
+ cheng2
192
+ cheng3
193
+ cheng4
194
+ chi1
195
+ chi2
196
+ chi3
197
+ chi4
198
+ chong1
199
+ chong2
200
+ chong3
201
+ chong4
202
+ chou1
203
+ chou2
204
+ chou3
205
+ chou4
206
+ chu1
207
+ chu2
208
+ chu3
209
+ chu4
210
+ chua1
211
+ chuai1
212
+ chuai2
213
+ chuai3
214
+ chuai4
215
+ chuan1
216
+ chuan2
217
+ chuan3
218
+ chuan4
219
+ chuang1
220
+ chuang2
221
+ chuang3
222
+ chuang4
223
+ chui1
224
+ chui2
225
+ chun1
226
+ chun2
227
+ chun3
228
+ chuo1
229
+ chuo4
230
+ ci1
231
+ ci2
232
+ ci3
233
+ ci4
234
+ cong1
235
+ cong2
236
+ cou4
237
+ cu1
238
+ cu4
239
+ cuan1
240
+ cuan2
241
+ cuan4
242
+ cui1
243
+ cui3
244
+ cui4
245
+ cun1
246
+ cun2
247
+ cun4
248
+ cuo1
249
+ cuo2
250
+ cuo4
251
+ d
252
+ da
253
+ da1
254
+ da2
255
+ da3
256
+ da4
257
+ dai1
258
+ dai2
259
+ dai3
260
+ dai4
261
+ dan1
262
+ dan2
263
+ dan3
264
+ dan4
265
+ dang1
266
+ dang2
267
+ dang3
268
+ dang4
269
+ dao1
270
+ dao2
271
+ dao3
272
+ dao4
273
+ de
274
+ de1
275
+ de2
276
+ dei3
277
+ den4
278
+ deng1
279
+ deng2
280
+ deng3
281
+ deng4
282
+ di1
283
+ di2
284
+ di3
285
+ di4
286
+ dia3
287
+ dian1
288
+ dian2
289
+ dian3
290
+ dian4
291
+ diao1
292
+ diao3
293
+ diao4
294
+ die1
295
+ die2
296
+ die4
297
+ ding1
298
+ ding2
299
+ ding3
300
+ ding4
301
+ diu1
302
+ dong1
303
+ dong3
304
+ dong4
305
+ dou1
306
+ dou2
307
+ dou3
308
+ dou4
309
+ du1
310
+ du2
311
+ du3
312
+ du4
313
+ duan1
314
+ duan2
315
+ duan3
316
+ duan4
317
+ dui1
318
+ dui4
319
+ dun1
320
+ dun3
321
+ dun4
322
+ duo1
323
+ duo2
324
+ duo3
325
+ duo4
326
+ e
327
+ e1
328
+ e2
329
+ e3
330
+ e4
331
+ ei2
332
+ en1
333
+ en4
334
+ er
335
+ er2
336
+ er3
337
+ er4
338
+ f
339
+ fa1
340
+ fa2
341
+ fa3
342
+ fa4
343
+ fan1
344
+ fan2
345
+ fan3
346
+ fan4
347
+ fang1
348
+ fang2
349
+ fang3
350
+ fang4
351
+ fei1
352
+ fei2
353
+ fei3
354
+ fei4
355
+ fen1
356
+ fen2
357
+ fen3
358
+ fen4
359
+ feng1
360
+ feng2
361
+ feng3
362
+ feng4
363
+ fo2
364
+ fou2
365
+ fou3
366
+ fu1
367
+ fu2
368
+ fu3
369
+ fu4
370
+ g
371
+ ga1
372
+ ga2
373
+ ga3
374
+ ga4
375
+ gai1
376
+ gai2
377
+ gai3
378
+ gai4
379
+ gan1
380
+ gan2
381
+ gan3
382
+ gan4
383
+ gang1
384
+ gang2
385
+ gang3
386
+ gang4
387
+ gao1
388
+ gao2
389
+ gao3
390
+ gao4
391
+ ge1
392
+ ge2
393
+ ge3
394
+ ge4
395
+ gei2
396
+ gei3
397
+ gen1
398
+ gen2
399
+ gen3
400
+ gen4
401
+ geng1
402
+ geng3
403
+ geng4
404
+ gong1
405
+ gong3
406
+ gong4
407
+ gou1
408
+ gou2
409
+ gou3
410
+ gou4
411
+ gu
412
+ gu1
413
+ gu2
414
+ gu3
415
+ gu4
416
+ gua1
417
+ gua2
418
+ gua3
419
+ gua4
420
+ guai1
421
+ guai2
422
+ guai3
423
+ guai4
424
+ guan1
425
+ guan2
426
+ guan3
427
+ guan4
428
+ guang1
429
+ guang2
430
+ guang3
431
+ guang4
432
+ gui1
433
+ gui2
434
+ gui3
435
+ gui4
436
+ gun3
437
+ gun4
438
+ guo1
439
+ guo2
440
+ guo3
441
+ guo4
442
+ h
443
+ ha1
444
+ ha2
445
+ ha3
446
+ hai1
447
+ hai2
448
+ hai3
449
+ hai4
450
+ han1
451
+ han2
452
+ han3
453
+ han4
454
+ hang1
455
+ hang2
456
+ hang4
457
+ hao1
458
+ hao2
459
+ hao3
460
+ hao4
461
+ he1
462
+ he2
463
+ he4
464
+ hei1
465
+ hen2
466
+ hen3
467
+ hen4
468
+ heng1
469
+ heng2
470
+ heng4
471
+ hong1
472
+ hong2
473
+ hong3
474
+ hong4
475
+ hou1
476
+ hou2
477
+ hou3
478
+ hou4
479
+ hu1
480
+ hu2
481
+ hu3
482
+ hu4
483
+ hua1
484
+ hua2
485
+ hua4
486
+ huai2
487
+ huai4
488
+ huan1
489
+ huan2
490
+ huan3
491
+ huan4
492
+ huang1
493
+ huang2
494
+ huang3
495
+ huang4
496
+ hui1
497
+ hui2
498
+ hui3
499
+ hui4
500
+ hun1
501
+ hun2
502
+ hun4
503
+ huo
504
+ huo1
505
+ huo2
506
+ huo3
507
+ huo4
508
+ i
509
+ j
510
+ ji1
511
+ ji2
512
+ ji3
513
+ ji4
514
+ jia
515
+ jia1
516
+ jia2
517
+ jia3
518
+ jia4
519
+ jian1
520
+ jian2
521
+ jian3
522
+ jian4
523
+ jiang1
524
+ jiang2
525
+ jiang3
526
+ jiang4
527
+ jiao1
528
+ jiao2
529
+ jiao3
530
+ jiao4
531
+ jie1
532
+ jie2
533
+ jie3
534
+ jie4
535
+ jin1
536
+ jin2
537
+ jin3
538
+ jin4
539
+ jing1
540
+ jing2
541
+ jing3
542
+ jing4
543
+ jiong3
544
+ jiu1
545
+ jiu2
546
+ jiu3
547
+ jiu4
548
+ ju1
549
+ ju2
550
+ ju3
551
+ ju4
552
+ juan1
553
+ juan2
554
+ juan3
555
+ juan4
556
+ jue1
557
+ jue2
558
+ jue4
559
+ jun1
560
+ jun4
561
+ k
562
+ ka1
563
+ ka2
564
+ ka3
565
+ kai1
566
+ kai2
567
+ kai3
568
+ kai4
569
+ kan1
570
+ kan2
571
+ kan3
572
+ kan4
573
+ kang1
574
+ kang2
575
+ kang4
576
+ kao1
577
+ kao2
578
+ kao3
579
+ kao4
580
+ ke1
581
+ ke2
582
+ ke3
583
+ ke4
584
+ ken3
585
+ keng1
586
+ kong1
587
+ kong3
588
+ kong4
589
+ kou1
590
+ kou2
591
+ kou3
592
+ kou4
593
+ ku1
594
+ ku2
595
+ ku3
596
+ ku4
597
+ kua1
598
+ kua3
599
+ kua4
600
+ kuai3
601
+ kuai4
602
+ kuan1
603
+ kuan2
604
+ kuan3
605
+ kuang1
606
+ kuang2
607
+ kuang4
608
+ kui1
609
+ kui2
610
+ kui3
611
+ kui4
612
+ kun1
613
+ kun3
614
+ kun4
615
+ kuo4
616
+ l
617
+ la
618
+ la1
619
+ la2
620
+ la3
621
+ la4
622
+ lai2
623
+ lai4
624
+ lan2
625
+ lan3
626
+ lan4
627
+ lang1
628
+ lang2
629
+ lang3
630
+ lang4
631
+ lao1
632
+ lao2
633
+ lao3
634
+ lao4
635
+ le
636
+ le1
637
+ le4
638
+ lei
639
+ lei1
640
+ lei2
641
+ lei3
642
+ lei4
643
+ leng1
644
+ leng2
645
+ leng3
646
+ leng4
647
+ li
648
+ li1
649
+ li2
650
+ li3
651
+ li4
652
+ lia3
653
+ lian2
654
+ lian3
655
+ lian4
656
+ liang2
657
+ liang3
658
+ liang4
659
+ liao1
660
+ liao2
661
+ liao3
662
+ liao4
663
+ lie1
664
+ lie2
665
+ lie3
666
+ lie4
667
+ lin1
668
+ lin2
669
+ lin3
670
+ lin4
671
+ ling2
672
+ ling3
673
+ ling4
674
+ liu1
675
+ liu2
676
+ liu3
677
+ liu4
678
+ long1
679
+ long2
680
+ long3
681
+ long4
682
+ lou1
683
+ lou2
684
+ lou3
685
+ lou4
686
+ lu1
687
+ lu2
688
+ lu3
689
+ lu4
690
+ luan2
691
+ luan3
692
+ luan4
693
+ lun1
694
+ lun2
695
+ lun4
696
+ luo1
697
+ luo2
698
+ luo3
699
+ luo4
700
+ lv2
701
+ lv3
702
+ lv4
703
+ lve3
704
+ lve4
705
+ m
706
+ ma
707
+ ma1
708
+ ma2
709
+ ma3
710
+ ma4
711
+ mai2
712
+ mai3
713
+ mai4
714
+ man1
715
+ man2
716
+ man3
717
+ man4
718
+ mang2
719
+ mang3
720
+ mao1
721
+ mao2
722
+ mao3
723
+ mao4
724
+ me
725
+ mei2
726
+ mei3
727
+ mei4
728
+ men
729
+ men1
730
+ men2
731
+ men4
732
+ meng
733
+ meng1
734
+ meng2
735
+ meng3
736
+ meng4
737
+ mi1
738
+ mi2
739
+ mi3
740
+ mi4
741
+ mian2
742
+ mian3
743
+ mian4
744
+ miao1
745
+ miao2
746
+ miao3
747
+ miao4
748
+ mie1
749
+ mie4
750
+ min2
751
+ min3
752
+ ming2
753
+ ming3
754
+ ming4
755
+ miu4
756
+ mo1
757
+ mo2
758
+ mo3
759
+ mo4
760
+ mou1
761
+ mou2
762
+ mou3
763
+ mu2
764
+ mu3
765
+ mu4
766
+ n
767
+ n2
768
+ na1
769
+ na2
770
+ na3
771
+ na4
772
+ nai2
773
+ nai3
774
+ nai4
775
+ nan1
776
+ nan2
777
+ nan3
778
+ nan4
779
+ nang1
780
+ nang2
781
+ nang3
782
+ nao1
783
+ nao2
784
+ nao3
785
+ nao4
786
+ ne
787
+ ne2
788
+ ne4
789
+ nei3
790
+ nei4
791
+ nen4
792
+ neng2
793
+ ni1
794
+ ni2
795
+ ni3
796
+ ni4
797
+ nian1
798
+ nian2
799
+ nian3
800
+ nian4
801
+ niang2
802
+ niang4
803
+ niao2
804
+ niao3
805
+ niao4
806
+ nie1
807
+ nie4
808
+ nin2
809
+ ning2
810
+ ning3
811
+ ning4
812
+ niu1
813
+ niu2
814
+ niu3
815
+ niu4
816
+ nong2
817
+ nong4
818
+ nou4
819
+ nu2
820
+ nu3
821
+ nu4
822
+ nuan3
823
+ nuo2
824
+ nuo4
825
+ nv2
826
+ nv3
827
+ nve4
828
+ o
829
+ o1
830
+ o2
831
+ ou1
832
+ ou2
833
+ ou3
834
+ ou4
835
+ p
836
+ pa1
837
+ pa2
838
+ pa4
839
+ pai1
840
+ pai2
841
+ pai3
842
+ pai4
843
+ pan1
844
+ pan2
845
+ pan4
846
+ pang1
847
+ pang2
848
+ pang4
849
+ pao1
850
+ pao2
851
+ pao3
852
+ pao4
853
+ pei1
854
+ pei2
855
+ pei4
856
+ pen1
857
+ pen2
858
+ pen4
859
+ peng1
860
+ peng2
861
+ peng3
862
+ peng4
863
+ pi1
864
+ pi2
865
+ pi3
866
+ pi4
867
+ pian1
868
+ pian2
869
+ pian4
870
+ piao1
871
+ piao2
872
+ piao3
873
+ piao4
874
+ pie1
875
+ pie2
876
+ pie3
877
+ pin1
878
+ pin2
879
+ pin3
880
+ pin4
881
+ ping1
882
+ ping2
883
+ po1
884
+ po2
885
+ po3
886
+ po4
887
+ pou1
888
+ pu1
889
+ pu2
890
+ pu3
891
+ pu4
892
+ q
893
+ qi1
894
+ qi2
895
+ qi3
896
+ qi4
897
+ qia1
898
+ qia3
899
+ qia4
900
+ qian1
901
+ qian2
902
+ qian3
903
+ qian4
904
+ qiang1
905
+ qiang2
906
+ qiang3
907
+ qiang4
908
+ qiao1
909
+ qiao2
910
+ qiao3
911
+ qiao4
912
+ qie1
913
+ qie2
914
+ qie3
915
+ qie4
916
+ qin1
917
+ qin2
918
+ qin3
919
+ qin4
920
+ qing1
921
+ qing2
922
+ qing3
923
+ qing4
924
+ qiong1
925
+ qiong2
926
+ qiu1
927
+ qiu2
928
+ qiu3
929
+ qu1
930
+ qu2
931
+ qu3
932
+ qu4
933
+ quan1
934
+ quan2
935
+ quan3
936
+ quan4
937
+ que1
938
+ que2
939
+ que4
940
+ qun2
941
+ r
942
+ ran2
943
+ ran3
944
+ rang1
945
+ rang2
946
+ rang3
947
+ rang4
948
+ rao2
949
+ rao3
950
+ rao4
951
+ re2
952
+ re3
953
+ re4
954
+ ren2
955
+ ren3
956
+ ren4
957
+ reng1
958
+ reng2
959
+ ri4
960
+ rong1
961
+ rong2
962
+ rong3
963
+ rou2
964
+ rou4
965
+ ru2
966
+ ru3
967
+ ru4
968
+ ruan2
969
+ ruan3
970
+ rui3
971
+ rui4
972
+ run4
973
+ ruo4
974
+ s
975
+ sa1
976
+ sa2
977
+ sa3
978
+ sa4
979
+ sai1
980
+ sai4
981
+ san1
982
+ san2
983
+ san3
984
+ san4
985
+ sang1
986
+ sang3
987
+ sang4
988
+ sao1
989
+ sao2
990
+ sao3
991
+ sao4
992
+ se4
993
+ sen1
994
+ seng1
995
+ sha1
996
+ sha2
997
+ sha3
998
+ sha4
999
+ shai1
1000
+ shai2
1001
+ shai3
1002
+ shai4
1003
+ shan1
1004
+ shan3
1005
+ shan4
1006
+ shang
1007
+ shang1
1008
+ shang3
1009
+ shang4
1010
+ shao1
1011
+ shao2
1012
+ shao3
1013
+ shao4
1014
+ she1
1015
+ she2
1016
+ she3
1017
+ she4
1018
+ shei2
1019
+ shen1
1020
+ shen2
1021
+ shen3
1022
+ shen4
1023
+ sheng1
1024
+ sheng2
1025
+ sheng3
1026
+ sheng4
1027
+ shi
1028
+ shi1
1029
+ shi2
1030
+ shi3
1031
+ shi4
1032
+ shou1
1033
+ shou2
1034
+ shou3
1035
+ shou4
1036
+ shu1
1037
+ shu2
1038
+ shu3
1039
+ shu4
1040
+ shua1
1041
+ shua2
1042
+ shua3
1043
+ shua4
1044
+ shuai1
1045
+ shuai3
1046
+ shuai4
1047
+ shuan1
1048
+ shuan4
1049
+ shuang1
1050
+ shuang3
1051
+ shui2
1052
+ shui3
1053
+ shui4
1054
+ shun3
1055
+ shun4
1056
+ shuo1
1057
+ shuo4
1058
+ si1
1059
+ si2
1060
+ si3
1061
+ si4
1062
+ song1
1063
+ song3
1064
+ song4
1065
+ sou1
1066
+ sou3
1067
+ sou4
1068
+ su1
1069
+ su2
1070
+ su4
1071
+ suan1
1072
+ suan4
1073
+ sui1
1074
+ sui2
1075
+ sui3
1076
+ sui4
1077
+ sun1
1078
+ sun3
1079
+ suo
1080
+ suo1
1081
+ suo2
1082
+ suo3
1083
+ t
1084
+ ta1
1085
+ ta2
1086
+ ta3
1087
+ ta4
1088
+ tai1
1089
+ tai2
1090
+ tai4
1091
+ tan1
1092
+ tan2
1093
+ tan3
1094
+ tan4
1095
+ tang1
1096
+ tang2
1097
+ tang3
1098
+ tang4
1099
+ tao1
1100
+ tao2
1101
+ tao3
1102
+ tao4
1103
+ te4
1104
+ teng2
1105
+ ti1
1106
+ ti2
1107
+ ti3
1108
+ ti4
1109
+ tian1
1110
+ tian2
1111
+ tian3
1112
+ tiao1
1113
+ tiao2
1114
+ tiao3
1115
+ tiao4
1116
+ tie1
1117
+ tie2
1118
+ tie3
1119
+ tie4
1120
+ ting1
1121
+ ting2
1122
+ ting3
1123
+ tong1
1124
+ tong2
1125
+ tong3
1126
+ tong4
1127
+ tou
1128
+ tou1
1129
+ tou2
1130
+ tou4
1131
+ tu1
1132
+ tu2
1133
+ tu3
1134
+ tu4
1135
+ tuan1
1136
+ tuan2
1137
+ tui1
1138
+ tui2
1139
+ tui3
1140
+ tui4
1141
+ tun1
1142
+ tun2
1143
+ tun4
1144
+ tuo1
1145
+ tuo2
1146
+ tuo3
1147
+ tuo4
1148
+ u
1149
+ v
1150
+ w
1151
+ wa
1152
+ wa1
1153
+ wa2
1154
+ wa3
1155
+ wa4
1156
+ wai1
1157
+ wai3
1158
+ wai4
1159
+ wan1
1160
+ wan2
1161
+ wan3
1162
+ wan4
1163
+ wang1
1164
+ wang2
1165
+ wang3
1166
+ wang4
1167
+ wei1
1168
+ wei2
1169
+ wei3
1170
+ wei4
1171
+ wen1
1172
+ wen2
1173
+ wen3
1174
+ wen4
1175
+ weng1
1176
+ weng4
1177
+ wo1
1178
+ wo2
1179
+ wo3
1180
+ wo4
1181
+ wu1
1182
+ wu2
1183
+ wu3
1184
+ wu4
1185
+ x
1186
+ xi1
1187
+ xi2
1188
+ xi3
1189
+ xi4
1190
+ xia1
1191
+ xia2
1192
+ xia4
1193
+ xian1
1194
+ xian2
1195
+ xian3
1196
+ xian4
1197
+ xiang1
1198
+ xiang2
1199
+ xiang3
1200
+ xiang4
1201
+ xiao1
1202
+ xiao2
1203
+ xiao3
1204
+ xiao4
1205
+ xie1
1206
+ xie2
1207
+ xie3
1208
+ xie4
1209
+ xin1
1210
+ xin2
1211
+ xin4
1212
+ xing1
1213
+ xing2
1214
+ xing3
1215
+ xing4
1216
+ xiong1
1217
+ xiong2
1218
+ xiu1
1219
+ xiu3
1220
+ xiu4
1221
+ xu
1222
+ xu1
1223
+ xu2
1224
+ xu3
1225
+ xu4
1226
+ xuan1
1227
+ xuan2
1228
+ xuan3
1229
+ xuan4
1230
+ xue1
1231
+ xue2
1232
+ xue3
1233
+ xue4
1234
+ xun1
1235
+ xun2
1236
+ xun4
1237
+ y
1238
+ ya
1239
+ ya1
1240
+ ya2
1241
+ ya3
1242
+ ya4
1243
+ yan1
1244
+ yan2
1245
+ yan3
1246
+ yan4
1247
+ yang1
1248
+ yang2
1249
+ yang3
1250
+ yang4
1251
+ yao1
1252
+ yao2
1253
+ yao3
1254
+ yao4
1255
+ ye1
1256
+ ye2
1257
+ ye3
1258
+ ye4
1259
+ yi
1260
+ yi1
1261
+ yi2
1262
+ yi3
1263
+ yi4
1264
+ yin1
1265
+ yin2
1266
+ yin3
1267
+ yin4
1268
+ ying1
1269
+ ying2
1270
+ ying3
1271
+ ying4
1272
+ yo1
1273
+ yong1
1274
+ yong2
1275
+ yong3
1276
+ yong4
1277
+ you1
1278
+ you2
1279
+ you3
1280
+ you4
1281
+ yu1
1282
+ yu2
1283
+ yu3
1284
+ yu4
1285
+ yuan1
1286
+ yuan2
1287
+ yuan3
1288
+ yuan4
1289
+ yue1
1290
+ yue4
1291
+ yun1
1292
+ yun2
1293
+ yun3
1294
+ yun4
1295
+ z
1296
+ za1
1297
+ za2
1298
+ za3
1299
+ zai1
1300
+ zai3
1301
+ zai4
1302
+ zan1
1303
+ zan2
1304
+ zan3
1305
+ zan4
1306
+ zang1
1307
+ zang4
1308
+ zao1
1309
+ zao2
1310
+ zao3
1311
+ zao4
1312
+ ze2
1313
+ ze4
1314
+ zei2
1315
+ zen3
1316
+ zeng1
1317
+ zeng4
1318
+ zha1
1319
+ zha2
1320
+ zha3
1321
+ zha4
1322
+ zhai1
1323
+ zhai2
1324
+ zhai3
1325
+ zhai4
1326
+ zhan1
1327
+ zhan2
1328
+ zhan3
1329
+ zhan4
1330
+ zhang1
1331
+ zhang2
1332
+ zhang3
1333
+ zhang4
1334
+ zhao1
1335
+ zhao2
1336
+ zhao3
1337
+ zhao4
1338
+ zhe
1339
+ zhe1
1340
+ zhe2
1341
+ zhe3
1342
+ zhe4
1343
+ zhen1
1344
+ zhen2
1345
+ zhen3
1346
+ zhen4
1347
+ zheng1
1348
+ zheng2
1349
+ zheng3
1350
+ zheng4
1351
+ zhi1
1352
+ zhi2
1353
+ zhi3
1354
+ zhi4
1355
+ zhong1
1356
+ zhong2
1357
+ zhong3
1358
+ zhong4
1359
+ zhou1
1360
+ zhou2
1361
+ zhou3
1362
+ zhou4
1363
+ zhu1
1364
+ zhu2
1365
+ zhu3
1366
+ zhu4
1367
+ zhua1
1368
+ zhua2
1369
+ zhua3
1370
+ zhuai1
1371
+ zhuai3
1372
+ zhuai4
1373
+ zhuan1
1374
+ zhuan2
1375
+ zhuan3
1376
+ zhuan4
1377
+ zhuang1
1378
+ zhuang4
1379
+ zhui1
1380
+ zhui4
1381
+ zhun1
1382
+ zhun2
1383
+ zhun3
1384
+ zhuo1
1385
+ zhuo2
1386
+ zi
1387
+ zi1
1388
+ zi2
1389
+ zi3
1390
+ zi4
1391
+ zong1
1392
+ zong2
1393
+ zong3
1394
+ zong4
1395
+ zou1
1396
+ zou2
1397
+ zou3
1398
+ zou4
1399
+ zu1
1400
+ zu2
1401
+ zu3
1402
+ zuan1
1403
+ zuan3
1404
+ zuan4
1405
+ zui2
1406
+ zui3
1407
+ zui4
1408
+ zun1
1409
+ zuo
1410
+ zuo1
1411
+ zuo2
1412
+ zuo3
1413
+ zuo4
1414
+ {
1415
+ ~
1416
+ ¡
1417
+ ¢
1418
+ £
1419
+ ¥
1420
+ §
1421
+ ¨
1422
+ ©
1423
+ «
1424
+ ®
1425
+ ¯
1426
+ °
1427
+ ±
1428
+ ²
1429
+ ³
1430
+ ´
1431
+ µ
1432
+ ·
1433
+ ¹
1434
+ º
1435
+ »
1436
+ ¼
1437
+ ½
1438
+ ¾
1439
+ ¿
1440
+ À
1441
+ Á
1442
+ Â
1443
+ Ã
1444
+ Ä
1445
+ Å
1446
+ Æ
1447
+ Ç
1448
+ È
1449
+ É
1450
+ Ê
1451
+ Í
1452
+ Î
1453
+ Ñ
1454
+ Ó
1455
+ Ö
1456
+ ×
1457
+ Ø
1458
+ Ú
1459
+ Ü
1460
+ Ý
1461
+ Þ
1462
+ ß
1463
+ à
1464
+ á
1465
+ â
1466
+ ã
1467
+ ä
1468
+ å
1469
+ æ
1470
+ ç
1471
+ è
1472
+ é
1473
+ ê
1474
+ ë
1475
+ ì
1476
+ í
1477
+ î
1478
+ ï
1479
+ ð
1480
+ ñ
1481
+ ò
1482
+ ó
1483
+ ô
1484
+ õ
1485
+ ö
1486
+ ø
1487
+ ù
1488
+ ú
1489
+ û
1490
+ ü
1491
+ ý
1492
+ Ā
1493
+ ā
1494
+ ă
1495
+ ą
1496
+ ć
1497
+ Č
1498
+ č
1499
+ Đ
1500
+ đ
1501
+ ē
1502
+ ė
1503
+ ę
1504
+ ě
1505
+ ĝ
1506
+ ğ
1507
+ ħ
1508
+ ī
1509
+ į
1510
+ İ
1511
+ ı
1512
+ Ł
1513
+ ł
1514
+ ń
1515
+ ņ
1516
+ ň
1517
+ ŋ
1518
+ Ō
1519
+ ō
1520
+ ő
1521
+ œ
1522
+ ř
1523
+ Ś
1524
+ ś
1525
+ Ş
1526
+ ş
1527
+ Š
1528
+ š
1529
+ Ť
1530
+ ť
1531
+ ũ
1532
+ ū
1533
+ ź
1534
+ Ż
1535
+ ż
1536
+ Ž
1537
+ ž
1538
+ ơ
1539
+ ư
1540
+ ǎ
1541
+ ǐ
1542
+ ǒ
1543
+ ǔ
1544
+ ǚ
1545
+ ș
1546
+ ț
1547
+ ɑ
1548
+ ɔ
1549
+ ɕ
1550
+ ə
1551
+ ɛ
1552
+ ɜ
1553
+ ɡ
1554
+ ɣ
1555
+ ɪ
1556
+ ɫ
1557
+ ɴ
1558
+ ɹ
1559
+ ɾ
1560
+ ʃ
1561
+ ʊ
1562
+ ʌ
1563
+ ʒ
1564
+ ʔ
1565
+ ʰ
1566
+ ʷ
1567
+ ʻ
1568
+ ʾ
1569
+ ʿ
1570
+ ˈ
1571
+ ː
1572
+ ˙
1573
+ ˜
1574
+ ˢ
1575
+ ́
1576
+ ̅
1577
+ Α
1578
+ Β
1579
+ Δ
1580
+ Ε
1581
+ Θ
1582
+ Κ
1583
+ Λ
1584
+ Μ
1585
+ Ξ
1586
+ Π
1587
+ Σ
1588
+ Τ
1589
+ Φ
1590
+ Χ
1591
+ Ψ
1592
+ Ω
1593
+ ά
1594
+ έ
1595
+ ή
1596
+ ί
1597
+ α
1598
+ β
1599
+ γ
1600
+ δ
1601
+ ε
1602
+ ζ
1603
+ η
1604
+ θ
1605
+ ι
1606
+ κ
1607
+ λ
1608
+ μ
1609
+ ν
1610
+ ξ
1611
+ ο
1612
+ π
1613
+ ρ
1614
+ ς
1615
+ σ
1616
+ τ
1617
+ υ
1618
+ φ
1619
+ χ
1620
+ ψ
1621
+ ω
1622
+ ϊ
1623
+ ό
1624
+ ύ
1625
+ ώ
1626
+ ϕ
1627
+ ϵ
1628
+ Ё
1629
+ А
1630
+ Б
1631
+ В
1632
+ Г
1633
+ Д
1634
+ Е
1635
+ Ж
1636
+ З
1637
+ И
1638
+ Й
1639
+ К
1640
+ Л
1641
+ М
1642
+ Н
1643
+ О
1644
+ П
1645
+ Р
1646
+ С
1647
+ Т
1648
+ У
1649
+ Ф
1650
+ Х
1651
+ Ц
1652
+ Ч
1653
+ Ш
1654
+ Щ
1655
+ Ы
1656
+ Ь
1657
+ Э
1658
+ Ю
1659
+ Я
1660
+ а
1661
+ б
1662
+ в
1663
+ г
1664
+ д
1665
+ е
1666
+ ж
1667
+ з
1668
+ и
1669
+ й
1670
+ к
1671
+ л
1672
+ м
1673
+ н
1674
+ о
1675
+ п
1676
+ р
1677
+ с
1678
+ т
1679
+ у
1680
+ ф
1681
+ х
1682
+ ц
1683
+ ч
1684
+ ш
1685
+ щ
1686
+ ъ
1687
+ ы
1688
+ ь
1689
+ э
1690
+ ю
1691
+ я
1692
+ ё
1693
+ і
1694
+ ְ
1695
+ ִ
1696
+ ֵ
1697
+ ֶ
1698
+ ַ
1699
+ ָ
1700
+ ֹ
1701
+ ּ
1702
+ ־
1703
+ ׁ
1704
+ א
1705
+ ב
1706
+ ג
1707
+ ד
1708
+ ה
1709
+ ו
1710
+ ז
1711
+ ח
1712
+ ט
1713
+ י
1714
+ כ
1715
+ ל
1716
+ ם
1717
+ מ
1718
+ ן
1719
+ נ
1720
+ ס
1721
+ ע
1722
+ פ
1723
+ ק
1724
+ ר
1725
+ ש
1726
+ ת
1727
+ أ
1728
+ ب
1729
+ ة
1730
+ ت
1731
+ ج
1732
+ ح
1733
+ د
1734
+ ر
1735
+ ز
1736
+ س
1737
+ ص
1738
+ ط
1739
+ ع
1740
+ ق
1741
+ ك
1742
+ ل
1743
+ م
1744
+ ن
1745
+ ه
1746
+ و
1747
+ ي
1748
+ َ
1749
+ ُ
1750
+ ِ
1751
+ ْ
1752
+
1753
+
1754
+
1755
+
1756
+
1757
+
1758
+
1759
+
1760
+
1761
+
1762
+
1763
+
1764
+
1765
+
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+
1774
+
1775
+
1776
+
1777
+
1778
+
1779
+
1780
+
1781
+
1782
+
1783
+
1784
+
1785
+
1786
+
1787
+
1788
+
1789
+
1790
+
1791
+
1792
+
1793
+
1794
+
1795
+
1796
+
1797
+
1798
+
1799
+
1800
+ ế
1801
+
1802
+
1803
+
1804
+
1805
+
1806
+
1807
+
1808
+
1809
+
1810
+
1811
+
1812
+
1813
+
1814
+
1815
+
1816
+
1817
+
1818
+
1819
+
1820
+
1821
+
1822
+
1823
+
1824
+
1825
+
1826
+
1827
+
1828
+
1829
+
1830
+
1831
+
1832
+
1833
+
1834
+
1835
+
1836
+
1837
+
1838
+
1839
+
1840
+
1841
+
1842
+
1843
+
1844
+
1845
+
1846
+
1847
+
1848
+
1849
+
1850
+
1851
+
1852
+
1853
+
1854
+
1855
+
1856
+
1857
+
1858
+
1859
+
1860
+
1861
+
1862
+
1863
+
1864
+
1865
+
1866
+
1867
+
1868
+
1869
+
1870
+
1871
+
1872
+
1873
+
1874
+
1875
+
1876
+
1877
+
1878
+
1879
+
1880
+
1881
+
1882
+
1883
+
1884
+
1885
+
1886
+
1887
+
1888
+
1889
+
1890
+
1891
+
1892
+
1893
+
1894
+
1895
+
1896
+
1897
+
1898
+
1899
+
1900
+
1901
+
1902
+
1903
+
1904
+
1905
+
1906
+
1907
+
1908
+
1909
+
1910
+
1911
+
1912
+
1913
+
1914
+
1915
+
1916
+
1917
+
1918
+
1919
+
1920
+
1921
+
1922
+
1923
+
1924
+
1925
+
1926
+
1927
+
1928
+
1929
+
1930
+
1931
+
1932
+
1933
+
1934
+
1935
+
1936
+
1937
+
1938
+
1939
+
1940
+
1941
+
1942
+
1943
+
1944
+
1945
+
1946
+
1947
+
1948
+
1949
+
1950
+
1951
+
1952
+
1953
+
1954
+
1955
+
1956
+
1957
+
1958
+
1959
+
1960
+
1961
+
1962
+
1963
+
1964
+
1965
+
1966
+
1967
+
1968
+
1969
+
1970
+
1971
+
1972
+
1973
+
1974
+
1975
+
1976
+
1977
+
1978
+
1979
+
1980
+
1981
+
1982
+
1983
+
1984
+
1985
+
1986
+
1987
+
1988
+
1989
+
1990
+
1991
+
1992
+
1993
+
1994
+
1995
+
1996
+
1997
+
1998
+
1999
+
2000
+
2001
+
2002
+
2003
+
2004
+
2005
+
2006
+
2007
+
2008
+
2009
+
2010
+
2011
+
2012
+
2013
+
2014
+
2015
+
2016
+
2017
+
2018
+
2019
+
2020
+
2021
+
2022
+
2023
+
2024
+
2025
+
2026
+
2027
+
2028
+
2029
+
2030
+
2031
+
2032
+
2033
+
2034
+
2035
+
2036
+
2037
+
2038
+
2039
+
2040
+
2041
+
2042
+
2043
+
2044
+
2045
+
2046
+
2047
+
2048
+
2049
+
2050
+
2051
+
2052
+
2053
+
2054
+
2055
+
2056
+
2057
+
2058
+
2059
+
2060
+
2061
+
2062
+
2063
+
2064
+
2065
+
2066
+
2067
+
2068
+
2069
+
2070
+
2071
+
2072
+
2073
+
2074
+
2075
+
2076
+
2077
+
2078
+
2079
+
2080
+
2081
+
2082
+
2083
+
2084
+
2085
+
2086
+
2087
+
2088
+
2089
+
2090
+
2091
+
2092
+
2093
+
2094
+
2095
+
2096
+
2097
+
2098
+
2099
+
2100
+
2101
+
2102
+
2103
+
2104
+
2105
+
2106
+
2107
+
2108
+
2109
+
2110
+
2111
+
2112
+
2113
+
2114
+
2115
+
2116
+
2117
+
2118
+
2119
+
2120
+
2121
+
2122
+
2123
+
2124
+
2125
+
2126
+
2127
+
2128
+
2129
+
2130
+
2131
+
2132
+
2133
+
2134
+
2135
+
2136
+
2137
+
2138
+
2139
+
2140
+
2141
+
2142
+
2143
+
2144
+
2145
+
2146
+
2147
+
2148
+
2149
+
2150
+
2151
+
2152
+
2153
+
2154
+
2155
+
2156
+
2157
+
2158
+
2159
+
2160
+
2161
+
2162
+
2163
+
2164
+
2165
+
2166
+
2167
+
2168
+
2169
+
2170
+
2171
+
2172
+
2173
+
2174
+
2175
+
2176
+
2177
+
2178
+
2179
+
2180
+
2181
+
2182
+
2183
+
2184
+
2185
+
2186
+
2187
+
2188
+
2189
+
2190
+
2191
+
2192
+
2193
+
2194
+
2195
+
2196
+
2197
+
2198
+
2199
+
2200
+
2201
+
2202
+
2203
+
2204
+
2205
+
2206
+
2207
+
2208
+
2209
+
2210
+
2211
+
2212
+
2213
+
2214
+
2215
+
2216
+
2217
+
2218
+
2219
+
2220
+
2221
+
2222
+
2223
+
2224
+
2225
+
2226
+
2227
+
2228
+
2229
+
2230
+
2231
+
2232
+
2233
+
2234
+
2235
+
2236
+
2237
+
2238
+
2239
+
2240
+
2241
+
2242
+
2243
+
2244
+
2245
+
2246
+
2247
+
2248
+
2249
+
2250
+
2251
+
2252
+
2253
+
2254
+
2255
+
2256
+
2257
+
2258
+
2259
+
2260
+
2261
+
2262
+
2263
+
2264
+
2265
+
2266
+
2267
+
2268
+
2269
+
2270
+
2271
+
2272
+
2273
+
2274
+
2275
+
2276
+
2277
+
2278
+
2279
+
2280
+
2281
+
2282
+
2283
+
2284
+
2285
+
2286
+
2287
+
2288
+
2289
+
2290
+
2291
+
2292
+
2293
+
2294
+
2295
+
2296
+
2297
+
2298
+
2299
+
2300
+
2301
+
2302
+
2303
+
2304
+
2305
+
2306
+
2307
+
2308
+
2309
+
2310
+
2311
+
2312
+
2313
+
2314
+
2315
+
2316
+
2317
+
2318
+
2319
+
2320
+
2321
+
2322
+
2323
+
2324
+
2325
+
2326
+
2327
+
2328
+
2329
+
2330
+
2331
+
2332
+
2333
+
2334
+
2335
+
2336
+
2337
+
2338
+
2339
+
2340
+
2341
+
2342
+
2343
+
2344
+
2345
+
2346
+
2347
+
2348
+
2349
+
2350
+
2351
+
2352
+
2353
+
2354
+
2355
+
2356
+
2357
+
2358
+
2359
+
2360
+
2361
+
2362
+
2363
+
2364
+
2365
+
2366
+
2367
+
2368
+
2369
+
2370
+
2371
+
2372
+
2373
+
2374
+
2375
+
2376
+
2377
+
2378
+
2379
+
2380
+
2381
+
2382
+
2383
+
2384
+
2385
+
2386
+
2387
+
2388
+
2389
+
2390
+
2391
+
2392
+
2393
+
2394
+
2395
+
2396
+
2397
+
2398
+
2399
+
2400
+
2401
+
2402
+
2403
+
2404
+
2405
+
2406
+
2407
+
2408
+
2409
+
2410
+
2411
+
2412
+
2413
+
2414
+
2415
+
2416
+
2417
+
2418
+
2419
+
2420
+
2421
+
2422
+
2423
+
2424
+
2425
+
2426
+
2427
+
2428
+
2429
+
2430
+
2431
+
2432
+
2433
+
2434
+
2435
+
2436
+
2437
+
2438
+
2439
+
2440
+
2441
+
2442
+
2443
+
2444
+
2445
+
2446
+
2447
+
2448
+
2449
+
2450
+
2451
+
2452
+
2453
+
2454
+
2455
+
2456
+
2457
+
2458
+
2459
+
2460
+
2461
+
2462
+
2463
+
2464
+
2465
+
2466
+
2467
+
2468
+
2469
+
2470
+
2471
+
2472
+
2473
+
2474
+
2475
+
2476
+
2477
+
2478
+
2479
+
2480
+
2481
+
2482
+
2483
+
2484
+
2485
+
2486
+
2487
+
2488
+
2489
+
2490
+
2491
+
2492
+
2493
+
2494
+
2495
+
2496
+
2497
+
2498
+
2499
+
2500
+
2501
+
2502
+
2503
+
2504
+
2505
+
2506
+
2507
+
2508
+
2509
+
2510
+
2511
+
2512
+
2513
+
2514
+
2515
+
2516
+
2517
+
2518
+
2519
+
2520
+
2521
+
2522
+
2523
+
2524
+
2525
+
2526
+
2527
+
2528
+
2529
+
2530
+
2531
+
2532
+
2533
+
2534
+
2535
+
2536
+
2537
+
2538
+
2539
+
2540
+
2541
+
2542
+
2543
+
2544
+
2545
+ 𠮶
data/librispeech_pc_test_clean_cross_sentence.lst CHANGED
The diff for this file is too large to render. See raw diff
 
finetune-cli.py DELETED
@@ -1,127 +0,0 @@
1
- import argparse
2
- from model import CFM, UNetT, DiT, Trainer
3
- from model.utils import get_tokenizer
4
- from model.dataset import load_dataset
5
- from cached_path import cached_path
6
- import shutil
7
- import os
8
-
9
- # -------------------------- Dataset Settings --------------------------- #
10
- target_sample_rate = 24000
11
- n_mel_channels = 100
12
- hop_length = 256
13
-
14
-
15
- # -------------------------- Argument Parsing --------------------------- #
16
- def parse_args():
17
- parser = argparse.ArgumentParser(description="Train CFM Model")
18
-
19
- parser.add_argument(
20
- "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
21
- )
22
- parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
23
- parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
24
- parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
25
- parser.add_argument(
26
- "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
27
- )
28
- parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
29
- parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
30
- parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
31
- parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
32
- parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
33
- parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
34
- parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
35
- parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
36
-
37
- parser.add_argument(
38
- "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
39
- )
40
- parser.add_argument(
41
- "--tokenizer_path",
42
- type=str,
43
- default=None,
44
- help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
45
- )
46
-
47
- return parser.parse_args()
48
-
49
-
50
- # -------------------------- Training Settings -------------------------- #
51
-
52
-
53
- def main():
54
- args = parse_args()
55
-
56
- # Model parameters based on experiment name
57
- if args.exp_name == "F5TTS_Base":
58
- wandb_resume_id = None
59
- model_cls = DiT
60
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
61
- if args.finetune:
62
- ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
63
- elif args.exp_name == "E2TTS_Base":
64
- wandb_resume_id = None
65
- model_cls = UNetT
66
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
67
- if args.finetune:
68
- ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
69
-
70
- if args.finetune:
71
- path_ckpt = os.path.join("ckpts", args.dataset_name)
72
- if not os.path.isdir(path_ckpt):
73
- os.makedirs(path_ckpt, exist_ok=True)
74
- shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
75
-
76
- checkpoint_path = os.path.join("ckpts", args.dataset_name)
77
-
78
- # Use the tokenizer and tokenizer_path provided in the command line arguments
79
- tokenizer = args.tokenizer
80
- if tokenizer == "custom":
81
- if not args.tokenizer_path:
82
- raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
83
- tokenizer_path = args.tokenizer_path
84
- else:
85
- tokenizer_path = args.dataset_name
86
-
87
- vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
88
-
89
- mel_spec_kwargs = dict(
90
- target_sample_rate=target_sample_rate,
91
- n_mel_channels=n_mel_channels,
92
- hop_length=hop_length,
93
- )
94
-
95
- e2tts = CFM(
96
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
97
- mel_spec_kwargs=mel_spec_kwargs,
98
- vocab_char_map=vocab_char_map,
99
- )
100
-
101
- trainer = Trainer(
102
- e2tts,
103
- args.epochs,
104
- args.learning_rate,
105
- num_warmup_updates=args.num_warmup_updates,
106
- save_per_updates=args.save_per_updates,
107
- checkpoint_path=checkpoint_path,
108
- batch_size=args.batch_size_per_gpu,
109
- batch_size_type=args.batch_size_type,
110
- max_samples=args.max_samples,
111
- grad_accumulation_steps=args.grad_accumulation_steps,
112
- max_grad_norm=args.max_grad_norm,
113
- wandb_project="CFM-TTS",
114
- wandb_run_name=args.exp_name,
115
- wandb_resume_id=wandb_resume_id,
116
- last_per_steps=args.last_per_steps,
117
- )
118
-
119
- train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
120
- trainer.train(
121
- train_dataset,
122
- resumable_with_seed=666, # seed for shuffling dataset
123
- )
124
-
125
-
126
- if __name__ == "__main__":
127
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
finetune_gradio.py DELETED
@@ -1,944 +0,0 @@
1
- import os
2
- import sys
3
-
4
- import tempfile
5
- import random
6
- from transformers import pipeline
7
- import gradio as gr
8
- import torch
9
- import gc
10
- import click
11
- import torchaudio
12
- from glob import glob
13
- import librosa
14
- import numpy as np
15
- from scipy.io import wavfile
16
- import shutil
17
- import time
18
-
19
- import json
20
- from model.utils import convert_char_to_pinyin
21
- import signal
22
- import psutil
23
- import platform
24
- import subprocess
25
- from datasets.arrow_writer import ArrowWriter
26
- from datasets import Dataset as Dataset_
27
- from api import F5TTS
28
-
29
-
30
- training_process = None
31
- system = platform.system()
32
- python_executable = sys.executable or "python"
33
- tts_api = None
34
- last_checkpoint = ""
35
- last_device = ""
36
-
37
- path_data = "data"
38
-
39
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
40
-
41
- pipe = None
42
-
43
-
44
- # Load metadata
45
- def get_audio_duration(audio_path):
46
- """Calculate the duration of an audio file."""
47
- audio, sample_rate = torchaudio.load(audio_path)
48
- num_channels = audio.shape[0]
49
- return audio.shape[1] / (sample_rate * num_channels)
50
-
51
-
52
- def clear_text(text):
53
- """Clean and prepare text by lowering the case and stripping whitespace."""
54
- return text.lower().strip()
55
-
56
-
57
- def get_rms(
58
- y,
59
- frame_length=2048,
60
- hop_length=512,
61
- pad_mode="constant",
62
- ): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
63
- padding = (int(frame_length // 2), int(frame_length // 2))
64
- y = np.pad(y, padding, mode=pad_mode)
65
-
66
- axis = -1
67
- # put our new within-frame axis at the end for now
68
- out_strides = y.strides + tuple([y.strides[axis]])
69
- # Reduce the shape on the framing axis
70
- x_shape_trimmed = list(y.shape)
71
- x_shape_trimmed[axis] -= frame_length - 1
72
- out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
73
- xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
74
- if axis < 0:
75
- target_axis = axis - 1
76
- else:
77
- target_axis = axis + 1
78
- xw = np.moveaxis(xw, -1, target_axis)
79
- # Downsample along the target axis
80
- slices = [slice(None)] * xw.ndim
81
- slices[axis] = slice(0, None, hop_length)
82
- x = xw[tuple(slices)]
83
-
84
- # Calculate power
85
- power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
86
-
87
- return np.sqrt(power)
88
-
89
-
90
- class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
91
- def __init__(
92
- self,
93
- sr: int,
94
- threshold: float = -40.0,
95
- min_length: int = 2000,
96
- min_interval: int = 300,
97
- hop_size: int = 20,
98
- max_sil_kept: int = 2000,
99
- ):
100
- if not min_length >= min_interval >= hop_size:
101
- raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
102
- if not max_sil_kept >= hop_size:
103
- raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
104
- min_interval = sr * min_interval / 1000
105
- self.threshold = 10 ** (threshold / 20.0)
106
- self.hop_size = round(sr * hop_size / 1000)
107
- self.win_size = min(round(min_interval), 4 * self.hop_size)
108
- self.min_length = round(sr * min_length / 1000 / self.hop_size)
109
- self.min_interval = round(min_interval / self.hop_size)
110
- self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
111
-
112
- def _apply_slice(self, waveform, begin, end):
113
- if len(waveform.shape) > 1:
114
- return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
115
- else:
116
- return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
117
-
118
- # @timeit
119
- def slice(self, waveform):
120
- if len(waveform.shape) > 1:
121
- samples = waveform.mean(axis=0)
122
- else:
123
- samples = waveform
124
- if samples.shape[0] <= self.min_length:
125
- return [waveform]
126
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
127
- sil_tags = []
128
- silence_start = None
129
- clip_start = 0
130
- for i, rms in enumerate(rms_list):
131
- # Keep looping while frame is silent.
132
- if rms < self.threshold:
133
- # Record start of silent frames.
134
- if silence_start is None:
135
- silence_start = i
136
- continue
137
- # Keep looping while frame is not silent and silence start has not been recorded.
138
- if silence_start is None:
139
- continue
140
- # Clear recorded silence start if interval is not enough or clip is too short
141
- is_leading_silence = silence_start == 0 and i > self.max_sil_kept
142
- need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
143
- if not is_leading_silence and not need_slice_middle:
144
- silence_start = None
145
- continue
146
- # Need slicing. Record the range of silent frames to be removed.
147
- if i - silence_start <= self.max_sil_kept:
148
- pos = rms_list[silence_start : i + 1].argmin() + silence_start
149
- if silence_start == 0:
150
- sil_tags.append((0, pos))
151
- else:
152
- sil_tags.append((pos, pos))
153
- clip_start = pos
154
- elif i - silence_start <= self.max_sil_kept * 2:
155
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
156
- pos += i - self.max_sil_kept
157
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
158
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
159
- if silence_start == 0:
160
- sil_tags.append((0, pos_r))
161
- clip_start = pos_r
162
- else:
163
- sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
164
- clip_start = max(pos_r, pos)
165
- else:
166
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
167
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
168
- if silence_start == 0:
169
- sil_tags.append((0, pos_r))
170
- else:
171
- sil_tags.append((pos_l, pos_r))
172
- clip_start = pos_r
173
- silence_start = None
174
- # Deal with trailing silence.
175
- total_frames = rms_list.shape[0]
176
- if silence_start is not None and total_frames - silence_start >= self.min_interval:
177
- silence_end = min(total_frames, silence_start + self.max_sil_kept)
178
- pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
179
- sil_tags.append((pos, total_frames + 1))
180
- # Apply and return slices.
181
- ####音频+起始时间+终止时间
182
- if len(sil_tags) == 0:
183
- return [[waveform, 0, int(total_frames * self.hop_size)]]
184
- else:
185
- chunks = []
186
- if sil_tags[0][0] > 0:
187
- chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
188
- for i in range(len(sil_tags) - 1):
189
- chunks.append(
190
- [
191
- self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
192
- int(sil_tags[i][1] * self.hop_size),
193
- int(sil_tags[i + 1][0] * self.hop_size),
194
- ]
195
- )
196
- if sil_tags[-1][1] < total_frames:
197
- chunks.append(
198
- [
199
- self._apply_slice(waveform, sil_tags[-1][1], total_frames),
200
- int(sil_tags[-1][1] * self.hop_size),
201
- int(total_frames * self.hop_size),
202
- ]
203
- )
204
- return chunks
205
-
206
-
207
- # terminal
208
- def terminate_process_tree(pid, including_parent=True):
209
- try:
210
- parent = psutil.Process(pid)
211
- except psutil.NoSuchProcess:
212
- # Process already terminated
213
- return
214
-
215
- children = parent.children(recursive=True)
216
- for child in children:
217
- try:
218
- os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
219
- except OSError:
220
- pass
221
- if including_parent:
222
- try:
223
- os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
224
- except OSError:
225
- pass
226
-
227
-
228
- def terminate_process(pid):
229
- if system == "Windows":
230
- cmd = f"taskkill /t /f /pid {pid}"
231
- os.system(cmd)
232
- else:
233
- terminate_process_tree(pid)
234
-
235
-
236
- def start_training(
237
- dataset_name="",
238
- exp_name="F5TTS_Base",
239
- learning_rate=1e-4,
240
- batch_size_per_gpu=400,
241
- batch_size_type="frame",
242
- max_samples=64,
243
- grad_accumulation_steps=1,
244
- max_grad_norm=1.0,
245
- epochs=11,
246
- num_warmup_updates=200,
247
- save_per_updates=400,
248
- last_per_steps=800,
249
- finetune=True,
250
- ):
251
- global training_process, tts_api
252
-
253
- if tts_api is not None:
254
- del tts_api
255
- gc.collect()
256
- torch.cuda.empty_cache()
257
- tts_api = None
258
-
259
- path_project = os.path.join(path_data, dataset_name + "_pinyin")
260
-
261
- if not os.path.isdir(path_project):
262
- yield (
263
- f"There is not project with name {dataset_name}",
264
- gr.update(interactive=True),
265
- gr.update(interactive=False),
266
- )
267
- return
268
-
269
- file_raw = os.path.join(path_project, "raw.arrow")
270
- if not os.path.isfile(file_raw):
271
- yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
272
- return
273
-
274
- # Check if a training process is already running
275
- if training_process is not None:
276
- return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
277
-
278
- yield "start train", gr.update(interactive=False), gr.update(interactive=False)
279
-
280
- # Command to run the training script with the specified arguments
281
- cmd = (
282
- f"accelerate launch finetune-cli.py --exp_name {exp_name} "
283
- f"--learning_rate {learning_rate} "
284
- f"--batch_size_per_gpu {batch_size_per_gpu} "
285
- f"--batch_size_type {batch_size_type} "
286
- f"--max_samples {max_samples} "
287
- f"--grad_accumulation_steps {grad_accumulation_steps} "
288
- f"--max_grad_norm {max_grad_norm} "
289
- f"--epochs {epochs} "
290
- f"--num_warmup_updates {num_warmup_updates} "
291
- f"--save_per_updates {save_per_updates} "
292
- f"--last_per_steps {last_per_steps} "
293
- f"--dataset_name {dataset_name}"
294
- )
295
- if finetune:
296
- cmd += f" --finetune {finetune}"
297
-
298
- print(cmd)
299
-
300
- try:
301
- # Start the training process
302
- training_process = subprocess.Popen(cmd, shell=True)
303
-
304
- time.sleep(5)
305
- yield "train start", gr.update(interactive=False), gr.update(interactive=True)
306
-
307
- # Wait for the training process to finish
308
- training_process.wait()
309
- time.sleep(1)
310
-
311
- if training_process is None:
312
- text_info = "train stop"
313
- else:
314
- text_info = "train complete !"
315
-
316
- except Exception as e: # Catch all exceptions
317
- # Ensure that we reset the training process variable in case of an error
318
- text_info = f"An error occurred: {str(e)}"
319
-
320
- training_process = None
321
-
322
- yield text_info, gr.update(interactive=True), gr.update(interactive=False)
323
-
324
-
325
- def stop_training():
326
- global training_process
327
- if training_process is None:
328
- return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
329
- terminate_process_tree(training_process.pid)
330
- training_process = None
331
- return "train stop", gr.update(interactive=True), gr.update(interactive=False)
332
-
333
-
334
- def create_data_project(name):
335
- name += "_pinyin"
336
- os.makedirs(os.path.join(path_data, name), exist_ok=True)
337
- os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
338
-
339
-
340
- def transcribe(file_audio, language="english"):
341
- global pipe
342
-
343
- if pipe is None:
344
- pipe = pipeline(
345
- "automatic-speech-recognition",
346
- model="openai/whisper-large-v3-turbo",
347
- torch_dtype=torch.float16,
348
- device=device,
349
- )
350
-
351
- text_transcribe = pipe(
352
- file_audio,
353
- chunk_length_s=30,
354
- batch_size=128,
355
- generate_kwargs={"task": "transcribe", "language": language},
356
- return_timestamps=False,
357
- )["text"].strip()
358
- return text_transcribe
359
-
360
-
361
- def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
362
- name_project += "_pinyin"
363
- path_project = os.path.join(path_data, name_project)
364
- path_dataset = os.path.join(path_project, "dataset")
365
- path_project_wavs = os.path.join(path_project, "wavs")
366
- file_metadata = os.path.join(path_project, "metadata.csv")
367
-
368
- if audio_files is None:
369
- return "You need to load an audio file."
370
-
371
- if os.path.isdir(path_project_wavs):
372
- shutil.rmtree(path_project_wavs)
373
-
374
- if os.path.isfile(file_metadata):
375
- os.remove(file_metadata)
376
-
377
- os.makedirs(path_project_wavs, exist_ok=True)
378
-
379
- if user:
380
- file_audios = [
381
- file
382
- for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
383
- for file in glob(os.path.join(path_dataset, format))
384
- ]
385
- if file_audios == []:
386
- return "No audio file was found in the dataset."
387
- else:
388
- file_audios = audio_files
389
-
390
- alpha = 0.5
391
- _max = 1.0
392
- slicer = Slicer(24000)
393
-
394
- num = 0
395
- error_num = 0
396
- data = ""
397
- for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
398
- audio, _ = librosa.load(file_audio, sr=24000, mono=True)
399
-
400
- list_slicer = slicer.slice(audio)
401
- for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
402
- name_segment = os.path.join(f"segment_{num}")
403
- file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
404
-
405
- tmp_max = np.abs(chunk).max()
406
- if tmp_max > 1:
407
- chunk /= tmp_max
408
- chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
409
- wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
410
-
411
- try:
412
- text = transcribe(file_segment, language)
413
- text = text.lower().strip().replace('"', "")
414
-
415
- data += f"{name_segment}|{text}\n"
416
-
417
- num += 1
418
- except: # noqa: E722
419
- error_num += 1
420
-
421
- with open(file_metadata, "w", encoding="utf-8") as f:
422
- f.write(data)
423
-
424
- if error_num != []:
425
- error_text = f"\nerror files : {error_num}"
426
- else:
427
- error_text = ""
428
-
429
- return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
430
-
431
-
432
- def format_seconds_to_hms(seconds):
433
- hours = int(seconds / 3600)
434
- minutes = int((seconds % 3600) / 60)
435
- seconds = seconds % 60
436
- return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
437
-
438
-
439
- def create_metadata(name_project, progress=gr.Progress()):
440
- name_project += "_pinyin"
441
- path_project = os.path.join(path_data, name_project)
442
- path_project_wavs = os.path.join(path_project, "wavs")
443
- file_metadata = os.path.join(path_project, "metadata.csv")
444
- file_raw = os.path.join(path_project, "raw.arrow")
445
- file_duration = os.path.join(path_project, "duration.json")
446
- file_vocab = os.path.join(path_project, "vocab.txt")
447
-
448
- if not os.path.isfile(file_metadata):
449
- return "The file was not found in " + file_metadata
450
-
451
- with open(file_metadata, "r", encoding="utf-8") as f:
452
- data = f.read()
453
-
454
- audio_path_list = []
455
- text_list = []
456
- duration_list = []
457
-
458
- count = data.split("\n")
459
- lenght = 0
460
- result = []
461
- error_files = []
462
- for line in progress.tqdm(data.split("\n"), total=count):
463
- sp_line = line.split("|")
464
- if len(sp_line) != 2:
465
- continue
466
- name_audio, text = sp_line[:2]
467
-
468
- file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
469
-
470
- if not os.path.isfile(file_audio):
471
- error_files.append(file_audio)
472
- continue
473
-
474
- duraction = get_audio_duration(file_audio)
475
- if duraction < 2 and duraction > 15:
476
- continue
477
- if len(text) < 4:
478
- continue
479
-
480
- text = clear_text(text)
481
- text = convert_char_to_pinyin([text], polyphone=True)[0]
482
-
483
- audio_path_list.append(file_audio)
484
- duration_list.append(duraction)
485
- text_list.append(text)
486
-
487
- result.append({"audio_path": file_audio, "text": text, "duration": duraction})
488
-
489
- lenght += duraction
490
-
491
- if duration_list == []:
492
- error_files_text = "\n".join(error_files)
493
- return f"Error: No audio files found in the specified path : \n{error_files_text}"
494
-
495
- min_second = round(min(duration_list), 2)
496
- max_second = round(max(duration_list), 2)
497
-
498
- with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
499
- for line in progress.tqdm(result, total=len(result), desc="prepare data"):
500
- writer.write(line)
501
-
502
- with open(file_duration, "w", encoding="utf-8") as f:
503
- json.dump({"duration": duration_list}, f, ensure_ascii=False)
504
-
505
- file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
506
- if not os.path.isfile(file_vocab_finetune):
507
- return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
508
- shutil.copy2(file_vocab_finetune, file_vocab)
509
-
510
- if error_files != []:
511
- error_text = "error files\n" + "\n".join(error_files)
512
- else:
513
- error_text = ""
514
-
515
- return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
516
-
517
-
518
- def check_user(value):
519
- return gr.update(visible=not value), gr.update(visible=value)
520
-
521
-
522
- def calculate_train(
523
- name_project,
524
- batch_size_type,
525
- max_samples,
526
- learning_rate,
527
- num_warmup_updates,
528
- save_per_updates,
529
- last_per_steps,
530
- finetune,
531
- ):
532
- name_project += "_pinyin"
533
- path_project = os.path.join(path_data, name_project)
534
- file_duraction = os.path.join(path_project, "duration.json")
535
-
536
- if not os.path.isfile(file_duraction):
537
- return (
538
- 1000,
539
- max_samples,
540
- num_warmup_updates,
541
- save_per_updates,
542
- last_per_steps,
543
- "project not found !",
544
- learning_rate,
545
- )
546
-
547
- with open(file_duraction, "r") as file:
548
- data = json.load(file)
549
-
550
- duration_list = data["duration"]
551
-
552
- samples = len(duration_list)
553
-
554
- if torch.cuda.is_available():
555
- gpu_properties = torch.cuda.get_device_properties(0)
556
- total_memory = gpu_properties.total_memory / (1024**3)
557
- elif torch.backends.mps.is_available():
558
- total_memory = psutil.virtual_memory().available / (1024**3)
559
-
560
- if batch_size_type == "frame":
561
- batch = int(total_memory * 0.5)
562
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
563
- batch_size_per_gpu = int(38400 / batch)
564
- else:
565
- batch_size_per_gpu = int(total_memory / 8)
566
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
567
- batch = batch_size_per_gpu
568
-
569
- if batch_size_per_gpu <= 0:
570
- batch_size_per_gpu = 1
571
-
572
- if samples < 64:
573
- max_samples = int(samples * 0.25)
574
- else:
575
- max_samples = 64
576
-
577
- num_warmup_updates = int(samples * 0.05)
578
- save_per_updates = int(samples * 0.10)
579
- last_per_steps = int(save_per_updates * 5)
580
-
581
- max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
582
- num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
583
- save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
584
- last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
585
-
586
- if finetune:
587
- learning_rate = 1e-5
588
- else:
589
- learning_rate = 7.5e-5
590
-
591
- return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
592
-
593
-
594
- def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
595
- try:
596
- checkpoint = torch.load(checkpoint_path)
597
- print("Original Checkpoint Keys:", checkpoint.keys())
598
-
599
- ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
600
-
601
- if ema_model_state_dict is not None:
602
- new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
603
- torch.save(new_checkpoint, new_checkpoint_path)
604
- return f"New checkpoint saved at: {new_checkpoint_path}"
605
- else:
606
- return "No 'ema_model_state_dict' found in the checkpoint."
607
-
608
- except Exception as e:
609
- return f"An error occurred: {e}"
610
-
611
-
612
- def vocab_check(project_name):
613
- name_project = project_name + "_pinyin"
614
- path_project = os.path.join(path_data, name_project)
615
-
616
- file_metadata = os.path.join(path_project, "metadata.csv")
617
-
618
- file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
619
- if not os.path.isfile(file_vocab):
620
- return f"the file {file_vocab} not found !"
621
-
622
- with open(file_vocab, "r", encoding="utf-8") as f:
623
- data = f.read()
624
-
625
- vocab = data.split("\n")
626
-
627
- if not os.path.isfile(file_metadata):
628
- return f"the file {file_metadata} not found !"
629
-
630
- with open(file_metadata, "r", encoding="utf-8") as f:
631
- data = f.read()
632
-
633
- miss_symbols = []
634
- miss_symbols_keep = {}
635
- for item in data.split("\n"):
636
- sp = item.split("|")
637
- if len(sp) != 2:
638
- continue
639
-
640
- text = sp[1].lower().strip()
641
-
642
- for t in text:
643
- if t not in vocab and t not in miss_symbols_keep:
644
- miss_symbols.append(t)
645
- miss_symbols_keep[t] = t
646
- if miss_symbols == []:
647
- info = "You can train using your language !"
648
- else:
649
- info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
650
-
651
- return info
652
-
653
-
654
- def get_random_sample_prepare(project_name):
655
- name_project = project_name + "_pinyin"
656
- path_project = os.path.join(path_data, name_project)
657
- file_arrow = os.path.join(path_project, "raw.arrow")
658
- if not os.path.isfile(file_arrow):
659
- return "", None
660
- dataset = Dataset_.from_file(file_arrow)
661
- random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
662
- text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
663
- audio_path = random_sample["audio_path"][0]
664
- return text, audio_path
665
-
666
-
667
- def get_random_sample_transcribe(project_name):
668
- name_project = project_name + "_pinyin"
669
- path_project = os.path.join(path_data, name_project)
670
- file_metadata = os.path.join(path_project, "metadata.csv")
671
- if not os.path.isfile(file_metadata):
672
- return "", None
673
-
674
- data = ""
675
- with open(file_metadata, "r", encoding="utf-8") as f:
676
- data = f.read()
677
-
678
- list_data = []
679
- for item in data.split("\n"):
680
- sp = item.split("|")
681
- if len(sp) != 2:
682
- continue
683
- list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
684
-
685
- if list_data == []:
686
- return "", None
687
-
688
- random_item = random.choice(list_data)
689
-
690
- return random_item[1], random_item[0]
691
-
692
-
693
- def get_random_sample_infer(project_name):
694
- text, audio = get_random_sample_transcribe(project_name)
695
- return (
696
- text,
697
- text,
698
- audio,
699
- )
700
-
701
-
702
- def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
703
- global last_checkpoint, last_device, tts_api
704
-
705
- if not os.path.isfile(file_checkpoint):
706
- return None
707
-
708
- if training_process is not None:
709
- device_test = "cpu"
710
- else:
711
- device_test = None
712
-
713
- if last_checkpoint != file_checkpoint or last_device != device_test:
714
- if last_checkpoint != file_checkpoint:
715
- last_checkpoint = file_checkpoint
716
- if last_device != device_test:
717
- last_device = device_test
718
-
719
- tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
720
-
721
- print("update", device_test, file_checkpoint)
722
-
723
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
724
- tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
725
- return f.name
726
-
727
-
728
- with gr.Blocks() as app:
729
- with gr.Row():
730
- project_name = gr.Textbox(label="project name", value="my_speak")
731
- bt_create = gr.Button("create new project")
732
-
733
- bt_create.click(fn=create_data_project, inputs=[project_name])
734
-
735
- with gr.Tabs():
736
- with gr.TabItem("transcribe Data"):
737
- ch_manual = gr.Checkbox(label="user", value=False)
738
-
739
- mark_info_transcribe = gr.Markdown(
740
- """```plaintext
741
- Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
742
-
743
- my_speak/
744
-
745
- └── dataset/
746
- ├── audio1.wav
747
- └── audio2.wav
748
- ...
749
- ```""",
750
- visible=False,
751
- )
752
-
753
- audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
754
- txt_lang = gr.Text(label="Language", value="english")
755
- bt_transcribe = bt_create = gr.Button("transcribe")
756
- txt_info_transcribe = gr.Text(label="info", value="")
757
- bt_transcribe.click(
758
- fn=transcribe_all,
759
- inputs=[project_name, audio_speaker, txt_lang, ch_manual],
760
- outputs=[txt_info_transcribe],
761
- )
762
- ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
763
-
764
- random_sample_transcribe = gr.Button("random sample")
765
-
766
- with gr.Row():
767
- random_text_transcribe = gr.Text(label="Text")
768
- random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
769
-
770
- random_sample_transcribe.click(
771
- fn=get_random_sample_transcribe,
772
- inputs=[project_name],
773
- outputs=[random_text_transcribe, random_audio_transcribe],
774
- )
775
-
776
- with gr.TabItem("prepare Data"):
777
- gr.Markdown(
778
- """```plaintext
779
- place all your wavs folder and your metadata.csv file in {your name project}
780
- my_speak/
781
-
782
- ├── wavs/
783
- │ ├── audio1.wav
784
- │ └── audio2.wav
785
- | ...
786
-
787
- └── metadata.csv
788
-
789
- file format metadata.csv
790
-
791
- audio1|text1
792
- audio2|text1
793
- ...
794
-
795
- ```"""
796
- )
797
-
798
- bt_prepare = bt_create = gr.Button("prepare")
799
- txt_info_prepare = gr.Text(label="info", value="")
800
- bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
801
-
802
- random_sample_prepare = gr.Button("random sample")
803
-
804
- with gr.Row():
805
- random_text_prepare = gr.Text(label="Pinyin")
806
- random_audio_prepare = gr.Audio(label="Audio", type="filepath")
807
-
808
- random_sample_prepare.click(
809
- fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
810
- )
811
-
812
- with gr.TabItem("train Data"):
813
- with gr.Row():
814
- bt_calculate = bt_create = gr.Button("Auto Settings")
815
- ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
816
- lb_samples = gr.Label(label="samples")
817
- batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
818
-
819
- with gr.Row():
820
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
821
- learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
822
-
823
- with gr.Row():
824
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
825
- max_samples = gr.Number(label="Max Samples", value=64)
826
-
827
- with gr.Row():
828
- grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
829
- max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
830
-
831
- with gr.Row():
832
- epochs = gr.Number(label="Epochs", value=10)
833
- num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
834
-
835
- with gr.Row():
836
- save_per_updates = gr.Number(label="Save per Updates", value=10)
837
- last_per_steps = gr.Number(label="Last per Steps", value=50)
838
-
839
- with gr.Row():
840
- start_button = gr.Button("Start Training")
841
- stop_button = gr.Button("Stop Training", interactive=False)
842
-
843
- txt_info_train = gr.Text(label="info", value="")
844
- start_button.click(
845
- fn=start_training,
846
- inputs=[
847
- project_name,
848
- exp_name,
849
- learning_rate,
850
- batch_size_per_gpu,
851
- batch_size_type,
852
- max_samples,
853
- grad_accumulation_steps,
854
- max_grad_norm,
855
- epochs,
856
- num_warmup_updates,
857
- save_per_updates,
858
- last_per_steps,
859
- ch_finetune,
860
- ],
861
- outputs=[txt_info_train, start_button, stop_button],
862
- )
863
- stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
864
- bt_calculate.click(
865
- fn=calculate_train,
866
- inputs=[
867
- project_name,
868
- batch_size_type,
869
- max_samples,
870
- learning_rate,
871
- num_warmup_updates,
872
- save_per_updates,
873
- last_per_steps,
874
- ch_finetune,
875
- ],
876
- outputs=[
877
- batch_size_per_gpu,
878
- max_samples,
879
- num_warmup_updates,
880
- save_per_updates,
881
- last_per_steps,
882
- lb_samples,
883
- learning_rate,
884
- ],
885
- )
886
-
887
- with gr.TabItem("reduse checkpoint"):
888
- txt_path_checkpoint = gr.Text(label="path checkpoint :")
889
- txt_path_checkpoint_small = gr.Text(label="path output :")
890
- txt_info_reduse = gr.Text(label="info", value="")
891
- reduse_button = gr.Button("reduse")
892
- reduse_button.click(
893
- fn=extract_and_save_ema_model,
894
- inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
895
- outputs=[txt_info_reduse],
896
- )
897
-
898
- with gr.TabItem("vocab check experiment"):
899
- check_button = gr.Button("check vocab")
900
- txt_info_check = gr.Text(label="info", value="")
901
- check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
902
-
903
- with gr.TabItem("test model"):
904
- exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
905
- nfe_step = gr.Number(label="n_step", value=32)
906
- file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
907
-
908
- random_sample_infer = gr.Button("random sample")
909
-
910
- ref_text = gr.Textbox(label="ref text")
911
- ref_audio = gr.Audio(label="audio ref", type="filepath")
912
- gen_text = gr.Textbox(label="gen text")
913
- random_sample_infer.click(
914
- fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
915
- )
916
- check_button_infer = gr.Button("infer")
917
- gen_audio = gr.Audio(label="audio gen", type="filepath")
918
-
919
- check_button_infer.click(
920
- fn=infer,
921
- inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
922
- outputs=[gen_audio],
923
- )
924
-
925
-
926
- @click.command()
927
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
928
- @click.option("--host", "-H", default=None, help="Host to run the app on")
929
- @click.option(
930
- "--share",
931
- "-s",
932
- default=False,
933
- is_flag=True,
934
- help="Share the app via Gradio share link",
935
- )
936
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
937
- def main(port, host, share, api):
938
- global app
939
- print("Starting app...")
940
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
941
-
942
-
943
- if __name__ == "__main__":
944
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_app.py DELETED
@@ -1,824 +0,0 @@
1
- import os
2
- import re
3
- import torch
4
- import torchaudio
5
- import gradio as gr
6
- import numpy as np
7
- import tempfile
8
- from einops import rearrange
9
- from vocos import Vocos
10
- from pydub import AudioSegment, silence
11
- from model import CFM, UNetT, DiT, MMDiT
12
- from cached_path import cached_path
13
- from model.utils import (
14
- load_checkpoint,
15
- get_tokenizer,
16
- convert_char_to_pinyin,
17
- save_spectrogram,
18
- )
19
- from transformers import pipeline
20
- import librosa
21
- import click
22
- import soundfile as sf
23
-
24
- try:
25
- import spaces
26
- USING_SPACES = True
27
- except ImportError:
28
- USING_SPACES = False
29
-
30
- def gpu_decorator(func):
31
- if USING_SPACES:
32
- return spaces.GPU(func)
33
- else:
34
- return func
35
-
36
-
37
-
38
- SPLIT_WORDS = [
39
- "but", "however", "nevertheless", "yet", "still",
40
- "therefore", "thus", "hence", "consequently",
41
- "moreover", "furthermore", "additionally",
42
- "meanwhile", "alternatively", "otherwise",
43
- "namely", "specifically", "for example", "such as",
44
- "in fact", "indeed", "notably",
45
- "in contrast", "on the other hand", "conversely",
46
- "in conclusion", "to summarize", "finally"
47
- ]
48
-
49
- device = (
50
- "cuda"
51
- if torch.cuda.is_available()
52
- else "mps" if torch.backends.mps.is_available() else "cpu"
53
- )
54
-
55
- print(f"Using {device} device")
56
-
57
- pipe = pipeline(
58
- "automatic-speech-recognition",
59
- model="openai/whisper-large-v3-turbo",
60
- torch_dtype=torch.float16,
61
- device=device,
62
- )
63
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
-
65
- # --------------------- Settings -------------------- #
66
-
67
- target_sample_rate = 24000
68
- n_mel_channels = 100
69
- hop_length = 256
70
- target_rms = 0.1
71
- nfe_step = 32 # 16, 32
72
- cfg_strength = 2.0
73
- ode_method = "euler"
74
- sway_sampling_coef = -1.0
75
- speed = 1.0
76
- # fix_duration = 27 # None or float (duration in seconds)
77
- fix_duration = None
78
-
79
-
80
- def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
- vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
- model = CFM(
85
- transformer=model_cls(
86
- **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
- ),
88
- mel_spec_kwargs=dict(
89
- target_sample_rate=target_sample_rate,
90
- n_mel_channels=n_mel_channels,
91
- hop_length=hop_length,
92
- ),
93
- odeint_kwargs=dict(
94
- method=ode_method,
95
- ),
96
- vocab_char_map=vocab_char_map,
97
- ).to(device)
98
-
99
- model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
-
101
- return model
102
-
103
-
104
- # load models
105
- F5TTS_model_cfg = dict(
106
- dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
- )
108
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
-
110
- F5TTS_ema_model = load_model(
111
- "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
- )
113
- E2TTS_ema_model = load_model(
114
- "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
- )
116
-
117
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
- if len(text.encode('utf-8')) <= max_chars:
119
- return [text]
120
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
- text += '.'
122
-
123
- sentences = re.split('([。.!?!?])', text)
124
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
-
126
- batches = []
127
- current_batch = ""
128
-
129
- def split_by_words(text):
130
- words = text.split()
131
- current_word_part = ""
132
- word_batches = []
133
- for word in words:
134
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
- current_word_part += word + ' '
136
- else:
137
- if current_word_part:
138
- # Try to find a suitable split word
139
- for split_word in split_words:
140
- split_index = current_word_part.rfind(' ' + split_word + ' ')
141
- if split_index != -1:
142
- word_batches.append(current_word_part[:split_index].strip())
143
- current_word_part = current_word_part[split_index:].strip() + ' '
144
- break
145
- else:
146
- # If no suitable split word found, just append the current part
147
- word_batches.append(current_word_part.strip())
148
- current_word_part = ""
149
- current_word_part += word + ' '
150
- if current_word_part:
151
- word_batches.append(current_word_part.strip())
152
- return word_batches
153
-
154
- for sentence in sentences:
155
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
- current_batch += sentence
157
- else:
158
- # If adding this sentence would exceed the limit
159
- if current_batch:
160
- batches.append(current_batch)
161
- current_batch = ""
162
-
163
- # If the sentence itself is longer than max_chars, split it
164
- if len(sentence.encode('utf-8')) > max_chars:
165
- # First, try to split by colon
166
- colon_parts = sentence.split(':')
167
- if len(colon_parts) > 1:
168
- for part in colon_parts:
169
- if len(part.encode('utf-8')) <= max_chars:
170
- batches.append(part)
171
- else:
172
- # If colon part is still too long, split by comma
173
- comma_parts = re.split('[,,]', part)
174
- if len(comma_parts) > 1:
175
- current_comma_part = ""
176
- for comma_part in comma_parts:
177
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
- current_comma_part += comma_part + ','
179
- else:
180
- if current_comma_part:
181
- batches.append(current_comma_part.rstrip(','))
182
- current_comma_part = comma_part + ','
183
- if current_comma_part:
184
- batches.append(current_comma_part.rstrip(','))
185
- else:
186
- # If no comma, split by words
187
- batches.extend(split_by_words(part))
188
- else:
189
- # If no colon, split by comma
190
- comma_parts = re.split('[,,]', sentence)
191
- if len(comma_parts) > 1:
192
- current_comma_part = ""
193
- for comma_part in comma_parts:
194
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
- current_comma_part += comma_part + ','
196
- else:
197
- if current_comma_part:
198
- batches.append(current_comma_part.rstrip(','))
199
- current_comma_part = comma_part + ','
200
- if current_comma_part:
201
- batches.append(current_comma_part.rstrip(','))
202
- else:
203
- # If no comma, split by words
204
- batches.extend(split_by_words(sentence))
205
- else:
206
- current_batch = sentence
207
-
208
- if current_batch:
209
- batches.append(current_batch)
210
-
211
- return batches
212
-
213
- def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
- if exp_name == "F5-TTS":
215
- ema_model = F5TTS_ema_model
216
- elif exp_name == "E2-TTS":
217
- ema_model = E2TTS_ema_model
218
-
219
- audio, sr = ref_audio
220
- if audio.shape[0] > 1:
221
- audio = torch.mean(audio, dim=0, keepdim=True)
222
-
223
- rms = torch.sqrt(torch.mean(torch.square(audio)))
224
- if rms < target_rms:
225
- audio = audio * target_rms / rms
226
- if sr != target_sample_rate:
227
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
- audio = resampler(audio)
229
- audio = audio.to(device)
230
-
231
- generated_waves = []
232
- spectrograms = []
233
-
234
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
- # Prepare the text
236
- if len(ref_text[-1].encode('utf-8')) == 1:
237
- ref_text = ref_text + " "
238
- text_list = [ref_text + gen_text]
239
- final_text_list = convert_char_to_pinyin(text_list)
240
-
241
- # Calculate duration
242
- ref_audio_len = audio.shape[-1] // hop_length
243
- zh_pause_punc = r"。,、;:?!"
244
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
- gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
-
248
- # inference
249
- with torch.inference_mode():
250
- generated, _ = ema_model.sample(
251
- cond=audio,
252
- text=final_text_list,
253
- duration=duration,
254
- steps=nfe_step,
255
- cfg_strength=cfg_strength,
256
- sway_sampling_coef=sway_sampling_coef,
257
- )
258
-
259
- generated = generated[:, ref_audio_len:, :]
260
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
- generated_wave = vocos.decode(generated_mel_spec.cpu())
262
- if rms < target_rms:
263
- generated_wave = generated_wave * rms / target_rms
264
-
265
- # wav -> numpy
266
- generated_wave = generated_wave.squeeze().cpu().numpy()
267
-
268
- generated_waves.append(generated_wave)
269
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
-
271
- # Combine all generated waves
272
- final_wave = np.concatenate(generated_waves)
273
-
274
- # Remove silence
275
- if remove_silence:
276
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
- sf.write(f.name, final_wave, target_sample_rate)
278
- aseg = AudioSegment.from_file(f.name)
279
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
- non_silent_wave = AudioSegment.silent(duration=0)
281
- for non_silent_seg in non_silent_segs:
282
- non_silent_wave += non_silent_seg
283
- aseg = non_silent_wave
284
- aseg.export(f.name, format="wav")
285
- final_wave, _ = torchaudio.load(f.name)
286
- final_wave = final_wave.squeeze().cpu().numpy()
287
-
288
- # Create a combined spectrogram
289
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
-
291
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
- spectrogram_path = tmp_spectrogram.name
293
- save_spectrogram(combined_spectrogram, spectrogram_path)
294
-
295
- return (target_sample_rate, final_wave), spectrogram_path
296
-
297
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
- if not custom_split_words.strip():
299
- custom_words = [word.strip() for word in custom_split_words.split(',')]
300
- global SPLIT_WORDS
301
- SPLIT_WORDS = custom_words
302
-
303
- print(gen_text)
304
-
305
- gr.Info("Converting audio...")
306
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
- aseg = AudioSegment.from_file(ref_audio_orig)
308
-
309
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
- non_silent_wave = AudioSegment.silent(duration=0)
311
- for non_silent_seg in non_silent_segs:
312
- non_silent_wave += non_silent_seg
313
- aseg = non_silent_wave
314
-
315
- audio_duration = len(aseg)
316
- if audio_duration > 15000:
317
- gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
- aseg = aseg[:15000]
319
- aseg.export(f.name, format="wav")
320
- ref_audio = f.name
321
-
322
- if not ref_text.strip():
323
- gr.Info("No reference text provided, transcribing reference audio...")
324
- ref_text = pipe(
325
- ref_audio,
326
- chunk_length_s=30,
327
- batch_size=128,
328
- generate_kwargs={"task": "transcribe"},
329
- return_timestamps=False,
330
- )["text"].strip()
331
- gr.Info("Finished transcription")
332
- else:
333
- gr.Info("Using custom reference text...")
334
-
335
- # Split the input text into batches
336
- audio, sr = torchaudio.load(ref_audio)
337
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
- print('ref_text', ref_text)
340
- for i, gen_text in enumerate(gen_text_batches):
341
- print(f'gen_text {i}', gen_text)
342
-
343
- gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
- return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
-
346
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
- # Split the script into speaker blocks
348
- speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
- speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
-
351
- generated_audio_segments = []
352
-
353
- for i in range(0, len(speaker_blocks), 2):
354
- speaker = speaker_blocks[i]
355
- text = speaker_blocks[i+1].strip()
356
-
357
- # Determine which speaker is talking
358
- if speaker == speaker1_name:
359
- ref_audio = ref_audio1
360
- ref_text = ref_text1
361
- elif speaker == speaker2_name:
362
- ref_audio = ref_audio2
363
- ref_text = ref_text2
364
- else:
365
- continue # Skip if the speaker is neither speaker1 nor speaker2
366
-
367
- # Generate audio for this block
368
- audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
-
370
- # Convert the generated audio to a numpy array
371
- sr, audio_data = audio
372
-
373
- # Save the audio data as a WAV file
374
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
- sf.write(temp_file.name, audio_data, sr)
376
- audio_segment = AudioSegment.from_wav(temp_file.name)
377
-
378
- generated_audio_segments.append(audio_segment)
379
-
380
- # Add a short pause between speakers
381
- pause = AudioSegment.silent(duration=500) # 500ms pause
382
- generated_audio_segments.append(pause)
383
-
384
- # Concatenate all audio segments
385
- final_podcast = sum(generated_audio_segments)
386
-
387
- # Export the final podcast
388
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
- podcast_path = temp_file.name
390
- final_podcast.export(podcast_path, format="wav")
391
-
392
- return podcast_path
393
-
394
- def parse_speechtypes_text(gen_text):
395
- # Pattern to find (Emotion)
396
- pattern = r'\((.*?)\)'
397
-
398
- # Split the text by the pattern
399
- tokens = re.split(pattern, gen_text)
400
-
401
- segments = []
402
-
403
- current_emotion = 'Regular'
404
-
405
- for i in range(len(tokens)):
406
- if i % 2 == 0:
407
- # This is text
408
- text = tokens[i].strip()
409
- if text:
410
- segments.append({'emotion': current_emotion, 'text': text})
411
- else:
412
- # This is emotion
413
- emotion = tokens[i].strip()
414
- current_emotion = emotion
415
-
416
- return segments
417
-
418
- def update_speed(new_speed):
419
- global speed
420
- speed = new_speed
421
- return f"Speed set to: {speed}"
422
-
423
- with gr.Blocks() as app_credits:
424
- gr.Markdown("""
425
- # Credits
426
-
427
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
- * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
- """)
430
- with gr.Blocks() as app_tts:
431
- gr.Markdown("# Batched TTS")
432
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
- model_choice = gr.Radio(
435
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
- )
437
- generate_btn = gr.Button("Synthesize", variant="primary")
438
- with gr.Accordion("Advanced Settings", open=False):
439
- ref_text_input = gr.Textbox(
440
- label="Reference Text",
441
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
- lines=2,
443
- )
444
- remove_silence = gr.Checkbox(
445
- label="Remove Silences",
446
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
- value=True,
448
- )
449
- split_words_input = gr.Textbox(
450
- label="Custom Split Words",
451
- info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
- lines=2,
453
- )
454
- speed_slider = gr.Slider(
455
- label="Speed",
456
- minimum=0.3,
457
- maximum=2.0,
458
- value=speed,
459
- step=0.1,
460
- info="Adjust the speed of the audio.",
461
- )
462
- speed_slider.change(update_speed, inputs=speed_slider)
463
-
464
- audio_output = gr.Audio(label="Synthesized Audio")
465
- spectrogram_output = gr.Image(label="Spectrogram")
466
-
467
- generate_btn.click(
468
- infer,
469
- inputs=[
470
- ref_audio_input,
471
- ref_text_input,
472
- gen_text_input,
473
- model_choice,
474
- remove_silence,
475
- split_words_input,
476
- ],
477
- outputs=[audio_output, spectrogram_output],
478
- )
479
-
480
- with gr.Blocks() as app_podcast:
481
- gr.Markdown("# Podcast Generation")
482
- speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
- ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
- ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
-
486
- speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
- ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
- ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
-
490
- script_input = gr.Textbox(label="Podcast Script", lines=10,
491
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
-
493
- podcast_model_choice = gr.Radio(
494
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
- )
496
- podcast_remove_silence = gr.Checkbox(
497
- label="Remove Silences",
498
- value=True,
499
- )
500
- generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
- podcast_output = gr.Audio(label="Generated Podcast")
502
-
503
- def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
- return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
-
506
- generate_podcast_btn.click(
507
- podcast_generation,
508
- inputs=[
509
- script_input,
510
- speaker1_name,
511
- ref_audio_input1,
512
- ref_text_input1,
513
- speaker2_name,
514
- ref_audio_input2,
515
- ref_text_input2,
516
- podcast_model_choice,
517
- podcast_remove_silence,
518
- ],
519
- outputs=podcast_output,
520
- )
521
-
522
- def parse_emotional_text(gen_text):
523
- # Pattern to find (Emotion)
524
- pattern = r'\((.*?)\)'
525
-
526
- # Split the text by the pattern
527
- tokens = re.split(pattern, gen_text)
528
-
529
- segments = []
530
-
531
- current_emotion = 'Regular'
532
-
533
- for i in range(len(tokens)):
534
- if i % 2 == 0:
535
- # This is text
536
- text = tokens[i].strip()
537
- if text:
538
- segments.append({'emotion': current_emotion, 'text': text})
539
- else:
540
- # This is emotion
541
- emotion = tokens[i].strip()
542
- current_emotion = emotion
543
-
544
- return segments
545
-
546
- with gr.Blocks() as app_emotional:
547
- # New section for emotional generation
548
- gr.Markdown(
549
- """
550
- # Multiple Speech-Type Generation
551
-
552
- This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
-
554
- **Example Input:**
555
-
556
- (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
- """
558
- )
559
-
560
- gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
-
562
- # Regular speech type (mandatory)
563
- with gr.Row():
564
- regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
- regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
- regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
-
568
- # Additional speech types (up to 9 more)
569
- max_speech_types = 10
570
- speech_type_names = []
571
- speech_type_audios = []
572
- speech_type_ref_texts = []
573
- speech_type_delete_btns = []
574
-
575
- for i in range(max_speech_types - 1):
576
- with gr.Row():
577
- name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
- audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
- ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
- delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
- speech_type_names.append(name_input)
582
- speech_type_audios.append(audio_input)
583
- speech_type_ref_texts.append(ref_text_input)
584
- speech_type_delete_btns.append(delete_btn)
585
-
586
- # Button to add speech type
587
- add_speech_type_btn = gr.Button("Add Speech Type")
588
-
589
- # Keep track of current number of speech types
590
- speech_type_count = gr.State(value=0)
591
-
592
- # Function to add a speech type
593
- def add_speech_type_fn(speech_type_count):
594
- if speech_type_count < max_speech_types - 1:
595
- speech_type_count += 1
596
- # Prepare updates for the components
597
- name_updates = []
598
- audio_updates = []
599
- ref_text_updates = []
600
- delete_btn_updates = []
601
- for i in range(max_speech_types - 1):
602
- if i < speech_type_count:
603
- name_updates.append(gr.update(visible=True))
604
- audio_updates.append(gr.update(visible=True))
605
- ref_text_updates.append(gr.update(visible=True))
606
- delete_btn_updates.append(gr.update(visible=True))
607
- else:
608
- name_updates.append(gr.update())
609
- audio_updates.append(gr.update())
610
- ref_text_updates.append(gr.update())
611
- delete_btn_updates.append(gr.update())
612
- else:
613
- # Optionally, show a warning
614
- # gr.Warning("Maximum number of speech types reached.")
615
- name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
- audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
- ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
- delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
-
621
- add_speech_type_btn.click(
622
- add_speech_type_fn,
623
- inputs=speech_type_count,
624
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
- )
626
-
627
- # Function to delete a speech type
628
- def make_delete_speech_type_fn(index):
629
- def delete_speech_type_fn(speech_type_count):
630
- # Prepare updates
631
- name_updates = []
632
- audio_updates = []
633
- ref_text_updates = []
634
- delete_btn_updates = []
635
-
636
- for i in range(max_speech_types - 1):
637
- if i == index:
638
- name_updates.append(gr.update(visible=False, value=''))
639
- audio_updates.append(gr.update(visible=False, value=None))
640
- ref_text_updates.append(gr.update(visible=False, value=''))
641
- delete_btn_updates.append(gr.update(visible=False))
642
- else:
643
- name_updates.append(gr.update())
644
- audio_updates.append(gr.update())
645
- ref_text_updates.append(gr.update())
646
- delete_btn_updates.append(gr.update())
647
-
648
- speech_type_count = max(0, speech_type_count - 1)
649
-
650
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
-
652
- return delete_speech_type_fn
653
-
654
- for i, delete_btn in enumerate(speech_type_delete_btns):
655
- delete_fn = make_delete_speech_type_fn(i)
656
- delete_btn.click(
657
- delete_fn,
658
- inputs=speech_type_count,
659
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
- )
661
-
662
- # Text input for the prompt
663
- gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
-
665
- # Model choice
666
- model_choice_emotional = gr.Radio(
667
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
- )
669
-
670
- with gr.Accordion("Advanced Settings", open=False):
671
- remove_silence_emotional = gr.Checkbox(
672
- label="Remove Silences",
673
- value=True,
674
- )
675
-
676
- # Generate button
677
- generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
-
679
- # Output audio
680
- audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
-
682
- def generate_emotional_speech(
683
- regular_audio,
684
- regular_ref_text,
685
- gen_text,
686
- *args,
687
- ):
688
- num_additional_speech_types = max_speech_types - 1
689
- speech_type_names_list = args[:num_additional_speech_types]
690
- speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
- speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
- model_choice = args[3 * num_additional_speech_types]
693
- remove_silence = args[3 * num_additional_speech_types + 1]
694
-
695
- # Collect the speech types and their audios into a dict
696
- speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
-
698
- for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
- if name_input and audio_input:
700
- speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
-
702
- # Parse the gen_text into segments
703
- segments = parse_speechtypes_text(gen_text)
704
-
705
- # For each segment, generate speech
706
- generated_audio_segments = []
707
- current_emotion = 'Regular'
708
-
709
- for segment in segments:
710
- emotion = segment['emotion']
711
- text = segment['text']
712
-
713
- if emotion in speech_types:
714
- current_emotion = emotion
715
- else:
716
- # If emotion not available, default to Regular
717
- current_emotion = 'Regular'
718
-
719
- ref_audio = speech_types[current_emotion]['audio']
720
- ref_text = speech_types[current_emotion].get('ref_text', '')
721
-
722
- # Generate speech for this segment
723
- audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
- sr, audio_data = audio
725
-
726
- generated_audio_segments.append(audio_data)
727
-
728
- # Concatenate all audio segments
729
- if generated_audio_segments:
730
- final_audio_data = np.concatenate(generated_audio_segments)
731
- return (sr, final_audio_data)
732
- else:
733
- gr.Warning("No audio generated.")
734
- return None
735
-
736
- generate_emotional_btn.click(
737
- generate_emotional_speech,
738
- inputs=[
739
- regular_audio,
740
- regular_ref_text,
741
- gen_text_input_emotional,
742
- ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
- model_choice_emotional,
744
- remove_silence_emotional,
745
- ],
746
- outputs=audio_output_emotional,
747
- )
748
-
749
- # Validation function to disable Generate button if speech types are missing
750
- def validate_speech_types(
751
- gen_text,
752
- regular_name,
753
- *args
754
- ):
755
- num_additional_speech_types = max_speech_types - 1
756
- speech_type_names_list = args[:num_additional_speech_types]
757
-
758
- # Collect the speech types names
759
- speech_types_available = set()
760
- if regular_name:
761
- speech_types_available.add(regular_name)
762
- for name_input in speech_type_names_list:
763
- if name_input:
764
- speech_types_available.add(name_input)
765
-
766
- # Parse the gen_text to get the speech types used
767
- segments = parse_emotional_text(gen_text)
768
- speech_types_in_text = set(segment['emotion'] for segment in segments)
769
-
770
- # Check if all speech types in text are available
771
- missing_speech_types = speech_types_in_text - speech_types_available
772
-
773
- if missing_speech_types:
774
- # Disable the generate button
775
- return gr.update(interactive=False)
776
- else:
777
- # Enable the generate button
778
- return gr.update(interactive=True)
779
-
780
- gen_text_input_emotional.change(
781
- validate_speech_types,
782
- inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
- outputs=generate_emotional_btn
784
- )
785
- with gr.Blocks() as app:
786
- gr.Markdown(
787
- """
788
- # E2/F5 TTS
789
-
790
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
-
792
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
-
795
- The checkpoints support English and Chinese.
796
-
797
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
-
799
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
- """
801
- )
802
- gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
-
804
- @click.command()
805
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
- @click.option("--host", "-H", default=None, help="Host to run the app on")
807
- @click.option(
808
- "--share",
809
- "-s",
810
- default=False,
811
- is_flag=True,
812
- help="Share the app via Gradio share link",
813
- )
814
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
- def main(port, host, share, api):
816
- global app
817
- print(f"Starting app...")
818
- app.queue(api_open=api).launch(
819
- server_name=host, server_port=port, share=share, show_api=api
820
- )
821
-
822
-
823
- if __name__ == "__main__":
824
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference-cli.py DELETED
@@ -1,170 +0,0 @@
1
- import argparse
2
- import codecs
3
- import re
4
- from pathlib import Path
5
-
6
- import numpy as np
7
- import soundfile as sf
8
- import tomli
9
- from cached_path import cached_path
10
-
11
- from model import DiT, UNetT
12
- from model.utils_infer import (
13
- load_vocoder,
14
- load_model,
15
- preprocess_ref_audio_text,
16
- infer_process,
17
- remove_silence_for_generated_wav,
18
- )
19
-
20
-
21
- parser = argparse.ArgumentParser(
22
- prog="python3 inference-cli.py",
23
- description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
24
- epilog="Specify options above to override one or more settings from config.",
25
- )
26
- parser.add_argument(
27
- "-c",
28
- "--config",
29
- help="Configuration file. Default=cli-config.toml",
30
- default="inference-cli.toml",
31
- )
32
- parser.add_argument(
33
- "-m",
34
- "--model",
35
- help="F5-TTS | E2-TTS",
36
- )
37
- parser.add_argument(
38
- "-p",
39
- "--ckpt_file",
40
- help="The Checkpoint .pt",
41
- )
42
- parser.add_argument(
43
- "-v",
44
- "--vocab_file",
45
- help="The vocab .txt",
46
- )
47
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
48
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
49
- parser.add_argument(
50
- "-t",
51
- "--gen_text",
52
- type=str,
53
- help="Text to generate.",
54
- )
55
- parser.add_argument(
56
- "-f",
57
- "--gen_file",
58
- type=str,
59
- help="File with text to generate. Ignores --text",
60
- )
61
- parser.add_argument(
62
- "-o",
63
- "--output_dir",
64
- type=str,
65
- help="Path to output folder..",
66
- )
67
- parser.add_argument(
68
- "--remove_silence",
69
- help="Remove silence.",
70
- )
71
- parser.add_argument(
72
- "--load_vocoder_from_local",
73
- action="store_true",
74
- help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
75
- )
76
- args = parser.parse_args()
77
-
78
- config = tomli.load(open(args.config, "rb"))
79
-
80
- ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
81
- ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
82
- gen_text = args.gen_text if args.gen_text else config["gen_text"]
83
- gen_file = args.gen_file if args.gen_file else config["gen_file"]
84
- if gen_file:
85
- gen_text = codecs.open(gen_file, "r", "utf-8").read()
86
- output_dir = args.output_dir if args.output_dir else config["output_dir"]
87
- model = args.model if args.model else config["model"]
88
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
89
- vocab_file = args.vocab_file if args.vocab_file else ""
90
- remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
91
- wave_path = Path(output_dir) / "out.wav"
92
- spectrogram_path = Path(output_dir) / "out.png"
93
- vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
94
-
95
- vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
96
-
97
-
98
- # load models
99
- if model == "F5-TTS":
100
- model_cls = DiT
101
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
102
- if ckpt_file == "":
103
- repo_name = "F5-TTS"
104
- exp_name = "F5TTS_Base"
105
- ckpt_step = 1200000
106
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
107
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
108
-
109
- elif model == "E2-TTS":
110
- model_cls = UNetT
111
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
112
- if ckpt_file == "":
113
- repo_name = "E2-TTS"
114
- exp_name = "E2TTS_Base"
115
- ckpt_step = 1200000
116
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
117
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
118
-
119
- print(f"Using {model}...")
120
- ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
121
-
122
-
123
- def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
124
- main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
125
- if "voices" not in config:
126
- voices = {"main": main_voice}
127
- else:
128
- voices = config["voices"]
129
- voices["main"] = main_voice
130
- for voice in voices:
131
- voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
132
- voices[voice]["ref_audio"], voices[voice]["ref_text"]
133
- )
134
- print("Voice:", voice)
135
- print("Ref_audio:", voices[voice]["ref_audio"])
136
- print("Ref_text:", voices[voice]["ref_text"])
137
-
138
- generated_audio_segments = []
139
- reg1 = r"(?=\[\w+\])"
140
- chunks = re.split(reg1, text_gen)
141
- reg2 = r"\[(\w+)\]"
142
- for text in chunks:
143
- match = re.match(reg2, text)
144
- if match:
145
- voice = match[1]
146
- else:
147
- print("No voice tag found, using main.")
148
- voice = "main"
149
- if voice not in voices:
150
- print(f"Voice {voice} not found, using main.")
151
- voice = "main"
152
- text = re.sub(reg2, "", text)
153
- gen_text = text.strip()
154
- ref_audio = voices[voice]["ref_audio"]
155
- ref_text = voices[voice]["ref_text"]
156
- print(f"Voice: {voice}")
157
- audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
158
- generated_audio_segments.append(audio)
159
-
160
- if generated_audio_segments:
161
- final_wave = np.concatenate(generated_audio_segments)
162
- with open(wave_path, "wb") as f:
163
- sf.write(f.name, final_wave, final_sample_rate)
164
- # Remove silence
165
- if remove_silence:
166
- remove_silence_for_generated_wav(f.name)
167
- print(f.name)
168
-
169
-
170
- main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference-cli.toml DELETED
@@ -1,10 +0,0 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
- ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
- # If an empty "", transcribes the reference audio automatically.
5
- ref_text = "Some call me nature, others call me mother nature."
6
- gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
- # File with text to generate. Ignores the text above.
8
- gen_file = ""
9
- remove_silence = false
10
- output_dir = "tests"
 
 
 
 
 
 
 
 
 
 
 
model/__init__.py CHANGED
@@ -5,6 +5,3 @@ from model.backbones.dit import DiT
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
8
-
9
-
10
- __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
 
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
 
 
 
model/backbones/dit.py CHANGED
@@ -13,6 +13,8 @@ import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
 
 
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
  from model.modules import (
@@ -21,16 +23,14 @@ from model.modules import (
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
  AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
  )
27
 
28
 
29
  # Text embedding
30
 
31
-
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
@@ -38,22 +38,20 @@ class TextEmbedding(nn.Module):
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
- self.text_blocks = nn.Sequential(
42
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43
- )
44
  else:
45
  self.extra_modeling = False
46
 
47
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
- batch, text_len = text.shape[0], text.shape[1]
51
- text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
55
 
56
- text = self.text_embed(text) # b n -> b n d
57
 
58
  # possible extra modeling
59
  if self.extra_modeling:
@@ -71,91 +69,88 @@ class TextEmbedding(nn.Module):
71
 
72
  # noised input audio and context mixing embedding
73
 
74
-
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
 
81
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
-
89
 
90
  # Transformer backbone using DiT blocks
91
 
92
-
93
  class DiT(nn.Module):
94
- def __init__(
95
- self,
96
- *,
97
- dim,
98
- depth=8,
99
- heads=8,
100
- dim_head=64,
101
- dropout=0.1,
102
- ff_mult=4,
103
- mel_dim=100,
104
- text_num_embeds=256,
105
- text_dim=None,
106
- conv_layers=0,
107
- long_skip_connection=False,
108
  ):
109
  super().__init__()
110
 
111
  self.time_embed = TimestepEmbedding(dim)
112
  if text_dim is None:
113
  text_dim = mel_dim
114
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116
 
117
  self.rotary_embed = RotaryEmbedding(dim_head)
118
 
119
  self.dim = dim
120
  self.depth = depth
121
-
122
  self.transformer_blocks = nn.ModuleList(
123
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
124
  )
125
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126
-
127
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
  self.proj_out = nn.Linear(dim, mel_dim)
129
 
130
  def forward(
131
  self,
132
- x: float["b n d"], # nosied input audio # noqa: F722
133
- cond: float["b n d"], # masked cond audio # noqa: F722
134
- text: int["b nt"], # text # noqa: F722
135
- time: float["b"] | float[""], # time step # noqa: F821 F722
136
  drop_audio_cond, # cfg for cond audio
137
- drop_text, # cfg for text
138
- mask: bool["b n"] | None = None, # noqa: F722
139
  ):
140
  batch, seq_len = x.shape[0], x.shape[1]
141
  if time.ndim == 0:
142
- time = time.repeat(batch)
143
-
144
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145
  t = self.time_embed(time)
146
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148
-
149
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
150
 
151
  if self.long_skip_connection is not None:
152
  residual = x
153
 
154
  for block in self.transformer_blocks:
155
- x = block(x, t, mask=mask, rope=rope)
156
 
157
  if self.long_skip_connection is not None:
158
- x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159
 
160
  x = self.norm_out(x, t)
161
  output = self.proj_out(x)
 
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
16
+ from einops import repeat
17
+
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
  from model.modules import (
 
23
  ConvPositionEmbedding,
24
  DiTBlock,
25
  AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
 
27
  )
28
 
29
 
30
  # Text embedding
31
 
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
42
  else:
43
  self.extra_modeling = False
44
 
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
50
 
51
  if drop_text: # cfg for text
52
  text = torch.zeros_like(text)
53
 
54
+ text = self.text_embed(text) # b n -> b n d
55
 
56
  # possible extra modeling
57
  if self.extra_modeling:
 
69
 
70
  # noised input audio and context mixing embedding
71
 
 
72
  class InputEmbedding(nn.Module):
73
  def __init__(self, mel_dim, text_dim, out_dim):
74
  super().__init__()
75
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
 
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
  if drop_audio_cond: # cfg for cond audio
80
  cond = torch.zeros_like(cond)
81
 
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
  x = self.conv_pos_embed(x) + x
84
  return x
85
+
86
 
87
  # Transformer backbone using DiT blocks
88
 
 
89
  class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
 
 
 
 
 
 
 
 
 
 
94
  ):
95
  super().__init__()
96
 
97
  self.time_embed = TimestepEmbedding(dim)
98
  if text_dim is None:
99
  text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
 
103
  self.rotary_embed = RotaryEmbedding(dim_head)
104
 
105
  self.dim = dim
106
  self.depth = depth
107
+
108
  self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
  )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
  self.proj_out = nn.Linear(dim, mel_dim)
124
 
125
  def forward(
126
  self,
127
+ x: float['b n d'], # nosied input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
  drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
  ):
135
  batch, seq_len = x.shape[0], x.shape[1]
136
  if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
  t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
 
146
  if self.long_skip_connection is not None:
147
  residual = x
148
 
149
  for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
 
152
  if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
 
155
  x = self.norm_out(x, t)
156
  output = self.proj_out(x)
model/backbones/mmdit.py CHANGED
@@ -12,6 +12,8 @@ from __future__ import annotations
12
  import torch
13
  from torch import nn
14
 
 
 
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
  from model.modules import (
@@ -19,14 +21,12 @@ from model.modules import (
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
  AdaLayerNormZero_Final,
22
- precompute_freqs_cis,
23
- get_pos_embed_indices,
24
  )
25
 
26
 
27
  # text embedding
28
 
29
-
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
@@ -35,7 +35,7 @@ class TextEmbedding(nn.Module):
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
@@ -54,37 +54,27 @@ class TextEmbedding(nn.Module):
54
 
55
  # noised input & masked cond audio embedding
56
 
57
-
58
  class AudioEmbedding(nn.Module):
59
  def __init__(self, in_dim, out_dim):
60
  super().__init__()
61
  self.linear = nn.Linear(2 * in_dim, out_dim)
62
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
 
64
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
  if drop_audio_cond:
66
  cond = torch.zeros_like(cond)
67
- x = torch.cat((x, cond), dim=-1)
68
  x = self.linear(x)
69
  x = self.conv_pos_embed(x) + x
70
  return x
71
-
72
 
73
  # Transformer backbone using MM-DiT blocks
74
 
75
-
76
  class MMDiT(nn.Module):
77
- def __init__(
78
- self,
79
- *,
80
- dim,
81
- depth=8,
82
- heads=8,
83
- dim_head=64,
84
- dropout=0.1,
85
- ff_mult=4,
86
- text_num_embeds=256,
87
- mel_dim=100,
88
  ):
89
  super().__init__()
90
 
@@ -96,16 +86,16 @@ class MMDiT(nn.Module):
96
 
97
  self.dim = dim
98
  self.depth = depth
99
-
100
  self.transformer_blocks = nn.ModuleList(
101
  [
102
  MMDiTBlock(
103
- dim=dim,
104
- heads=heads,
105
- dim_head=dim_head,
106
- dropout=dropout,
107
- ff_mult=ff_mult,
108
- context_pre_only=i == depth - 1,
109
  )
110
  for i in range(depth)
111
  ]
@@ -115,30 +105,30 @@ class MMDiT(nn.Module):
115
 
116
  def forward(
117
  self,
118
- x: float["b n d"], # nosied input audio # noqa: F722
119
- cond: float["b n d"], # masked cond audio # noqa: F722
120
- text: int["b nt"], # text # noqa: F722
121
- time: float["b"] | float[""], # time step # noqa: F821 F722
122
  drop_audio_cond, # cfg for cond audio
123
- drop_text, # cfg for text
124
- mask: bool["b n"] | None = None, # noqa: F722
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
128
- time = time.repeat(batch)
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
133
- x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
136
  text_len = text.shape[1]
137
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
-
140
  for block in self.transformer_blocks:
141
- c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
 
143
  x = self.norm_out(x, t)
144
  output = self.proj_out(x)
 
12
  import torch
13
  from torch import nn
14
 
15
+ from einops import repeat
16
+
17
  from x_transformers.x_transformers import RotaryEmbedding
18
 
19
  from model.modules import (
 
21
  ConvPositionEmbedding,
22
  MMDiTBlock,
23
  AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
 
25
  )
26
 
27
 
28
  # text embedding
29
 
 
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
 
54
 
55
  # noised input & masked cond audio embedding
56
 
 
57
  class AudioEmbedding(nn.Module):
58
  def __init__(self, in_dim, out_dim):
59
  super().__init__()
60
  self.linear = nn.Linear(2 * in_dim, out_dim)
61
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
 
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
  if drop_audio_cond:
65
  cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
  x = self.linear(x)
68
  x = self.conv_pos_embed(x) + x
69
  return x
70
+
71
 
72
  # Transformer backbone using MM-DiT blocks
73
 
 
74
  class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
 
 
 
 
 
 
 
 
78
  ):
79
  super().__init__()
80
 
 
86
 
87
  self.dim = dim
88
  self.depth = depth
89
+
90
  self.transformer_blocks = nn.ModuleList(
91
  [
92
  MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
  )
100
  for i in range(depth)
101
  ]
 
105
 
106
  def forward(
107
  self,
108
+ x: float['b n d'], # nosied input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
  drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
  ):
116
  batch = x.shape[0]
117
  if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
 
120
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
  t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
 
125
  seq_len = x.shape[1]
126
  text_len = text.shape[1]
127
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
  for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
 
133
  x = self.norm_out(x, t)
134
  output = self.proj_out(x)
model/backbones/unett.py CHANGED
@@ -14,6 +14,8 @@ import torch
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
 
 
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
@@ -24,16 +26,14 @@ from model.modules import (
24
  Attention,
25
  AttnProcessor,
26
  FeedForward,
27
- precompute_freqs_cis,
28
- get_pos_embed_indices,
29
  )
30
 
31
 
32
  # Text embedding
33
 
34
-
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
@@ -41,22 +41,20 @@ class TextEmbedding(nn.Module):
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
- self.text_blocks = nn.Sequential(
45
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
- )
47
  else:
48
  self.extra_modeling = False
49
 
50
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
51
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
- batch, text_len = text.shape[0], text.shape[1]
54
- text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
58
 
59
- text = self.text_embed(text) # b n -> b n d
60
 
61
  # possible extra modeling
62
  if self.extra_modeling:
@@ -74,40 +72,28 @@ class TextEmbedding(nn.Module):
74
 
75
  # noised input audio and context mixing embedding
76
 
77
-
78
  class InputEmbedding(nn.Module):
79
  def __init__(self, mel_dim, text_dim, out_dim):
80
  super().__init__()
81
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
 
84
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
  if drop_audio_cond: # cfg for cond audio
86
  cond = torch.zeros_like(cond)
87
 
88
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
  x = self.conv_pos_embed(x) + x
90
  return x
91
 
92
 
93
  # Flat UNet Transformer backbone
94
 
95
-
96
  class UNetT(nn.Module):
97
- def __init__(
98
- self,
99
- *,
100
- dim,
101
- depth=8,
102
- heads=8,
103
- dim_head=64,
104
- dropout=0.1,
105
- ff_mult=4,
106
- mel_dim=100,
107
- text_num_embeds=256,
108
- text_dim=None,
109
- conv_layers=0,
110
- skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
113
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -115,7 +101,7 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -124,7 +110,7 @@ class UNetT(nn.Module):
124
 
125
  self.dim = dim
126
  self.skip_connect_type = skip_connect_type
127
- needs_skip_proj = skip_connect_type == "concat"
128
 
129
  self.depth = depth
130
  self.layers = nn.ModuleList([])
@@ -134,57 +120,53 @@ class UNetT(nn.Module):
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
- processor=AttnProcessor(),
138
- dim=dim,
139
- heads=heads,
140
- dim_head=dim_head,
141
- dropout=dropout,
142
- )
143
 
144
  ff_norm = RMSNorm(dim)
145
- ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
-
147
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
-
149
- self.layers.append(
150
- nn.ModuleList(
151
- [
152
- skip_proj,
153
- attn_norm,
154
- attn,
155
- ff_norm,
156
- ff,
157
- ]
158
- )
159
- )
160
 
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
164
  def forward(
165
  self,
166
- x: float["b n d"], # nosied input audio # noqa: F722
167
- cond: float["b n d"], # masked cond audio # noqa: F722
168
- text: int["b nt"], # text # noqa: F722
169
- time: float["b"] | float[""], # time step # noqa: F821 F722
170
  drop_audio_cond, # cfg for cond audio
171
- drop_text, # cfg for text
172
- mask: bool["b n"] | None = None, # noqa: F722
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
176
- time = time.repeat(batch)
177
-
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
184
- x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
  if mask is not None:
186
  mask = F.pad(mask, (1, 0), value=1)
187
-
188
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
 
190
  # flat unet transformer
@@ -202,18 +184,18 @@ class UNetT(nn.Module):
202
 
203
  if is_later_half:
204
  skip = skips.pop()
205
- if skip_connect_type == "concat":
206
- x = torch.cat((x, skip), dim=-1)
207
  x = maybe_skip_proj(x)
208
- elif skip_connect_type == "add":
209
  x = x + skip
210
 
211
  # attention and feedforward blocks
212
- x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
  x = ff(ff_norm(x)) + x
214
 
215
  assert len(skips) == 0
216
 
217
- x = self.norm_out(x)[:, 1:, :] # unpack t from x
218
 
219
  return self.proj_out(x)
 
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
17
+ from einops import repeat, pack, unpack
18
+
19
  from x_transformers import RMSNorm
20
  from x_transformers.x_transformers import RotaryEmbedding
21
 
 
26
  Attention,
27
  AttnProcessor,
28
  FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
 
30
  )
31
 
32
 
33
  # Text embedding
34
 
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
45
  else:
46
  self.extra_modeling = False
47
 
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
53
 
54
  if drop_text: # cfg for text
55
  text = torch.zeros_like(text)
56
 
57
+ text = self.text_embed(text) # b n -> b n d
58
 
59
  # possible extra modeling
60
  if self.extra_modeling:
 
72
 
73
  # noised input audio and context mixing embedding
74
 
 
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
 
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
 
89
 
90
  # Flat UNet Transformer backbone
91
 
 
92
  class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
 
 
 
 
 
 
 
 
 
 
97
  ):
98
  super().__init__()
99
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
 
101
  self.time_embed = TimestepEmbedding(dim)
102
  if text_dim is None:
103
  text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
 
110
 
111
  self.dim = dim
112
  self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
 
115
  self.depth = depth
116
  self.layers = nn.ModuleList([])
 
120
 
121
  attn_norm = RMSNorm(dim)
122
  attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
 
130
  ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
 
 
 
 
142
 
143
  self.norm_out = RMSNorm(dim)
144
  self.proj_out = nn.Linear(dim, mel_dim)
145
 
146
  def forward(
147
  self,
148
+ x: float['b n d'], # nosied input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
  drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
  ):
156
  batch, seq_len = x.shape[0], x.shape[1]
157
  if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
  t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
 
165
  # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
  if mask is not None:
168
  mask = F.pad(mask, (1, 0), value=1)
169
+
170
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
 
172
  # flat unet transformer
 
184
 
185
  if is_later_half:
186
  skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
  x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
  x = x + skip
192
 
193
  # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
  x = ff(ff_norm(x)) + x
196
 
197
  assert len(skips) == 0
198
 
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
 
201
  return self.proj_out(x)
model/cfm.py CHANGED
@@ -18,34 +18,34 @@ from torch.nn.utils.rnn import pad_sequence
18
 
19
  from torchdiffeq import odeint
20
 
 
 
21
  from model.modules import MelSpec
 
22
  from model.utils import (
23
- default,
24
- exists,
25
- list_str_to_idx,
26
- list_str_to_tensor,
27
- lens_to_mask,
28
- mask_from_frac_lengths,
29
- )
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
- sigma=0.0,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
- method="euler" # 'midpoint'
41
  ),
42
- audio_drop_prob=0.3,
43
- cond_drop_prob=0.2,
44
- num_channels=None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
- frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
- vocab_char_map: dict[str:int] | None = None,
49
  ):
50
  super().__init__()
51
 
@@ -81,37 +81,33 @@ class CFM(nn.Module):
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
- cond: float["b n d"] | float["b nw"], # noqa: F722
85
- text: int["b nt"] | list[str], # noqa: F722
86
- duration: int | int["b"], # noqa: F821
87
  *,
88
- lens: int["b"] | None = None, # noqa: F821
89
- steps=32,
90
- cfg_strength=1.0,
91
- sway_sampling_coef=None,
92
  seed: int | None = None,
93
- max_duration=4096,
94
- vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
- no_ref_audio=False,
96
- duplicate_test=False,
97
- t_inter=0.1,
98
- edit_mask=None,
99
  ):
100
  self.eval()
101
 
102
- if next(self.parameters()).dtype == torch.float16:
103
- cond = cond.half()
104
-
105
  # raw wave
106
 
107
  if cond.ndim == 2:
108
  cond = self.mel_spec(cond)
109
- cond = cond.permute(0, 2, 1)
110
  assert cond.shape[-1] == self.num_channels
111
 
112
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
113
  if not exists(lens):
114
- lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
115
 
116
  # text
117
 
@@ -123,37 +119,30 @@ class CFM(nn.Module):
123
  assert text.shape[0] == batch
124
 
125
  if exists(text):
126
- text_lens = (text != -1).sum(dim=-1)
127
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
128
 
129
  # duration
130
 
131
  cond_mask = lens_to_mask(lens)
132
- if edit_mask is not None:
133
- cond_mask = cond_mask & edit_mask
134
 
135
  if isinstance(duration, int):
136
- duration = torch.full((batch,), duration, device=device, dtype=torch.long)
137
 
138
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
139
- duration = duration.clamp(max=max_duration)
140
  max_duration = duration.amax()
141
-
142
  # duplicate test corner for inner time step oberservation
143
  if duplicate_test:
144
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
145
-
146
- cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
147
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
148
- cond_mask = cond_mask.unsqueeze(-1)
149
- step_cond = torch.where(
150
- cond_mask, cond, torch.zeros_like(cond)
151
- ) # allow direct control (cut cond audio) with lens passed in
152
 
153
- if batch > 1:
154
- mask = lens_to_mask(duration)
155
- else: # save memory and speed up, as single inference need no mask currently
156
- mask = None
157
 
158
  # test for no ref audio
159
  if no_ref_audio:
@@ -166,15 +155,11 @@ class CFM(nn.Module):
166
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
167
 
168
  # predict flow
169
- pred = self.transformer(
170
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
171
- )
172
  if cfg_strength < 1e-5:
173
  return pred
174
-
175
- null_pred = self.transformer(
176
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
177
- )
178
  return pred + (pred - null_pred) * cfg_strength
179
 
180
  # noise input
@@ -184,8 +169,8 @@ class CFM(nn.Module):
184
  for dur in duration:
185
  if exists(seed):
186
  torch.manual_seed(seed)
187
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
188
- y0 = pad_sequence(y0, padding_value=0, batch_first=True)
189
 
190
  t_start = 0
191
 
@@ -195,37 +180,37 @@ class CFM(nn.Module):
195
  y0 = (1 - t_start) * y0 + t_start * test_cond
196
  steps = int(steps * (1 - t_start))
197
 
198
- t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
199
  if sway_sampling_coef is not None:
200
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
201
 
202
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
203
-
204
  sampled = trajectory[-1]
205
  out = sampled
206
  out = torch.where(cond_mask, cond, out)
207
 
208
  if exists(vocoder):
209
- out = out.permute(0, 2, 1)
210
  out = vocoder(out)
211
 
212
  return out, trajectory
213
 
214
  def forward(
215
  self,
216
- inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
217
- text: int["b nt"] | list[str], # noqa: F722
218
  *,
219
- lens: int["b"] | None = None, # noqa: F821
220
  noise_scheduler: str | None = None,
221
  ):
222
  # handle raw wave
223
  if inp.ndim == 2:
224
  inp = self.mel_spec(inp)
225
- inp = inp.permute(0, 2, 1)
226
  assert inp.shape[-1] == self.num_channels
227
 
228
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
229
 
230
  # handle text as string
231
  if isinstance(text, list):
@@ -237,12 +222,12 @@ class CFM(nn.Module):
237
 
238
  # lens and mask
239
  if not exists(lens):
240
- lens = torch.full((batch,), seq_len, device=device)
241
-
242
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
243
 
244
  # get a random span to mask out for training conditionally
245
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
246
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
247
 
248
  if exists(mask):
@@ -255,16 +240,19 @@ class CFM(nn.Module):
255
  x0 = torch.randn_like(x1)
256
 
257
  # time step
258
- time = torch.rand((batch,), dtype=dtype, device=self.device)
259
  # TODO. noise_scheduler
260
 
261
  # sample xt (φ_t(x) in the paper)
262
- t = time.unsqueeze(-1).unsqueeze(-1)
263
  φ = (1 - t) * x0 + t * x1
264
  flow = x1 - x0
265
 
266
  # only predict what is within the random mask span for infilling
267
- cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
 
 
 
268
 
269
  # transformer and cfg training with a drop rate
270
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
@@ -273,15 +261,13 @@ class CFM(nn.Module):
273
  drop_text = True
274
  else:
275
  drop_text = False
276
-
277
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
278
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
279
- pred = self.transformer(
280
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
281
- )
282
 
283
  # flow matching loss
284
- loss = F.mse_loss(pred, flow, reduction="none")
285
  loss = loss[rand_span_mask]
286
 
287
  return loss.mean(), cond, pred
 
18
 
19
  from torchdiffeq import odeint
20
 
21
+ from einops import rearrange
22
+
23
  from model.modules import MelSpec
24
+
25
  from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
 
 
 
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
+ sigma = 0.,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
  ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
  ):
50
  super().__init__()
51
 
 
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
  *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
  seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
 
98
  ):
99
  self.eval()
100
 
 
 
 
101
  # raw wave
102
 
103
  if cond.ndim == 2:
104
  cond = self.mel_spec(cond)
105
+ cond = rearrange(cond, 'b d n -> b n d')
106
  assert cond.shape[-1] == self.num_channels
107
 
108
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
109
  if not exists(lens):
110
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
111
 
112
  # text
113
 
 
119
  assert text.shape[0] == batch
120
 
121
  if exists(text):
122
+ text_lens = (text != -1).sum(dim = -1)
123
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
124
 
125
  # duration
126
 
127
  cond_mask = lens_to_mask(lens)
 
 
128
 
129
  if isinstance(duration, int):
130
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
131
 
132
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
133
+ duration = duration.clamp(max = max_duration)
134
  max_duration = duration.amax()
135
+
136
  # duplicate test corner for inner time step oberservation
137
  if duplicate_test:
138
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
139
+
140
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
141
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
142
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
143
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
 
 
144
 
145
+ mask = lens_to_mask(duration)
 
 
 
146
 
147
  # test for no ref audio
148
  if no_ref_audio:
 
155
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
156
 
157
  # predict flow
158
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
 
 
159
  if cfg_strength < 1e-5:
160
  return pred
161
+
162
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
 
 
163
  return pred + (pred - null_pred) * cfg_strength
164
 
165
  # noise input
 
169
  for dur in duration:
170
  if exists(seed):
171
  torch.manual_seed(seed)
172
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
173
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
174
 
175
  t_start = 0
176
 
 
180
  y0 = (1 - t_start) * y0 + t_start * test_cond
181
  steps = int(steps * (1 - t_start))
182
 
183
+ t = torch.linspace(t_start, 1, steps, device = self.device)
184
  if sway_sampling_coef is not None:
185
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
186
 
187
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
188
+
189
  sampled = trajectory[-1]
190
  out = sampled
191
  out = torch.where(cond_mask, cond, out)
192
 
193
  if exists(vocoder):
194
+ out = rearrange(out, 'b n d -> b d n')
195
  out = vocoder(out)
196
 
197
  return out, trajectory
198
 
199
  def forward(
200
  self,
201
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
202
+ text: int['b nt'] | list[str],
203
  *,
204
+ lens: int['b'] | None = None,
205
  noise_scheduler: str | None = None,
206
  ):
207
  # handle raw wave
208
  if inp.ndim == 2:
209
  inp = self.mel_spec(inp)
210
+ inp = rearrange(inp, 'b d n -> b n d')
211
  assert inp.shape[-1] == self.num_channels
212
 
213
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
214
 
215
  # handle text as string
216
  if isinstance(text, list):
 
222
 
223
  # lens and mask
224
  if not exists(lens):
225
+ lens = torch.full((batch,), seq_len, device = device)
226
+
227
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
228
 
229
  # get a random span to mask out for training conditionally
230
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
231
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
232
 
233
  if exists(mask):
 
240
  x0 = torch.randn_like(x1)
241
 
242
  # time step
243
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
244
  # TODO. noise_scheduler
245
 
246
  # sample xt (φ_t(x) in the paper)
247
+ t = rearrange(time, 'b -> b 1 1')
248
  φ = (1 - t) * x0 + t * x1
249
  flow = x1 - x0
250
 
251
  # only predict what is within the random mask span for infilling
252
+ cond = torch.where(
253
+ rand_span_mask[..., None],
254
+ torch.zeros_like(x1), x1
255
+ )
256
 
257
  # transformer and cfg training with a drop rate
258
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
 
261
  drop_text = True
262
  else:
263
  drop_text = False
264
+
265
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
266
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
267
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
 
 
268
 
269
  # flow matching loss
270
+ loss = F.mse_loss(pred, flow, reduction = 'none')
271
  loss = loss[rand_span_mask]
272
 
273
  return loss.mean(), cond, pred
model/dataset.py CHANGED
@@ -6,67 +6,65 @@ import torch
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
- from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
- from torch import nn
 
12
 
13
  from model.modules import MelSpec
14
- from model.utils import default
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
- target_sample_rate=24_000,
22
- n_mel_channels=100,
23
- hop_length=256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
- self.mel_spectrogram = MelSpec(
29
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
30
- )
31
-
32
  def get_frame_len(self, index):
33
  row = self.data[index]
34
- audio = row["audio"]["array"]
35
- sample_rate = row["audio"]["sampling_rate"]
36
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
37
 
38
  def __len__(self):
39
  return len(self.data)
40
-
41
  def __getitem__(self, index):
42
  row = self.data[index]
43
- audio = row["audio"]["array"]
44
 
45
  # logger.info(f"Audio shape: {audio.shape}")
46
 
47
- sample_rate = row["audio"]["sampling_rate"]
48
  duration = audio.shape[-1] / sample_rate
49
 
50
  if duration > 30 or duration < 0.3:
51
  return self.__getitem__((index + 1) % len(self.data))
52
-
53
  audio_tensor = torch.from_numpy(audio).float()
54
-
55
  if sample_rate != self.target_sample_rate:
56
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
57
  audio_tensor = resampler(audio_tensor)
58
-
59
- audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
60
-
61
  mel_spec = self.mel_spectrogram(audio_tensor)
62
-
63
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
64
-
65
- text = row["text"]
66
-
67
  return dict(
68
- mel_spec=mel_spec,
69
- text=text,
70
  )
71
 
72
 
@@ -74,39 +72,28 @@ class CustomDataset(Dataset):
74
  def __init__(
75
  self,
76
  custom_dataset: Dataset,
77
- durations=None,
78
- target_sample_rate=24_000,
79
- hop_length=256,
80
- n_mel_channels=100,
81
- preprocessed_mel=False,
82
- mel_spec_module: nn.Module | None = None,
83
  ):
84
  self.data = custom_dataset
85
  self.durations = durations
86
  self.target_sample_rate = target_sample_rate
87
  self.hop_length = hop_length
88
  self.preprocessed_mel = preprocessed_mel
89
-
90
  if not preprocessed_mel:
91
- self.mel_spectrogram = default(
92
- mel_spec_module,
93
- MelSpec(
94
- target_sample_rate=target_sample_rate,
95
- hop_length=hop_length,
96
- n_mel_channels=n_mel_channels,
97
- ),
98
- )
99
 
100
  def get_frame_len(self, index):
101
- if (
102
- self.durations is not None
103
- ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
104
  return self.durations[index] * self.target_sample_rate / self.hop_length
105
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
106
-
107
  def __len__(self):
108
  return len(self.data)
109
-
110
  def __getitem__(self, index):
111
  row = self.data[index]
112
  audio_path = row["audio_path"]
@@ -118,57 +105,48 @@ class CustomDataset(Dataset):
118
 
119
  else:
120
  audio, source_sample_rate = torchaudio.load(audio_path)
121
- if audio.shape[0] > 1:
122
- audio = torch.mean(audio, dim=0, keepdim=True)
123
 
124
  if duration > 30 or duration < 0.3:
125
  return self.__getitem__((index + 1) % len(self.data))
126
-
127
  if source_sample_rate != self.target_sample_rate:
128
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
129
  audio = resampler(audio)
130
-
131
  mel_spec = self.mel_spectrogram(audio)
132
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
133
-
134
  return dict(
135
- mel_spec=mel_spec,
136
- text=text,
137
  )
138
-
139
 
140
  # Dynamic Batch Sampler
141
 
142
-
143
  class DynamicBatchSampler(Sampler[list[int]]):
144
- """Extension of Sampler that will do the following:
145
- 1. Change the batch size (essentially number of sequences)
146
- in a batch to ensure that the total number of frames are less
147
- than a certain threshold.
148
- 2. Make sure the padding efficiency in the batch is high.
149
  """
150
 
151
- def __init__(
152
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
153
- ):
154
  self.sampler = sampler
155
  self.frames_threshold = frames_threshold
156
  self.max_samples = max_samples
157
 
158
  indices, batches = [], []
159
  data_source = self.sampler.data_source
160
-
161
- for idx in tqdm(
162
- self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
163
- ):
164
  indices.append((idx, data_source.get_frame_len(idx)))
165
- indices.sort(key=lambda elem: elem[1])
166
 
167
  batch = []
168
  batch_frames = 0
169
- for idx, frame_len in tqdm(
170
- indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
171
- ):
172
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
173
  batch.append(idx)
174
  batch_frames += frame_len
@@ -204,91 +182,61 @@ class DynamicBatchSampler(Sampler[list[int]]):
204
 
205
  # Load dataset
206
 
207
-
208
  def load_dataset(
209
- dataset_name: str,
210
- tokenizer: str = "pinyin",
211
- dataset_type: str = "CustomDataset",
212
- audio_type: str = "raw",
213
- mel_spec_module: nn.Module | None = None,
214
- mel_spec_kwargs: dict = dict(),
215
- ) -> CustomDataset | HFDataset:
216
- """
217
- dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
218
- - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
219
- """
220
-
221
  print("Loading dataset ...")
222
 
223
  if dataset_type == "CustomDataset":
224
  if audio_type == "raw":
225
  try:
226
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
227
- except: # noqa: E722
228
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
229
  preprocessed_mel = False
230
  elif audio_type == "mel":
231
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
232
  preprocessed_mel = True
233
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
234
- data_dict = json.load(f)
235
- durations = data_dict["duration"]
236
- train_dataset = CustomDataset(
237
- train_dataset,
238
- durations=durations,
239
- preprocessed_mel=preprocessed_mel,
240
- mel_spec_module=mel_spec_module,
241
- **mel_spec_kwargs,
242
- )
243
-
244
- elif dataset_type == "CustomDatasetPath":
245
- try:
246
- train_dataset = load_from_disk(f"{dataset_name}/raw")
247
- except: # noqa: E722
248
- train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
249
-
250
- with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
251
  data_dict = json.load(f)
252
  durations = data_dict["duration"]
253
- train_dataset = CustomDataset(
254
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
255
- )
256
 
257
  elif dataset_type == "HFDataset":
258
- print(
259
- "Should manually modify the path of huggingface dataset to your need.\n"
260
- + "May also the corresponding script cuz different dataset may have different format."
261
- )
262
  pre, post = dataset_name.split("_")
263
- train_dataset = HFDataset(
264
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
265
- )
266
 
267
  return train_dataset
268
 
269
 
270
  # collation
271
 
272
-
273
  def collate_fn(batch):
274
- mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
275
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
276
  max_mel_length = mel_lengths.amax()
277
 
278
  padded_mel_specs = []
279
  for spec in mel_specs: # TODO. maybe records mask for attention here
280
  padding = (0, max_mel_length - spec.size(-1))
281
- padded_spec = F.pad(spec, padding, value=0)
282
  padded_mel_specs.append(padded_spec)
283
-
284
  mel_specs = torch.stack(padded_mel_specs)
285
 
286
- text = [item["text"] for item in batch]
287
  text_lengths = torch.LongTensor([len(item) for item in text])
288
 
289
  return dict(
290
- mel=mel_specs,
291
- mel_lengths=mel_lengths,
292
- text=text,
293
- text_lengths=text_lengths,
294
  )
 
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
  from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
 
14
  from model.modules import MelSpec
 
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
 
 
30
  def get_frame_len(self, index):
31
  row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
 
36
  def __len__(self):
37
  return len(self.data)
38
+
39
  def __getitem__(self, index):
40
  row = self.data[index]
41
+ audio = row['audio']['array']
42
 
43
  # logger.info(f"Audio shape: {audio.shape}")
44
 
45
+ sample_rate = row['audio']['sampling_rate']
46
  duration = audio.shape[-1] / sample_rate
47
 
48
  if duration > 30 or duration < 0.3:
49
  return self.__getitem__((index + 1) % len(self.data))
50
+
51
  audio_tensor = torch.from_numpy(audio).float()
52
+
53
  if sample_rate != self.target_sample_rate:
54
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
  audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
  mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
  return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
  )
69
 
70
 
 
72
  def __init__(
73
  self,
74
  custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
 
80
  ):
81
  self.data = custom_dataset
82
  self.durations = durations
83
  self.target_sample_rate = target_sample_rate
84
  self.hop_length = hop_length
85
  self.preprocessed_mel = preprocessed_mel
 
86
  if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
 
 
 
 
 
 
 
88
 
89
  def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
 
 
91
  return self.durations[index] * self.target_sample_rate / self.hop_length
92
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
  def __len__(self):
95
  return len(self.data)
96
+
97
  def __getitem__(self, index):
98
  row = self.data[index]
99
  audio_path = row["audio_path"]
 
105
 
106
  else:
107
  audio, source_sample_rate = torchaudio.load(audio_path)
 
 
108
 
109
  if duration > 30 or duration < 0.3:
110
  return self.__getitem__((index + 1) % len(self.data))
111
+
112
  if source_sample_rate != self.target_sample_rate:
113
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
  audio = resampler(audio)
115
+
116
  mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
  return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
  )
123
+
124
 
125
  # Dynamic Batch Sampler
126
 
 
127
  class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
  """
134
 
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
 
 
136
  self.sampler = sampler
137
  self.frames_threshold = frames_threshold
138
  self.max_samples = max_samples
139
 
140
  indices, batches = [], []
141
  data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
 
 
144
  indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
 
147
  batch = []
148
  batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
 
 
150
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
  batch.append(idx)
152
  batch_frames += frame_len
 
182
 
183
  # Load dataset
184
 
 
185
  def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str,
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+
 
 
 
 
 
193
  print("Loading dataset ...")
194
 
195
  if dataset_type == "CustomDataset":
196
  if audio_type == "raw":
197
  try:
198
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
199
+ except:
200
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
201
  preprocessed_mel = False
202
  elif audio_type == "mel":
203
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
204
  preprocessed_mel = True
205
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  data_dict = json.load(f)
207
  durations = data_dict["duration"]
208
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
 
 
209
 
210
  elif dataset_type == "HFDataset":
211
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
212
+ "May also the corresponding script cuz different dataset may have different format.")
 
 
213
  pre, post = dataset_name.split("_")
214
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
 
 
215
 
216
  return train_dataset
217
 
218
 
219
  # collation
220
 
 
221
  def collate_fn(batch):
222
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
223
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
224
  max_mel_length = mel_lengths.amax()
225
 
226
  padded_mel_specs = []
227
  for spec in mel_specs: # TODO. maybe records mask for attention here
228
  padding = (0, max_mel_length - spec.size(-1))
229
+ padded_spec = F.pad(spec, padding, value = 0)
230
  padded_mel_specs.append(padded_spec)
231
+
232
  mel_specs = torch.stack(padded_mel_specs)
233
 
234
+ text = [item['text'] for item in batch]
235
  text_lengths = torch.LongTensor([len(item) for item in text])
236
 
237
  return dict(
238
+ mel = mel_specs,
239
+ mel_lengths = mel_lengths,
240
+ text = text,
241
+ text_lengths = text_lengths,
242
  )
model/ecapa_tdnn.py CHANGED
@@ -9,14 +9,13 @@ import torch.nn as nn
9
  import torch.nn.functional as F
10
 
11
 
12
- """ Res2Conv1d + BatchNorm1d + ReLU
13
- """
14
-
15
 
16
  class Res2Conv1dReluBn(nn.Module):
17
- """
18
  in_channels == out_channels == channels
19
- """
20
 
21
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22
  super().__init__()
@@ -52,9 +51,8 @@ class Res2Conv1dReluBn(nn.Module):
52
  return out
53
 
54
 
55
- """ Conv1d + BatchNorm1d + ReLU
56
- """
57
-
58
 
59
  class Conv1dReluBn(nn.Module):
60
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -66,9 +64,8 @@ class Conv1dReluBn(nn.Module):
66
  return self.bn(F.relu(self.conv(x)))
67
 
68
 
69
- """ The SE connection of 1D case.
70
- """
71
-
72
 
73
  class SE_Connect(nn.Module):
74
  def __init__(self, channels, se_bottleneck_dim=128):
@@ -85,8 +82,8 @@ class SE_Connect(nn.Module):
85
  return out
86
 
87
 
88
- """ SE-Res2Block of the ECAPA-TDNN architecture.
89
- """
90
 
91
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92
  # return nn.Sequential(
@@ -96,7 +93,6 @@ class SE_Connect(nn.Module):
96
  # SE_Connect(channels)
97
  # )
98
 
99
-
100
  class SE_Res2Block(nn.Module):
101
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102
  super().__init__()
@@ -126,9 +122,8 @@ class SE_Res2Block(nn.Module):
126
  return x + residual
127
 
128
 
129
- """ Attentive weighted mean and standard deviation pooling.
130
- """
131
-
132
 
133
  class AttentiveStatsPool(nn.Module):
134
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -143,6 +138,7 @@ class AttentiveStatsPool(nn.Module):
143
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144
 
145
  def forward(self, x):
 
146
  if self.global_context_att:
147
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -155,52 +151,38 @@ class AttentiveStatsPool(nn.Module):
155
  # alpha = F.relu(self.linear1(x_in))
156
  alpha = torch.softmax(self.linear2(alpha), dim=2)
157
  mean = torch.sum(alpha * x, dim=2)
158
- residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159
  std = torch.sqrt(residuals.clamp(min=1e-9))
160
  return torch.cat([mean, std], dim=1)
161
 
162
 
163
  class ECAPA_TDNN(nn.Module):
164
- def __init__(
165
- self,
166
- feat_dim=80,
167
- channels=512,
168
- emb_dim=192,
169
- global_context_att=False,
170
- feat_type="wavlm_large",
171
- sr=16000,
172
- feature_selection="hidden_states",
173
- update_extract=False,
174
- config_path=None,
175
- ):
176
  super().__init__()
177
 
178
  self.feat_type = feat_type
179
  self.feature_selection = feature_selection
180
  self.update_extract = update_extract
181
  self.sr = sr
182
-
183
- torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184
  try:
185
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186
- self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187
- except: # noqa: E722
188
- self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189
 
190
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191
- self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192
- ):
193
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195
- self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196
- ):
197
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198
 
199
  self.feat_num = self.get_feat_num()
200
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201
 
202
- if feat_type != "fbank" and feat_type != "mfcc":
203
- freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204
  for name, param in self.feature_extract.named_parameters():
205
  for freeze_val in freeze_list:
206
  if freeze_val in name:
@@ -216,46 +198,18 @@ class ECAPA_TDNN(nn.Module):
216
  self.channels = [channels] * 4 + [1536]
217
 
218
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219
- self.layer2 = SE_Res2Block(
220
- self.channels[0],
221
- self.channels[1],
222
- kernel_size=3,
223
- stride=1,
224
- padding=2,
225
- dilation=2,
226
- scale=8,
227
- se_bottleneck_dim=128,
228
- )
229
- self.layer3 = SE_Res2Block(
230
- self.channels[1],
231
- self.channels[2],
232
- kernel_size=3,
233
- stride=1,
234
- padding=3,
235
- dilation=3,
236
- scale=8,
237
- se_bottleneck_dim=128,
238
- )
239
- self.layer4 = SE_Res2Block(
240
- self.channels[2],
241
- self.channels[3],
242
- kernel_size=3,
243
- stride=1,
244
- padding=4,
245
- dilation=4,
246
- scale=8,
247
- se_bottleneck_dim=128,
248
- )
249
 
250
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251
  cat_channels = channels * 3
252
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253
- self.pooling = AttentiveStatsPool(
254
- self.channels[-1], attention_channels=128, global_context_att=global_context_att
255
- )
256
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258
 
 
259
  def get_feat_num(self):
260
  self.feature_extract.eval()
261
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -272,12 +226,12 @@ class ECAPA_TDNN(nn.Module):
272
  x = self.feature_extract([sample for sample in x])
273
  else:
274
  with torch.no_grad():
275
- if self.feat_type == "fbank" or self.feat_type == "mfcc":
276
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277
  else:
278
  x = self.feature_extract([sample for sample in x])
279
 
280
- if self.feat_type == "fbank":
281
  x = x.log()
282
 
283
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -309,22 +263,6 @@ class ECAPA_TDNN(nn.Module):
309
  return out
310
 
311
 
312
- def ECAPA_TDNN_SMALL(
313
- feat_dim,
314
- emb_dim=256,
315
- feat_type="wavlm_large",
316
- sr=16000,
317
- feature_selection="hidden_states",
318
- update_extract=False,
319
- config_path=None,
320
- ):
321
- return ECAPA_TDNN(
322
- feat_dim=feat_dim,
323
- channels=512,
324
- emb_dim=emb_dim,
325
- feat_type=feat_type,
326
- sr=sr,
327
- feature_selection=feature_selection,
328
- update_extract=update_extract,
329
- config_path=config_path,
330
- )
 
9
  import torch.nn.functional as F
10
 
11
 
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
 
14
 
15
  class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
  in_channels == out_channels == channels
18
+ '''
19
 
20
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
  super().__init__()
 
51
  return out
52
 
53
 
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
 
56
 
57
  class Conv1dReluBn(nn.Module):
58
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
64
  return self.bn(F.relu(self.conv(x)))
65
 
66
 
67
+ ''' The SE connection of 1D case.
68
+ '''
 
69
 
70
  class SE_Connect(nn.Module):
71
  def __init__(self, channels, se_bottleneck_dim=128):
 
82
  return out
83
 
84
 
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
 
88
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
  # return nn.Sequential(
 
93
  # SE_Connect(channels)
94
  # )
95
 
 
96
  class SE_Res2Block(nn.Module):
97
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
  super().__init__()
 
122
  return x + residual
123
 
124
 
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
 
127
 
128
  class AttentiveStatsPool(nn.Module):
129
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
 
138
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
 
140
  def forward(self, x):
141
+
142
  if self.global_context_att:
143
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
151
  # alpha = F.relu(self.linear1(x_in))
152
  alpha = torch.softmax(self.linear2(alpha), dim=2)
153
  mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
  std = torch.sqrt(residuals.clamp(min=1e-9))
156
  return torch.cat([mean, std], dim=1)
157
 
158
 
159
  class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
 
 
 
 
 
 
 
 
 
 
162
  super().__init__()
163
 
164
  self.feat_type = feat_type
165
  self.feature_selection = feature_selection
166
  self.update_extract = update_extract
167
  self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
  try:
171
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
 
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
 
 
177
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
 
 
179
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
 
181
  self.feat_num = self.get_feat_num()
182
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
 
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
  for name, param in self.feature_extract.named_parameters():
187
  for freeze_val in freeze_list:
188
  if freeze_val in name:
 
198
  self.channels = [channels] * 4 + [1536]
199
 
200
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
  cat_channels = channels * 3
207
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
 
 
209
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
 
212
+
213
  def get_feat_num(self):
214
  self.feature_extract.eval()
215
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
 
226
  x = self.feature_extract([sample for sample in x])
227
  else:
228
  with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
  else:
232
  x = self.feature_extract([sample for sample in x])
233
 
234
+ if self.feat_type == 'fbank':
235
  x = x.log()
236
 
237
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
 
263
  return out
264
 
265
 
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/modules.py CHANGED
@@ -16,45 +16,45 @@ from torch import nn
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
 
19
  from x_transformers.x_transformers import apply_rotary_pos_emb
20
 
21
 
22
  # raw wav to mel spec
23
 
24
-
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
- filter_length=1024,
29
- hop_length=256,
30
- win_length=1024,
31
- n_mel_channels=100,
32
- target_sample_rate=24_000,
33
- normalize=False,
34
- power=1,
35
- norm=None,
36
- center=True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
- sample_rate=target_sample_rate,
43
- n_fft=filter_length,
44
- win_length=win_length,
45
- hop_length=hop_length,
46
- n_mels=n_mel_channels,
47
- power=power,
48
- center=center,
49
- normalized=normalize,
50
- norm=norm,
51
  )
52
 
53
- self.register_buffer("dummy", torch.tensor(0), persistent=False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
- inp = inp.squeeze(1) # 'b 1 nw -> b nw'
58
 
59
  assert len(inp.shape) == 2
60
 
@@ -62,13 +62,12 @@ class MelSpec(nn.Module):
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
- mel = mel.clamp(min=1e-5).log()
66
  return mel
67
-
68
 
69
  # sinusoidal position embedding
70
 
71
-
72
  class SinusPositionEmbedding(nn.Module):
73
  def __init__(self, dim):
74
  super().__init__()
@@ -86,37 +85,35 @@ class SinusPositionEmbedding(nn.Module):
86
 
87
  # convolutional position embedding
88
 
89
-
90
  class ConvPositionEmbedding(nn.Module):
91
- def __init__(self, dim, kernel_size=31, groups=16):
92
  super().__init__()
93
  assert kernel_size % 2 != 0
94
  self.conv1d = nn.Sequential(
95
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
96
  nn.Mish(),
97
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
98
  nn.Mish(),
99
  )
100
 
101
- def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
102
  if mask is not None:
103
  mask = mask[..., None]
104
- x = x.masked_fill(~mask, 0.0)
105
 
106
- x = x.permute(0, 2, 1)
107
  x = self.conv1d(x)
108
- out = x.permute(0, 2, 1)
109
 
110
  if mask is not None:
111
- out = out.masked_fill(~mask, 0.0)
112
 
113
  return out
114
 
115
 
116
  # rotary positional embedding related
117
 
118
-
119
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
120
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
121
  # has some connection to NTK literature
122
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -129,14 +126,12 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
129
  freqs_sin = torch.sin(freqs) # imaginary part
130
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
131
 
132
-
133
- def get_pos_embed_indices(start, length, max_pos, scale=1.0):
134
  # length = length if isinstance(length, int) else length.max()
135
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
136
- pos = (
137
- start.unsqueeze(1)
138
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
139
- )
140
  # avoid extra long error.
141
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
142
  return pos
@@ -144,7 +139,6 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
144
 
145
  # Global Response Normalization layer (Instance Normalization ?)
146
 
147
-
148
  class GRN(nn.Module):
149
  def __init__(self, dim):
150
  super().__init__()
@@ -160,7 +154,6 @@ class GRN(nn.Module):
160
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
161
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
162
 
163
-
164
  class ConvNeXtV2Block(nn.Module):
165
  def __init__(
166
  self,
@@ -170,9 +163,7 @@ class ConvNeXtV2Block(nn.Module):
170
  ):
171
  super().__init__()
172
  padding = (dilation * (7 - 1)) // 2
173
- self.dwconv = nn.Conv1d(
174
- dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
175
- ) # depthwise conv
176
  self.norm = nn.LayerNorm(dim, eps=1e-6)
177
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
178
  self.act = nn.GELU()
@@ -195,7 +186,6 @@ class ConvNeXtV2Block(nn.Module):
195
  # AdaLayerNormZero
196
  # return with modulated x for attn input, and params for later mlp modulation
197
 
198
-
199
  class AdaLayerNormZero(nn.Module):
200
  def __init__(self, dim):
201
  super().__init__()
@@ -205,7 +195,7 @@ class AdaLayerNormZero(nn.Module):
205
 
206
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
207
 
208
- def forward(self, x, emb=None):
209
  emb = self.linear(self.silu(emb))
210
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
211
 
@@ -216,7 +206,6 @@ class AdaLayerNormZero(nn.Module):
216
  # AdaLayerNormZero for final layer
217
  # return only with modulated x for attn input, cuz no more mlp modulation
218
 
219
-
220
  class AdaLayerNormZero_Final(nn.Module):
221
  def __init__(self, dim):
222
  super().__init__()
@@ -236,16 +225,22 @@ class AdaLayerNormZero_Final(nn.Module):
236
 
237
  # FeedForward
238
 
239
-
240
  class FeedForward(nn.Module):
241
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
242
  super().__init__()
243
  inner_dim = int(dim * mult)
244
  dim_out = dim_out if dim_out is not None else dim
245
 
246
  activation = nn.GELU(approximate=approximate)
247
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
248
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
 
 
 
 
 
249
 
250
  def forward(self, x):
251
  return self.ff(x)
@@ -254,7 +249,6 @@ class FeedForward(nn.Module):
254
  # Attention with possible joint part
255
  # modified from diffusers/src/diffusers/models/attention_processor.py
256
 
257
-
258
  class Attention(nn.Module):
259
  def __init__(
260
  self,
@@ -263,8 +257,8 @@ class Attention(nn.Module):
263
  heads: int = 8,
264
  dim_head: int = 64,
265
  dropout: float = 0.0,
266
- context_dim: Optional[int] = None, # if not None -> joint attention
267
- context_pre_only=None,
268
  ):
269
  super().__init__()
270
 
@@ -300,21 +294,20 @@ class Attention(nn.Module):
300
 
301
  def forward(
302
  self,
303
- x: float["b n d"], # noised input x # noqa: F722
304
- c: float["b n d"] = None, # context c # noqa: F722
305
- mask: bool["b n"] | None = None, # noqa: F722
306
- rope=None, # rotary position embedding for x
307
- c_rope=None, # rotary position embedding for c
308
  ) -> torch.Tensor:
309
  if c is not None:
310
- return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
311
  else:
312
- return self.processor(self, x, mask=mask, rope=rope)
313
 
314
 
315
  # Attention processor
316
 
317
-
318
  class AttnProcessor:
319
  def __init__(self):
320
  pass
@@ -322,10 +315,11 @@ class AttnProcessor:
322
  def __call__(
323
  self,
324
  attn: Attention,
325
- x: float["b n d"], # noised input x # noqa: F722
326
- mask: bool["b n"] | None = None, # noqa: F722
327
- rope=None, # rotary position embedding
328
  ) -> torch.FloatTensor:
 
329
  batch_size = x.shape[0]
330
 
331
  # `sample` projections.
@@ -336,7 +330,7 @@ class AttnProcessor:
336
  # apply rotary position embedding
337
  if rope is not None:
338
  freqs, xpos_scale = rope
339
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
340
 
341
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
342
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -351,7 +345,7 @@ class AttnProcessor:
351
  # mask. e.g. inference got a batch with different target durations, mask out the padding
352
  if mask is not None:
353
  attn_mask = mask
354
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
355
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
356
  else:
357
  attn_mask = None
@@ -366,16 +360,15 @@ class AttnProcessor:
366
  x = attn.to_out[1](x)
367
 
368
  if mask is not None:
369
- mask = mask.unsqueeze(-1)
370
- x = x.masked_fill(~mask, 0.0)
371
 
372
  return x
373
-
374
 
375
  # Joint Attention processor for MM-DiT
376
  # modified from diffusers/src/diffusers/models/attention_processor.py
377
 
378
-
379
  class JointAttnProcessor:
380
  def __init__(self):
381
  pass
@@ -383,11 +376,11 @@ class JointAttnProcessor:
383
  def __call__(
384
  self,
385
  attn: Attention,
386
- x: float["b n d"], # noised input x # noqa: F722
387
- c: float["b nt d"] = None, # context c, here text # noqa: F722
388
- mask: bool["b n"] | None = None, # noqa: F722
389
- rope=None, # rotary position embedding for x
390
- c_rope=None, # rotary position embedding for c
391
  ) -> torch.FloatTensor:
392
  residual = x
393
 
@@ -406,12 +399,12 @@ class JointAttnProcessor:
406
  # apply rope for context and noised input independently
407
  if rope is not None:
408
  freqs, xpos_scale = rope
409
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
410
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
411
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
412
  if c_rope is not None:
413
  freqs, xpos_scale = c_rope
414
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
415
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
416
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
417
 
@@ -428,8 +421,8 @@ class JointAttnProcessor:
428
 
429
  # mask. e.g. inference got a batch with different target durations, mask out the padding
430
  if mask is not None:
431
- attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
432
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
434
  else:
435
  attn_mask = None
@@ -440,8 +433,8 @@ class JointAttnProcessor:
440
 
441
  # Split the attention outputs.
442
  x, c = (
443
- x[:, : residual.shape[1]],
444
- x[:, residual.shape[1] :],
445
  )
446
 
447
  # linear proj
@@ -452,8 +445,8 @@ class JointAttnProcessor:
452
  c = attn.to_out_c(c)
453
 
454
  if mask is not None:
455
- mask = mask.unsqueeze(-1)
456
- x = x.masked_fill(~mask, 0.0)
457
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
458
 
459
  return x, c
@@ -461,24 +454,24 @@ class JointAttnProcessor:
461
 
462
  # DiT Block
463
 
464
-
465
  class DiTBlock(nn.Module):
466
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
467
- super().__init__()
468
 
 
 
 
469
  self.attn_norm = AdaLayerNormZero(dim)
470
  self.attn = Attention(
471
- processor=AttnProcessor(),
472
- dim=dim,
473
- heads=heads,
474
- dim_head=dim_head,
475
- dropout=dropout,
476
- )
477
-
478
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
480
 
481
- def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
482
  # pre-norm & modulation for attention input
483
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
484
 
@@ -487,7 +480,7 @@ class DiTBlock(nn.Module):
487
 
488
  # process attention output for input x
489
  x = x + gate_msa.unsqueeze(1) * attn_output
490
-
491
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
492
  ff_output = self.ff(norm)
493
  x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -497,9 +490,8 @@ class DiTBlock(nn.Module):
497
 
498
  # MMDiT Block https://arxiv.org/abs/2403.03206
499
 
500
-
501
  class MMDiTBlock(nn.Module):
502
- r"""
503
  modified from diffusers/src/diffusers/models/attention.py
504
 
505
  notes.
@@ -508,33 +500,33 @@ class MMDiTBlock(nn.Module):
508
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
509
  """
510
 
511
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
512
  super().__init__()
513
 
514
  self.context_pre_only = context_pre_only
515
-
516
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
517
  self.attn_norm_x = AdaLayerNormZero(dim)
518
  self.attn = Attention(
519
- processor=JointAttnProcessor(),
520
- dim=dim,
521
- heads=heads,
522
- dim_head=dim_head,
523
- dropout=dropout,
524
- context_dim=dim,
525
- context_pre_only=context_pre_only,
526
- )
527
 
528
  if not context_pre_only:
529
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
530
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
531
  else:
532
  self.ff_norm_c = None
533
  self.ff_c = None
534
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
535
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
536
 
537
- def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
538
  # pre-norm & modulation for attention input
539
  if self.context_pre_only:
540
  norm_c = self.attn_norm_c(c, t)
@@ -548,7 +540,7 @@ class MMDiTBlock(nn.Module):
548
  # process attention output for context c
549
  if self.context_pre_only:
550
  c = None
551
- else: # if not last layer
552
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
553
 
554
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -557,7 +549,7 @@ class MMDiTBlock(nn.Module):
557
 
558
  # process attention output for input x
559
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
560
-
561
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
562
  x_ff_output = self.ff_x(norm_x)
563
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -567,15 +559,17 @@ class MMDiTBlock(nn.Module):
567
 
568
  # time step conditioning embedding
569
 
570
-
571
  class TimestepEmbedding(nn.Module):
572
  def __init__(self, dim, freq_embed_dim=256):
573
  super().__init__()
574
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
575
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
 
 
576
 
577
- def forward(self, timestep: float["b"]): # noqa: F821
578
  time_hidden = self.time_embed(timestep)
579
- time_hidden = time_hidden.to(timestep.dtype)
580
  time = self.time_mlp(time_hidden) # b d
581
  return time
 
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
19
+ from einops import rearrange
20
  from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
22
 
23
  # raw wav to mel spec
24
 
 
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
  )
52
 
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
 
59
  assert len(inp.shape) == 2
60
 
 
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
  return mel
67
+
68
 
69
  # sinusoidal position embedding
70
 
 
71
  class SinusPositionEmbedding(nn.Module):
72
  def __init__(self, dim):
73
  super().__init__()
 
85
 
86
  # convolutional position embedding
87
 
 
88
  class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
  super().__init__()
91
  assert kernel_size % 2 != 0
92
  self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
  nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
  nn.Mish(),
97
  )
98
 
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
  if mask is not None:
101
  mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
 
104
+ x = rearrange(x, 'b n d -> b d n')
105
  x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
 
108
  if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
 
111
  return out
112
 
113
 
114
  # rotary positional embedding related
115
 
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
 
117
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
  # has some connection to NTK literature
119
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
126
  freqs_sin = torch.sin(freqs) # imaginary part
127
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
 
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
 
130
  # length = length if isinstance(length, int) else length.max()
131
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
 
135
  # avoid extra long error.
136
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
  return pos
 
139
 
140
  # Global Response Normalization layer (Instance Normalization ?)
141
 
 
142
  class GRN(nn.Module):
143
  def __init__(self, dim):
144
  super().__init__()
 
154
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
 
 
157
  class ConvNeXtV2Block(nn.Module):
158
  def __init__(
159
  self,
 
163
  ):
164
  super().__init__()
165
  padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
 
 
167
  self.norm = nn.LayerNorm(dim, eps=1e-6)
168
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
  self.act = nn.GELU()
 
186
  # AdaLayerNormZero
187
  # return with modulated x for attn input, and params for later mlp modulation
188
 
 
189
  class AdaLayerNormZero(nn.Module):
190
  def __init__(self, dim):
191
  super().__init__()
 
195
 
196
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
 
198
+ def forward(self, x, emb = None):
199
  emb = self.linear(self.silu(emb))
200
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
 
 
206
  # AdaLayerNormZero for final layer
207
  # return only with modulated x for attn input, cuz no more mlp modulation
208
 
 
209
  class AdaLayerNormZero_Final(nn.Module):
210
  def __init__(self, dim):
211
  super().__init__()
 
225
 
226
  # FeedForward
227
 
 
228
  class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
  super().__init__()
231
  inner_dim = int(dim * mult)
232
  dim_out = dim_out if dim_out is not None else dim
233
 
234
  activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
 
245
  def forward(self, x):
246
  return self.ff(x)
 
249
  # Attention with possible joint part
250
  # modified from diffusers/src/diffusers/models/attention_processor.py
251
 
 
252
  class Attention(nn.Module):
253
  def __init__(
254
  self,
 
257
  heads: int = 8,
258
  dim_head: int = 64,
259
  dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
  ):
263
  super().__init__()
264
 
 
294
 
295
  def forward(
296
  self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
  ) -> torch.Tensor:
303
  if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
  else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
 
308
 
309
  # Attention processor
310
 
 
311
  class AttnProcessor:
312
  def __init__(self):
313
  pass
 
315
  def __call__(
316
  self,
317
  attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
  ) -> torch.FloatTensor:
322
+
323
  batch_size = x.shape[0]
324
 
325
  # `sample` projections.
 
330
  # apply rotary position embedding
331
  if rope is not None:
332
  freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
 
335
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
345
  # mask. e.g. inference got a batch with different target durations, mask out the padding
346
  if mask is not None:
347
  attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
  else:
351
  attn_mask = None
 
360
  x = attn.to_out[1](x)
361
 
362
  if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
 
366
  return x
367
+
368
 
369
  # Joint Attention processor for MM-DiT
370
  # modified from diffusers/src/diffusers/models/attention_processor.py
371
 
 
372
  class JointAttnProcessor:
373
  def __init__(self):
374
  pass
 
376
  def __call__(
377
  self,
378
  attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
  ) -> torch.FloatTensor:
385
  residual = x
386
 
 
399
  # apply rope for context and noised input independently
400
  if rope is not None:
401
  freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
  if c_rope is not None:
406
  freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
 
 
421
 
422
  # mask. e.g. inference got a batch with different target durations, mask out the padding
423
  if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
  else:
428
  attn_mask = None
 
433
 
434
  # Split the attention outputs.
435
  x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
  )
439
 
440
  # linear proj
 
445
  c = attn.to_out_c(c)
446
 
447
  if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
 
452
  return x, c
 
454
 
455
  # DiT Block
456
 
 
457
  class DiTBlock(nn.Module):
 
 
458
 
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
  self.attn_norm = AdaLayerNormZero(dim)
463
  self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
 
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
  # pre-norm & modulation for attention input
476
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
 
 
480
 
481
  # process attention output for input x
482
  x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
  ff_output = self.ff(norm)
486
  x = x + gate_mlp.unsqueeze(1) * ff_output
 
490
 
491
  # MMDiT Block https://arxiv.org/abs/2403.03206
492
 
 
493
  class MMDiTBlock(nn.Module):
494
+ r"""
495
  modified from diffusers/src/diffusers/models/attention.py
496
 
497
  notes.
 
500
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
  """
502
 
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
  super().__init__()
505
 
506
  self.context_pre_only = context_pre_only
507
+
508
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
  self.attn_norm_x = AdaLayerNormZero(dim)
510
  self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
 
520
  if not context_pre_only:
521
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
  else:
524
  self.ff_norm_c = None
525
  self.ff_c = None
526
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
 
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
  # pre-norm & modulation for attention input
531
  if self.context_pre_only:
532
  norm_c = self.attn_norm_c(c, t)
 
540
  # process attention output for context c
541
  if self.context_pre_only:
542
  c = None
543
+ else: # if not last layer
544
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
 
546
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
549
 
550
  # process attention output for input x
551
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
  x_ff_output = self.ff_x(norm_x)
555
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
 
559
 
560
  # time step conditioning embedding
561
 
 
562
  class TimestepEmbedding(nn.Module):
563
  def __init__(self, dim, freq_embed_dim=256):
564
  super().__init__()
565
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
 
572
+ def forward(self, timestep: float['b']):
573
  time_hidden = self.time_embed(timestep)
 
574
  time = self.time_mlp(time_hidden) # b d
575
  return time
model/trainer.py CHANGED
@@ -10,6 +10,8 @@ from torch.optim import AdamW
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
 
 
13
  from accelerate import Accelerator
14
  from accelerate.utils import DistributedDataParallelKwargs
15
 
@@ -22,69 +24,66 @@ from model.dataset import DynamicBatchSampler, collate_fn
22
 
23
  # trainer
24
 
25
-
26
  class Trainer:
27
  def __init__(
28
  self,
29
  model: CFM,
30
  epochs,
31
  learning_rate,
32
- num_warmup_updates=20000,
33
- save_per_updates=1000,
34
- checkpoint_path=None,
35
- batch_size=32,
36
  batch_size_type: str = "sample",
37
- max_samples=32,
38
- grad_accumulation_steps=1,
39
- max_grad_norm=1.0,
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
- wandb_project="test_e2-tts",
43
- wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
45
- last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
- ema_kwargs: dict = dict(),
48
- bnb_optimizer: bool = False,
49
  ):
50
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
51
-
52
- logger = "wandb" if wandb.api.api_key else None
53
- print(f"Using logger: {logger}")
54
 
55
  self.accelerator = Accelerator(
56
- log_with=logger,
57
- kwargs_handlers=[ddp_kwargs],
58
- gradient_accumulation_steps=grad_accumulation_steps,
59
- **accelerate_kwargs,
60
  )
61
-
62
- if logger == "wandb":
63
- if exists(wandb_resume_id):
64
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
65
- else:
66
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
67
- self.accelerator.init_trackers(
68
- project_name=wandb_project,
69
- init_kwargs=init_kwargs,
70
- config={
71
- "epochs": epochs,
72
  "learning_rate": learning_rate,
73
- "num_warmup_updates": num_warmup_updates,
74
  "batch_size": batch_size,
75
  "batch_size_type": batch_size_type,
76
  "max_samples": max_samples,
77
  "grad_accumulation_steps": grad_accumulation_steps,
78
  "max_grad_norm": max_grad_norm,
79
  "gpus": self.accelerator.num_processes,
80
- "noise_scheduler": noise_scheduler,
81
- },
82
  )
83
 
84
  self.model = model
85
 
86
  if self.is_main:
87
- self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
 
 
 
 
88
 
89
  self.ema_model.to(self.accelerator.device)
90
 
@@ -92,7 +91,7 @@ class Trainer:
92
  self.num_warmup_updates = num_warmup_updates
93
  self.save_per_updates = save_per_updates
94
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
95
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
96
 
97
  self.batch_size = batch_size
98
  self.batch_size_type = batch_size_type
@@ -104,13 +103,10 @@ class Trainer:
104
 
105
  self.duration_predictor = duration_predictor
106
 
107
- if bnb_optimizer:
108
- import bitsandbytes as bnb
109
-
110
- self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
111
- else:
112
- self.optimizer = AdamW(model.parameters(), lr=learning_rate)
113
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
114
 
115
  @property
116
  def is_main(self):
@@ -120,112 +116,76 @@ class Trainer:
120
  self.accelerator.wait_for_everyone()
121
  if self.is_main:
122
  checkpoint = dict(
123
- model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
124
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
125
- ema_model_state_dict=self.ema_model.state_dict(),
126
- scheduler_state_dict=self.scheduler.state_dict(),
127
- step=step,
128
  )
129
  if not os.path.exists(self.checkpoint_path):
130
  os.makedirs(self.checkpoint_path)
131
- if last:
132
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
133
  print(f"Saved last checkpoint at step {step}")
134
  else:
135
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
136
 
137
  def load_checkpoint(self):
138
- if (
139
- not exists(self.checkpoint_path)
140
- or not os.path.exists(self.checkpoint_path)
141
- or not os.listdir(self.checkpoint_path)
142
- ):
143
  return 0
144
-
145
  self.accelerator.wait_for_everyone()
146
  if "model_last.pt" in os.listdir(self.checkpoint_path):
147
  latest_checkpoint = "model_last.pt"
148
  else:
149
- latest_checkpoint = sorted(
150
- [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
151
- key=lambda x: int("".join(filter(str.isdigit, x))),
152
- )[-1]
153
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
154
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
155
 
156
  if self.is_main:
157
- self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
158
-
159
- if "step" in checkpoint:
160
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
161
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
162
- if self.scheduler:
163
- self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
164
- step = checkpoint["step"]
165
- else:
166
- checkpoint["model_state_dict"] = {
167
- k.replace("ema_model.", ""): v
168
- for k, v in checkpoint["ema_model_state_dict"].items()
169
- if k not in ["initted", "step"]
170
- }
171
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
172
- step = 0
173
-
174
- del checkpoint
175
- gc.collect()
176
  return step
177
 
178
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
179
  if exists(resumable_with_seed):
180
  generator = torch.Generator()
181
  generator.manual_seed(resumable_with_seed)
182
- else:
183
  generator = None
184
 
185
  if self.batch_size_type == "sample":
186
- train_dataloader = DataLoader(
187
- train_dataset,
188
- collate_fn=collate_fn,
189
- num_workers=num_workers,
190
- pin_memory=True,
191
- persistent_workers=True,
192
- batch_size=self.batch_size,
193
- shuffle=True,
194
- generator=generator,
195
- )
196
  elif self.batch_size_type == "frame":
197
  self.accelerator.even_batches = False
198
  sampler = SequentialSampler(train_dataset)
199
- batch_sampler = DynamicBatchSampler(
200
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
201
- )
202
- train_dataloader = DataLoader(
203
- train_dataset,
204
- collate_fn=collate_fn,
205
- num_workers=num_workers,
206
- pin_memory=True,
207
- persistent_workers=True,
208
- batch_sampler=batch_sampler,
209
- )
210
  else:
211
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
212
-
213
  # accelerator.prepare() dispatches batches to devices;
214
  # which means the length of dataloader calculated before, should consider the number of devices
215
- warmup_steps = (
216
- self.num_warmup_updates * self.accelerator.num_processes
217
- ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
218
- # otherwise by default with split_batches=False, warmup steps change with num_processes
219
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
220
  decay_steps = total_steps - warmup_steps
221
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
222
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
223
- self.scheduler = SequentialLR(
224
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
225
- )
226
- train_dataloader, self.scheduler = self.accelerator.prepare(
227
- train_dataloader, self.scheduler
228
- ) # actual steps = 1 gpu steps / gpus
229
  start_step = self.load_checkpoint()
230
  global_step = start_step
231
 
@@ -240,36 +200,23 @@ class Trainer:
240
  for epoch in range(skipped_epoch, self.epochs):
241
  self.model.train()
242
  if exists(resumable_with_seed) and epoch == skipped_epoch:
243
- progress_bar = tqdm(
244
- skipped_dataloader,
245
- desc=f"Epoch {epoch+1}/{self.epochs}",
246
- unit="step",
247
- disable=not self.accelerator.is_local_main_process,
248
- initial=skipped_batch,
249
- total=orig_epoch_step,
250
- )
251
  else:
252
- progress_bar = tqdm(
253
- train_dataloader,
254
- desc=f"Epoch {epoch+1}/{self.epochs}",
255
- unit="step",
256
- disable=not self.accelerator.is_local_main_process,
257
- )
258
 
259
  for batch in progress_bar:
260
  with self.accelerator.accumulate(self.model):
261
- text_inputs = batch["text"]
262
- mel_spec = batch["mel"].permute(0, 2, 1)
263
  mel_lengths = batch["mel_lengths"]
264
 
265
  # TODO. add duration predictor training
266
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
267
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
268
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
269
 
270
- loss, cond, pred = self.model(
271
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
272
- )
273
  self.accelerator.backward(loss)
274
 
275
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -286,15 +233,13 @@ class Trainer:
286
 
287
  if self.accelerator.is_local_main_process:
288
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
289
-
290
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
291
-
292
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
293
  self.save_checkpoint(global_step)
294
-
295
  if global_step % self.last_per_steps == 0:
296
  self.save_checkpoint(global_step, last=True)
297
-
298
- self.save_checkpoint(global_step, last=True)
299
-
300
  self.accelerator.end_training()
 
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
13
+ from einops import rearrange
14
+
15
  from accelerate import Accelerator
16
  from accelerate.utils import DistributedDataParallelKwargs
17
 
 
24
 
25
  # trainer
26
 
 
27
  class Trainer:
28
  def __init__(
29
  self,
30
  model: CFM,
31
  epochs,
32
  learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
  batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
  noise_scheduler: str | None = None,
42
  duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
  wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
  accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
 
49
  ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
 
52
 
53
  self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
  )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
 
 
68
  "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
  "batch_size": batch_size,
71
  "batch_size_type": batch_size_type,
72
  "max_samples": max_samples,
73
  "grad_accumulation_steps": grad_accumulation_steps,
74
  "max_grad_norm": max_grad_norm,
75
  "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
 
77
  )
78
 
79
  self.model = model
80
 
81
  if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
 
88
  self.ema_model.to(self.accelerator.device)
89
 
 
91
  self.num_warmup_updates = num_warmup_updates
92
  self.save_per_updates = save_per_updates
93
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
 
96
  self.batch_size = batch_size
97
  self.batch_size_type = batch_size_type
 
103
 
104
  self.duration_predictor = duration_predictor
105
 
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
 
 
 
110
 
111
  @property
112
  def is_main(self):
 
116
  self.accelerator.wait_for_everyone()
117
  if self.is_main:
118
  checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
  )
125
  if not os.path.exists(self.checkpoint_path):
126
  os.makedirs(self.checkpoint_path)
127
+ if last == True:
128
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
  print(f"Saved last checkpoint at step {step}")
130
  else:
131
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
 
133
  def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
 
 
 
 
135
  return 0
136
+
137
  self.accelerator.wait_for_everyone()
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
+ latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
 
 
 
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
145
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
146
 
147
  if self.is_main:
148
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
149
+
150
+ if self.scheduler:
151
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
152
+
153
+ step = checkpoint['step']
154
+ del checkpoint; gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
155
  return step
156
 
157
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
158
+
159
  if exists(resumable_with_seed):
160
  generator = torch.Generator()
161
  generator.manual_seed(resumable_with_seed)
162
+ else:
163
  generator = None
164
 
165
  if self.batch_size_type == "sample":
166
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
167
+ batch_size=self.batch_size, shuffle=True, generator=generator)
 
 
 
 
 
 
 
 
168
  elif self.batch_size_type == "frame":
169
  self.accelerator.even_batches = False
170
  sampler = SequentialSampler(train_dataset)
171
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
172
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
173
+ batch_sampler=batch_sampler)
 
 
 
 
 
 
 
 
174
  else:
175
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}")
176
+
177
  # accelerator.prepare() dispatches batches to devices;
178
  # which means the length of dataloader calculated before, should consider the number of devices
179
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
180
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
 
 
181
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
182
  decay_steps = total_steps - warmup_steps
183
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
184
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
185
+ self.scheduler = SequentialLR(self.optimizer,
186
+ schedulers=[warmup_scheduler, decay_scheduler],
187
+ milestones=[warmup_steps])
188
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
 
 
189
  start_step = self.load_checkpoint()
190
  global_step = start_step
191
 
 
200
  for epoch in range(skipped_epoch, self.epochs):
201
  self.model.train()
202
  if exists(resumable_with_seed) and epoch == skipped_epoch:
203
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
204
+ initial=skipped_batch, total=orig_epoch_step)
 
 
 
 
 
 
205
  else:
206
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
 
 
 
 
 
207
 
208
  for batch in progress_bar:
209
  with self.accelerator.accumulate(self.model):
210
+ text_inputs = batch['text']
211
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
212
  mel_lengths = batch["mel_lengths"]
213
 
214
  # TODO. add duration predictor training
215
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
216
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
217
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
218
 
219
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
 
 
220
  self.accelerator.backward(loss)
221
 
222
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
 
233
 
234
  if self.accelerator.is_local_main_process:
235
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
236
+
237
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
238
+
239
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
240
  self.save_checkpoint(global_step)
241
+
242
  if global_step % self.last_per_steps == 0:
243
  self.save_checkpoint(global_step, last=True)
244
+
 
 
245
  self.accelerator.end_training()
model/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  import math
5
  import random
6
  import string
@@ -8,7 +9,6 @@ from tqdm import tqdm
8
  from collections import defaultdict
9
 
10
  import matplotlib
11
-
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
@@ -17,8 +17,17 @@ import torch.nn.functional as F
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
 
 
 
20
  import jieba
21
  from pypinyin import lazy_pinyin, Style
 
 
 
 
 
 
22
 
23
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
24
  from model.modules import MelSpec
@@ -26,102 +35,106 @@ from model.modules import MelSpec
26
 
27
  # seed everything
28
 
29
-
30
- def seed_everything(seed=0):
31
  random.seed(seed)
32
- os.environ["PYTHONHASHSEED"] = str(seed)
33
  torch.manual_seed(seed)
34
  torch.cuda.manual_seed(seed)
35
  torch.cuda.manual_seed_all(seed)
36
  torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
-
40
  # helpers
41
 
42
-
43
  def exists(v):
44
  return v is not None
45
 
46
-
47
  def default(v, d):
48
  return v if exists(v) else d
49
 
50
-
51
  # tensor helpers
52
 
 
 
 
 
53
 
54
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
55
  if not exists(length):
56
  length = t.amax()
57
 
58
- seq = torch.arange(length, device=t.device)
59
- return seq[None, :] < t[:, None]
60
-
61
-
62
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
63
- max_seq_len = seq_len.max().item()
64
- seq = torch.arange(max_seq_len, device=start.device).long()
65
- start_mask = seq[None, :] >= start[:, None]
66
- end_mask = seq[None, :] < end[:, None]
67
- return start_mask & end_mask
68
 
 
 
 
 
 
 
 
 
69
 
70
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
 
71
  lengths = (frac_lengths * seq_len).long()
72
  max_start = seq_len - lengths
73
 
74
  rand = torch.rand_like(frac_lengths)
75
- start = (max_start * rand).long().clamp(min=0)
76
  end = start + lengths
77
 
78
  return mask_from_start_end_indices(seq_len, start, end)
79
 
 
 
 
 
80
 
81
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
82
  if not exists(mask):
83
- return t.mean(dim=1)
84
 
85
- t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
86
- num = t.sum(dim=1)
87
- den = mask.float().sum(dim=1)
88
 
89
- return num / den.clamp(min=1.0)
90
 
91
 
92
  # simple utf-8 tokenizer, since paper went character based
93
- def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
94
- list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
95
- text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
 
 
 
96
  return text
97
 
98
-
99
  # char tokenizer, based on custom dataset's extracted .txt file
100
  def list_str_to_idx(
101
  text: list[str] | list[list[str]],
102
  vocab_char_map: dict[str, int], # {char: idx}
103
- padding_value=-1,
104
- ) -> int["b nt"]: # noqa: F722
105
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
106
- text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
107
  return text
108
 
109
 
110
  # Get tokenizer
111
 
112
-
113
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
114
- """
115
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
116
  - "char" for char-wise tokenizer, need .txt vocab_file
117
  - "byte" for utf-8 tokenizer
118
- - "custom" if you're directly passing in a path to the vocab.txt you want to use
119
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
120
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
121
- - if use "byte", set to 256 (unicode byte range)
122
- """
123
  if tokenizer in ["pinyin", "char"]:
124
- with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
125
  vocab_char_map = {}
126
  for i, char in enumerate(f):
127
  vocab_char_map[char[:-1]] = i
@@ -131,31 +144,20 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
131
  elif tokenizer == "byte":
132
  vocab_char_map = None
133
  vocab_size = 256
134
- elif tokenizer == "custom":
135
- with open(dataset_name, "r", encoding="utf-8") as f:
136
- vocab_char_map = {}
137
- for i, char in enumerate(f):
138
- vocab_char_map[char[:-1]] = i
139
- vocab_size = len(vocab_char_map)
140
 
141
  return vocab_char_map, vocab_size
142
 
143
 
144
  # convert char to pinyin
145
 
146
-
147
- def convert_char_to_pinyin(text_list, polyphone=True):
148
  final_text_list = []
149
- god_knows_why_en_testset_contains_zh_quote = str.maketrans(
150
- {"“": '"', "”": '"', "‘": "'", "’": "'"}
151
- ) # in case librispeech (orig no-pc) test-clean
152
- custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
153
  for text in text_list:
154
  char_list = []
155
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
156
- text = text.translate(custom_trans)
157
  for seg in jieba.cut(text):
158
- seg_byte_len = len(bytes(seg, "UTF-8"))
159
  if seg_byte_len == len(seg): # if pure alphabets and symbols
160
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
161
  char_list.append(" ")
@@ -184,7 +186,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
184
  # save spectrogram
185
  def save_spectrogram(spectrogram, path):
186
  plt.figure(figsize=(12, 4))
187
- plt.imshow(spectrogram, origin="lower", aspect="auto")
188
  plt.colorbar()
189
  plt.savefig(path)
190
  plt.close()
@@ -192,15 +194,13 @@ def save_spectrogram(spectrogram, path):
192
 
193
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
194
  def get_seedtts_testset_metainfo(metalst):
195
- f = open(metalst)
196
- lines = f.readlines()
197
- f.close()
198
  metainfo = []
199
  for line in lines:
200
- if len(line.strip().split("|")) == 5:
201
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
202
- elif len(line.strip().split("|")) == 4:
203
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -210,20 +210,18 @@ def get_seedtts_testset_metainfo(metalst):
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
- f = open(metalst)
214
- lines = f.readlines()
215
- f.close()
216
  metainfo = []
217
  for line in lines:
218
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
219
 
220
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
222
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
223
 
224
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
226
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
227
 
228
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
 
@@ -235,30 +233,21 @@ def padded_mel_batch(ref_mels):
235
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
  padded_ref_mels = []
237
  for mel in ref_mels:
238
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
239
  padded_ref_mels.append(padded_ref_mel)
240
  padded_ref_mels = torch.stack(padded_ref_mels)
241
- padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
242
  return padded_ref_mels
243
 
244
 
245
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
 
247
-
248
  def get_inference_prompt(
249
- metainfo,
250
- speed=1.0,
251
- tokenizer="pinyin",
252
- polyphone=True,
253
- target_sample_rate=24000,
254
- n_mel_channels=100,
255
- hop_length=256,
256
- target_rms=0.1,
257
- use_truth_duration=False,
258
- infer_batch_size=1,
259
- num_buckets=200,
260
- min_secs=3,
261
- max_secs=40,
262
  ):
263
  prompts_all = []
264
 
@@ -266,15 +255,13 @@ def get_inference_prompt(
266
  max_tokens = max_secs * target_sample_rate // hop_length
267
 
268
  batch_accum = [0] * num_buckets
269
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
270
- [[] for _ in range(num_buckets)] for _ in range(6)
271
- )
272
 
273
- mel_spectrogram = MelSpec(
274
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
275
- )
276
 
277
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
 
278
  # Audio
279
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
280
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
@@ -286,11 +273,9 @@ def get_inference_prompt(
286
  ref_audio = resampler(ref_audio)
287
 
288
  # Text
289
- if len(prompt_text[-1].encode("utf-8")) == 1:
290
- prompt_text = prompt_text + " "
291
  text = [prompt_text + gt_text]
292
  if tokenizer == "pinyin":
293
- text_list = convert_char_to_pinyin(text, polyphone=polyphone)
294
  else:
295
  text_list = text
296
 
@@ -306,19 +291,19 @@ def get_inference_prompt(
306
  # # test vocoder resynthesis
307
  # ref_audio = gt_audio
308
  else:
309
- ref_text_len = len(prompt_text.encode("utf-8"))
310
- gen_text_len = len(gt_text.encode("utf-8"))
 
311
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
312
 
313
  # to mel spectrogram
314
  ref_mel = mel_spectrogram(ref_audio)
315
- ref_mel = ref_mel.squeeze(0)
316
 
317
  # deal with batch
318
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
319
- assert (
320
- min_tokens <= total_mel_len <= max_tokens
321
- ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
322
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
323
 
324
  utts[bucket_i].append(utt)
@@ -332,39 +317,28 @@ def get_inference_prompt(
332
 
333
  if batch_accum[bucket_i] >= infer_batch_size:
334
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
335
- prompts_all.append(
336
- (
337
- utts[bucket_i],
338
- ref_rms_list[bucket_i],
339
- padded_mel_batch(ref_mels[bucket_i]),
340
- ref_mel_lens[bucket_i],
341
- total_mel_lens[bucket_i],
342
- final_text_list[bucket_i],
343
- )
344
- )
345
  batch_accum[bucket_i] = 0
346
- (
347
- utts[bucket_i],
348
- ref_rms_list[bucket_i],
349
- ref_mels[bucket_i],
350
- ref_mel_lens[bucket_i],
351
- total_mel_lens[bucket_i],
352
- final_text_list[bucket_i],
353
- ) = [], [], [], [], [], []
354
 
355
  # add residual
356
  for bucket_i, bucket_frames in enumerate(batch_accum):
357
  if bucket_frames > 0:
358
- prompts_all.append(
359
- (
360
- utts[bucket_i],
361
- ref_rms_list[bucket_i],
362
- padded_mel_batch(ref_mels[bucket_i]),
363
- ref_mel_lens[bucket_i],
364
- total_mel_lens[bucket_i],
365
- final_text_list[bucket_i],
366
- )
367
- )
368
  # not only leave easy work for last workers
369
  random.seed(666)
370
  random.shuffle(prompts_all)
@@ -375,7 +349,6 @@ def get_inference_prompt(
375
  # get wav_res_ref_text of seed-tts test metalst
376
  # https://github.com/BytedanceSpeech/seed-tts-eval
377
 
378
-
379
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
380
  f = open(metalst)
381
  lines = f.readlines()
@@ -383,14 +356,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
383
 
384
  test_set_ = []
385
  for line in tqdm(lines):
386
- if len(line.strip().split("|")) == 5:
387
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
388
- elif len(line.strip().split("|")) == 4:
389
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
390
 
391
- if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
392
  continue
393
- gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
394
  if not os.path.isabs(prompt_wav):
395
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
396
 
@@ -399,69 +372,63 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
399
  num_jobs = len(gpus)
400
  if num_jobs == 1:
401
  return [(gpus[0], test_set_)]
402
-
403
  wav_per_job = len(test_set_) // num_jobs + 1
404
  test_set = []
405
  for i in range(num_jobs):
406
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
407
 
408
  return test_set
409
 
410
 
411
  # get librispeech test-clean cross sentence test
412
 
413
-
414
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
415
  f = open(metalst)
416
  lines = f.readlines()
417
  f.close()
418
 
419
  test_set_ = []
420
  for line in tqdm(lines):
421
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
422
 
423
  if eval_ground_truth:
424
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
425
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
426
  else:
427
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
428
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
429
- gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
430
 
431
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
432
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
433
 
434
  test_set_.append((gen_wav, ref_wav, gen_txt))
435
 
436
  num_jobs = len(gpus)
437
  if num_jobs == 1:
438
  return [(gpus[0], test_set_)]
439
-
440
  wav_per_job = len(test_set_) // num_jobs + 1
441
  test_set = []
442
  for i in range(num_jobs):
443
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
444
 
445
  return test_set
446
 
447
 
448
  # load asr model
449
 
450
-
451
- def load_asr_model(lang, ckpt_dir=""):
452
  if lang == "zh":
453
- from funasr import AutoModel
454
-
455
  model = AutoModel(
456
- model=os.path.join(ckpt_dir, "paraformer-zh"),
457
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
458
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
459
- # spk_model = os.path.join(ckpt_dir, "cam++"),
460
  disable_update=True,
461
- ) # following seed-tts setting
462
  elif lang == "en":
463
- from faster_whisper import WhisperModel
464
-
465
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
466
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
467
  return model
@@ -469,50 +436,41 @@ def load_asr_model(lang, ckpt_dir=""):
469
 
470
  # WER Evaluation, the way Seed-TTS does
471
 
472
-
473
  def run_asr_wer(args):
474
  rank, lang, test_set, ckpt_dir = args
475
 
476
  if lang == "zh":
477
- import zhconv
478
-
479
  torch.cuda.set_device(rank)
480
  elif lang == "en":
481
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
482
  else:
483
- raise NotImplementedError(
484
- "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
485
- )
486
 
487
- asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
488
-
489
- from zhon.hanzi import punctuation
490
 
491
  punctuation_all = punctuation + string.punctuation
492
  wers = []
493
 
494
- from jiwer import compute_measures
495
-
496
  for gen_wav, prompt_wav, truth in tqdm(test_set):
497
  if lang == "zh":
498
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
499
  hypo = res[0]["text"]
500
- hypo = zhconv.convert(hypo, "zh-cn")
501
  elif lang == "en":
502
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
503
- hypo = ""
504
  for segment in segments:
505
- hypo = hypo + " " + segment.text
506
 
507
  # raw_truth = truth
508
  # raw_hypo = hypo
509
 
510
  for x in punctuation_all:
511
- truth = truth.replace(x, "")
512
- hypo = hypo.replace(x, "")
513
 
514
- truth = truth.replace(" ", " ")
515
- hypo = hypo.replace(" ", " ")
516
 
517
  if lang == "zh":
518
  truth = " ".join([x for x in truth])
@@ -536,22 +494,22 @@ def run_asr_wer(args):
536
 
537
  # SIM Evaluation
538
 
539
-
540
  def run_sim(args):
541
  rank, test_set, ckpt_dir = args
542
  device = f"cuda:{rank}"
543
 
544
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
545
- state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
546
- model.load_state_dict(state_dict["model"], strict=False)
547
 
548
- use_gpu = True if torch.cuda.is_available() else False
549
  if use_gpu:
550
  model = model.cuda(device)
551
  model.eval()
552
 
553
  sim_list = []
554
  for wav1, wav2, truth in tqdm(test_set):
 
555
  wav1, sr1 = torchaudio.load(wav1)
556
  wav2, sr2 = torchaudio.load(wav2)
557
 
@@ -566,55 +524,22 @@ def run_sim(args):
566
  with torch.no_grad():
567
  emb1 = model(wav1)
568
  emb2 = model(wav2)
569
-
570
  sim = F.cosine_similarity(emb1, emb2)[0].item()
571
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
572
  sim_list.append(sim)
573
-
574
  return sim_list
575
 
576
 
577
  # filter func for dirty data with many repetitions
578
 
579
-
580
- def repetition_found(text, length=2, tolerance=10):
581
  pattern_count = defaultdict(int)
582
  for i in range(len(text) - length + 1):
583
- pattern = text[i : i + length]
584
  pattern_count[pattern] += 1
585
  for pattern, count in pattern_count.items():
586
  if count > tolerance:
587
  return True
588
  return False
589
-
590
-
591
- # load model checkpoint for inference
592
-
593
-
594
- def load_checkpoint(model, ckpt_path, device, use_ema=True):
595
- if device == "cuda":
596
- model = model.half()
597
-
598
- ckpt_type = ckpt_path.split(".")[-1]
599
- if ckpt_type == "safetensors":
600
- from safetensors.torch import load_file
601
-
602
- checkpoint = load_file(ckpt_path)
603
- else:
604
- checkpoint = torch.load(ckpt_path, weights_only=True)
605
-
606
- if use_ema:
607
- if ckpt_type == "safetensors":
608
- checkpoint = {"ema_model_state_dict": checkpoint}
609
- checkpoint["model_state_dict"] = {
610
- k.replace("ema_model.", ""): v
611
- for k, v in checkpoint["ema_model_state_dict"].items()
612
- if k not in ["initted", "step"]
613
- }
614
- model.load_state_dict(checkpoint["model_state_dict"])
615
- else:
616
- if ckpt_type == "safetensors":
617
- checkpoint = {"model_state_dict": checkpoint}
618
- model.load_state_dict(checkpoint["model_state_dict"])
619
-
620
- return model.to(device)
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import re
5
  import math
6
  import random
7
  import string
 
9
  from collections import defaultdict
10
 
11
  import matplotlib
 
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
 
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
25
+ import zhconv
26
+ from zhon.hanzi import punctuation
27
+ from jiwer import compute_measures
28
+
29
+ from funasr import AutoModel
30
+ from faster_whisper import WhisperModel
31
 
32
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
  from model.modules import MelSpec
 
35
 
36
  # seed everything
37
 
38
+ def seed_everything(seed = 0):
 
39
  random.seed(seed)
40
+ os.environ['PYTHONHASHSEED'] = str(seed)
41
  torch.manual_seed(seed)
42
  torch.cuda.manual_seed(seed)
43
  torch.cuda.manual_seed_all(seed)
44
  torch.backends.cudnn.deterministic = True
45
  torch.backends.cudnn.benchmark = False
46
 
 
47
  # helpers
48
 
 
49
  def exists(v):
50
  return v is not None
51
 
 
52
  def default(v, d):
53
  return v if exists(v) else d
54
 
 
55
  # tensor helpers
56
 
57
+ def lens_to_mask(
58
+ t: int['b'],
59
+ length: int | None = None
60
+ ) -> bool['b n']:
61
 
 
62
  if not exists(length):
63
  length = t.amax()
64
 
65
+ seq = torch.arange(length, device = t.device)
66
+ return einx.less('n, b -> b n', seq, t)
 
 
 
 
 
 
 
 
67
 
68
+ def mask_from_start_end_indices(
69
+ seq_len: int['b'],
70
+ start: int['b'],
71
+ end: int['b']
72
+ ):
73
+ max_seq_len = seq_len.max().item()
74
+ seq = torch.arange(max_seq_len, device = start.device).long()
75
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
76
 
77
+ def mask_from_frac_lengths(
78
+ seq_len: int['b'],
79
+ frac_lengths: float['b']
80
+ ):
81
  lengths = (frac_lengths * seq_len).long()
82
  max_start = seq_len - lengths
83
 
84
  rand = torch.rand_like(frac_lengths)
85
+ start = (max_start * rand).long().clamp(min = 0)
86
  end = start + lengths
87
 
88
  return mask_from_start_end_indices(seq_len, start, end)
89
 
90
+ def maybe_masked_mean(
91
+ t: float['b n d'],
92
+ mask: bool['b n'] = None
93
+ ) -> float['b d']:
94
 
 
95
  if not exists(mask):
96
+ return t.mean(dim = 1)
97
 
98
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
99
+ num = reduce(t, 'b n d -> b d', 'sum')
100
+ den = reduce(mask.float(), 'b n -> b', 'sum')
101
 
102
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
103
 
104
 
105
  # simple utf-8 tokenizer, since paper went character based
106
+ def list_str_to_tensor(
107
+ text: list[str],
108
+ padding_value = -1
109
+ ) -> int['b nt']:
110
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
111
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
112
  return text
113
 
 
114
  # char tokenizer, based on custom dataset's extracted .txt file
115
  def list_str_to_idx(
116
  text: list[str] | list[list[str]],
117
  vocab_char_map: dict[str, int], # {char: idx}
118
+ padding_value = -1
119
+ ) -> int['b nt']:
120
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
121
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
122
  return text
123
 
124
 
125
  # Get tokenizer
126
 
 
127
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
128
+ '''
129
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
  - "char" for char-wise tokenizer, need .txt vocab_file
131
  - "byte" for utf-8 tokenizer
 
132
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
+ - if use "byte", set to 256 (unicode byte range)
135
+ '''
136
  if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
138
  vocab_char_map = {}
139
  for i, char in enumerate(f):
140
  vocab_char_map[char[:-1]] = i
 
144
  elif tokenizer == "byte":
145
  vocab_char_map = None
146
  vocab_size = 256
 
 
 
 
 
 
147
 
148
  return vocab_char_map, vocab_size
149
 
150
 
151
  # convert char to pinyin
152
 
153
+ def convert_char_to_pinyin(text_list, polyphone = True):
 
154
  final_text_list = []
155
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
 
 
 
156
  for text in text_list:
157
  char_list = []
158
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
 
159
  for seg in jieba.cut(text):
160
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
161
  if seg_byte_len == len(seg): # if pure alphabets and symbols
162
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
163
  char_list.append(" ")
 
186
  # save spectrogram
187
  def save_spectrogram(spectrogram, path):
188
  plt.figure(figsize=(12, 4))
189
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
190
  plt.colorbar()
191
  plt.savefig(path)
192
  plt.close()
 
194
 
195
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
196
  def get_seedtts_testset_metainfo(metalst):
197
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
198
  metainfo = []
199
  for line in lines:
200
+ if len(line.strip().split('|')) == 5:
201
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
202
+ elif len(line.strip().split('|')) == 4:
203
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
 
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
214
  metainfo = []
215
  for line in lines:
216
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
217
 
218
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
219
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
220
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
221
 
222
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
223
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
224
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
225
 
226
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
227
 
 
233
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
234
  padded_ref_mels = []
235
  for mel in ref_mels:
236
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
237
  padded_ref_mels.append(padded_ref_mel)
238
  padded_ref_mels = torch.stack(padded_ref_mels)
239
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
240
  return padded_ref_mels
241
 
242
 
243
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
244
 
 
245
  def get_inference_prompt(
246
+ metainfo,
247
+ speed = 1., tokenizer = "pinyin", polyphone = True,
248
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
249
+ use_truth_duration = False,
250
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
 
 
 
 
 
 
 
 
251
  ):
252
  prompts_all = []
253
 
 
255
  max_tokens = max_secs * target_sample_rate // hop_length
256
 
257
  batch_accum = [0] * num_buckets
258
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
259
+ ([[] for _ in range(num_buckets)] for _ in range(6))
 
260
 
261
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
 
 
262
 
263
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
264
+
265
  # Audio
266
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
267
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
 
273
  ref_audio = resampler(ref_audio)
274
 
275
  # Text
 
 
276
  text = [prompt_text + gt_text]
277
  if tokenizer == "pinyin":
278
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
279
  else:
280
  text_list = text
281
 
 
291
  # # test vocoder resynthesis
292
  # ref_audio = gt_audio
293
  else:
294
+ zh_pause_punc = r"。,、;:?!"
295
+ ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
296
+ gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
297
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
298
 
299
  # to mel spectrogram
300
  ref_mel = mel_spectrogram(ref_audio)
301
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
302
 
303
  # deal with batch
304
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
305
+ assert min_tokens <= total_mel_len <= max_tokens, \
306
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
 
307
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
308
 
309
  utts[bucket_i].append(utt)
 
317
 
318
  if batch_accum[bucket_i] >= infer_batch_size:
319
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
320
+ prompts_all.append((
321
+ utts[bucket_i],
322
+ ref_rms_list[bucket_i],
323
+ padded_mel_batch(ref_mels[bucket_i]),
324
+ ref_mel_lens[bucket_i],
325
+ total_mel_lens[bucket_i],
326
+ final_text_list[bucket_i]
327
+ ))
 
 
328
  batch_accum[bucket_i] = 0
329
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
 
 
 
 
 
 
 
330
 
331
  # add residual
332
  for bucket_i, bucket_frames in enumerate(batch_accum):
333
  if bucket_frames > 0:
334
+ prompts_all.append((
335
+ utts[bucket_i],
336
+ ref_rms_list[bucket_i],
337
+ padded_mel_batch(ref_mels[bucket_i]),
338
+ ref_mel_lens[bucket_i],
339
+ total_mel_lens[bucket_i],
340
+ final_text_list[bucket_i]
341
+ ))
 
 
342
  # not only leave easy work for last workers
343
  random.seed(666)
344
  random.shuffle(prompts_all)
 
349
  # get wav_res_ref_text of seed-tts test metalst
350
  # https://github.com/BytedanceSpeech/seed-tts-eval
351
 
 
352
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
353
  f = open(metalst)
354
  lines = f.readlines()
 
356
 
357
  test_set_ = []
358
  for line in tqdm(lines):
359
+ if len(line.strip().split('|')) == 5:
360
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
361
+ elif len(line.strip().split('|')) == 4:
362
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
363
 
364
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
365
  continue
366
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
367
  if not os.path.isabs(prompt_wav):
368
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
369
 
 
372
  num_jobs = len(gpus)
373
  if num_jobs == 1:
374
  return [(gpus[0], test_set_)]
375
+
376
  wav_per_job = len(test_set_) // num_jobs + 1
377
  test_set = []
378
  for i in range(num_jobs):
379
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
380
 
381
  return test_set
382
 
383
 
384
  # get librispeech test-clean cross sentence test
385
 
386
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
 
387
  f = open(metalst)
388
  lines = f.readlines()
389
  f.close()
390
 
391
  test_set_ = []
392
  for line in tqdm(lines):
393
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
394
 
395
  if eval_ground_truth:
396
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
397
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
398
  else:
399
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
400
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
401
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
402
 
403
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
404
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
405
 
406
  test_set_.append((gen_wav, ref_wav, gen_txt))
407
 
408
  num_jobs = len(gpus)
409
  if num_jobs == 1:
410
  return [(gpus[0], test_set_)]
411
+
412
  wav_per_job = len(test_set_) // num_jobs + 1
413
  test_set = []
414
  for i in range(num_jobs):
415
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
416
 
417
  return test_set
418
 
419
 
420
  # load asr model
421
 
422
+ def load_asr_model(lang, ckpt_dir = ""):
 
423
  if lang == "zh":
 
 
424
  model = AutoModel(
425
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
426
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
427
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
428
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
429
  disable_update=True,
430
+ ) # following seed-tts setting
431
  elif lang == "en":
 
 
432
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
433
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
434
  return model
 
436
 
437
  # WER Evaluation, the way Seed-TTS does
438
 
 
439
  def run_asr_wer(args):
440
  rank, lang, test_set, ckpt_dir = args
441
 
442
  if lang == "zh":
 
 
443
  torch.cuda.set_device(rank)
444
  elif lang == "en":
445
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
446
  else:
447
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
 
448
 
449
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
 
 
450
 
451
  punctuation_all = punctuation + string.punctuation
452
  wers = []
453
 
 
 
454
  for gen_wav, prompt_wav, truth in tqdm(test_set):
455
  if lang == "zh":
456
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
457
  hypo = res[0]["text"]
458
+ hypo = zhconv.convert(hypo, 'zh-cn')
459
  elif lang == "en":
460
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
461
+ hypo = ''
462
  for segment in segments:
463
+ hypo = hypo + ' ' + segment.text
464
 
465
  # raw_truth = truth
466
  # raw_hypo = hypo
467
 
468
  for x in punctuation_all:
469
+ truth = truth.replace(x, '')
470
+ hypo = hypo.replace(x, '')
471
 
472
+ truth = truth.replace(' ', ' ')
473
+ hypo = hypo.replace(' ', ' ')
474
 
475
  if lang == "zh":
476
  truth = " ".join([x for x in truth])
 
494
 
495
  # SIM Evaluation
496
 
 
497
  def run_sim(args):
498
  rank, test_set, ckpt_dir = args
499
  device = f"cuda:{rank}"
500
 
501
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
502
+ state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
503
+ model.load_state_dict(state_dict['model'], strict=False)
504
 
505
+ use_gpu=True if torch.cuda.is_available() else False
506
  if use_gpu:
507
  model = model.cuda(device)
508
  model.eval()
509
 
510
  sim_list = []
511
  for wav1, wav2, truth in tqdm(test_set):
512
+
513
  wav1, sr1 = torchaudio.load(wav1)
514
  wav2, sr2 = torchaudio.load(wav2)
515
 
 
524
  with torch.no_grad():
525
  emb1 = model(wav1)
526
  emb2 = model(wav2)
527
+
528
  sim = F.cosine_similarity(emb1, emb2)[0].item()
529
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
530
  sim_list.append(sim)
531
+
532
  return sim_list
533
 
534
 
535
  # filter func for dirty data with many repetitions
536
 
537
+ def repetition_found(text, length = 2, tolerance = 10):
 
538
  pattern_count = defaultdict(int)
539
  for i in range(len(text) - length + 1):
540
+ pattern = text[i:i + length]
541
  pattern_count[pattern] += 1
542
  for pattern, count in pattern_count.items():
543
  if count > tolerance:
544
  return True
545
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/utils_infer.py DELETED
@@ -1,357 +0,0 @@
1
- # A unified script for inference process
2
- # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
-
4
- import re
5
- import tempfile
6
-
7
- import numpy as np
8
- import torch
9
- import torchaudio
10
- import tqdm
11
- from pydub import AudioSegment, silence
12
- from transformers import pipeline
13
- from vocos import Vocos
14
-
15
- from model import CFM
16
- from model.utils import (
17
- load_checkpoint,
18
- get_tokenizer,
19
- convert_char_to_pinyin,
20
- )
21
-
22
-
23
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
24
-
25
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
26
-
27
-
28
- # -----------------------------------------
29
-
30
- target_sample_rate = 24000
31
- n_mel_channels = 100
32
- hop_length = 256
33
- target_rms = 0.1
34
- cross_fade_duration = 0.15
35
- ode_method = "euler"
36
- nfe_step = 32 # 16, 32
37
- cfg_strength = 2.0
38
- sway_sampling_coef = -1.0
39
- speed = 1.0
40
- fix_duration = None
41
-
42
- # -----------------------------------------
43
-
44
-
45
- # chunk text into smaller pieces
46
-
47
-
48
- def chunk_text(text, max_chars=135):
49
- """
50
- Splits the input text into chunks, each with a maximum number of characters.
51
-
52
- Args:
53
- text (str): The text to be split.
54
- max_chars (int): The maximum number of characters per chunk.
55
-
56
- Returns:
57
- List[str]: A list of text chunks.
58
- """
59
- chunks = []
60
- current_chunk = ""
61
- # Split the text into sentences based on punctuation followed by whitespace
62
- sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
63
-
64
- for sentence in sentences:
65
- if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
66
- current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
67
- else:
68
- if current_chunk:
69
- chunks.append(current_chunk.strip())
70
- current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
71
-
72
- if current_chunk:
73
- chunks.append(current_chunk.strip())
74
-
75
- return chunks
76
-
77
-
78
- # load vocoder
79
- def load_vocoder(is_local=False, local_path="", device=device):
80
- if is_local:
81
- print(f"Load vocos from local path {local_path}")
82
- vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
83
- state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
84
- vocos.load_state_dict(state_dict)
85
- vocos.eval()
86
- else:
87
- print("Download Vocos from huggingface charactr/vocos-mel-24khz")
88
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
89
- return vocos
90
-
91
-
92
- # load asr pipeline
93
-
94
- asr_pipe = None
95
-
96
-
97
- def initialize_asr_pipeline(device=device):
98
- global asr_pipe
99
- asr_pipe = pipeline(
100
- "automatic-speech-recognition",
101
- model="openai/whisper-large-v3-turbo",
102
- torch_dtype=torch.float16,
103
- device=device,
104
- )
105
-
106
-
107
- # load model for inference
108
-
109
-
110
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
111
- if vocab_file == "":
112
- vocab_file = "Emilia_ZH_EN"
113
- tokenizer = "pinyin"
114
- else:
115
- tokenizer = "custom"
116
-
117
- print("\nvocab : ", vocab_file)
118
- print("tokenizer : ", tokenizer)
119
- print("model : ", ckpt_path, "\n")
120
-
121
- vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
122
- model = CFM(
123
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
124
- mel_spec_kwargs=dict(
125
- target_sample_rate=target_sample_rate,
126
- n_mel_channels=n_mel_channels,
127
- hop_length=hop_length,
128
- ),
129
- odeint_kwargs=dict(
130
- method=ode_method,
131
- ),
132
- vocab_char_map=vocab_char_map,
133
- ).to(device)
134
-
135
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
136
-
137
- return model
138
-
139
-
140
- # preprocess reference audio and text
141
-
142
-
143
- def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
144
- show_info("Converting audio...")
145
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
146
- aseg = AudioSegment.from_file(ref_audio_orig)
147
-
148
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
149
- non_silent_wave = AudioSegment.silent(duration=0)
150
- for non_silent_seg in non_silent_segs:
151
- non_silent_wave += non_silent_seg
152
- aseg = non_silent_wave
153
-
154
- audio_duration = len(aseg)
155
- if audio_duration > 15000:
156
- show_info("Audio is over 15s, clipping to only first 15s.")
157
- aseg = aseg[:15000]
158
- aseg.export(f.name, format="wav")
159
- ref_audio = f.name
160
-
161
- if not ref_text.strip():
162
- global asr_pipe
163
- if asr_pipe is None:
164
- initialize_asr_pipeline(device=device)
165
- show_info("No reference text provided, transcribing reference audio...")
166
- ref_text = asr_pipe(
167
- ref_audio,
168
- chunk_length_s=30,
169
- batch_size=128,
170
- generate_kwargs={"task": "transcribe"},
171
- return_timestamps=False,
172
- )["text"].strip()
173
- show_info("Finished transcription")
174
- else:
175
- show_info("Using custom reference text...")
176
-
177
- # Add the functionality to ensure it ends with ". "
178
- if not ref_text.endswith(". ") and not ref_text.endswith("。"):
179
- if ref_text.endswith("."):
180
- ref_text += " "
181
- else:
182
- ref_text += ". "
183
-
184
- return ref_audio, ref_text
185
-
186
-
187
- # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
188
-
189
-
190
- def infer_process(
191
- ref_audio,
192
- ref_text,
193
- gen_text,
194
- model_obj,
195
- show_info=print,
196
- progress=tqdm,
197
- target_rms=target_rms,
198
- cross_fade_duration=cross_fade_duration,
199
- nfe_step=nfe_step,
200
- cfg_strength=cfg_strength,
201
- sway_sampling_coef=sway_sampling_coef,
202
- speed=speed,
203
- fix_duration=fix_duration,
204
- device=device,
205
- ):
206
- # Split the input text into batches
207
- audio, sr = torchaudio.load(ref_audio)
208
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
209
- gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
210
- for i, gen_text in enumerate(gen_text_batches):
211
- print(f"gen_text {i}", gen_text)
212
-
213
- show_info(f"Generating audio in {len(gen_text_batches)} batches...")
214
- return infer_batch_process(
215
- (audio, sr),
216
- ref_text,
217
- gen_text_batches,
218
- model_obj,
219
- progress=progress,
220
- target_rms=target_rms,
221
- cross_fade_duration=cross_fade_duration,
222
- nfe_step=nfe_step,
223
- cfg_strength=cfg_strength,
224
- sway_sampling_coef=sway_sampling_coef,
225
- speed=speed,
226
- fix_duration=fix_duration,
227
- device=device,
228
- )
229
-
230
-
231
- # infer batches
232
-
233
-
234
- def infer_batch_process(
235
- ref_audio,
236
- ref_text,
237
- gen_text_batches,
238
- model_obj,
239
- progress=tqdm,
240
- target_rms=0.1,
241
- cross_fade_duration=0.15,
242
- nfe_step=32,
243
- cfg_strength=2.0,
244
- sway_sampling_coef=-1,
245
- speed=1,
246
- fix_duration=None,
247
- device=None,
248
- ):
249
- audio, sr = ref_audio
250
- if audio.shape[0] > 1:
251
- audio = torch.mean(audio, dim=0, keepdim=True)
252
-
253
- rms = torch.sqrt(torch.mean(torch.square(audio)))
254
- if rms < target_rms:
255
- audio = audio * target_rms / rms
256
- if sr != target_sample_rate:
257
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
258
- audio = resampler(audio)
259
- audio = audio.to(device)
260
-
261
- generated_waves = []
262
- spectrograms = []
263
-
264
- if len(ref_text[-1].encode("utf-8")) == 1:
265
- ref_text = ref_text + " "
266
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
267
- # Prepare the text
268
- text_list = [ref_text + gen_text]
269
- final_text_list = convert_char_to_pinyin(text_list)
270
-
271
- ref_audio_len = audio.shape[-1] // hop_length
272
- if fix_duration is not None:
273
- duration = int(fix_duration * target_sample_rate / hop_length)
274
- else:
275
- # Calculate duration
276
- ref_text_len = len(ref_text.encode("utf-8"))
277
- gen_text_len = len(gen_text.encode("utf-8"))
278
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
279
-
280
- # inference
281
- with torch.inference_mode():
282
- generated, _ = model_obj.sample(
283
- cond=audio,
284
- text=final_text_list,
285
- duration=duration,
286
- steps=nfe_step,
287
- cfg_strength=cfg_strength,
288
- sway_sampling_coef=sway_sampling_coef,
289
- )
290
-
291
- generated = generated.to(torch.float32)
292
- generated = generated[:, ref_audio_len:, :]
293
- generated_mel_spec = generated.permute(0, 2, 1)
294
- generated_wave = vocos.decode(generated_mel_spec.cpu())
295
- if rms < target_rms:
296
- generated_wave = generated_wave * rms / target_rms
297
-
298
- # wav -> numpy
299
- generated_wave = generated_wave.squeeze().cpu().numpy()
300
-
301
- generated_waves.append(generated_wave)
302
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
303
-
304
- # Combine all generated waves with cross-fading
305
- if cross_fade_duration <= 0:
306
- # Simply concatenate
307
- final_wave = np.concatenate(generated_waves)
308
- else:
309
- final_wave = generated_waves[0]
310
- for i in range(1, len(generated_waves)):
311
- prev_wave = final_wave
312
- next_wave = generated_waves[i]
313
-
314
- # Calculate cross-fade samples, ensuring it does not exceed wave lengths
315
- cross_fade_samples = int(cross_fade_duration * target_sample_rate)
316
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
317
-
318
- if cross_fade_samples <= 0:
319
- # No overlap possible, concatenate
320
- final_wave = np.concatenate([prev_wave, next_wave])
321
- continue
322
-
323
- # Overlapping parts
324
- prev_overlap = prev_wave[-cross_fade_samples:]
325
- next_overlap = next_wave[:cross_fade_samples]
326
-
327
- # Fade out and fade in
328
- fade_out = np.linspace(1, 0, cross_fade_samples)
329
- fade_in = np.linspace(0, 1, cross_fade_samples)
330
-
331
- # Cross-faded overlap
332
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
333
-
334
- # Combine
335
- new_wave = np.concatenate(
336
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
337
- )
338
-
339
- final_wave = new_wave
340
-
341
- # Create a combined spectrogram
342
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
343
-
344
- return final_wave, target_sample_rate, combined_spectrogram
345
-
346
-
347
- # remove silence from generated wav
348
-
349
-
350
- def remove_silence_for_generated_wav(filename):
351
- aseg = AudioSegment.from_file(filename)
352
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
353
- non_silent_wave = AudioSegment.silent(duration=0)
354
- for non_silent_seg in non_silent_segs:
355
- non_silent_wave += non_silent_seg
356
- aseg = non_silent_wave
357
- aseg.export(filename, format="wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
pyproject.toml DELETED
@@ -1,62 +0,0 @@
1
- [build-system]
2
- requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "f5-tts"
7
- version = "0.3.1"
8
- description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
- readme = "README.md"
10
- license = {text = "MIT License"}
11
- classifiers = [
12
- "License :: OSI Approved :: MIT License",
13
- "Operating System :: OS Independent",
14
- "Programming Language :: Python :: 3",
15
- ]
16
- dependencies = [
17
- "accelerate>=0.33.0",
18
- "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
19
- "cached_path",
20
- "click",
21
- "datasets",
22
- "ema_pytorch>=0.5.2",
23
- "gradio>=3.45.2",
24
- "hydra-core>=1.3.0",
25
- "jieba",
26
- "librosa",
27
- "matplotlib",
28
- "numpy<=1.26.4",
29
- "pydub",
30
- "pypinyin",
31
- "safetensors",
32
- "soundfile",
33
- "tomli",
34
- "torch>=2.0.0",
35
- "torchaudio>=2.0.0",
36
- "torchdiffeq",
37
- "tqdm>=4.65.0",
38
- "transformers",
39
- "transformers_stream_generator",
40
- "vocos",
41
- "wandb",
42
- "x_transformers>=1.31.14",
43
- ]
44
-
45
- [project.optional-dependencies]
46
- eval = [
47
- "faster_whisper==0.10.1",
48
- "funasr",
49
- "jiwer",
50
- "modelscope",
51
- "zhconv",
52
- "zhon",
53
- ]
54
-
55
- [project.urls]
56
- Homepage = "https://github.com/SWivid/F5-TTS"
57
-
58
- [project.scripts]
59
- "f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
60
- "f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
61
- "f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
62
- "f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,26 +1,26 @@
1
- torch
2
- torchaudio
3
  accelerate>=0.33.0
4
- bitsandbytes>0.37.0
5
- cached_path
6
- click
7
  datasets
 
 
8
  ema_pytorch>=0.5.2
9
- gradio
 
10
  jieba
 
11
  librosa
12
  matplotlib
13
- numpy<=1.26.4
14
- pydub
15
  pypinyin
16
- safetensors
17
- soundfile
18
- tomli
19
  torchdiffeq
20
  tqdm>=4.65.0
21
  transformers
22
  vocos
23
  wandb
24
  x_transformers>=1.31.14
25
- f5_tts @ git+https://huggingface.co/spaces/mrfakename/E2-F5-TTS
26
- detoxify @ git+https://github.com/unitaryai/detoxify
 
 
 
 
 
 
 
1
  accelerate>=0.33.0
 
 
 
2
  datasets
3
+ einops>=0.8.0
4
+ einx>=0.3.0
5
  ema_pytorch>=0.5.2
6
+ faster_whisper
7
+ funasr
8
  jieba
9
+ jiwer
10
  librosa
11
  matplotlib
 
 
12
  pypinyin
13
+ torch>=2.0
14
+ torchaudio>=2.3.0
 
15
  torchdiffeq
16
  tqdm>=4.65.0
17
  transformers
18
  vocos
19
  wandb
20
  x_transformers>=1.31.14
21
+ zhconv
22
+ zhon
23
+ cached_path
24
+ pydub
25
+ txtsplit
26
+ detoxify
requirements_eval.txt DELETED
@@ -1,5 +0,0 @@
1
- faster_whisper
2
- funasr
3
- jiwer
4
- zhconv
5
- zhon
 
 
 
 
 
 
ruff.toml DELETED
@@ -1,10 +0,0 @@
1
- line-length = 120
2
- target-version = "py310"
3
-
4
- [lint]
5
- # Only ignore variables with names starting with "_".
6
- dummy-variable-rgx = "^_.*$"
7
-
8
- [lint.isort]
9
- force-single-line = true
10
- lines-after-imports = 2
 
 
 
 
 
 
 
 
 
 
 
samples/country.flac DELETED
Binary file (180 kB)
 
samples/main.flac DELETED
Binary file (279 kB)
 
samples/story.toml DELETED
@@ -1,19 +0,0 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
- ref_audio = "samples/main.flac"
4
- # If an empty "", transcribes the reference audio automatically.
5
- ref_text = ""
6
- gen_text = ""
7
- # File with text to generate. Ignores the text above.
8
- gen_file = "samples/story.txt"
9
- remove_silence = true
10
- output_dir = "samples"
11
-
12
- [voices.town]
13
- ref_audio = "samples/town.flac"
14
- ref_text = ""
15
-
16
- [voices.country]
17
- ref_audio = "samples/country.flac"
18
- ref_text = ""
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
samples/story.txt DELETED
@@ -1 +0,0 @@
1
- A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
 
 
samples/town.flac DELETED
Binary file (229 kB)
 
scripts/count_max_epoch.py CHANGED
@@ -1,7 +1,6 @@
1
- """ADAPTIVE BATCH SIZE"""
2
-
3
- print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
4
- print(" -> least padding, gather wavs with accumulated frames in a batch\n")
5
 
6
  # data
7
  total_hours = 95282
 
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
 
4
 
5
  # data
6
  total_hours = 95282
scripts/count_params_gflops.py CHANGED
@@ -1,15 +1,13 @@
1
- import sys
2
- import os
3
-
4
  sys.path.append(os.getcwd())
5
 
6
- from model import M2_TTS, DiT
7
 
8
  import torch
9
  import thop
10
 
11
 
12
- """ ~155M """
13
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -17,11 +15,11 @@ import thop
17
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19
 
20
- """ ~335M """
21
  # FLOPs: 622.1 G, Params: 333.2 M
22
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23
  # FLOPs: 363.4 G, Params: 335.8 M
24
- transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25
 
26
 
27
  model = M2_TTS(transformer=transformer)
@@ -32,8 +30,6 @@ duration = 20
32
  frame_length = int(duration * target_sample_rate / hop_length)
33
  text_length = 150
34
 
35
- flops, params = thop.profile(
36
- model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37
- )
38
  print(f"FLOPs: {flops / 1e9} G")
39
  print(f"Params: {params / 1e6} M")
 
1
+ import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
 
6
  import torch
7
  import thop
8
 
9
 
10
+ ''' ~155M '''
11
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
 
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
 
18
+ ''' ~335M '''
19
  # FLOPs: 622.1 G, Params: 333.2 M
20
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
  # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
 
24
 
25
  model = M2_TTS(transformer=transformer)
 
30
  frame_length = int(duration * target_sample_rate / hop_length)
31
  text_length = 150
32
 
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
 
 
34
  print(f"FLOPs: {flops / 1e9} G")
35
  print(f"Params: {params / 1e6} M")