Spaces: Running on T4

yuta0306 committed • Commit 565faca • 0 Parent(s)

first commit

Files changed:
- .gitattributes +35 -0
- .gitignore +4 -0
- Dockerfile +66 -0
- README.md +11 -0
- fam/__init__.py +0 -0
- fam/llm/.fast_inference.py.swp +0 -0
- fam/llm/__init__.py +0 -0
- fam/llm/adapters/__init__.py +2 -0
- fam/llm/adapters/base.py +5 -0
- fam/llm/adapters/flattened_encodec.py +38 -0
- fam/llm/adapters/tilted_encodec.py +45 -0
- fam/llm/decoders.py +99 -0
- fam/llm/enhancers.py +102 -0
- fam/llm/fast_inference.py +151 -0
- fam/llm/fast_inference_utils.py +453 -0
- fam/llm/fast_model.py +261 -0
- fam/llm/inference.py +710 -0
- fam/llm/layers/__init__.py +3 -0
- fam/llm/layers/attn.py +185 -0
- fam/llm/layers/combined.py +52 -0
- fam/llm/layers/layers.py +72 -0
- fam/llm/mixins/__init__.py +2 -0
- fam/llm/mixins/causal.py +511 -0
- fam/llm/mixins/non_causal.py +67 -0
- fam/llm/model.py +524 -0
- fam/llm/sample.py +731 -0
- fam/llm/serving.py +197 -0
- fam/llm/utils.py +91 -0
- fam/py.typed +0 -0
- fam/quantiser/__init__.py +0 -0
- fam/quantiser/audio/__init__.py +0 -0
- fam/quantiser/audio/speaker_encoder/__init__.py +0 -0
- fam/quantiser/audio/speaker_encoder/audio.py +22 -0
- fam/quantiser/audio/speaker_encoder/ckpt/.gitattributes +1 -0
- fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt +3 -0
- fam/quantiser/audio/speaker_encoder/model.py +123 -0
- fam/quantiser/text/tokenise.py +32 -0
- fam/ui/app.py +201 -0
- requirements.txt +138 -0
- setup.py +6 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
+/.venv/
+/models/
+/**/__pycache__/
+/**/*.pyc
Dockerfile
ADDED
@@ -0,0 +1,66 @@
+FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install -y --no-install-recommends \
+    libgirepository1.0-dev \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    # python build dependencies \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libxml2-dev \
+    libxmlsec1-dev \
+    libffi-dev \
+    liblzma-dev \
+    # nightly dependencies \
+    libdbus-glib-1-dev \
+    libpng-dev \
+    libjpeg-dev \
+    libcairo2-dev \
+    # gradio dependencies \
+    ffmpeg \
+    # fairseq2 dependencies \
+    libsndfile-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:${PATH}
+WORKDIR ${HOME}/app
+
+COPY --chown=user . $HOME/app
+RUN curl https://pyenv.run | bash
+ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ARG PYTHON_VERSION=3.10.13
+RUN pyenv install ${PYTHON_VERSION} && \
+    pyenv global ${PYTHON_VERSION} && \
+    pyenv rehash && \
+    pip install --no-cache-dir -U pip setuptools wheel
+RUN pip install packaging && \
+    pip install -r ${HOME}/app/requirements.txt && \
+    pip install -U flash-attn gradio spacy transformers fastapi tyro julius audiocraft tiktoken hf-transfer && \
+    pip install -U --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+ENV PYTHONPATH=${HOME}/app \
+    PYTHONUNBUFFERED=1 \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    GRADIO_ALLOW_FLAGGING=never \
+    GRADIO_NUM_PORTS=1 \
+    GRADIO_SERVER_NAME=0.0.0.0 \
+    GRADIO_THEME=huggingface \
+    TQDM_POSITION=-1 \
+    TQDM_MININTERVAL=1 \
+    SYSTEM=spaces
+CMD python fam/llm/serving.py --huggingface_repo_id kotoba-tech/kotoba-speech-v0.1 & python fam/ui/app.py
README.md
ADDED
@@ -0,0 +1,11 @@
+---
+title: Kotoba Voice Testing
+emoji: 🚀
+colorFrom: red
+colorTo: green
+sdk: docker
+pinned: false
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
fam/__init__.py
ADDED
File without changes
fam/llm/.fast_inference.py.swp
ADDED
Binary file (16.4 kB)
fam/llm/__init__.py
ADDED
File without changes
fam/llm/adapters/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
+from fam.llm.adapters.tilted_encodec import TiltedEncodec
fam/llm/adapters/base.py
ADDED
@@ -0,0 +1,5 @@
+from abc import ABC
+
+
+class BaseDataAdapter(ABC):
+    pass
fam/llm/adapters/flattened_encodec.py
ADDED
@@ -0,0 +1,38 @@
+from fam.llm.adapters.base import BaseDataAdapter
+
+
+class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter):
+    def __init__(self, end_of_audio_token):
+        self._end_of_audio_token = end_of_audio_token
+
+    def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
+        assert len(tokens) == 1
+        tokens = tokens[0]
+
+        text_ids = []
+        extracted_audio_ids = [[], []]
+
+        for t in tokens:
+            if t < self._end_of_audio_token:
+                extracted_audio_ids[0].append(t)
+            elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token:
+                extracted_audio_ids[1].append(t - self._end_of_audio_token)
+            # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token
+            elif t > 2 * self._end_of_audio_token:
+                text_ids.append(t)
+
+        if len(set([len(x) for x in extracted_audio_ids])) != 1:
+            min_len = min([len(x) for x in extracted_audio_ids])
+            max_len = max([len(x) for x in extracted_audio_ids])
+            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+            print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
+            print([len(x) for x in extracted_audio_ids])
+            extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+        return text_ids[:-1], extracted_audio_ids
+
+    def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+        """
+        Performs the required combination and padding as needed.
+        """
+        raise NotImplementedError
fam/llm/adapters/tilted_encodec.py
ADDED
@@ -0,0 +1,45 @@
+from fam.llm.adapters.base import BaseDataAdapter
+
+
+class TiltedEncodec(BaseDataAdapter):
+    def __init__(self, end_of_audio_token):
+        self._end_of_audio_token = end_of_audio_token
+
+    def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
+        assert len(tokens) > 1
+
+        text_ids = []
+        extracted_audio_ids = []
+
+        extracted_audio_ids.append([])
+        # Handle first hierarchy as special case as it contains text tokens as well
+        # TODO: maybe it doesn't need a special case, and can be handled on its own :)
+        for t in tokens[0]:
+            if t > self._end_of_audio_token:
+                text_ids.append(t)
+            elif t < self._end_of_audio_token:
+                extracted_audio_ids[0].append(t)
+
+        # Handle the rest of the hierarchies
+        for i in range(1, len(tokens)):
+            token_hierarchy_ids = tokens[i]
+            extracted_audio_ids.append([])
+            for t in token_hierarchy_ids:
+                if t < self._end_of_audio_token:
+                    extracted_audio_ids[i].append(t)
+
+        if len(set([len(x) for x in extracted_audio_ids])) != 1:
+            min_len = min([len(x) for x in extracted_audio_ids])
+            max_len = max([len(x) for x in extracted_audio_ids])
+            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+            print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
+            print([len(x) for x in extracted_audio_ids])
+            extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+        return text_ids[:-1], extracted_audio_ids
+
+    def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+        """
+        Performs the required combination and padding as needed.
+        """
+        raise NotImplementedError
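A quick sanity sketch of the flattened adapter above (the token values are made up for illustration; with end_of_audio_token=1024, ids below 1024 belong to codebook 0, ids in [1024, 2048) are codebook-1 ids shifted up by 1024, 2 * 1024 marks end of audio, and larger ids are text):

    from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook

    adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
    # interleaved stream: cb0, cb1, cb0, cb1, ..., then text ids
    text_ids, audio_ids = adapter.decode([[5, 1030, 7, 1042, 2050, 2051]])
    assert text_ids == [2050]              # the last text id is dropped by text_ids[:-1]
    assert audio_ids == [[5, 7], [6, 18]]  # codebook-1 ids are shifted back down by 1024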
fam/llm/decoders.py
ADDED
@@ -0,0 +1,99 @@
+import os
+import pathlib
+import uuid
+from abc import ABC, abstractmethod
+from typing import Callable, Optional, Union
+
+import julius
+import torch
+from audiocraft.data.audio import audio_read, audio_write
+from audiocraft.models import MultiBandDiffusion  # type: ignore
+
+
+class Decoder(ABC):
+    @abstractmethod
+    def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
+        raise NotImplementedError
+
+
+class EncodecDecoder(Decoder):
+    def __init__(
+        self,
+        tokeniser_decode_fn: Callable[[list[int]], str],
+        data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
+        output_dir: str,
+    ):
+        self._mbd_bandwidth = 6  # 1.5
+        self._mbd_sample_rate = 24_000
+        self._end_of_audio_token = 1024
+        self._num_codebooks = 8
+        self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth)
+
+        self.tokeniser_decode_fn = tokeniser_decode_fn
+        self._data_adapter_fn = data_adapter_fn
+
+        self.output_dir = pathlib.Path(output_dir).resolve()
+        os.makedirs(self.output_dir, exist_ok=True)
+
+    def _save_audio(self, name: str, wav: torch.Tensor):
+        audio_write(
+            name,
+            wav.squeeze(0).cpu(),
+            self._mbd_sample_rate,
+            strategy="loudness",
+            loudness_compressor=True,
+        )
+
+    def get_tokens(self, audio_path: str) -> list[list[int]]:
+        """
+        Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
+        limited codebook reconstruction or sampling from the second stage model only).
+        """
+        wav, sr = audio_read(audio_path)
+        if sr != self._mbd_sample_rate:
+            wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
+        if wav.ndim == 2:
+            wav = wav.unsqueeze(1)
+        wav = wav.to("cuda")
+        tokens = self.mbd.codec_model.encode(wav)
+        tokens = tokens[0][0]
+        return tokens.tolist()
+
+    def decode(
+        self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
+    ) -> Union[str, torch.Tensor]:
+        # TODO: this has strange behaviour -- if causal is True, it returns tokens; if causal is False, it SAVES the audio file.
+        text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
+        text = self.tokeniser_decode_fn(text_ids)
+        print(f"Text: {text}")
+
+        tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)
+
+        if tokens.shape[1] < self._num_codebooks:
+            tokens = torch.cat(
+                [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1
+            )
+
+        if causal:
+            return tokens
+        else:
+            with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
+                wav = self.mbd.tokens_to_wav(tokens)
+                # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)
+
+            if wav.shape[-1] < 9600:
+                # this causes a problem for the code below, and is also odd :)
+                # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
+                raise Exception("wav predicted is shorter than 400ms!")
+
+            try:
+                wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
+                self._save_audio(wav_file_name, wav)
+                print(f"\nSaved audio to {wav_file_name}.wav")
+                return wav_file_name
+            except Exception as e:
+                print(f"Failed to save audio! Reason: {e}")
+                wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
+                self._save_audio(wav_file_name, wav)
+                print(f"\nSaved audio to {wav_file_name}.wav")
+                return wav_file_name
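A hedged sketch of using EncodecDecoder.get_tokens above to tokenise reference audio (constructing the decoder instantiates MultiBandDiffusion, so this needs a CUDA device and the audiocraft weights; "reference.wav" and the stub tokeniser are hypothetical):

    from fam.llm.adapters import TiltedEncodec
    from fam.llm.decoders import EncodecDecoder

    decoder = EncodecDecoder(
        tokeniser_decode_fn=lambda ids: "",  # stub: no text ids in this round-trip test
        data_adapter_fn=TiltedEncodec(end_of_audio_token=1024).decode,
        output_dir="outputs",
    )
    codes = decoder.get_tokens("reference.wav")  # per-codebook lists of EnCodec ids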
fam/llm/enhancers.py
ADDED
@@ -0,0 +1,102 @@
+import os
+from abc import ABC
+from typing import Literal, Optional
+
+from df.enhance import enhance, init_df, load_audio, save_audio
+from pydub import AudioSegment
+
+
+def convert_to_wav(input_file: str, output_file: str):
+    """Convert an audio file to WAV format
+
+    Args:
+        input_file (str): path to input audio file
+        output_file (str): path to output WAV file
+
+    """
+    # Detect the format of the input file
+    format = input_file.split(".")[-1].lower()
+
+    # Read the audio file
+    audio = AudioSegment.from_file(input_file, format=format)
+
+    # Export as WAV
+    audio.export(output_file, format="wav")
+
+
+def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+    """Generate the output file path
+
+    Args:
+        audio_file (str): path to input audio file
+        tag (str): tag to append to the output file name
+        ext (str, optional): extension of the output file. Defaults to None.
+
+    Returns:
+        str: path to output file
+    """
+
+    directory = "./enhanced"
+    # Get the name of the input file
+    filename = os.path.basename(audio_file)
+
+    # Get the name of the input file without the extension
+    filename_without_extension = os.path.splitext(filename)[0]
+
+    # Get the extension of the input file
+    extension = ext or os.path.splitext(filename)[1]
+
+    # Generate the output file path
+    output_file = os.path.join(directory, filename_without_extension + tag + extension)
+
+    return output_file
+
+
+class BaseEnhancer(ABC):
+    """Base class for audio enhancers"""
+
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+        raise NotImplementedError
+
+    def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+        output_file = make_output_file_path(audio_file, tag, ext=ext)
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
+        return output_file
+
+
+class DFEnhancer(BaseEnhancer):
+    def __init__(self, *args, **kwargs):
+        self.model, self.df_state, _ = init_df()
+
+    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+        output_file = output_file or self.get_output_file(audio_file, "_df")
+
+        audio, _ = load_audio(audio_file, sr=self.df_state.sr())
+
+        enhanced = enhance(self.model, self.df_state, audio)
+
+        save_audio(output_file, enhanced, self.df_state.sr())
+
+        return output_file
+
+
+def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
+    """Get an audio enhancer
+
+    Args:
+        enhancer_name (Literal["df"]): name of the audio enhancer
+
+    Raises:
+        ValueError: if the enhancer name is not recognised
+
+    Returns:
+        BaseEnhancer: audio enhancer
+    """
+
+    if enhancer_name == "df":
+        return DFEnhancer()
+    else:
+        raise ValueError(f"Unknown enhancer name: {enhancer_name}")
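Minimal usage of the enhancer factory above, assuming DeepFilterNet's df package is installed ("speech.wav" is a hypothetical input file):

    from fam.llm.enhancers import get_enhancer

    enhancer = get_enhancer("df")
    cleaned = enhancer("speech.wav")  # writes ./enhanced/speech_df.wav by default
    print(cleaned)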
fam/llm/fast_inference.py
ADDED
@@ -0,0 +1,151 @@
+import argparse
+import os
+import shutil
+import tempfile
+import time
+from pathlib import Path
+
+import librosa
+import torch
+from huggingface_hub import snapshot_download
+
+from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
+from fam.llm.decoders import EncodecDecoder
+from fam.llm.fast_inference_utils import build_model, main
+from fam.llm.inference import (
+    EncodecDecoder,
+    InferenceConfig,
+    Model,
+    TiltedEncodec,
+    TrainedBPETokeniser,
+    get_cached_embedding,
+    get_cached_file,
+    get_enhancer,
+)
+from fam.llm.utils import (
+    check_audio_file,
+    get_default_dtype,
+    get_device,
+    normalize_text,
+)
+
+
+class TTS:
+    def __init__(
+        self, model_name: str = "kotoba-tech/kotoba-speech-v0.1", *, seed: int = 1337, output_dir: str = "outputs", first_model_path=None,
+    ):
+        """
+        model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/kotoba-tech/)
+        """
+
+        # NOTE: this needs to come first so that we don't change global state when we want to use
+        # the torch.compiled-model.
+        self._dtype = get_default_dtype()
+        self._device = get_device()
+        self._model_dir = snapshot_download(repo_id=model_name)
+        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
+        self.output_dir = output_dir
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
+        config_second_stage = InferenceConfig(
+            ckpt_path=second_stage_ckpt_path,
+            num_samples=1,
+            seed=seed,
+            device=self._device,
+            dtype=self._dtype,
+            compile=False,
+            init_from="resume",
+            output_dir=self.output_dir,
+        )
+        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
+        self.llm_second_stage = Model(
+            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+        )
+        self.enhancer = get_enhancer("df")
+
+        self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
+        self.model, self.tokenizer, self.smodel, self.model_size = build_model(
+            precision=self.precision,
+            checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
+            spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
+            device=self._device,
+            compile=True,
+            compile_prefill=True,
+            first_model_path=first_model_path,
+        )
+
+
+    def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
+        """
+        text: Text to speak
+        spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
+        top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] is good. This is a measure of speech stability - improves text following for a challenging speaker
+        guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
+        temperature: Temperature for sampling applied to both LLMs (first & second stage)
+
+        returns: path to speech .wav file
+        """
+        text = normalize_text(text)
+        spk_ref_path = get_cached_file(spk_ref_path)
+        check_audio_file(spk_ref_path)
+        spk_emb = get_cached_embedding(
+            spk_ref_path,
+            self.smodel,
+        ).to(device=self._device, dtype=self.precision)
+
+        start = time.time()
+        # first stage LLM
+        tokens = main(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            model_size=self.model_size,
+            prompt=text,
+            spk_emb=spk_emb,
+            top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
+            guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
+            temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
+        )
+        text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
+
+        b_speaker_embs = spk_emb.unsqueeze(0)
+
+        # second stage LLM + multi-band diffusion model
+        wav_files = self.llm_second_stage(
+            texts=[text],
+            encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
+            speaker_embs=b_speaker_embs,
+            batch_size=1,
+            guidance_scale=None,
+            top_p=None,
+            top_k=200,
+            temperature=1.0,
+            max_new_tokens=None,
+        )
+
+        # enhance using deepfilternet
+        wav_file = wav_files[0]
+        with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
+            self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
+            shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
+            print(f"\nSaved audio to {wav_file}.wav")
+
+        # calculating real-time factor (RTF)
+        time_to_synth_s = time.time() - start
+        audio, sr = librosa.load(str(wav_file) + ".wav")
+        duration_s = librosa.get_duration(y=audio, sr=sr)
+        print(f"\nTotal time to synth (s): {time_to_synth_s}")
+        print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
+
+        return str(wav_file) + ".wav"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Kotoba-Speech text-to-speech demo")
+    parser.add_argument("--in_file", default="/home/data/reazon_large-v2_denoise/large.jsonl", help="Name of the file")
+    parser.add_argument("--text", type=str, default="This is a demo for text-to-speech.", help="input text")
+    parser.add_argument("--spk_ref_path", type=str, default="assets/bria.mp3", help="speaker path")
+    parser.add_argument("--first_model_path", type=str, default=None, help="path to the first model")
+    args = parser.parse_args()
+    tts = TTS(first_model_path=args.first_model_path)
+    tts.synthesise(text=args.text, spk_ref_path=args.spk_ref_path)
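Programmatic use of the TTS class above, mirroring the __main__ block; "my_speaker.wav" is a hypothetical reference clip (the docstring asks for at least 30 s of audio):

    from fam.llm.fast_inference import TTS

    tts = TTS()  # downloads kotoba-tech/kotoba-speech-v0.1 and compiles the first stage
    wav_path = tts.synthesise(
        text="This is a demo for text-to-speech.",
        spk_ref_path="my_speaker.wav",
        top_p=0.95,
        guidance_scale=3.0,
    )
    print(wav_path)  # path to the enhanced .wav file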
fam/llm/fast_inference_utils.py
ADDED
@@ -0,0 +1,453 @@
+# Copyright (c) Kotoba Technologies, Inc. and affiliates.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list of
+#    conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this
+#    list of conditions and the following disclaimer in the documentation and/or other
+#    materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import itertools
+import gc
+import time
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+import tqdm
+
+
+def device_sync(device):
+    if "cuda" in device:
+        torch.cuda.synchronize()
+    elif "cpu" in device:
+        pass
+    else:
+        print(f"device={device} is not yet supported")
+
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+torch._inductor.config.fx_graph_cache = (
+    True  # Experimental feature to reduce compilation times, will be on by default in future
+)
+
+# imports need to happen after setting above flags
+from fam.llm.fast_model import Transformer
+from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
+from fam.quantiser.text.tokenise import TrainedBPETokeniser
+
+
+def multinomial_sample_one_no_sync(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
+    # ref: huggingface/transformers
+
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # Keep at least min_tokens_to_keep
+    sorted_indices_to_remove[-1:] = 0
+
+    # scatter sorted tensors to original indexing
+    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
+    scores = logits.masked_fill(indices_to_remove, -float("Inf"))
+    return scores
+
+
+def logits_to_probs(
+    logits,
+    *,
+    temperature: torch.Tensor,
+    top_p: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+):
+    logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature))
+
+    if top_k is not None:
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+        pivot = v.select(-1, -1).unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+    if top_p is not None:
+        logits = top_p_sample(logits, top_p)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+
+    return probs
+
+
+def sample(
+    logits,
+    guidance_scale: torch.Tensor,
+    temperature: torch.Tensor,
+    top_p: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+):
+    # (b, t, vocab_size)
+    logits = logits[:, -1]
+    logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0)
+    logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb
+    probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k)
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+
+def prefill(
+    model: Transformer,
+    x: torch.Tensor,
+    spk_emb: torch.Tensor,
+    input_pos: torch.Tensor,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    # input_pos: [B, S]
+    logits = model(x, spk_emb, input_pos)
+    return sample(logits, **sampling_kwargs)[0]
+
+
+def decode_one_token(
+    model: Transformer,
+    x: torch.Tensor,
+    spk_emb: torch.Tensor,
+    input_pos: torch.Tensor,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # input_pos: [B, 1]
+    assert input_pos.shape[-1] == 1
+    logits = model(x, spk_emb, input_pos)
+    return sample(logits, **sampling_kwargs)
+
+
+def decode_n_tokens(
+    model: Transformer,
+    cur_token: torch.Tensor,
+    spk_emb: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    callback=lambda _: _,
+    return_probs: bool = False,
+    end_of_audio_token: int = 2048,
+    **sampling_kwargs,
+):
+    new_tokens, new_probs = [], []
+    for i in tqdm.tqdm(range(num_new_tokens)):
+        if (cur_token == end_of_audio_token).any():
+            break
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        ):  # Actually better for Inductor to codegen attention here
+            next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
+        input_pos += 1
+        new_tokens.append(next_token.clone())
+        callback(new_tokens[-1])
+        if return_probs:
+            new_probs.append(next_prob.clone())
+        cur_token = next_token.view(1, -1).repeat(2, 1)
+
+    return new_tokens, new_probs
+
+
+def model_forward(model, x, spk_emb, input_pos):
+    return model(x, spk_emb, input_pos)
+
+
+@torch.no_grad()
+def generate(
+    model: Transformer,
+    prompt: torch.Tensor,
+    spk_emb: torch.Tensor,
+    *,
+    max_new_tokens: Optional[int] = None,
+    callback=lambda x: x,
+    end_of_audio_token: int = 2048,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(0)
+    if max_new_tokens is None:
+        max_seq_length = model.config.block_size
+    else:
+        max_seq_length = T + max_new_tokens
+        max_seq_length = min(max_seq_length, model.config.block_size)
+    max_new_tokens = max_seq_length - T
+    if max_new_tokens <= 0:
+        raise ValueError("Prompt is too long to generate more tokens")
+
+    device, dtype = prompt.device, prompt.dtype
+
+    seq = torch.clone(prompt)
+    input_pos = torch.arange(0, T, device=device)
+
+    next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
+    seq = torch.cat([seq, next_token.view(1)])
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    generated_tokens, _ = decode_n_tokens(
+        model,
+        next_token.view(1, -1).repeat(2, 1),
+        spk_emb,
+        input_pos,
+        max_new_tokens - 1,
+        callback=callback,
+        end_of_audio_token=end_of_audio_token,
+        **sampling_kwargs,
+    )
+    seq = torch.cat([seq, torch.cat(generated_tokens)])
+
+    return seq
+
+
+def encode_tokens(tokenizer, string, device="cuda"):
+    tokens = tokenizer.encode(string)
+    return torch.tensor(tokens, dtype=torch.int, device=device)
+
+
+def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path=None, unwanted_prefix="_orig_mod."):
+    ##### MODEL
+    with torch.device("meta"):
+        model = Transformer.from_name("kotoba-speech-v0.1")
+
+    # TODO(quantization): enable
+    # if "int8" in str(checkpoint_path):
+    #     print("Using int8 weight-only quantization!")
+    #     from quantize import WeightOnlyInt8QuantHandler
+    #     simple_quantizer = WeightOnlyInt8QuantHandler(model)
+    #     model = simple_quantizer.convert_for_runtime()
+    # from quantize import WeightOnlyInt8QuantHandler
+
+    # if "int4" in str(checkpoint_path):
+    #     print("Using int4 quantization!")
+    #     path_comps = checkpoint_path.name.split(".")
+    #     assert path_comps[-2].startswith("g")
+    #     groupsize = int(path_comps[-2][1:])
+    #     from quantize import WeightOnlyInt4QuantHandler
+    #     simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+    #     model = simple_quantizer.convert_for_runtime()
+
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
+
+    ###### TOKENIZER
+    tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
+    tokenizer = TrainedBPETokeniser(**tokenizer_info)
+
+    if first_model_path is not None:
+        trained_ckpt = torch.load(str(first_model_path), mmap=True, weights_only=False)
+        state_dict = trained_ckpt["state_dict"]
+        del checkpoint
+        gc.collect()
+        torch.cuda.empty_cache()
+    else:
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
+        if "state_dict" in checkpoint.keys():
+            state_dict = checkpoint["state_dict"]
+        else:
+            state_dict = checkpoint["model"]
+    # convert Kotoba-Speech model weights naming to gptfast naming
+    for k, v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+    state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
+    state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
+    state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
+    state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")
+    for k, v in list(state_dict.items()):
+        if k.startswith("transformer.h."):
+            state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k)
+            k = k.replace("transformer.h.", "layers.")
+        if ".attn.c_attn." in k:
+            state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k)
+            k = k.replace(".attn.c_attn.", ".attention.wqkv.")
+        if ".attn.c_proj." in k:
+            state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k)
+            k = k.replace(".attn.c_proj.", ".attention.wo.")
+        if ".mlp.swiglu.w1." in k:
+            state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k)
+            k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")
+        if ".mlp.swiglu.w3." in k:
+            state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k)
+            k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")
+        if ".ln_1." in k:
+            state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k)
+            k = k.replace(".ln_1.", ".attention_norm.")
+        if ".ln_2." in k:
+            state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k)
+            k = k.replace(".ln_2.", ".ffn_norm.")
+        if ".mlp.c_proj." in k:
+            state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k)
+            k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")
+
+    model.load_state_dict(state_dict, assign=True)
+    # simple_quantizer = WeightOnlyInt8QuantHandler(model)
+    # quantized_state_dict = simple_quantizer.create_quantized_state_dict()
+    # model = simple_quantizer.convert_for_runtime()
+    # model.load_state_dict(quantized_state_dict, assign=True)
+    model = model.to(device=device, dtype=precision)
+
+    ###### SPEAKER EMBEDDER
+    # TODO: fix!
+    smodel = SpeakerEncoder(
+        weights_fpath=spk_emb_ckpt_path,
+        device=device,
+        eval=True,
+        verbose=False,
+    )
+    return model.eval(), tokenizer, smodel
+
+
+def build_model(
+    *,
+    precision: torch.dtype,
+    checkpoint_path: Path = Path(""),
+    spk_emb_ckpt_path: Path = Path(""),
+    compile_prefill: bool = False,
+    compile: bool = True,
+    device: str = "cuda",
+    first_model_path: Optional[str] = None,
+):
+    assert checkpoint_path.is_file(), checkpoint_path
+
+    print(f"Using device={device}")
+
+    print("Loading model ...")
+    t0 = time.time()
+    if first_model_path is None:
+        # model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
+        model, tokenizer, smodel = _load_model(
+            checkpoint_path, spk_emb_ckpt_path, device, precision, unwanted_prefix="first_stage_model_transformer."
+        )
+
+    else:
+        model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path, unwanted_prefix="first_stage_model_transformer.")
+
+
+    device_sync(device=device)  # MKG
+    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    torch.manual_seed(1234)
+    model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
+
+    with torch.device(device):
+        model.setup_spk_cond_mask()
+        model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)
+
+    if compile:
+        print("Compiling... Can take up to 2 mins.")
+        global decode_one_token, prefill
+        decode_one_token = torch.compile(
+            decode_one_token,
+            mode="max-autotune",
+            fullgraph=True,
+        )
+
+        if compile_prefill:
+            prefill = torch.compile(
+                prefill,
+                fullgraph=True,
+                dynamic=True,
+            )
+
+    encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
+    spk_emb = torch.randn((1, 256), device=device, dtype=precision)
+
+    device_sync(device=device)  # MKG
+    t0 = time.perf_counter()
+    y = generate(
+        model,
+        encoded,
+        spk_emb,
+        max_new_tokens=200,
+        callback=lambda x: x,
+        temperature=torch.tensor(1.0, device=device, dtype=precision),
+        top_k=None,
+        top_p=torch.tensor(0.95, device=device, dtype=precision),
+        guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
+        end_of_audio_token=9999,  # don't end early for compilation stage.
+    )
+
+    device_sync(device=device)  # MKG
+
+    print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+    return model, tokenizer, smodel, model_size
+
+
+def main(
+    *,
+    model,
+    tokenizer,
+    model_size,
+    prompt: str,
+    guidance_scale: torch.Tensor,
+    temperature: torch.Tensor,
+    spk_emb: torch.Tensor,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
+    device: str = "cuda",
+) -> list:
+    """Generates text samples based on a pre-trained Transformer model and tokenizer."""
+
+    encoded = encode_tokens(tokenizer, prompt, device=device)
+    prompt_length = encoded.size(0)
+
+    aggregate_metrics: dict = {
+        "tokens_per_sec": [],
+    }
+
+    device_sync(device=device)  # MKG
+
+    if True:
+        callback = lambda x: x
+        t0 = time.perf_counter()
+
+        y = generate(
+            model,
+            encoded,
+            spk_emb,
+            callback=callback,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            guidance_scale=guidance_scale,
+        )
+
+        device_sync(device=device)  # MKG
+        t = time.perf_counter() - t0
+
+        tokens_generated = y.size(0) - prompt_length
+        tokens_sec = tokens_generated / t
+        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
+        print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
+
+    return y.tolist()
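A small standalone check of the nucleus (top-p) filter above, with illustrative values; note that sample() applies this filtering after mixing the conditional and speaker-unconditional streams as guidance_scale * cond + (1 - guidance_scale) * uncond:

    import torch

    from fam.llm.fast_inference_utils import top_p_sample

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    filtered = top_p_sample(logits, top_p=torch.tensor(0.9))
    # ids outside the top-0.9 cumulative probability mass are set to -inf,
    # so the softmax renormalises over the surviving ids only
    probs = torch.softmax(filtered, dim=-1)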
fam/llm/fast_model.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Kotoba Technologies, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
5 |
+
# provided that the following conditions are met:
|
6 |
+
#
|
7 |
+
# 1. Redistributions of source code must retain the above copyright notice, this list of
|
8 |
+
# conditions and the following disclaimer.
|
9 |
+
#
|
10 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice, this
|
11 |
+
# list of conditions and the following disclaimer in the documentation and/or other
|
12 |
+
# materials provided with the distribution.
|
13 |
+
#
|
14 |
+
# 3. Neither the name of the copyright holder nor the names of its contributors
|
15 |
+
# may be used to endorse or promote products derived from this software without
|
16 |
+
# specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
|
19 |
+
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
20 |
+
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
21 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
22 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
23 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
24 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
from dataclasses import dataclass
|
27 |
+
from functools import reduce
|
28 |
+
from math import gcd
|
29 |
+
from typing import Optional, Tuple
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.nn as nn
|
33 |
+
from torch import Tensor
|
34 |
+
from torch.nn import functional as F
|
35 |
+
|
36 |
+
from fam.llm.utils import get_default_dtype
|
37 |
+
|
38 |
+
import logging
|
39 |
+
|
40 |
+
# Adjust the logging level
|
41 |
+
logger = logging.getLogger("torch")
|
42 |
+
logger.setLevel(logging.ERROR)
|
43 |
+
|
44 |
+
|
45 |
+
def find_multiple(n: int, *args: Tuple[int]) -> int:
|
46 |
+
k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
|
47 |
+
if n % k == 0:
|
48 |
+
return n
|
49 |
+
return n + k - (n % k)
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class ModelArgs:
|
54 |
+
block_size: int = 2048
|
55 |
+
vocab_size: int = 32000
|
56 |
+
n_layer: int = 32
|
57 |
+
n_head: int = 32
|
58 |
+
dim: int = 4096
|
59 |
+
speaker_emb_dim: int = 256
|
60 |
+
intermediate_size: int = None
|
61 |
+
n_local_heads: int = -1
|
62 |
+
head_dim: int = 64
|
63 |
+
norm_eps: float = 1e-5
|
64 |
+
dtype: torch.dtype = torch.bfloat16
|
65 |
+
|
66 |
+
def __post_init__(self):
|
67 |
+
if self.n_local_heads == -1:
|
68 |
+
self.n_local_heads = self.n_head
|
69 |
+
if self.intermediate_size is None:
|
70 |
+
hidden_dim = 4 * self.dim
|
71 |
+
n_hidden = int(2 * hidden_dim / 3)
|
72 |
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
73 |
+
self.head_dim = self.dim // self.n_head
|
74 |
+
|
75 |
+
self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]
|
76 |
+
|
77 |
+
@classmethod
|
78 |
+
def from_name(cls, name: str):
|
79 |
+
if name in transformer_configs:
|
80 |
+
return cls(**transformer_configs[name])
|
81 |
+
# fuzzy search
|
82 |
+
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
|
83 |
+
assert len(config) == 1, name
|
84 |
+
return cls(**transformer_configs[config[0]])
|
85 |
+
|
86 |
+
|
87 |
+
transformer_configs = {
|
88 |
+
"kotoba-speech-v0.1": dict(
|
89 |
+
n_layer=24,
|
90 |
+
n_head=16,
|
91 |
+
dim=2048,
|
92 |
+
vocab_size=2562,
|
93 |
+
),
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
class KVCache(nn.Module):
|
98 |
+
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
|
99 |
+
super().__init__()
|
100 |
+
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
101 |
+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
102 |
+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
103 |
+
|
104 |
+
def update(self, input_pos, k_val, v_val):
|
105 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
106 |
+
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
        self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_spk_cond_mask(self):
        self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
        self.spk_cond_mask[0] = 1

    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
            )

        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

    def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
        mask = self.causal_mask[None, None, input_pos]
        x = (
            self.tok_embeddings(idx)
            + self.pos_embeddings(input_pos)
            # masking for speaker condition free guidance
            + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
        )

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class SwiGLU(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return F.silu(self.w1(x)) * self.w3(x)


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.swiglu = SwiGLU(config)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.swiglu(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
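
The cache update above is the static-allocation trick: `setup_caches` sizes every layer's `KVCache` to a fixed `max_seq_length` (rounded up with `find_multiple`), and `update` scatter-writes into it at `input_pos`, so tensor shapes never change between the prefill pass and the one-token decode steps. A minimal, self-contained sketch of the same pattern, with toy sizes of my own choosing rather than values from any model config:

# --- illustrative sketch, not part of fam/llm/fast_model.py ---
import torch

B, H, S, D = 1, 2, 16, 4  # batch, kv-heads, max_seq_length, head_dim (toy sizes)
k_cache = torch.zeros(B, H, S, D)  # pre-allocated once, never resized

prefill_pos = torch.arange(3)  # prefill: write positions 0..2 in one shot
k_cache[:, :, prefill_pos] = torch.randn(B, H, 3, D)

next_pos = torch.tensor([3])  # decode: one new position per step, written in place
k_cache[:, :, next_pos] = torch.randn(B, H, 1, D)

Positions that have not been written yet still hold zeros; `Transformer.forward` keeps attention away from them by slicing the precomputed `causal_mask` with the same `input_pos`.
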
fam/llm/inference.py
ADDED
@@ -0,0 +1,710 @@
import dataclasses
import hashlib
import json
import os
import pathlib
import shutil
import subprocess
import tempfile
import time
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Type, Union

import torch
import tqdm
import tqdm.contrib.concurrent
import tyro
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
from fam.llm.decoders import Decoder, EncodecDecoder
from fam.llm.enhancers import BaseEnhancer, get_enhancer
from fam.llm.model import GPT, GPTConfig
from fam.llm.utils import check_audio_file, get_default_dtype, normalize_text
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
from fam.quantiser.text.tokenise import TrainedBPETokeniser


@dataclass
class InferenceConfig:
    ckpt_path: str  # path to checkpoint
    output_dir: str
    num_samples: int = 10  # number of samples to draw
    seed: int = 1337  # random seed
    device: str = "cuda"
    dtype: str = "bfloat16"
    compile: bool = False
    init_from: str = "resume"  # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')

    def __str__(self):
        field_strs = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            field_strs.append(f"  {field.name}: {value}")

        return "InferenceConfig:\n" + "\n".join(field_strs)


class Model:
    def __init__(
        self,
        config: InferenceConfig,
        tokenizer_cls: Type[TrainedBPETokeniser],
        decoder_cls: Type[Decoder],
        data_adapter_fn,
        use_kv_cache: Optional[Literal["vanilla"]] = None,
    ):
        # TODO: disentangle the encodec stuff and numbers etc from the rest of this code (esp. at encoder-only / second-stage model inference)
        # TODO: remove magic number
        self._encodec_codes_pad_token = 1024
        self._num_encodec_codebooks = 8
        self.config = config
        self.use_kv_cache = use_kv_cache

        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)
        torch.backends.cuda.matmul.allow_tf32 = config.dtype != "float32"  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = config.dtype != "float32"  # allow tf32 on cudnn
        device_type = "cuda" if "cuda" in config.device else "cpu"  # for later use in torch.autocast
        self.ptdtype = {
            "float32": torch.float32,
            "tfloat32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[config.dtype]
        self._ctx = (
            nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)
        )

        self.use_bpe_tokenizer = False
        self.load_meta = None
        self.speaker_cond = None
        self.meta = None
        self.model = None
        self.checkpoint_config = None
        self.vocab_sizes = None
        self.smodel = None

        self._init_model()

        self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
        self.decoder = decoder_cls(
            tokeniser_decode_fn=self.tokenizer.decode,
            output_dir=self.config.output_dir,
            data_adapter_fn=data_adapter_fn,
        )

    def _init_model(self):
        if self.config.init_from == "resume":
            # init from a model saved in a specific directory
            checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
            self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]

            self.load_meta = False
            self.speaker_cond = False

            if "config" in checkpoint:
                self.checkpoint_config = checkpoint["config"]

                self.meta = checkpoint["meta"]
                self.load_meta = True

            if self.load_meta:
                self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
                self.speaker_cond = self.meta.get("speaker_cond")

            if self.speaker_cond:
                speaker_emb_size = self.meta["speaker_emb_size"]

            model_args = checkpoint["model_args"]
            if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
                self._encodec_ctx_window = model_args["block_size"]

            gptconf = GPTConfig(**model_args)

            # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
            state_dict = checkpoint["model"]
            unwanted_prefix = "_orig_mod."
            for k, v in list(state_dict.items()):
                if k.startswith(unwanted_prefix):
                    state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
            self.model.load_state_dict(state_dict)

        # model
        self.model.eval()
        self.model.to(self.config.device)

        if self.config.compile:
            from einops._torch_specific import allow_ops_in_compiled_graph

            allow_ops_in_compiled_graph()
            self.model = torch.compile(self.model)  # type: ignore

        if self.use_kv_cache is not None:
            if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
                raise Exception("kv_cache not supported for non-causal models!")

            if self.use_kv_cache == "vanilla":
                self.model.enable_kv_cache()
            else:
                raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")

    def causal_sample(
        self,
        *,
        texts: list[str],
        batch_size: int,
        max_new_tokens: int,
        temperature: Optional[float],
        top_k: Optional[int],
        top_p: Optional[float],
        speaker_embs: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = None,
    ) -> list[torch.Tensor]:
        """
        Returns a list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
        Any flattening / interleaving / tilting gets reversed before the output is returned.
        """
        if speaker_embs is not None:
            assert len(texts) == len(speaker_embs)

        encoded_texts = [self.tokenizer.encode(text) for text in texts]

        ## create multiple hierarchies and get seq_lens
        seq_lens = []
        xs = []
        for i, encoded_text in enumerate(encoded_texts):
            encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
            # TODO: remove magic number
            xs.append(
                torch.cat(
                    # [1st hierarchy of text, *remaining hierarchies of padded tokens]
                    # TODO: self.vocab_sizes should be from the model config?
                    [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
                    dim=0,
                ).unsqueeze(0)
            )  # b x [(b=1, c, t)]
            seq_lens.append(xs[-1].shape[-1])
        max_len = max(seq_lens)
        assert len(xs) == len(seq_lens)

        ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
        x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
        for i, _xs in enumerate(xs):
            assert _xs.shape[-1] == seq_lens[i]
            x[i, :, : seq_lens[i]] = _xs

        ## check that the input is correct
        for i in range(x.shape[0]):
            assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]

            # TODO: remove magic number
            if x.shape[1] > 1:
                assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])

        assert x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True

        if self.speaker_cond is False:
            speaker_embs = None

        # run sampling loop
        with torch.no_grad():
            with self._ctx:  # type: ignore
                to_return = []
                for k in range(self.config.num_samples):
                    assert seq_lens is not None
                    assert batch_size is not None

                    if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
                        raise Exception(
                            f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
                        )

                    y = self.model.generate(
                        x,
                        max_new_tokens,
                        seq_lens=seq_lens,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        speaker_embs=speaker_embs,
                        batch_size=batch_size,
                        guidance_scale=guidance_scale,
                        dtype=self.ptdtype,
                        end_of_audio_token=self.tokenizer.offset - 1,
                        end_of_text_token=self.tokenizer.eot_token,
                    )
                    for i in range(len(y)):
                        to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))

        return to_return

    def non_causal_sample(
        self,
        *,
        texts: list[str],
        encodec_tokens: list[torch.Tensor],
        batch_size: int,
        top_k: Optional[int],
        temperature: Optional[float],
        speaker_embs: Optional[torch.Tensor] = None,
    ) -> list[str]:
        """
        Returns paths to saved audio files.
        """
        if speaker_embs is not None:
            assert len(texts) == len(speaker_embs)

        encoded_texts = [self.tokenizer.encode(text) for text in texts]

        # setup input
        # TODO: same code is used during data prep. refactor
        padded_hierarchies_inputs = []
        for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
            x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
                None, None, ...
            ]  # (b=1, c=1, t)

            # TODO: should only happen if decoder is EncodecDecoder?
            assert encodec_token.shape[0] == 1
            encodec_token = encodec_token[0].tolist()  # (b=1, c, t) -> (c, t)
            assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks

            ## setup hierarchies of tokens
            # TODO: refactor and merge with code in processing.py
            text_tokens = encoded_text  # (t,)

            hierarchies_in = []
            hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
            hierarchies_in.append(
                [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
            )

            ## adding padding / cutting to the right size as needed
            # TODO: refactor and merge with code in processing.py
            padded_hierarchies_input = []
            for _, t_hierarchy in enumerate(hierarchies_in):
                assert len(t_hierarchy) == len(hierarchies_in[0])
                if len(t_hierarchy) < self._encodec_ctx_window:
                    padded_hierarchies_input.append(
                        t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
                    )
                elif len(t_hierarchy) > self._encodec_ctx_window:
                    padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
                else:
                    padded_hierarchies_input.append(t_hierarchy)

            padded_hierarchies_inputs.append(padded_hierarchies_input)

        ## check that the input is correct
        in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
        assert in_x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True

        if self.speaker_cond is False:
            speaker_embs = None

        # run sampling loop
        with torch.no_grad():
            with self._ctx:  # type: ignore
                to_return = []
                for k in range(self.config.num_samples):
                    y = self.model.generate(
                        in_x,
                        None,
                        temperature=temperature,
                        top_k=top_k,
                        # TODO: handle separate top_p for this model explicitly
                        top_p=None,
                        speaker_embs=speaker_embs,
                        batch_size=batch_size,
                        guidance_scale=None,
                    )

                    b_tokens = torch.cat([in_x, y], dim=1)
                    for tokens in b_tokens:
                        try:
                            to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
                        except Exception as e:
                            print("failed to run MBD.")
                            print(f"reason: {str(e)}")
                            to_return.append(None)

        return to_return

    def __call__(
        self,
        *,
        texts: list[str],
        batch_size: int,
        max_new_tokens: Optional[int],
        top_k: Optional[int],
        top_p: Optional[float],
        temperature: Optional[float],
        encodec_tokens: Optional[list[torch.Tensor]] = None,
        speaker_embs: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = None,
    ):
        if self.checkpoint_config.get("causal", True):
            return self.causal_sample(
                texts=texts,
                batch_size=batch_size,
                speaker_embs=speaker_embs,
                guidance_scale=guidance_scale,
                max_new_tokens=max_new_tokens,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )
        else:
            assert encodec_tokens is not None
            assert guidance_scale is None
            assert max_new_tokens is None
            assert top_p is None

            return self.non_causal_sample(
                texts=texts,
                encodec_tokens=encodec_tokens,
                batch_size=batch_size,
                speaker_embs=speaker_embs,
                top_k=top_k,
                temperature=temperature,
            )


def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
    if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
        return
    json.dump(
        {
            "speaker": ref_path,
            "text": text,
        },
        pathlib.Path(str(wav_path) + ".json").open("w"),
    )


def get_cached_file(file_or_uri: str):
    """
    If it's a URL, download it to a local cache file and return that path.
    Otherwise return the path as is.
    """
    is_uri = file_or_uri.startswith("http")

    cache_path = None
    if is_uri:
        ext = pathlib.Path(file_or_uri).suffix
        # hash the file path to get the cache name
        _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext

        os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
        cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")

        if not os.path.exists(cache_path):
            command = f"curl -o {cache_path} {file_or_uri}"
            subprocess.run(command, shell=True, check=True)
    else:
        if os.path.exists(file_or_uri):
            cache_path = file_or_uri
        else:
            raise FileNotFoundError(f"File {file_or_uri} not found!")
    return cache_path


def get_cached_embedding(local_file_path: str, spkemb_model):
    if not os.path.exists(local_file_path):
        raise FileNotFoundError(f"File {local_file_path} not found!")

    # hash the file path to get the cache name
    _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"

    os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
    cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")

    if not os.path.exists(cache_path):
        spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0)  # (b=1, c)
        torch.save(spk_emb, cache_path)
    else:
        spk_emb = torch.load(cache_path)

    return spk_emb
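
Both helpers above implement the same content-addressed disk cache under `~/.cache/fam/`: the key is the md5 of the path or URL, so repeated runs against the same reference audio skip the download and the speaker-encoder forward pass. A hypothetical usage sketch (the URL is a placeholder and `spkemb_model` is assumed to be an already-constructed SpeakerEncoder, neither ships with the repo):

# --- illustrative sketch, not part of fam/llm/inference.py ---
local_path = get_cached_file("https://example.com/reference.wav")  # downloads once, reuses ~/.cache/fam/audio_<md5>.wav
spk_emb = get_cached_embedding(local_path, spkemb_model)  # computes once, reuses ~/.cache/fam/embedding_<md5>.pt
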

def _sample_utterance_batch(
    texts: list[str],
    spk_cond_paths: list[Optional[str]],
    spkemb_model,
    first_stage_model,
    second_stage_model,
    enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
    first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
    guidance_scale: Optional[Tuple[float, float]],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: Optional[float],
    batch_size: int = 128,
) -> List[str]:
    speaker_embs = []
    refs = spk_cond_paths.copy()

    # multithreaded loop to cache all the files
    spk_cond_paths = tqdm.contrib.concurrent.thread_map(
        get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
    )

    for i, (text, spk_cond_path) in tqdm.tqdm(
        enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
    ):
        texts[i] = normalize_text(text)
        speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)

    b_speaker_embs = torch.cat(speaker_embs, dim=0)

    start = time.time()
    b_tokens = first_stage_model(
        texts=texts,
        speaker_embs=b_speaker_embs,
        batch_size=batch_size,
        guidance_scale=guidance_scale,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    # TODO: set batch size for second stage model!
    wav_files = second_stage_model(
        texts=texts,
        encodec_tokens=b_tokens,
        speaker_embs=b_speaker_embs,
        batch_size=batch_size,
        guidance_scale=None,
        top_p=None,
        top_k=top_k,
        temperature=temperature,
        max_new_tokens=None,
    )

    for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
        if wav_file is None:
            continue

        with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
            if enhancer is not None:
                enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
                enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
                # copy enhanced_tmp.name back to wav_file
                print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
                shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")

            save_result_metadata(
                wav_file,
                ref_name,
                text,
                first_stage_ckpt_path,
                second_stage_ckpt_path,
            )

    print(f"time_to_synth_s: {time.time() - start}")
    return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]


def sample_utterance(
    text: str,
    spk_cond_path: Optional[str],
    spkemb_model,
    first_stage_model,
    second_stage_model,
    enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
    first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
    guidance_scale: Optional[Tuple[float, float]],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: Optional[float],
) -> str:
    # NOTE: supports max. 220 characters atm.
    # Long-form synthesis coming soon...
    MAX_CHARS = 220
    if len(text) > MAX_CHARS:
        print(
            f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech... This can lead to unpredictable speech at the end.***"
        )

    return _sample_utterance_batch(
        texts=[text],
        spk_cond_paths=[spk_cond_path],
        spkemb_model=spkemb_model,
        first_stage_model=first_stage_model,
        second_stage_model=second_stage_model,
        enhancer=enhancer,
        first_stage_ckpt_path=first_stage_ckpt_path,
        second_stage_ckpt_path=second_stage_ckpt_path,
        batch_size=1,
        guidance_scale=guidance_scale,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )[0]


def build_models(config_first_stage, config_second_stage, model_dir, device, use_kv_cache):
    smodel = SpeakerEncoder(
        weights_fpath=os.path.join(model_dir, "speaker_encoder.pt"), device=device, eval=True, verbose=False
    )
    data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
    llm_first_stage = Model(
        config_first_stage,
        TrainedBPETokeniser,
        EncodecDecoder,
        data_adapter_fn=data_adapter.decode,
        use_kv_cache=use_kv_cache,
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
    llm_second_stage = Model(
        config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
    )
    return smodel, llm_first_stage, llm_second_stage


def get_first_stage_path(model_dir: str):
    """Absolute path to checkpoint for the first stage model."""
    return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")


def get_second_stage_path(model_dir: str):
    """Absolute path to checkpoint for the second stage model."""
    return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")


@dataclass
class SamplingControllerConfig:
    """
    Sample from a trained model.
    """

    spk_cond_path: str
    """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""

    huggingface_repo_id: str = "kotoba-tech/kotoba-speech-v0.1"
    """Hugging Face Hub repo id to download the model checkpoints from."""

    text: str = (
        "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice."
    )
    """Text to synthesise."""

    num_samples: int = 1
    """Number of samples to generate from each model."""

    max_new_tokens: int = 864
    """Maximum number of new tokens to generate from the first stage model."""

    temperature: float = 1.0
    """Temperature for sampling applied to both models."""

    top_k: Optional[int] = None
    """Top k for sampling applied to both models."""

    top_p: Optional[float] = 0.95
    """Top p for sampling applied to first-stage model."""

    seed: int = 1337
    """Random seed for sampling."""

    device: Literal["cuda", "cpu"] = "cuda"
    """Device to use for sampling."""

    dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
    """Data type to use for sampling."""

    compile: bool = False
    """Whether to compile the model using PyTorch 2.0."""

    enhancer: Optional[Literal["df"]] = "df"
    """Enhancer to use for post-processing."""

    init_from: str = "resume"
    """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""

    use_kv_cache: Optional[Literal["vanilla"]] = "vanilla"
    """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [vanilla] torch attention with a hand-implemented kv-cache."""

    output_dir: str = "samples/"
    """Relative path to output directory"""

    guidance_scale: Optional[Tuple[float, float]] = (3.0, 1.0)
    """Guidance scale for sampling: (speaker conditioning guidance_scale, prompt conditioning guidance scale)."""

    batch_size: int = 128
    """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For an H100 and the 1B model,
    1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
    64 (w/ guidance) work well."""


if __name__ == "__main__":
    # TODO: add support for batch sampling via CLI. Function has been implemented above.
    sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)

    check_audio_file(sampling_config.spk_cond_path)

    model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
    first_stage_ckpt_path = get_first_stage_path(model_dir)
    second_stage_ckpt_path = get_second_stage_path(model_dir)

    config_first_stage = InferenceConfig(
        ckpt_path=first_stage_ckpt_path,
        num_samples=sampling_config.num_samples,
        seed=sampling_config.seed,
        device=sampling_config.device,
        dtype=sampling_config.dtype,
        compile=sampling_config.compile,
        init_from=sampling_config.init_from,
        output_dir=sampling_config.output_dir,
    )

    config_second_stage = InferenceConfig(
        ckpt_path=second_stage_ckpt_path,
        num_samples=sampling_config.num_samples,
        seed=sampling_config.seed,
        device=sampling_config.device,
        dtype=sampling_config.dtype,
        compile=sampling_config.compile,
        init_from=sampling_config.init_from,
        output_dir=sampling_config.output_dir,
    )

    sampling_config.max_new_tokens *= (
        2  # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
    )

    # define models
    smodel, llm_first_stage, llm_second_stage = build_models(
        config_first_stage,
        config_second_stage,
        model_dir=model_dir,
        device=sampling_config.device,
        use_kv_cache=sampling_config.use_kv_cache,
    )

    sample_utterance(
        sampling_config.text,
        os.path.expanduser(sampling_config.spk_cond_path),
        smodel,
        llm_first_stage,
        llm_second_stage,
        sampling_config.enhancer,
        first_stage_ckpt_path,
        second_stage_ckpt_path,
        sampling_config.guidance_scale,
        max_new_tokens=sampling_config.max_new_tokens,
        top_k=sampling_config.top_k,
        top_p=sampling_config.top_p,
        temperature=sampling_config.temperature,
    )
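
One detail worth unpacking from the `__main__` block: `max_new_tokens` is doubled because the first-stage model emits two EnCodec codebooks flattened and interleaved into one stream, so every audio frame costs two tokens. A back-of-envelope check (the 75 Hz frame rate is the standard 24 kHz EnCodec setting and is an assumption on my part, not something stated in this file):

# --- illustrative arithmetic, not part of fam/llm/inference.py ---
max_new_tokens = 864 * 2      # doubled for flattened interleaving of 2 codebooks
frames = max_new_tokens // 2  # 864 audio frames
print(frames / 75)            # ~11.5 seconds of audio per sample, under the assumed frame rate
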
fam/llm/layers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from fam.llm.layers.attn import SelfAttention
from fam.llm.layers.combined import Block
from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU
fam/llm/layers/attn.py
ADDED
@@ -0,0 +1,185 @@
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F


class SelfAttention(nn.Module):
    def __init__(self, config):
        """
        Initializes the SelfAttention module.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.
        """
        super().__init__()
        self._validate_config(config)
        self._initialize_parameters(config)

    def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
        """
        Empties the key-value cache.

        Args:
            batch_size: The batch size.
            kv_cache_maxlen: The maximum length of the key-value cache.
            dtype: The data type of the cache.

        Raises:
            Exception: If trying to empty the KV cache when it is disabled.
        """
        if self.kv_cache_enabled is False:
            raise Exception("Trying to empty KV cache when it is disabled")

        # register so that the cache moves devices along with the module
        # TODO: get rid of re-allocation.
        self.register_buffer(
            "kv_cache",
            torch.zeros(
                2,
                batch_size,
                kv_cache_maxlen,
                self.n_head,
                self.n_embd // self.n_head,
                dtype=dtype,
                device=self.c_attn.weight.device,
            ),
            persistent=False,
        )

        self.kv_cache_first_empty_index = 0

    def _initialize_parameters(self, config):
        """
        Initializes the parameters of the SelfAttention module.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.
        """
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.causal = config.causal
        self.attn_kernel_type = config.attn_kernel_type
        self.attn_dropout = nn.Dropout(config.dropout)

        self.kv_cache_enabled = False

    def _validate_config(self, config):
        """
        Validates the configuration parameters.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.

        Raises:
            AssertionError: If the embedding dimension is not divisible by the number of heads.
        """
        assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"

    def _update_kv_cache(self, q, k, v):
        """
        Updates the key-value cache.

        Args:
            q: The query tensor.
            k: The key tensor.
            v: The value tensor.

        Returns:
            The updated key and value tensors.

        Raises:
            AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
        """
        q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]

        if self.kv_cache_first_empty_index == 0:
            assert q_time == k_time and q_time == v_time
        else:
            assert (
                q_time == 1
            ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"

        self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
        self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
        self.kv_cache_first_empty_index += q_time

        k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
        v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]

        return k, v

    def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
        """
        Performs attention using the torch.nn.functional.scaled_dot_product_attention function.

        Args:
            c_x: The input tensor.

        Returns:
            The output tensor.
        """
        q, k, v = c_x.split(1, dim=2)  # q, k, v of shape (B, T, 1, nh, hs)
        q = q.squeeze(2)  # (B, T, nh, hs)
        k = k.squeeze(2)  # (B, T, nh, hs)
        v = v.squeeze(2)  # (B, T, nh, hs)

        # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
        # use no mask for the "one time step" parts.
        # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
        is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)

        if self.kv_cache_enabled:
            k, v = self._update_kv_cache(q, k, v)

        q = q.transpose(1, 2)  # (B, nh, T, hs)
        k = k.transpose(1, 2)  # (B, nh, T, hs)
        v = v.transpose(1, 2)  # (B, nh, T, hs)
        y = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=is_causal_attn_mask,
        ).transpose(
            1, 2
        )  # (B, nh, T, hs) -> (B, T, nh, hs)

        return y

    def forward(self, x):
        """
        Performs the forward pass of the SelfAttention module.

        Args:
            x: The input tensor.

        Returns:
            The output tensor.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head)  # (B, T, 3, nh, hs)

        # causal self-attention;
        if self.attn_kernel_type == "torch_attn":
            y = self._torch_attn(c_x)
        else:
            raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")

        y = y.contiguous().view(B, T, C)  # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
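
A minimal sketch (not repo code) of driving `SelfAttention` with its hand-rolled kv-cache: one prefill pass over the prompt with a causal mask, then one-token decode steps that attend over the cache with no mask. The `SimpleNamespace` stands in for the real config object, carrying only the fields this module reads; in the full model the `kv_cache_enabled` flag is flipped by the model-level `enable_kv_cache()` (see inference.py) rather than by hand.

# --- illustrative sketch, not part of fam/llm/layers/attn.py ---
from types import SimpleNamespace

import torch

from fam.llm.layers.attn import SelfAttention

cfg = SimpleNamespace(n_embd=64, n_head=4, bias=False, dropout=0.0, causal=True, attn_kernel_type="torch_attn")
attn = SelfAttention(cfg).eval()

attn.kv_cache_enabled = True  # normally set via the model's enable_kv_cache()
attn.empty_kv_cache(batch_size=1, kv_cache_maxlen=32, dtype=torch.float32)

prompt = torch.randn(1, 8, cfg.n_embd)
y = attn(prompt)                # prefill: is_causal mask over all 8 positions
step = torch.randn(1, 1, cfg.n_embd)
y = attn(step)                  # decode: one query attends to the 9 cached positions
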
fam/llm/layers/combined.py
ADDED
@@ -0,0 +1,52 @@
import torch.nn as nn

from fam.llm.layers.attn import SelfAttention
from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm


class Block(nn.Module):
    """
    Block class represents a single block in the model.

    Args:
        config (object): Configuration object containing parameters for the block.

    Attributes:
        ln_1 (object): Layer normalization for the attention layer.
        ln_2 (object): Layer normalization for the feed-forward layer.
        attn (object): Self-attention layer.
        mlp (object): Multi-layer perceptron layer.

    Methods:
        forward(x): Performs forward pass through the block.
    """

    def __init__(self, config):
        super().__init__()
        if config.norm_type == "rmsnorm":
            if config.rmsnorm_eps is None:
                raise Exception("RMSNorm requires rmsnorm_eps to be set")
            self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # attn norm
            self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # ffn norm
        elif config.norm_type == "layernorm":
            self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)  # attn norm
            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)  # ffn norm
        else:
            raise Exception(f"Unknown norm type: {config.norm_type}")
        self.attn = SelfAttention(config)

        self.mlp = MLP(config)

    def forward(self, x):
        """
        Performs forward pass through the block.

        Args:
            x (tensor): Input tensor.

        Returns:
            tensor: Output tensor after passing through the block.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
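
Since `Block` only touches a handful of config fields, it can be smoke-tested standalone; the residual wiring is the usual pre-norm ordering, `x + attn(ln(x))` then `x + mlp(ln(x))`. A sketch (not repo code) with a `SimpleNamespace` standing in for the real config and arbitrary illustrative field values:

# --- illustrative sketch, not part of fam/llm/layers/combined.py ---
from types import SimpleNamespace

import torch

from fam.llm.layers.combined import Block

cfg = SimpleNamespace(
    n_embd=64,
    n_head=4,
    bias=False,
    dropout=0.0,
    causal=True,
    attn_kernel_type="torch_attn",
    norm_type="rmsnorm",
    rmsnorm_eps=1e-5,
    nonlinearity_type="swiglu",
    swiglu_multiple_of=32,
)
x = torch.randn(2, 10, cfg.n_embd)
out = Block(cfg)(x)          # shape is preserved by the two residual branches
assert out.shape == x.shape
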
fam/llm/layers/layers.py
ADDED
@@ -0,0 +1,72 @@
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class RMSNorm(torch.nn.Module):
    def __init__(self, ndim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(ndim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x) * self.weight


class SwiGLU(nn.Module):
    def __init__(self, in_dim, out_dim, bias) -> None:
        super().__init__()
        self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
        self.w3 = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x):
        return F.silu(self.w1(x)) * self.w3(x)


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.non_linearity = config.nonlinearity_type
        hidden_dim = 4 * config.n_embd
        if config.nonlinearity_type == "gelu":
            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.gelu = nn.GELU()
            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        elif config.nonlinearity_type == "swiglu":
            if config.swiglu_multiple_of is None:
                raise Exception("SwiGLU requires swiglu_multiple_of to be set")
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
            # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
            self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        else:
            raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        if self.non_linearity == "gelu":
            x = self.c_fc(x)
            x = self.gelu(x)
        elif self.non_linearity == "swiglu":
            x = self.swiglu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
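
The `swiglu` branch of `MLP` shrinks the hidden width by 2/3 before rounding up, keeping the three-matrix SwiGLU at roughly the parameter count of the two-matrix GELU MLP, and the rounding to `swiglu_multiple_of` keeps the matrix shapes hardware-friendly. A worked example of that sizing arithmetic, with illustrative values not read from any checkpoint:

# --- illustrative arithmetic, not part of fam/llm/layers/layers.py ---
import math

n_embd, multiple_of = 2048, 256   # hypothetical config values
hidden = 4 * n_embd               # 8192: the gelu MLP width
hidden = int(2 * hidden / 3)      # 5461: shrink so 3 matrices cost ~the same as 2
hidden = multiple_of * math.ceil(hidden / multiple_of)
print(hidden)                     # 5632
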
fam/llm/mixins/__init__.py
ADDED
@@ -0,0 +1,2 @@
from fam.llm.mixins.causal import CausalInferenceMixin
from fam.llm.mixins.non_causal import NonCausalInferenceMixin
fam/llm/mixins/causal.py
ADDED
@@ -0,0 +1,511 @@
from typing import Optional, Tuple

import numpy as np
import torch
import tqdm
from torch.nn import functional as F


def top_p_sample(prob_dist: torch.Tensor, top_p: float):
    sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1)
    cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)  # (b, vocab_size)

    sorted_indices_to_remove = cum_sum_probs > top_p

    # shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0
    sorted_indices_to_remove = sorted_indices_to_remove.bool()

    # replace probs to be removed with 0 in the sorted_probs
    sorted_probs[sorted_indices_to_remove] = 0

    # reverse the sorting process
    reversed_indices = torch.argsort(sorted_indices)
    prob_dist = torch.gather(sorted_probs, -1, reversed_indices)

    # normalize
    prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True)

    return prob_dist
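
A toy check of `top_p_sample` (not repo code, meant to be run alongside the function above): with `top_p=0.8` the tail of the distribution is zeroed, the first token that crosses the cumulative threshold is kept by the shift trick, and the survivors are renormalised.

# --- illustrative sketch, not part of fam/llm/mixins/causal.py ---
import torch

probs = torch.tensor([[0.4, 0.3, 0.2, 0.06, 0.04]])
print(top_p_sample(probs, top_p=0.8))
# tensor([[0.4444, 0.3333, 0.2222, 0.0000, 0.0000]])
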

class CausalInferenceMixin:
    """
    Mixin class for performing inference in a causal language model.

    This mixin provides methods for predicting the next token in a sequence, sampling from the model,
    and applying token prediction masks.

    Attributes:
        None

    Methods:
        _sample_next_token: Predicts the next token in the sequence.
        _create_token_pred_mask: Creates a token prediction mask based on sequence lengths.
        _apply_token_pred_mask: Applies a token prediction mask to the next token predictions.
        _sample_batch: Samples a batch of tokens from the model.
        _sort_for_batching: Sorts the input sequences for efficient batching.
        _causal_sample: Generates a sequence of tokens using causal sampling.
    """

    @torch.no_grad()
    def _sample_next_token(
        self,
        *,
        idx: torch.Tensor,
        speaker_embs: Optional[torch.Tensor],
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        guidance_scale: Optional[float],
    ) -> torch.Tensor:
        """
        Predict the next token in the sequence.

        Args:
            idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
            temperature (float): Sampling temperature.
            top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
            top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
            guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.

        Returns:
            torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies).
        """
        if top_k is not None and top_p is not None:
            raise ValueError("Only one of top_k and top_p can be set")

        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]

        # forward the model to get the logits for the index in the sequence
        list_logits, _ = self(
            idx_cond, speaker_embs=speaker_embs
        )  # list with len num_hierarchies of (b,1,vocab_size) tensors

        if guidance_scale is not None:
            assert idx_cond.shape[0] % 2 == 0
            assert list_logits[0].shape[0] % 2 == 0

            for i, logits in enumerate(list_logits):
                logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0)
                list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond

            assert list_logits[0].shape[0] == idx_cond.shape[0] // 2

        # pluck the logits at the final step and scale by desired temperature
        list_logits = [
            logits[:, -1, :] / temperature for logits in list_logits
        ]  # list with len num_hierarchies of (b,vocab_size) tensors

        # optionally crop the logits to only the top k options
        if top_k is not None:
            for i in range(len(list_logits)):
                logits = list_logits[i]
                v, _ = torch.topk(
                    logits, min(top_k, logits.size(-1))
                )  # returns a descending sorted list of values and indices of top_k values
                logits[logits < v[:, [-1]]] = -float("Inf")  # set all logits below the smallest top_k value to -Inf
                list_logits[i] = logits

        # apply softmax to convert logits to (normalized) probabilities
        probs = [
            F.softmax(logits, dim=-1) for logits in list_logits
        ]  # list of len num_hierarchies of (b,vocab_size) tensors

        if top_p is not None:
            for i in range(len(probs)):
                probs[i] = top_p_sample(probs[i], top_p)

        # sample from the distribution
        idx_next = [
            torch.multinomial(prob, num_samples=1) for prob in probs
        ]  # list of len num_hierarchies of (b,1) tensors
        idx_next = torch.cat(idx_next, dim=-1)  # (b, num_hierarchies) tensor

        return idx_next  # (b, num_hierarchies) tensor
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
|
137 |
+
"""
|
138 |
+
Creates a token prediction mask based on sequence lengths.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
|
142 |
+
seq_lens (list[int]): List of sequence lengths for each sequence in idx.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
torch.Tensor: Token prediction mask of shape (batch, time).
|
146 |
+
"""
|
147 |
+
token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device)
|
148 |
+
for i in range(len(seq_lens)):
|
149 |
+
token_pred_mask[i, : seq_lens[i]] = True
|
150 |
+
|
151 |
+
assert (token_pred_mask[:, : min(seq_lens)] == 1).all()
|
152 |
+
|
153 |
+
return token_pred_mask
|
154 |
+
|
155 |
+
@torch.no_grad()
|
156 |
+
def _apply_token_pred_mask(
|
157 |
+
self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor
|
158 |
+
) -> torch.Tensor:
|
159 |
+
"""
|
160 |
+
Applies a token prediction mask to the next token predictions.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies).
|
164 |
+
orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies).
|
165 |
+
token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1).
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
torch.Tensor: Updated next token predictions after applying the token prediction mask.
|
169 |
+
"""
|
170 |
+
idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t
|
171 |
+
|
172 |
+
return idx_next
|
173 |
+
|
    @torch.no_grad()
    def _sample_batch(
        self,
        *,
        idx: torch.Tensor,
        max_new_tokens: int,
        seq_lens: list[int],
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        speaker_embs: Optional[torch.Tensor],
        guidance_scale: Optional[float],
    ):
        """
        Samples a batch of tokens from the model.

        Args:
            idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
            max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
            seq_lens (list[int]): List of sequence lengths for each sequence in idx.
            temperature (float): Sampling temperature.
            top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
            top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
            guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.

        Returns:
            torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
        """
        assert max(seq_lens) <= idx.shape[-1]
        token_pred_mask = self._create_token_pred_mask(idx, seq_lens)
        input = torch.clone(idx)

        min_seq_lens = min(seq_lens)
        idx = idx[:, :, :min_seq_lens]

        if guidance_scale is not None:
            if speaker_embs is None:
                raise Exception("Guidance is only supported for conditional models")

            # create speaker embeddings equivalent to the batch size, filling with None
            # for second half to do unconditional generation.
            speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0])

        for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
            if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
                idx_input = idx[:, :, -1:]
            else:
                idx_input = idx

            if guidance_scale is not None:
                # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
                if timestep == min_seq_lens:
                    print("[hack!!!!] Guidance is on, so we're doubling batch size!")

                # replicate idx in the batch dimension
                idx_input = (
                    idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2])
                )

                # sanity checks
                assert idx_input.shape[0] % 2 == 0

            idx_next = self._sample_next_token(
                idx=idx_input,
                speaker_embs=speaker_embs,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                guidance_scale=guidance_scale,
            )  # (b, num_hierarchies)

            assert idx_next.shape[0] == idx.shape[0]

            if timestep < token_pred_mask.shape[-1]:
                idx_next = self._apply_token_pred_mask(
                    idx_next=idx_next,
                    orig_input_at_t=input[:, :, timestep],
                    token_pred_mask_at_t=token_pred_mask[:, [timestep]],
                )

            idx_next = idx_next.unsqueeze(-1)  # (b, num_hierarchies, T=1) tensor
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=2)

        return idx
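    # Editor's note (a hedged sketch, not part of the original file): the batch
    # doubling above feeds conditional rows in the first half and unconditional rows
    # in the second half, so `_sample_next_token` can blend them as
    # `guidance_scale * cond + (1 - guidance_scale) * uncond`. With toy numbers:
    #
    #   >>> cond, uncond, g = 1.0, -2.0, 3.0
    #   >>> g * cond + (1 - g) * uncond
    #   7.0
    #
    # i.e. a guidance_scale > 1 pushes the logits away from the unconditional prediction.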
    @torch.no_grad()
    def _sort_for_batching(
        self,
        *,
        idx: torch.Tensor,
        seq_lens: list[int],
        speaker_embs: Optional[torch.Tensor],
        batch_size: int,
        max_new_tokens: int,
    ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
        """
        Sorts the input sequences for efficient batching.

        Args:
            idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
            seq_lens (list[int]): List of sequence lengths for each sequence in idx.
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
            batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
            max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).

        Returns:
            Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
                - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order.
                - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order.
                - idx (torch.Tensor): Input sequence indices in sorted order.
                - seq_lens (list[int]): Sequence lengths in sorted order.
                - speaker_embs (Optional[torch.Tensor]): Speaker embeddings in sorted order.
                - max_token_len (int): Effective maximum number of tokens to generate.
        """
        assert len(seq_lens) == idx.shape[0]
        assert max(seq_lens) <= idx.shape[-1]

        sorted_indices = np.argsort(seq_lens)
        inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32)
        inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32)

        idx = idx[sorted_indices]
        seq_lens = [seq_lens[i] for i in sorted_indices]
        speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None
        max_token_len = 0

        # figure out effective max_tokens to generate
        for start_index in range(0, len(seq_lens), batch_size):
            end_index = min(start_index + batch_size, len(seq_lens))
            batch_seq_lens = seq_lens[start_index:end_index]
            # random heuristic...
            # TODO: fix!
            max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens)

        return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len
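    # Editor's illustration (a minimal sketch, not part of the original file) of the
    # argsort / inverse-permutation pattern used above to undo the sort later:
    #
    #   >>> import numpy as np
    #   >>> seq_lens = [5, 2, 9, 1]
    #   >>> order = np.argsort(seq_lens)               # array([3, 1, 0, 2])
    #   >>> inverse = np.zeros(4, dtype=np.int32)
    #   >>> inverse[order] = np.arange(4, dtype=np.int32)
    #   >>> np.array(sorted(seq_lens))[inverse]        # restores the original order
    #   array([5, 2, 9, 1])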
    @torch.no_grad()
    def _causal_sample(
        self,
        *,
        idx: torch.Tensor,
        max_new_tokens: int,
        seq_lens: list[int],
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        speaker_embs: Optional[torch.Tensor],
        batch_size: int,
        guidance_scale: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Generates a sequence of tokens using causal sampling.

        Args:
            idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
            max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
            seq_lens (list[int]): List of sequence lengths for each sequence in idx.
            temperature (float): Sampling temperature.
            top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
            top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
            batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
            guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.

        Returns:
            torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
        """
        (
            _,
            invert_sorted_indices,
            idx,
            seq_lens,
            speaker_embs,
            max_token_len,
        ) = self._sort_for_batching(
            idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens
        )

        return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device)

        for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "):
            end_index = min(start_index + batch_size, len(seq_lens))

            kv_batch_size = end_index - start_index
            if guidance_scale is not None:
                kv_batch_size = 2 * kv_batch_size

            if self.kv_cache_enabled:
                print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
                self.empty_kv_cache(
                    batch_size=kv_batch_size,
                    kv_cache_maxlen=self.config.block_size,
                    dtype=torch.bfloat16,
                )

            batch_seq_lens = seq_lens[start_index:end_index]
            batch_max_new_tokens = max_token_len - min(batch_seq_lens)

            batch_idx = idx[start_index:end_index]
            batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None

            batch_idx = self._sample_batch(
                idx=batch_idx,
                max_new_tokens=batch_max_new_tokens,
                seq_lens=batch_seq_lens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                speaker_embs=batch_speaker_embs,
                guidance_scale=guidance_scale,
            )
            return_idx[start_index:end_index] = batch_idx

        return return_idx[invert_sorted_indices]
    def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
        """
        Empties key-value (KV) cache for causal attention.

        Args:
            batch_size (int): The batch size.
            kv_cache_maxlen (int): The maximum length of the KV cache.
            dtype (torch.dtype): The data type of the KV cache.

        Raises:
            Exception: If KV cache is enabled for non-causal attention.
        """
        if self.kv_cache_enabled is False:
            raise Exception("KV cache is not enabled")
        if self.config.causal is False:
            raise Exception("KV cache is not supported for non-causal attention")

        self.kv_pos = 0
        for block in self.transformer.h:
            block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype)

    def enable_kv_cache(self):
        """
        Enables key-value (KV) cache for causal attention.

        Raises:
            Exception: If KV cache is enabled for non-causal attention.
        """
        if self.config.causal is False:
            raise Exception("KV cache is not supported for non-causal attention")

        self.kv_cache_enabled = True
        for block in self.transformer.h:
            block.attn.kv_cache_enabled = True

    def disable_kv_cache(self):
        """
        Disables the key-value cache for the transformer and all its blocks.
        """
        self.kv_cache_enabled = False
        for block in self.transformer.h:
            block.attn.kv_cache_enabled = False
            block.attn.kv_cache = None
            block.attn.kv_cache_first_empty_index = 0
    @torch.no_grad()
    def _slow_causal_sampling_loop(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        speaker_embs: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = None,
    ):
        """
        Old non-batched version of causal sampling. Kept for testing / reference.

        Take a conditioning sequence of indices idx (LongTensor of shape (b,n_head,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
        assert idx.size(0) == 1, "can only do one sequence at a time for now"
        assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop"

        if self.config.causal is not True:
            raise Exception("Causal sampling is only supported for causal models")

        if self.kv_cache_enabled:
            print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
            self.empty_kv_cache(
                batch_size=1,
                kv_cache_maxlen=self.config.block_size,
                dtype=torch.bfloat16,
            )

        for i in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            # (crop along the time dimension of the (b, c, t) tensor)
            idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]

            if self.kv_cache_enabled:
                if i > 0:
                    idx_cond = idx_cond[:, :, -1:]

            # forward the model to get the logits for the index in the sequence
            list_logits, _ = self(idx_cond, speaker_embs=speaker_embs)

            if guidance_scale is not None:
                # we've already checked that kv-caching is not switched on
                # so this should be ok.
                list_logits_uncond, _ = self(idx_cond, speaker_embs=None)
                list_logits = [
                    (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond
                    for logits, logits_uncond in zip(list_logits, list_logits_uncond)
                ]

            # pluck the logits at the final step and scale by desired temperature
            list_logits = [logits[:, -1, :] / temperature for logits in list_logits]

            # optionally crop the logits to only the top k options
            if top_k is not None:
                for j in range(len(list_logits)):
                    logits = list_logits[j]
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                    list_logits[j] = logits

            # apply softmax to convert logits to (normalized) probabilities
            probs = [F.softmax(logits, dim=-1) for logits in list_logits]
            # sample from the distribution
            idx_next = torch.tensor(
                [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device
            )  # (c, 1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2)

        return idx
fam/llm/mixins/non_causal.py
ADDED
@@ -0,0 +1,67 @@
from typing import Optional

import torch
from torch.nn import functional as F


class NonCausalInferenceMixin:
    """
    Mixin class for non-causal inference in a language model.

    This class provides methods for performing non-causal sampling using a language model.
    """

    @torch.no_grad()
    def _non_causal_sample(
        self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
    ):
        """
        Perform non-causal sampling.

        Args:
            idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
            temperature (float): Temperature parameter for scaling the logits.
            top_k (int): Number of top options to consider.

        Returns:
            torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
        """
        b, c, t = idx.size()
        assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
        # forward the model to get the logits for the index in the sequence
        list_logits, _ = self(idx, speaker_embs=speaker_embs)  # c x (b, t, vocab_size)

        # scale by desired temperature
        list_logits = [logits / temperature for logits in list_logits]  # c x (b, t, vocab_size)

        # optionally crop the logits to only the top k options
        if top_k is not None:
            for i in range(len(list_logits)):
                logits = list_logits[i]  # (b, t, vocab_size)

                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))  # (b, t, top_k)
                logits[logits < v[:, :, [-1]]] = -float("Inf")
                list_logits[i] = logits  # (b, t, vocab_size)
                assert logits.shape[0] == b and logits.shape[1] == t

        # apply softmax to convert logits to (normalized) probabilities
        # TODO: check shapes here!
        probs = [F.softmax(logits, dim=-1) for logits in list_logits]  # c x (b, t, vocab_size)
        assert probs[0].shape[0] == b and probs[0].shape[1] == t

        outs = []
        for b_prob in probs:  # c x (b, t, vocab_size) -> (b, t, vocab_size)
            out = [
                torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
            ]  # b x (t, vocab_size) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
            assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
            out = torch.cat(out, dim=0)  # (b, 1, t)
            assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
            outs.append(out)

        out = torch.cat(outs, dim=1)  # (b, c, t)
        assert out.shape[0] == b and out.shape[2] == t

        return out
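
The reshaping in `_non_causal_sample` above is easiest to see in isolation. The
snippet below is an editor's sketch (not part of the repository): it samples one
token per timestep from a single hierarchy's probabilities, mirroring the
(t, vocab) -> (t, 1) -> (1, t) -> (1, 1, t) comment trail in the file.

import torch
from torch.nn import functional as F

b, t, vocab = 2, 4, 10
probs = F.softmax(torch.randn(b, t, vocab), dim=-1)  # one hierarchy's probabilities

# per batch element: each of the t rows is a categorical distribution over vocab
out = [torch.multinomial(p, num_samples=1).transpose(0, 1).unsqueeze(0) for p in probs]
out = torch.cat(out, dim=0)  # (b, 1, t)
print(out.shape)  # torch.Size([2, 1, 4])
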
fam/llm/model.py
ADDED
@@ -0,0 +1,524 @@
import inspect
import math
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import torch
import torch.nn as nn
import tqdm
from einops import rearrange
from torch.nn import functional as F

from fam.llm.layers import Block, LayerNorm, RMSNorm
from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin

END_OF_TEXT_TOKEN = 1537


def _select_spkemb(spkemb, mask):
    _, examples, _ = spkemb.shape
    mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb)  # shape: (batch, time, examples)
    spkemb = spkemb.transpose(1, 2)  # b ex c -> b c ex
    mask = mask.transpose(1, 2)  # b t ex -> b ex t
    return torch.bmm(spkemb, mask).transpose(1, 2)  # b c t -> b t c
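
# Editor's illustration (a minimal sketch, not part of the original file) of
# _select_spkemb: the one-hot + bmm acts as a differentiable gather that picks,
# per timestep, one of the packed speaker embeddings.
#
#   >>> import torch
#   >>> spkemb = torch.tensor([[[1.0, 1.0], [2.0, 2.0]]])  # (batch=1, examples=2, c=2)
#   >>> mask = torch.tensor([[0, 0, 1]])                   # which example each timestep uses
#   >>> _select_spkemb(spkemb, mask)
#   tensor([[[1., 1.],
#            [1., 1.],
#            [2., 2.]]])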
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class GPTConfig:
|
29 |
+
block_size: int = 1024
|
30 |
+
vocab_sizes: list = field(default_factory=list)
|
31 |
+
target_vocab_sizes: Optional[list] = None
|
32 |
+
n_layer: int = 12
|
33 |
+
n_head: int = 12
|
34 |
+
n_embd: int = 768
|
35 |
+
dropout: float = 0.0
|
36 |
+
spkemb_dropout: float = 0.0
|
37 |
+
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
38 |
+
causal: bool = (
|
39 |
+
True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens
|
40 |
+
)
|
41 |
+
spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
|
42 |
+
norm_type: str = "layernorm" # "rmsnorm" or "layernorm
|
43 |
+
rmsnorm_eps: Optional[float] = None # only used for rmsnorm
|
44 |
+
nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
|
45 |
+
swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
|
46 |
+
attn_kernel_type: Literal["torch_attn"] = "torch_attn"
|
47 |
+
#Literal["fa2", "torch_attn", "hand"] = "fa2"
|
48 |
+
kv_cache_enabled: bool = False # whether to use key-value cache for attention
|
49 |
+
|
50 |
+
|
51 |
+
def _check_speaker_emb_dims(
|
52 |
+
speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
|
53 |
+
) -> Union[torch.Tensor, list]:
|
54 |
+
"""
|
55 |
+
Checks that the speaker embedding dimensions are correct, and reshapes them if necessary.
|
56 |
+
"""
|
57 |
+
if type(speaker_embs) == list:
|
58 |
+
b_se = len(speaker_embs)
|
59 |
+
for i, s in enumerate(speaker_embs):
|
60 |
+
if s is not None:
|
61 |
+
emb_dim = s.shape[-1]
|
62 |
+
if s.ndim == 1:
|
63 |
+
speaker_embs[i] = speaker_embs[i].unsqueeze(0)
|
64 |
+
else:
|
65 |
+
if speaker_embs.ndim == 2:
|
66 |
+
# if we have a single speaker embedding for the whole sequence,
|
67 |
+
# add a dummy dimension for backwards compatibility
|
68 |
+
speaker_embs = speaker_embs[:, None, :]
|
69 |
+
|
70 |
+
# num_examples is the number of utterances packed into this sequence
|
71 |
+
b_se, num_examples, emb_dim = speaker_embs.size()
|
72 |
+
|
73 |
+
assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}"
|
74 |
+
assert (
|
75 |
+
emb_dim == expected_speaker_emb_dim
|
76 |
+
), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
|
77 |
+
|
78 |
+
return speaker_embs
|
79 |
+
|
80 |
+
|
81 |
+
class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
|
82 |
+
def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
|
83 |
+
"""
|
84 |
+
Initialize the GPT model.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
config (GPTConfig): Configuration object for the model.
|
88 |
+
speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None.
|
89 |
+
"""
|
90 |
+
super().__init__()
|
91 |
+
assert config.vocab_sizes is not None
|
92 |
+
assert config.block_size is not None
|
93 |
+
self.config = config
|
94 |
+
|
95 |
+
self.kv_cache_enabled = False # disabled by default
|
96 |
+
self.kv_pos = 0
|
97 |
+
|
98 |
+
self.speaker_emb_dim = speaker_emb_dim
|
99 |
+
self.spk_emb_on_text = config.spk_emb_on_text
|
100 |
+
if self.config.causal is True and self.spk_emb_on_text is False:
|
101 |
+
print("!!!!!!!!!!!!!!!!!!")
|
102 |
+
print(
|
103 |
+
f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
|
104 |
+
)
|
105 |
+
print("!!!!!!!!!!!!!!!!!!")
|
106 |
+
if self.config.causal is False and self.spk_emb_on_text is False:
|
107 |
+
raise Exception(
|
108 |
+
"Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
|
109 |
+
)
|
110 |
+
|
111 |
+
if config.norm_type == "rmsnorm":
|
112 |
+
if config.rmsnorm_eps is None:
|
113 |
+
raise Exception("RMSNorm requires rmsnorm_eps to be set")
|
114 |
+
ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
|
115 |
+
elif config.norm_type == "layernorm":
|
116 |
+
ln_f = LayerNorm(config.n_embd, bias=config.bias)
|
117 |
+
else:
|
118 |
+
raise Exception(f"Unknown norm type: {config.norm_type}")
|
119 |
+
|
120 |
+
self.transformer = nn.ModuleDict(
|
121 |
+
dict(
|
122 |
+
wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd,) for vsize in config.vocab_sizes]),
|
123 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
124 |
+
drop=nn.Dropout(config.dropout),
|
125 |
+
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
126 |
+
ln_f=ln_f,
|
127 |
+
)
|
128 |
+
)
|
129 |
+
if speaker_emb_dim is not None:
|
130 |
+
self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False) # ここで256->2048
|
131 |
+
|
132 |
+
self.lm_heads = nn.ModuleList()
|
133 |
+
if config.target_vocab_sizes is not None:
|
134 |
+
assert config.causal is False
|
135 |
+
else:
|
136 |
+
assert config.causal is True
|
137 |
+
|
138 |
+
for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes:
|
139 |
+
self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))
|
140 |
+
|
141 |
+
if config.target_vocab_sizes is None:
|
142 |
+
for i in range(len(config.vocab_sizes)):
|
143 |
+
# TODO: do we not need to take the transpose here?
|
144 |
+
# https://paperswithcode.com/method/weight-tying
|
145 |
+
self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore
|
146 |
+
assert len(self.lm_heads) == len(
|
147 |
+
self.transformer.wtes # type: ignore
|
148 |
+
), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrics ({len(self.transformer.wtes)})." # type: ignore
|
149 |
+
# - causal
|
150 |
+
# GPT(
|
151 |
+
# (transformer): ModuleDict(
|
152 |
+
# (wtes): ModuleList(
|
153 |
+
# (0): Embedding(2562, 2048)
|
154 |
+
# )
|
155 |
+
# (wpe): Embedding(2048, 2048)
|
156 |
+
# (drop): Dropout(p=0.0, inplace=False)
|
157 |
+
# (h): ModuleList(
|
158 |
+
# (0-23): 24 x Block(
|
159 |
+
# (ln_1): RMSNorm()
|
160 |
+
# (ln_2): RMSNorm()
|
161 |
+
# (attn): SelfAttention(
|
162 |
+
# (c_attn): Linear(in_features=2048, out_features=6144, bias=False)
|
163 |
+
# (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
|
164 |
+
# (resid_dropout): Dropout(p=0.0, inplace=False)
|
165 |
+
# )
|
166 |
+
# (mlp): MLP(
|
167 |
+
# (swiglu): SwiGLU(
|
168 |
+
# (w1): Linear(in_features=2048, out_features=5632, bias=False)
|
169 |
+
# (w3): Linear(in_features=2048, out_features=5632, bias=False)
|
170 |
+
# )
|
171 |
+
# (c_proj): Linear(in_features=5632, out_features=2048, bias=False)
|
172 |
+
# (dropout): Dropout(p=0.0, inplace=False)
|
173 |
+
# )
|
174 |
+
# )
|
175 |
+
# )
|
176 |
+
# (ln_f): RMSNorm()
|
177 |
+
# )
|
178 |
+
# (speaker_cond_pos): Linear(in_features=256, out_features=2048, bias=False)
|
179 |
+
# (lm_heads): ModuleList(
|
180 |
+
# (0): Linear(in_features=2048, out_features=2562, bias=False)
|
181 |
+
# )
|
182 |
+
# )
|
183 |
+
# GPTConfig(block_size=2048, vocab_sizes=[2562], target_vocab_sizes=None, n_layer=24, n_head=16, n_embd=2048, dropout=0.0, spkemb_dropout=0.1, bias=False, causal=True, spk_emb_on_text=True, norm_type='rmsnorm', rmsnorm_eps=1e-05, nonlinearity_type='swiglu', swiglu_multiple_of=256, attn_kernel_type='torch_attn', kv_cache_enabled=False)
|
184 |
+
#
|
185 |
+
# - non causal
|
186 |
+
# GPT(
|
187 |
+
# (transformer): ModuleDict(
|
188 |
+
# (wtes): ModuleList(
|
189 |
+
# (0): Embedding(1538, 384)
|
190 |
+
# (1): Embedding(1025, 384)
|
191 |
+
# )
|
192 |
+
# (wpe): Embedding(1024, 384)
|
193 |
+
# (drop): Dropout(p=0.0, inplace=False)
|
194 |
+
# (h): ModuleList(
|
195 |
+
# (0-5): 6 x Block(
|
196 |
+
# (ln_1): LayerNorm()
|
197 |
+
# (ln_2): LayerNorm()
|
198 |
+
# (attn): SelfAttention(
|
199 |
+
# (c_attn): Linear(in_features=384, out_features=1152, bias=False)
|
200 |
+
# (c_proj): Linear(in_features=384, out_features=384, bias=False)
|
201 |
+
# (resid_dropout): Dropout(p=0.0, inplace=False)
|
202 |
+
# )
|
203 |
+
# (mlp): MLP(
|
204 |
+
# (c_fc): Linear(in_features=384, out_features=1536, bias=False)
|
205 |
+
# (gelu): GELU(approximate='none')
|
206 |
+
# (c_proj): Linear(in_features=1536, out_features=384, bias=False)
|
207 |
+
# (dropout): Dropout(p=0.0, inplace=False)
|
208 |
+
# )
|
209 |
+
# )
|
210 |
+
# )
|
211 |
+
# (ln_f): LayerNorm()
|
212 |
+
# )
|
213 |
+
# (speaker_cond_pos): Linear(in_features=256, out_features=384, bias=False)
|
214 |
+
# (lm_heads): ModuleList(
|
215 |
+
# (0-5): 6 x Linear(in_features=384, out_features=1025, bias=False)
|
216 |
+
# )
|
217 |
+
# )
|
218 |
+
# GPTConfig(block_size=1024, vocab_sizes=[1538, 1025], target_vocab_sizes=[1025, 1025, 1025, 1025, 1025, 1025], n_layer=6, n_head=6, n_embd=384, dropout=0.0, spkemb_dropout=0.0, bias=False, causal=False, spk_emb_on_text=True, norm_type='layernorm', rmsnorm_eps=None, nonlinearity_type='gelu', swiglu_multiple_of=None, attn_kernel_type='fa2', kv_cache_enabled=False)
|
219 |
+
# if config.causal is False:
|
220 |
+
# embed()
|
221 |
+
# init all weights
|
222 |
+
self.apply(self._init_weights)
|
223 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
224 |
+
for pn, p in self.named_parameters():
|
225 |
+
if pn.endswith("c_proj.weight"):
|
226 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
|
227 |
+
|
228 |
+
# report number of parameters
|
229 |
+
print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
|
230 |
+
|
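    # Editor's note on the weight-tying TODO above (a minimal sketch, not part of the
    # original file): no transpose is needed because nn.Linear stores its weight as
    # (out_features, in_features), which already matches nn.Embedding's
    # (num_embeddings, embedding_dim).
    #
    #   >>> import torch.nn as nn
    #   >>> emb = nn.Embedding(10, 4)
    #   >>> head = nn.Linear(4, 10, bias=False)
    #   >>> head.weight = emb.weight                # both now share one (10, 4) parameter
    #   >>> head.weight.data_ptr() == emb.weight.data_ptr()
    #   True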
    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
        """
        This is in a separate function so we can test it easily.
        """
        # find index of end of text token in each sequence, then generate a binary mask
        # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token.
        # Note: this does NOT mask the <end_of_text_token> token. This is important so that the first audio token predicted
        # has speaker information to use.

        # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
        is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
        # use > 0, in case end_of_text_token is repeated for any reason.
        mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
        spk_emb = spk_emb * mask[:, :, None]

        return spk_emb
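    # Editor's illustration (a minimal sketch, not part of the original file) of the
    # cumsum trick in _mask_spk_emb_on_text: the mask switches on AT the end-of-text
    # token and stays on, so text positions lose speaker conditioning.
    #
    #   >>> import torch
    #   >>> is_end_of_text = torch.tensor([[0, 0, 1, 0, 0]])
    #   >>> (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
    #   tensor([[0., 0., 1., 1., 1.]])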
    def forward(
        self,
        idx,
        targets=None,
        speaker_embs=None,
        embedding=None,
        speaker_emb_mask=None,
        loss_reduce: Literal["mean", "none"] = "mean",
    ):
        device = idx.device
        b, num_hierarchies, t = idx.size()

        if speaker_embs is not None:
            speaker_embs = _check_speaker_emb_dims(
                speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
            )

        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        if self.kv_cache_enabled:
            if self.kv_pos == 0:
                pos = torch.arange(0, t, dtype=torch.long, device=device)
                self.kv_pos += t
            else:
                assert t == 1, "KV cache is only supported for single token inputs"
                pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device)  # shape (1)
                self.kv_pos += 1
        else:
            pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        # forward the GPT model itself
        # assert num_hierarchies == len(
        #     self.transformer.wtes
        # ), f"Input tensor has {num_hierarchies} hierarchies, but model has {len(self.transformer.wtes)} set of input embeddings."

        # embed the tokens, positional encoding, and speaker embedding
        tok_emb = torch.zeros((b, t, self.config.n_embd), device=device)
        # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings.
        for i, wte in enumerate(self.transformer.wtes):
            mask_pad = idx[:, i, :] == -1  # indices of padding (-1) elements
            masked_idx = idx[:, i, :].clone()
            masked_idx[mask_pad] = 0
            embedded_idx = wte(masked_idx)
            embedded_idx[mask_pad] = 0  # zero out the embeddings at padded positions
            tok_emb += embedded_idx
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)

        spk_emb = 0.0
        if speaker_embs is not None:
            if type(speaker_embs) == list:
                assert speaker_emb_mask is None
                assert self.training is False
                assert self.spk_emb_on_text is True
                spk_emb = []
                for speaker_emb_row in speaker_embs:
                    if speaker_emb_row is not None:
                        spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0)))
                        assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}"
                    else:
                        spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype))
                spk_emb = torch.cat(spk_emb, dim=0)

                assert (
                    spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
                ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
            else:
                speakers_embedded = self.speaker_cond_pos(speaker_embs)  # shape (b, num_examples, c)

                if speaker_emb_mask is not None:
                    spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
                    assert spk_emb.shape == (b, t, self.config.n_embd)
                else:
                    spk_emb = speakers_embedded
                    # if we don't have a mask, we assume that the speaker embedding is the same for all tokens
                    # then num_examples dimension just becomes the time dimension
                    assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1

                if self.training and self.config.spkemb_dropout > 0.0:
                    # Remove speaker conditioning at random.
                    dropout = torch.ones_like(speakers_embedded) * (
                        torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
                    )
                    spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)

            if self.spk_emb_on_text is False:
                assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
                spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
        elif embedding is not None:
            spk_emb = self.speaker_cond_pos(embedding)

        # TODO: implement causal attention mask here
        # memo (shapes, for b, t, d=2048): tok_emb=(b,t,d), pos_emb=(t,d), spk_emb=(b,1,d)
        # train:  tok_emb.shape=torch.Size([128, 187, 2048]), pos_emb.shape=torch.Size([187, 2048]), spk_emb.shape=torch.Size([128, 1, 1, 187]) <- spk_emb should presumably be (b,1,2048)?
        # sample: tok_emb.shape=torch.Size([2, 369, 2048]), pos_emb.shape=torch.Size([369, 2048]), spk_emb.shape=torch.Size([2, 1, 2048])
        x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            list_logits = [lm_head(x) for lm_head in self.lm_heads]

            losses = [
                F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets[:, i, :].contiguous().view(-1),
                    ignore_index=-1,
                    reduction=loss_reduce,
                )
                for i, logits in enumerate(list_logits)
            ]
            # TODO: should we do this better without stack somehow?
            losses = torch.stack(losses)
            if loss_reduce == "mean":
                losses = losses.mean()
            else:
                losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            if self.config.causal:
                list_logits = [
                    lm_head(x[:, [-1], :]) for lm_head in self.lm_heads
                ]  # note: using list [-1] to preserve the time dim
            else:
                list_logits = [lm_head(x) for lm_head in self.lm_heads]
            losses = None

        return list_logits, losses
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer
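    # Editor's illustration (a minimal sketch, not part of the original file) of the
    # dim-based weight-decay split above: matmul/embedding weights decay, while biases
    # and norm gains do not.
    #
    #   >>> import torch.nn as nn
    #   >>> m = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    #   >>> [n for n, p in m.named_parameters() if p.dim() >= 2]
    #   ['0.weight']
    #   >>> [n for n, p in m.named_parameters() if p.dim() < 2]
    #   ['0.bias', '1.weight', '1.bias']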
    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        seq_lens: Optional[list] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        speaker_embs: Optional[torch.Tensor] = None,
        batch_size: Optional[int] = None,
        guidance_scale: Optional[float] = None,
    ):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"

        if self.config.causal:
            if seq_lens is None or batch_size is None:
                raise Exception("seq_lens and batch_size must be provided for causal sampling")

            return self._causal_sample(
                idx=idx,
                max_new_tokens=max_new_tokens,
                seq_lens=seq_lens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                speaker_embs=speaker_embs,
                batch_size=batch_size,
                guidance_scale=guidance_scale,
            )

        else:
            if seq_lens is not None:
                raise Exception("seq_lens is not supported yet for non-causal sampling")

            if batch_size is None:
                raise Exception("batch_size must be provided for non-causal sampling")

            if guidance_scale is not None:
                raise Exception("guidance_scale is not supported for non-causal sampling")

            if top_p is not None:
                raise Exception("top_p is not supported for non-causal sampling")

            out = []
            for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"):
                end_index = min(start_index + batch_size, idx.shape[0])
                out.append(
                    self._non_causal_sample(
                        idx=idx[start_index:end_index],
                        speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
                        temperature=temperature,
                        top_k=top_k,
                    )
                )
            return torch.cat(out, dim=0)
fam/llm/sample.py
ADDED
@@ -0,0 +1,731 @@
import dataclasses
import hashlib
import json
import os
import pathlib
import shutil
import subprocess
import tempfile
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Literal, Optional, Type, Union

import librosa
import torch
import tqdm
import tqdm.contrib.concurrent
import tyro
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
from fam.llm.decoders import Decoder, EncodecDecoder
from fam.llm.enhancers import BaseEnhancer, get_enhancer
from fam.llm.model import GPT, GPTConfig
from fam.llm.utils import normalize_text
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
from fam.quantiser.text.tokenise import TrainedBPETokeniser


@dataclass
class InferenceConfig:
    ckpt_path: str  # path to checkpoint
    output_dir: str
    num_samples: int = 10  # number of samples to draw
    seed: int = 1337  # random seed
    device: str = "cuda"
    dtype: str = "bfloat16"
    compile: bool = False
    init_from: str = "resume"  # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')

    def __str__(self):
        field_strs = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            field_strs.append(f"  {field.name}: {value}")

        return "InferenceConfig:\n" + "\n".join(field_strs)
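
# Editor's illustration (a hedged sketch, not part of the original file; the paths
# are hypothetical) of constructing the config consumed by `Model` below:
#
#   config = InferenceConfig(
#       ckpt_path="./ckpts/first_stage.pt",  # hypothetical checkpoint location
#       output_dir="./samples",
#       num_samples=2,
#       device="cuda",
#       dtype="bfloat16",
#   )
#   print(config)  # uses the __str__ above to list all fields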
class Model:
    """
    Class to sample from a trained model.
    """

    def __init__(
        self,
        config: InferenceConfig,
        tokenizer_cls: Type[TrainedBPETokeniser],
        decoder_cls: Type[Decoder],
        data_adapter_fn,
        use_kv_cache: Optional[Literal["none", "flash_decoding", "vanilla"]] = None,
        first_model_path=None,
    ):
        # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference)
        # TODO: remove magic number
        self._encodec_codes_pad_token = 1024
        self._num_encodec_codebooks = 8
        self.config = config
        self.use_kv_cache = use_kv_cache

        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)
        torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False  # allow tf32 on cudnn
        device_type = "cuda" if "cuda" in config.device else "cpu"  # for later use in torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "tfloat32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[config.dtype]
        self._ctx = (
            nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
        )

        self.use_bpe_tokenizer = False
        self.load_meta = None
        self.speaker_cond = None
        self.meta = None
        self.model = None
        self.checkpoint_config = None
        self.vocab_sizes = None
        self.smodel = None
        self.first_model_path = first_model_path

        self._init_model()

        self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
        self.decoder = decoder_cls(
            tokeniser_decode_fn=self.tokenizer.decode,
            output_dir=self.config.output_dir,
            data_adapter_fn=data_adapter_fn,
        )

    def _init_model(self):
        if self.config.init_from == "resume":
            # init from a model saved in a specific directory
            checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
            self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]

            self.load_meta = False
            self.speaker_cond = False

            if "config" in checkpoint:
                self.checkpoint_config = checkpoint["config"]

                self.meta = checkpoint["meta"]
                self.load_meta = True

            if self.load_meta:
                self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
                self.speaker_cond = self.meta.get("speaker_cond")

            if self.speaker_cond:
                speaker_emb_size = self.meta["speaker_emb_size"]

            model_args = checkpoint["model_args"]
            if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
                self._encodec_ctx_window = model_args["block_size"]

            gptconf = GPTConfig(**model_args)

            # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
            if not getattr(self.config, "train_from_scratch", False):
                state_dict = checkpoint["model"]
                unwanted_prefix = "_orig_mod."
                for k, v in list(state_dict.items()):
                    if k.startswith(unwanted_prefix):
                        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
                self.model.load_state_dict(state_dict)
            else:
                print("training a model from scratch!!")

            if self.first_model_path is not None:
                if self.checkpoint_config.get("causal", True):
                    new_dict = {}
                    state_dict = torch.load(self.first_model_path)["state_dict"]
                    for key, val in state_dict.items():
                        if "first_stage_model_transformer" in key:
                            new_dict[key.replace("first_stage_model_transformer.", "")] = val
                    self.model.load_state_dict(new_dict)

        # model
        self.model.eval()
        self.model.to(self.config.device)

        if self.config.compile:
            from einops._torch_specific import allow_ops_in_compiled_graph

            allow_ops_in_compiled_graph()
            self.model = torch.compile(self.model)  # type: ignore

        if self.use_kv_cache is not None:
            if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
                raise Exception("kv_cache not supported for non-causal models!")

            if self.use_kv_cache == "flash_decoding":
                self.model.enable_kv_cache()
                for block in self.model.transformer.h:
                    block.attn.attn_kernel_type = "fd"
            elif self.use_kv_cache == "vanilla":
                for block in self.model.transformer.h:
                    if block.attn.attn_kernel_type != "fa2":
                        raise Exception(
                            f"kv_cache only supported for flash attention 2 but found {block.attn.attn_kernel_type} inside model!"
                        )
                self.model.enable_kv_cache()
            else:
                raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")

    def causal_sample(
        self,
        *,
        texts: list[str],
        batch_size: int,
        max_new_tokens: int,
        temperature: Optional[float],
        top_k: Optional[int],
        top_p: Optional[float],
        speaker_embs: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = None,
    ) -> list[torch.Tensor]:
        """
        Returns a list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
        Any flattening / interleaving / tilting gets reversed before the output is returned.
        """
        if speaker_embs is not None:
            assert len(texts) == len(speaker_embs)

        encoded_texts = [self.tokenizer.encode(text) for text in texts]

        ## create multiple hierarchies and get seq_lens
        seq_lens = []
        xs = []
        for i, encoded_text in enumerate(encoded_texts):
            encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
            # TODO: remove magic number
            xs.append(
                torch.cat(
                    # [1st hierarchy of text, *remaining hierarchies of padded tokens]
                    # TODO: self.vocab_sizes should be from the model config?
                    [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
                    dim=0,
                ).unsqueeze(0)
            )  # b x [(b=1, c, t)]
            seq_lens.append(xs[-1].shape[-1])
        max_len = max(seq_lens)
        assert len(xs) == len(seq_lens)

        ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
        x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
        for i, _xs in enumerate(xs):
            assert _xs.shape[-1] == seq_lens[i]
            x[i, :, : seq_lens[i]] = _xs

        ## check that the input is correct
        for i in range(x.shape[0]):
            assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]

            # TODO: remove magic number
            if x.shape[1] > 1:
                assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])

        assert x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True

        if self.speaker_cond is False:
            speaker_embs = None

        # run sampling loop
        with torch.no_grad():
            with self._ctx:  # type: ignore
                to_return = []
                for k in range(self.config.num_samples):
                    assert seq_lens is not None
                    assert batch_size is not None

                    if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
                        raise Exception(
                            f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
                        )

                    y = self.model.generate(
                        x,
                        max_new_tokens,
                        seq_lens=seq_lens,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        speaker_embs=speaker_embs,
                        batch_size=batch_size,
                        guidance_scale=guidance_scale,
                    )
                    for i in range(len(y)):
                        to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))

                return to_return
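
    # Example of the input layout built by causal_sample above: for a 2-hierarchy
    # model and encoded text [5, 9, 2] (illustrative ids), each batch element
    # becomes a (1, c=2, t=3) tensor:
    #   hierarchy 0 (text tokens): [   5,    9,    2]
    #   hierarchy 1 (padding):     [1024, 1024, 1024]
    # where 1024 is the padding / end-of-audio token id. Shorter elements in the
    # batch are right-padded with zeros up to max_len and masked out via seq_lens.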

    def non_causal_sample(
        self,
        *,
        texts: list[str],
        encodec_tokens: list[torch.Tensor],
        batch_size: int,
        top_k: Optional[int],
        temperature: Optional[float],
        speaker_embs: Optional[torch.Tensor] = None,
    ) -> list[str]:
        """
        Returns paths to saved audio files.
        """
        if speaker_embs is not None:
            assert len(texts) == len(speaker_embs)

        encoded_texts = [self.tokenizer.encode(text) for text in texts]

        # setup input
        # TODO: same code is used during data prep. refactor
        padded_hierarchies_inputs = []
        for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
            x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
                None, None, ...
            ]  # (b=1, c=1, t)

            # TODO: should this only happen if the decoder is an EncodecDecoder?
            assert encodec_token.shape[0] == 1
            encodec_token = encodec_token[0].tolist()  # (b=1, c, t) -> (c, t)
            assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks

            ## setup hierarchies of tokens
            # TODO: refactor and merge with code in processing.py
            text_tokens = encoded_text  # (t,)

            # print(f'{encodec_tokens[0].shape=}, {len(encodec_tokens)}')
            hierarchies_in = []
            hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
            hierarchies_in.append(
                [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
            )

            ## adding padding / cutting to the right size as needed
            # TODO: refactor and merge with code in processing.py
            padded_hierarchies_input = []
            for _, t_hierarchy in enumerate(hierarchies_in):
                assert len(t_hierarchy) == len(hierarchies_in[0])
                if len(t_hierarchy) < self._encodec_ctx_window:
                    padded_hierarchies_input.append(
                        t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
                    )
                elif len(t_hierarchy) > self._encodec_ctx_window:
                    padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
                else:
                    padded_hierarchies_input.append(t_hierarchy)

            padded_hierarchies_inputs.append(padded_hierarchies_input)

        ## check that the input is correct
        in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
        assert in_x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True

        if self.speaker_cond is False:
            speaker_embs = None

        # run sampling loop
        with torch.no_grad():
            with self._ctx:  # type: ignore
                to_return = []
                for k in range(self.config.num_samples):
                    y = self.model.generate(
                        in_x,
                        None,
                        temperature=temperature,
                        top_k=top_k,
                        # TODO: handle separate top_p for this model explicitly
                        top_p=None,
                        speaker_embs=speaker_embs,
                        batch_size=batch_size,
                        guidance_scale=None,
                    )

                    b_tokens = torch.cat([in_x, y], dim=1)
                    for tokens in b_tokens:
                        try:
                            to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
                        except Exception as e:
                            print("failed to run MBD.")
                            print(f"reason: {str(e)}")
                            to_return.append(None)

                return to_return

    def __call__(
        self,
        *,
        texts: list[str],
        batch_size: int,
        max_new_tokens: Optional[int],
        top_k: Optional[int],
        top_p: Optional[float],
        temperature: Optional[float],
        encodec_tokens: Optional[list[torch.Tensor]] = None,
        speaker_embs: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = None,
    ):
        if self.checkpoint_config.get("causal", True):
            return self.causal_sample(
                texts=texts,
                batch_size=batch_size,
                speaker_embs=speaker_embs,
                guidance_scale=guidance_scale,
                max_new_tokens=max_new_tokens,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )
        else:
            assert encodec_tokens is not None
            assert guidance_scale is None
            assert max_new_tokens is None
            assert top_p is None

            return self.non_causal_sample(
                texts=texts,
                encodec_tokens=encodec_tokens,
                batch_size=batch_size,
                speaker_embs=speaker_embs,
                top_k=top_k,
                temperature=temperature,
            )


def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
    if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
        return
    json.dump(
        {
            "speaker": ref_path,
            "text": text,
        },
        pathlib.Path(str(wav_path) + ".json").open("w"),
    )


def get_cached_file(file_or_uri: str):
    """
    If it's a remote (http) URI, download it to a local cache file and return that path.
    Otherwise return the path as is.
    """
    is_uri = file_or_uri.startswith("http")

    cache_path = None
    if is_uri:
        ext = pathlib.Path(file_or_uri).suffix
        # hash the file path to get the cache name
        _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext

        os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
        cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")

        if not os.path.exists(cache_path):
            command = f"curl -o {cache_path} {file_or_uri}"
            subprocess.run(command, shell=True, check=True)
    else:
        if os.path.exists(file_or_uri):
            cache_path = file_or_uri
        else:
            raise FileNotFoundError(f"File {file_or_uri} not found!")

    # check that the audio file loads; the min. 30s duration check is currently disabled
    audio, sr = librosa.load(cache_path)
    # assert librosa.get_duration(y=audio, sr=sr) >= 30, "Speaker reference audio file needs to be >= 30s in duration."

    return cache_path
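
# Example of the caching scheme above: a remote reference like
#   get_cached_file("https://example.com/ref.wav")   # illustrative URL
# is downloaded once to ~/.cache/fam/audio_<md5(url)>.wav, and the same local path
# is returned on every subsequent call with the same URL.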


def get_cached_embedding(local_file_path: str, spkemb_model):
    if not os.path.exists(local_file_path):
        raise FileNotFoundError(f"File {local_file_path} not found!")

    # hash the file path to get the cache name
    _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"

    os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
    cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")

    if not os.path.exists(cache_path):
        spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0)  # (b=1, c)
        torch.save(spk_emb, cache_path)
    else:
        spk_emb = torch.load(cache_path)

    return spk_emb


def _sample_utterance_batch(
    texts: list[str],
    spk_cond_paths: list[Optional[str]],
    spkemb_model,
    first_stage_model,
    second_stage_model,
    enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
    first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
    guidance_scale: Optional[float],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: Optional[float],
    batch_size: int = 128,
) -> List[str]:
    speaker_embs = []
    refs = spk_cond_paths.copy()

    # multithreaded loop to cache all the files
    spk_cond_paths = tqdm.contrib.concurrent.thread_map(
        get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
    )

    for i, (text, spk_cond_path) in tqdm.tqdm(
        enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
    ):
        texts[i] = normalize_text(text)
        speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)

    b_speaker_embs = torch.cat(speaker_embs, dim=0)
    b_tokens = first_stage_model(
        texts=texts,
        speaker_embs=b_speaker_embs,
        batch_size=batch_size,
        guidance_scale=guidance_scale,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    # TODO: set batch size for second stage model!
    wav_files = second_stage_model(
        texts=texts,
        encodec_tokens=b_tokens,
        speaker_embs=b_speaker_embs,
        batch_size=batch_size,
        guidance_scale=None,
        top_p=None,
        top_k=top_k,
        temperature=temperature,
        max_new_tokens=None,
    )

    for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
        if wav_file is None:
            continue

        with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
            if enhancer is not None:
                enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
                enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
                # copy enhanced_tmp.name back to wav_file
                print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
                shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")

            save_result_metadata(
                wav_file,
                ref_name,
                text,
                first_stage_ckpt_path,
                second_stage_ckpt_path,
            )
    return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]


def sample_utterance(
    text: str,
    spk_cond_path: Optional[str],
    spkemb_model,
    first_stage_model,
    second_stage_model,
    enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
    first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
    guidance_scale: Optional[float],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: Optional[float],
) -> str:
    # NOTE: supports max. 220 characters atm.
    # Long-form synthesis coming soon...
    MAX_CHARS = 220
    if len(text) > MAX_CHARS:
        print(
            f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech... This can lead to unpredictable speech at the end.***"
        )

    return _sample_utterance_batch(
        texts=[text],
        spk_cond_paths=[spk_cond_path],
        spkemb_model=spkemb_model,
        first_stage_model=first_stage_model,
        second_stage_model=second_stage_model,
        enhancer=enhancer,
        first_stage_ckpt_path=first_stage_ckpt_path,
        second_stage_ckpt_path=second_stage_ckpt_path,
        batch_size=1,
        guidance_scale=guidance_scale,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )[0]


def build_models(config_first_stage, config_second_stage, device, use_kv_cache, first_model_path=None):
    smodel = SpeakerEncoder(device=device, eval=True, verbose=False)
    data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
    llm_first_stage = Model(
        config_first_stage,
        TrainedBPETokeniser,
        EncodecDecoder,
        data_adapter_fn=data_adapter.decode,
        use_kv_cache=use_kv_cache,
        first_model_path=first_model_path,
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
    llm_second_stage = Model(
        config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
    )
    return smodel, llm_first_stage, llm_second_stage


def get_first_stage_path(model_dir: str):
    """Absolute path to checkpoint for the first stage model."""
    return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")


def get_second_stage_path(model_dir: str):
    """Absolute path to checkpoint for the second stage model."""
    return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")


@dataclass
class SamplingControllerConfig:
    """
    Sample from a trained model.
    """

    huggingface_repo_id: str
    """Hugging Face repo id to download the model checkpoints from."""

    spk_cond_path: str
    """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""

    text: str = "コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。"
    """Text to synthesise."""

    num_samples: int = 1
    """Number of samples to generate from each model."""

    max_new_tokens: int = 864
    """Maximum number of new tokens to generate from the first stage model."""

    temperature: float = 1.0
    """Temperature for sampling applied to both models."""

    top_k: Optional[int] = None
    """Top k for sampling applied to both models."""

    top_p: Optional[float] = 0.95
    """Top p for sampling applied to the first-stage model."""

    seed: int = 1337
    """Random seed for sampling."""

    device: Literal["cuda", "cpu"] = "cuda"
    """Device to use for sampling."""

    dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
    """Data type to use for sampling."""

    compile: bool = False
    """Whether to compile the model using PyTorch 2.0."""

    enhancer: Optional[Literal["df"]] = "df"
    """Enhancer to use for post-processing."""

    init_from: str = "resume"
    """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""

    use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None
    """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the
    flash decoding kernel, 3) [vanilla] use flash attention 2 with a hand-implemented kv-cache."""

    output_dir: str = "samples/"
    """Relative path to output directory."""

    guidance_scale: Optional[float] = 3.0
    """Guidance scale for sampling."""

    batch_size: int = 128
    """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For an H100 and a 1B model,
    1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
    64 (w/ guidance) work well."""

    first_model_path: Optional[str] = None
    """Path to a first-stage checkpoint to load weights from (see the state-dict loading above)."""


if __name__ == "__main__":
    # TODO: add support for batch sampling via CLI. Function has been implemented above.
    sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)

    model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
    first_stage_ckpt_path = get_first_stage_path(model_dir)
    second_stage_ckpt_path = get_second_stage_path(model_dir)

    config_first_stage = InferenceConfig(
        ckpt_path=first_stage_ckpt_path,
        num_samples=sampling_config.num_samples,
        seed=sampling_config.seed,
        device=sampling_config.device,
        dtype=sampling_config.dtype,
        compile=sampling_config.compile,
        init_from=sampling_config.init_from,
        output_dir=sampling_config.output_dir,
    )

    config_second_stage = InferenceConfig(
        ckpt_path=second_stage_ckpt_path,
        num_samples=sampling_config.num_samples,
        seed=sampling_config.seed,
        device=sampling_config.device,
        dtype=sampling_config.dtype,
        compile=sampling_config.compile,
        init_from=sampling_config.init_from,
        output_dir=sampling_config.output_dir,
    )

    sampling_config.max_new_tokens *= (
        2  # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
    )

    # define models
    smodel, llm_first_stage, llm_second_stage = build_models(
        config_first_stage,
        config_second_stage,
        sampling_config.device,
        sampling_config.use_kv_cache,
        sampling_config.first_model_path,
    )

    print("Synthesising utterance...")
    sample_utterance(
        sampling_config.text,
        os.path.expanduser(sampling_config.spk_cond_path),
        smodel,
        llm_first_stage,
        llm_second_stage,
        sampling_config.enhancer,
        first_stage_ckpt_path,
        second_stage_ckpt_path,
        sampling_config.guidance_scale,
        max_new_tokens=sampling_config.max_new_tokens,
        top_k=sampling_config.top_k,
        top_p=sampling_config.top_p,
        temperature=sampling_config.temperature,
    )
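
For reference, a sketch of how the CLI above might be invoked (the repo id and speaker path are placeholders; since tyro is called with use_underscores=True, flag names match the SamplingControllerConfig field names):

    python fam/llm/sample.py \
        --huggingface_repo_id <repo_id> \
        --spk_cond_path <30s_reference.wav> \
        --text "コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。"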
fam/llm/serving.py
ADDED
@@ -0,0 +1,197 @@
import json
import logging
import os
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Literal, Optional

import fastapi
import fastapi.middleware.cors
import torch
import tyro
import uvicorn
from attr import dataclass
from fastapi import Request
from fastapi.responses import Response
from huggingface_hub import snapshot_download

from fam.llm.sample import (
    InferenceConfig,
    Model,
    build_models,
    get_first_stage_path,
    get_second_stage_path,
    # sample_utterance,
)
from fam.llm.fast_inference import TTS

logger = logging.getLogger(__name__)


## Setup FastAPI server.
app = fastapi.FastAPI()


@dataclass
class ServingConfig:
    huggingface_repo_id: str
    """Hugging Face repo id to download the model checkpoints from."""

    max_new_tokens: int = 864 * 2
    """Maximum number of new tokens to generate from the first stage model."""

    temperature: float = 1.0
    """Temperature for sampling applied to both models."""

    top_k: int = 200
    """Top k for sampling applied to both models."""

    seed: int = 1337
    """Random seed for sampling."""

    dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
    """Data type to use for sampling."""

    enhancer: Optional[Literal["df"]] = "df"
    """Enhancer to use for post-processing."""

    port: int = 58003


# Singleton
class _GlobalState:
    config: ServingConfig
    tts: TTS


GlobalState = _GlobalState()


@dataclass(frozen=True)
class TTSRequest:
    text: str
    guidance: Optional[float] = 3.0
    top_p: Optional[float] = 0.95
    speaker_ref_path: Optional[str] = None
    top_k: Optional[int] = None


def sample_utterance(
    text: str,
    spk_cond_path: str | None,
    guidance_scale,
    max_new_tokens,
    top_k,
    top_p,
    temperature,
) -> str:
    # NOTE: max_new_tokens and top_k are accepted for API compatibility but are not
    # forwarded to TTS.synthesise below.
    return GlobalState.tts.synthesise(
        text,
        spk_cond_path,
        top_p=top_p,
        guidance_scale=guidance_scale,
        temperature=temperature,
    )


@app.post("/tts", response_class=Response)
async def text_to_speech(req: Request):
    audiodata = await req.body()
    payload = None
    wav_out_path = None

    try:
        headers = req.headers
        payload = headers["X-Payload"]
        payload = json.loads(payload)
        tts_req = TTSRequest(**payload)
        with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
            if tts_req.speaker_ref_path is None:
                wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
            else:
                wav_path = tts_req.speaker_ref_path
            wav_out_path = sample_utterance(
                tts_req.text,
                wav_path,
                guidance_scale=tts_req.guidance,
                max_new_tokens=GlobalState.config.max_new_tokens,
                temperature=GlobalState.config.temperature,
                top_k=tts_req.top_k,
                top_p=tts_req.top_p,
            )
        with open(wav_out_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")
    except Exception as e:
        # traceback_str = "".join(traceback.format_tb(e.__traceback__))
        logger.exception(f"Error processing request {payload}")
        return Response(
            content="Something went wrong. Please try again in a few mins or contact us on Discord",
            status_code=500,
        )
    finally:
        if wav_out_path is not None:
            Path(wav_out_path).unlink(missing_ok=True)


def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
    with tempfile.NamedTemporaryFile() as unknown_format_tmp:
        assert unknown_format_tmp.write(audiodata) > 0
        unknown_format_tmp.flush()

        subprocess.check_output(
            # arbitrary 2 minute cutoff
            shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
        )

        return wav_tmp.name


if __name__ == "__main__":
    # This has to be here to avoid some weird audiocraft shenanigans messing up matplotlib
    from fam.llm.enhancers import get_enhancer

    for name in logging.root.manager.loggerDict:
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
    logging.root.setLevel(logging.INFO)

    GlobalState.config = tyro.cli(ServingConfig)
    app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    common_config = dict(
        num_samples=1,
        seed=1337,
        device=device,
        dtype=GlobalState.config.dtype,
        compile=False,
        init_from="resume",
        output_dir=tempfile.mkdtemp(),
    )
    model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id)
    config1 = InferenceConfig(
        ckpt_path=get_first_stage_path(model_dir),
        **common_config,
    )

    config2 = InferenceConfig(
        ckpt_path=get_second_stage_path(model_dir),
        **common_config,
    )

    # NOTE: config1 / config2 are constructed here but not passed anywhere below;
    # TTS() loads its own checkpoints.
    GlobalState.tts = TTS()

    # start server
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=GlobalState.config.port,
        log_level="info",
    )
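
For reference, a minimal client sketch for the /tts endpoint above (the server must be running on the configured port; the speaker URL is a placeholder). It mirrors the request format used by fam/ui/app.py: sampling parameters travel as JSON in the X-Payload header, and raw reference-audio bytes go in the request body only when speaker_ref_path is not set:

    import json
    import requests

    payload = {
        "text": "こんにちは。",
        "guidance": 3.0,
        "top_p": 0.95,
        "speaker_ref_path": "https://example.com/speaker.wav",  # placeholder
    }
    resp = requests.post(
        "http://127.0.0.1:58003/tts",
        headers={"Content-Type": "audio/wav", "X-Payload": json.dumps(payload)},
        data=None,  # put raw audio bytes here instead, when speaker_ref_path is None
    )
    if resp.status_code == 200:
        with open("out.wav", "wb") as f:
            f.write(resp.content)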
fam/llm/utils.py
ADDED
@@ -0,0 +1,91 @@
import os
import re
import subprocess
import tempfile

import librosa
import torch


def normalize_text(text: str) -> str:
    # maps typographic unicode punctuation (keyed by codepoint) to ASCII equivalents
    unicode_conversion = {
        8175: "'",
        8189: "'",
        8190: "'",
        8208: "-",
        8209: "-",
        8210: "-",
        8211: "-",
        8212: "-",
        8213: "-",
        8214: "||",
        8216: "'",
        8217: "'",
        8218: ",",
        8219: "`",
        8220: '"',
        8221: '"',
        8222: ",,",
        8223: '"',
        8228: ".",
        8229: "..",
        8230: "...",
        8242: "'",
        8243: '"',
        8245: "'",
        8246: '"',
        180: "'",
        2122: "TM",  # Trademark
    }

    text = text.translate(unicode_conversion)

    non_bpe_chars = set([c for c in list(text) if ord(c) >= 256])
    # if len(non_bpe_chars) > 0:
    #     non_bpe_points = [(c, ord(c)) for c in non_bpe_chars]
    #     raise ValueError(f"Non-BPE single token characters found: {non_bpe_points}")

    text = text.replace("\t", " ")
    text = text.replace("\n", " ")
    text = text.replace("*", " ")
    text = text.strip()
    text = re.sub(r"\s\s+", " ", text)  # remove multiple spaces
    return text


def check_audio_file(path_or_uri, threshold_s=10):  # default used to be 30
    if "http" in path_or_uri:
        temp_fd, filepath = tempfile.mkstemp()
        os.close(temp_fd)  # Close the file descriptor; curl will create a new connection
        curl_command = ["curl", "-L", path_or_uri, "-o", filepath]
        subprocess.run(curl_command, check=True)
    else:
        filepath = path_or_uri

    audio, sr = librosa.load(filepath)
    duration_s = librosa.get_duration(y=audio, sr=sr)
    if duration_s < threshold_s:
        raise Exception(
            f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
        )

    # Clean up the temporary file if it was created
    if "http" in path_or_uri:
        os.remove(filepath)


def get_default_dtype() -> str:
    """Compute default 'dtype' based on GPU architecture"""
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            device_properties = torch.cuda.get_device_properties(i)
            # float16 for Tesla/Turing (compute capability <= 7), bfloat16 for newer architectures
            dtype = "float16" if device_properties.major <= 7 else "bfloat16"
    else:
        dtype = "float16"

    print(f"using dtype={dtype}")
    return dtype


def get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"
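
As a quick check of normalize_text above: typographic punctuation is mapped back to ASCII by codepoint, newlines become spaces, and runs of whitespace collapse:

    assert normalize_text("It\u2019s  a \u201ctest\u201d\n") == 'It\'s a "test"'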
fam/py.typed
ADDED
File without changes
fam/quantiser/__init__.py
ADDED
File without changes
fam/quantiser/audio/__init__.py
ADDED
File without changes
fam/quantiser/audio/speaker_encoder/__init__.py
ADDED
File without changes
fam/quantiser/audio/speaker_encoder/audio.py
ADDED
@@ -0,0 +1,22 @@
import librosa
import numpy as np

mel_window_length = 25
mel_window_step = 10
mel_n_channels = 40
sampling_rate = 16000


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this is not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels,
    )
    return frames.astype(np.float32).T
fam/quantiser/audio/speaker_encoder/ckpt/.gitattributes
ADDED
@@ -0,0 +1 @@
*.pt filter=lfs diff=lfs merge=lfs -text
fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
size 17090379
fam/quantiser/audio/speaker_encoder/model.py
ADDED
@@ -0,0 +1,123 @@
import os
from time import perf_counter as timer
from typing import List, Optional, Union

import librosa
import numpy as np
import torch
from torch import nn

from fam.quantiser.audio.speaker_encoder import audio

DEFAULT_SPKENC_CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/ckpt.pt")

mel_window_step = 10
mel_n_channels = 40
sampling_rate = 16000
partials_n_frames = 160
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


class SpeakerEncoder(nn.Module):
    def __init__(
        self,
        weights_fpath: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = True,
        eval: bool = False,
    ):
        super().__init__()

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        start = timer()
        if eval and weights_fpath is None:
            weights_fpath = DEFAULT_SPKENC_CKPT_PATH

        if weights_fpath is not None:
            checkpoint = torch.load(weights_fpath, map_location="cpu")

            self.load_state_dict(checkpoint["model_state"], strict=False)
            self.to(device)

            if eval:
                self.eval()

            if verbose:
                print("Loaded the speaker embedding model on %s in %.2f seconds." % (device.type, timer() - start))

    def forward(self, mels: torch.FloatTensor):
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    @staticmethod
    def compute_partial_slices(n_samples: int, rate, min_coverage):
        # Compute how many frames separate two partial utterances
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))

        # Compute the slices
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partials_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))

        # Evaluate whether extra padding is warranted or not
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
        if coverage < min_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]

        return wav_slices, mel_slices

    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True):
        wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
        max_wave_length = wav_slices[-1].stop
        if max_wave_length >= len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

        mel = audio.wav_to_mel_spectrogram(wav)
        mels = np.array([mel[s] for s in mel_slices])
        with torch.no_grad():
            mels = torch.from_numpy(mels).to(self.device)  # type: ignore
            partial_embeds = self(mels)

        if numpy:
            partial_embeds = partial_embeds.cpu().numpy()
            raw_embed = np.mean(partial_embeds, axis=0)
            embed = raw_embed / np.linalg.norm(raw_embed, 2)
        else:
            raw_embed = partial_embeds.mean(dim=0)
            embed = raw_embed / torch.linalg.norm(raw_embed, 2)

        if return_partials:
            return embed, partial_embeds, wav_slices
        return embed

    def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
        raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
        return raw_embed / np.linalg.norm(raw_embed, 2)

    def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor:
        wav_tgt, _ = librosa.load(fpath, sr=16000)
        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
        embedding = self.embed_utterance(wav_tgt, numpy=numpy)
        return embedding
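
A minimal usage sketch for the encoder above (the wav path is a placeholder): with eval=True and no explicit checkpoint it loads the bundled ckpt/ckpt.pt and returns a unit-norm 256-dimensional speaker embedding:

    from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder

    enc = SpeakerEncoder(device="cpu", eval=True, verbose=False)
    emb = enc.embed_utterance_from_file("reference.wav", numpy=True)  # placeholder path
    print(emb.shape)  # (256,)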
fam/quantiser/text/tokenise.py
ADDED
@@ -0,0 +1,32 @@
import tiktoken


class TrainedBPETokeniser:
    def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
        self.tokenizer = tiktoken.Encoding(
            name=name,
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens,
        )
        self.offset = offset

    def encode(self, text: str) -> list[int]:
        # note: we add an end-of-text token!
        tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
        if self.offset is not None:
            tokens = [x + self.offset for x in tokens]

        return tokens

    def decode(self, tokens: list[int]):
        if self.offset is not None:
            tokens = [x - self.offset for x in tokens]
        return self.tokenizer.decode(tokens)

    @property
    def eot_token(self):
        if self.offset is not None:
            return self.tokenizer.eot_token + self.offset
        else:
            return self.tokenizer.eot_token
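
A sketch of the offset round-trip (the base encoding here is a stand-in pulled from tiktoken's registry via private attributes; in this repo the constructor arguments come from the trained tokenizer config stored with the checkpoints):

    import tiktoken

    base = tiktoken.get_encoding("gpt2")  # stand-in for the trained BPE
    tok = TrainedBPETokeniser(
        name="demo",
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        special_tokens=base._special_tokens,
        offset=512,
    )
    ids = tok.encode("hello")               # BPE ids + eot token, all shifted up by 512
    assert tok.decode(ids[:-1]) == "hello"  # decode shifts back down before detokenising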
fam/ui/app.py
ADDED
@@ -0,0 +1,201 @@
import io
import json
import os

import gradio as gr
import requests
import soundfile as sf

API_SERVER_URL = "http://127.0.0.1:58003/tts"
RADIO_CHOICES = ["Preset voices", "Upload target voice", "Record your voice"]
MAX_CHARS = 220
PRESET_VOICES = {
    # female
    "Ava": "https://cdn.themetavoice.xyz/speakers/ava.flac",
    "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
    # male
    "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
    "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
}


def denormalise_top_p(top_p):
    # maps the 0-10 UI slider to top_p in the range [0.9, 1.0]
    return round(0.9 + top_p / 100, 2)


def denormalise_guidance(guidance):
    # maps the 1-5 UI slider to guidance in the range [1.0, 3.0]
    return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)
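
# Worked examples for the two mappings above:
#   denormalise_top_p(5.0)    = round(0.9 + 5 / 100, 2)           = 0.95
#   denormalise_guidance(5.0) = 1 + ((5 - 1) * (3 - 1)) / (5 - 1) = 3.0
# so the slider defaults (5.0 and 5.0) correspond to top_p=0.95 and guidance=3.0,
# matching the TTSRequest defaults in fam/llm/serving.py.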

def _handle_edge_cases(to_say, upload_target):
    if not to_say:
        raise gr.Error("Please provide text to synthesise")

    def _check_file_size(path):
        if not path:
            return
        filesize = os.path.getsize(path)
        filesize_mb = filesize / 1024 / 1024
        if filesize_mb >= 50:
            raise gr.Error(
                f"Please upload a sample less than 50MB for voice cloning. Provided: {round(filesize_mb)} MB"
            )

    _check_file_size(upload_target)


def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target):
    d_top_p = denormalise_top_p(top_p)
    d_guidance = denormalise_guidance(guidance)

    _handle_edge_cases(to_say, upload_target)

    to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]

    custom_target_path = None
    if toggle == RADIO_CHOICES[1]:
        custom_target_path = upload_target
    elif toggle == RADIO_CHOICES[2]:
        custom_target_path = record_target

    config = {
        "text": to_say,
        "guidance": d_guidance,
        "top_p": d_top_p,
        "speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None,
    }
    headers = {"Content-Type": "audio/wav", "X-Payload": json.dumps(config)}
    if not custom_target_path:
        response = requests.post(API_SERVER_URL, headers=headers, data=None)
    else:
        with open(custom_target_path, "rb") as f:
            data = f.read()
        response = requests.post(API_SERVER_URL, headers=headers, data=data)

    wav, sr = None, None
    if response.status_code == 200:
        audio_buffer = io.BytesIO(response.content)
        audio_buffer.seek(0)
        wav, sr = sf.read(audio_buffer, dtype="float32")
    else:
        print(f"Something went wrong. response status code: {response.status_code}")

    return sr, wav


def change_voice_selection_layout(choice):
    index = RADIO_CHOICES.index(choice)
    return [
        gr.update(visible=True) if i == index else gr.update(visible=False)
        for i in range(len(RADIO_CHOICES))
    ]


title = "# TTS by Kotoba-Speech"

description = """
<strong>Kotoba-Speech v0.1</strong>は、1.2Bのトランスフォーマーに基づく音声生成モデルです。
以下の機能をサポートしています:
\n
* 日本語における滑らかなテキスト読み上げ生成
* スピーチプロンプトを通じたOne-shot音声クローニング

Kotoba Technologiesは、公開されたモデルを商用可能なApache 2.0ライセンスで公開します。
推論およびモデルコードは、Meta-Voiceをベースに作られており、学習コードは弊社のGitHubで近日中に公開する予定です。
Kotoba Technologiesは、音声基盤モデルの開発に取り組んでおり、今後もモデルの公開を行なっていきます。是非、[Discord Community](https://discord.gg/qPVFqhGN7Z)に参加してご意見ください!

<strong>Kotoba-Speech v0.1</strong> is a 1.2B Transformer-based speech generative model. It supports the following properties:
\n
* Fluent text-to-speech generation in Japanese
* One-shot voice cloning through a speech prompt

We are releasing our model under the Apache 2.0 license. Our inference and model code is adapted from Meta-Voice, and we will release our training code on our GitHub repository shortly.
Kotoba Technologies is committed to developing speech foundation models, and we'll continue releasing models. Please join [our Discord](https://discord.gg/qPVFqhGN7Z) to contribute to our community.
"""

with gr.Blocks(title="TTS by Kotoba-Speech") as demo:
    gr.Markdown(title)

    with gr.Row():
        gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            to_say = gr.TextArea(
                label="What should I say!?",
                lines=4,
                value="コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。",
            )

            with gr.Row(), gr.Column():
                # voice settings
                top_p = gr.Slider(
                    value=5.0,
                    minimum=0.0,
                    maximum=10.0,
                    step=1.0,
                    label="Speech Stability - improves text following for a challenging speaker",
                )
                guidance = gr.Slider(
                    value=5.0,
                    minimum=1.0,
                    maximum=5.0,
                    step=1.0,
                    label="Speaker similarity - How closely to match speaker identity and speech style.",
                )

                # voice select
                toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])

            with gr.Row(visible=True) as row_1:
                preset_dropdown = gr.Dropdown(
                    PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
                )
                with gr.Accordion("Preview: Preset voices", open=False):
                    for label, path in PRESET_VOICES.items():
                        gr.Audio(value=path, label=label)

            with gr.Row(visible=False) as row_2:
                upload_target = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
                    min_length=10,
                    max_length=90,
                )

            with gr.Row(visible=False) as row_3:
                record_target = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label="Record your voice with a microphone to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
                    min_length=10,
                    max_length=90,
                )

            toggle.change(
                change_voice_selection_layout,
                inputs=toggle,
                outputs=[row_1, row_2, row_3],
            )

        with gr.Column():
            speech = gr.Audio(
                type="numpy",
                label="Kotoba-Speech says...",
            )

    submit = gr.Button("Generate Speech")
    submit.click(
        fn=tts,
        inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target],
        outputs=speech,
    )


demo.queue(default_concurrency_limit=2)
# demo.launch()
demo.launch(server_name="0.0.0.0", server_port=3000, share=True)
requirements.txt
ADDED
@@ -0,0 +1,138 @@
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.1
beautifulsoup4==4.12.2
bleach==6.1.0
blinker==1.4
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
comm==0.2.0
cryptography==3.4.8
dbus-python==1.2.18
debugpy==1.8.0
decorator==5.1.1
deepfilternet
defusedxml==0.7.1
distro==1.7.0
entrypoints==0.4
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
filelock==3.9.0
fqdn==1.5.1
fsspec
gradio==4.21.0
httplib2==0.20.2
idna==3.4
importlib-metadata==4.6.4
ipykernel==6.26.0
ipython==8.17.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter-archive==3.4.0
jupyter-contrib-core==0.4.2
jupyter-contrib-nbextensions==0.7.0
jupyter-events==0.9.0
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.0
jupyter-nbextensions-configurator==0.6.3
jupyter_client==7.4.9
jupyter_core==5.5.0
jupyter_server==2.10.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
librosa
lxml==4.9.3
Mako==1.1.3
Markdown==3.3.6
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==3.0.2
more-itertools==8.10.0
mpmath==1.3.0
nbclassic==1.0.0
nbclient==0.9.0
nbconvert==7.11.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.0
notebook==6.5.5
notebook_shim==0.2.3
numpy==1.24.1
oauthlib==3.2.0
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
Pillow==9.3.0
platformdirs==3.11.0
prometheus-client==0.18.0
prompt-toolkit==3.0.39
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.16.1
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
# python-apt==2.4.0
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0.1
pyzmq==24.0.1
referencing==0.30.2
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.12.0
SecretStorage==3.3.1
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soundfile==0.12.1
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
terminado==0.17.1
tinycss2==1.2.1
tomli==2.0.1
torch
torchaudio
torchvision
tornado==6.3.3
traitlets==5.13.0
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions
uri-template==1.3.0
urllib3==1.26.13
wadllib==1.3.6
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
widgetsnbextension==4.0.9
zipp==1.0.0
setup.py
ADDED
@@ -0,0 +1,6 @@
from setuptools import find_packages, setup  # type: ignore

setup(
    name="fam",
    packages=find_packages(".", exclude=["tests"]),
)
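
With this setup.py, the `fam` package (tests excluded from packaging) can be installed for local development in editable mode:

    pip install -e .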