import random
import torch
from slam_llm.utils.model_utils import get_custom_model_factory
from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
import whisper
import numpy as np
from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
import os
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from typing import Callable
def update_progress(progress_callback: Callable[[str], None] | None, message: str):
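    """Forward a status message to the progress callback, if one was provided."""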
    if progress_callback:
        progress_callback(message)
def pull_model_ckpt():
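    """Download the model checkpoint from the Hugging Face Hub unless it is already present locally."""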
    if not os.path.exists(CKPT_LOCAL_DIR):
        os.makedirs(CKPT_LOCAL_DIR)
    if os.path.exists(CKPT_PATH):
        return
    hf_hub_download(
        repo_id=CKPT_REPO,
        filename=CKPT_NAME,
        local_dir=CKPT_LOCAL_DIR,
        token=os.getenv("HF_TOKEN"),
    )
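# Fetch the checkpoint at import time so the weights are in place before the first request.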
pull_model_ckpt()
def extract_audio_feature(audio_path, mel_size):
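    """Load a wav file and return its Whisper log-mel spectrogram plus the downsampled feature length."""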
print("Extracting audio features from", audio_path)
audio_raw = whisper.load_audio(audio_path)
audio_raw = whisper.pad_or_trim(audio_raw)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
audio_length = (audio_mel.shape[0] + 1) // 2
audio_length = audio_length // 5
audio_res = audio_mel
return audio_res, audio_length
def get_input_ids(length, special_token_a, special_token_t, vocab_config):
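    """Build placeholder input ids: one layer-shifted audio sequence per codec layer plus a final text-layer sequence."""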
    input_ids = []
    for i in range(vocab_config.code_layer):
        input_ids_item = []
        input_ids_item.append(layershift(vocab_config.input_a, i))
        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
        input_ids_item += [
            (layershift(vocab_config.eoa, i)),
            layershift(special_token_a, i),
        ]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor(
        [vocab_config.input_t]
        + [vocab_config.pad_t] * length
        + [vocab_config.eot, special_token_t]
    )
    input_ids.append(input_id_T.unsqueeze(0))
    return input_ids
def generate_from_wav(
    wav_path, model, codec_decoder, dataset_config, decode_config, device
):
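    """Run the speech-to-speech model on one wav file and return (decoded audio, generated text).

    Returns (None, text) when decode_config.decode_text_only is set.
    """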
    mel_size = dataset_config.mel_size
    prompt = dataset_config.prompt
    prompt_template = "USER: {}\n ASSISTANT: "
    vocab_config = dataset_config.vocab_config
    special_token_a = vocab_config.answer_a
    special_token_t = vocab_config.answer_t
    code_layer = vocab_config.code_layer
    task_type = dataset_config.task_type

    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)

    prompt = prompt_template.format(prompt)
    prompt_ids = model.tokenizer.encode(prompt)
    prompt_length = len(prompt_ids)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)

    example_ids = get_input_ids(
        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
    )
    # Splice the prompt token ids into the text layer after the audio placeholder tokens.
    text_layer = example_ids[code_layer]
    text_layer = torch.cat(
        (
            text_layer[:, : audio_length + 1],
            prompt_ids.unsqueeze(0),
            text_layer[:, -2:],
        ),
        dim=1,
    )  # <bos> <audio> <prompt> <eos> <task>
    example_ids[code_layer] = text_layer

    input_length = audio_length
    example_mask = example_ids[0][0].ge(-1)
    example_ids = torch.stack(example_ids).squeeze()

    input_ids = example_ids.unsqueeze(0).to(device)
    attention_mask = example_mask.unsqueeze(0).to(device)
    audio_mel = audio_mel.unsqueeze(0).to(device)
    input_length = torch.tensor([input_length]).to(device)
    audio_length = torch.tensor([audio_length]).to(device)
    task_type = [task_type]

    # Mark the token positions corresponding to the audio input.
    modality_mask = torch.zeros_like(attention_mask)
    padding_left = 1  # +1 for <bos>
    modality_mask[0, padding_left : padding_left + audio_length] = True

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "audio_mel": audio_mel,
        "input_length": input_length,
        "audio_length": audio_length,
        "modality_mask": modality_mask,
        "task_types": task_type,
    }

    model_outputs = model.generate(**batch, **decode_config)
    # One generated token stream per codec codebook layer (indices 0-6); the text stream is at index 7.
    text_outputs = model_outputs[7]
    audio_outputs = model_outputs[:7]

    output_text = model.tokenizer.decode(
        text_outputs, add_special_tokens=False, skip_special_tokens=True
    )
    if decode_config.decode_text_only:
        return None, output_text

    audio_tokens = [audio_outputs[layer] for layer in range(7)]
    audiolist = reconscruct_snac(audio_tokens)
    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = codec_decoder.decode(audio)

    return audio_hat, output_text
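# Module-level cache so the model, codec decoder, and device are initialized only once.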
model = None
codec_decoder = None
device = None
def generate(
    wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
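    """Load the model on first use, then generate a spoken reply for the input wav.

    Returns the output waveform as a numpy array together with its sample rate.
    """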
    global model, codec_decoder, device
    config = OmegaConf.structured(InferenceConfig())
    train_config, model_config, dataset_config, decode_config = (
        config.train_config,
        config.model_config,
        config.dataset_config,
        config.decode_config,
    )
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    # Build the model and codec decoder only on the first call.
    if model is None or codec_decoder is None or device is None:
        update_progress(progress_callback, "Loading model")
        model_factory = get_custom_model_factory(model_config)
        model, _ = model_factory(train_config, model_config, CKPT_PATH)
        codec_decoder = model.codec_decoder
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

    update_progress(progress_callback, "Generating")
    output_wav, output_text = generate_from_wav(
        wav_path, model, codec_decoder, dataset_config, decode_config, device
    )
    # Return the waveform as a numpy array together with its sample rate (24 kHz).
    return output_wav.squeeze().cpu().numpy(), 24000
if __name__ == "__main__":
    wav_path = "sample.wav"
    generate(wav_path)
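    # A minimal sketch for saving the result, assuming the third-party `soundfile`
    # package is available (it is not imported by this script):
    #   import soundfile as sf
    #   audio, sr = generate(wav_path)
    #   sf.write("response.wav", audio, sr)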