# mini-omni-s2s / s2s.py
import os
import random
from typing import Callable

import numpy as np
import torch
import whisper
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from slam_llm.utils.model_utils import get_custom_model_factory

from s2s_config import CKPT_LOCAL_DIR, CKPT_NAME, CKPT_PATH, CKPT_REPO, InferenceConfig
# Note: "reconscruct_snac" is the helper's actual (misspelled) name in utils/snac_utils.
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors


def update_progress(progress_callback: Callable[[str], None] | None, message: str):
if progress_callback:
progress_callback(message)
def pull_model_ckpt():
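    """Download the model checkpoint from CKPT_REPO into CKPT_LOCAL_DIR unless it already exists locally."""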
    os.makedirs(CKPT_LOCAL_DIR, exist_ok=True)
if os.path.exists(CKPT_PATH):
return
hf_hub_download(
repo_id=CKPT_REPO,
filename=CKPT_NAME,
local_dir=CKPT_LOCAL_DIR,
token=os.getenv("HF_TOKEN"),
)
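

# Fetch the checkpoint at import time so the weights are already on disk
# before generate() is first called.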
pull_model_ckpt()
def extract_audio_feature(audio_path, mel_size):
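    """Load a wav file and return its Whisper log-mel spectrogram and the resulting feature length."""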
print("Extracting audio features from", audio_path)
audio_raw = whisper.load_audio(audio_path)
audio_raw = whisper.pad_or_trim(audio_raw)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
    # The Whisper encoder halves the number of mel frames; the speech projector
    # then downsamples by a further factor of 5, giving the effective feature length.
    audio_length = (audio_mel.shape[0] + 1) // 2
    audio_length = audio_length // 5
    return audio_mel, audio_length
def get_input_ids(length, special_token_a, special_token_t, vocab_config):
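    """Build the padded input-id template for every codec layer plus the text layer.

    Each of the vocab_config.code_layer audio layers is
    [input_a, pad_a * length, eoa, special_token_a], layer-shifted by its index;
    the final text layer is [input_t, pad_t * length, eot, special_token_t].
    The caller later splices the tokenized prompt into the pad region of the text layer.
    """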
input_ids = []
for i in range(vocab_config.code_layer):
input_ids_item = []
input_ids_item.append(layershift(vocab_config.input_a, i))
input_ids_item += [layershift(vocab_config.pad_a, i)] * length
input_ids_item += [
(layershift(vocab_config.eoa, i)),
layershift(special_token_a, i),
]
input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
input_id_T = torch.tensor(
[vocab_config.input_t]
+ [vocab_config.pad_t] * length
+ [vocab_config.eot, special_token_t]
)
input_ids.append(input_id_T.unsqueeze(0))
return input_ids
def generate_from_wav(
wav_path, model, codec_decoder, dataset_config, decode_config, device
):
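    """Run speech-to-speech inference for a single wav file.

    Returns (audio, text): the decoded waveform tensor and the generated text,
    with audio set to None when decode_config.decode_text_only is enabled.
    """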
mel_size = dataset_config.mel_size
prompt = dataset_config.prompt
prompt_template = "USER: {}\n ASSISTANT: "
vocab_config = dataset_config.vocab_config
special_token_a = vocab_config.answer_a
special_token_t = vocab_config.answer_t
code_layer = vocab_config.code_layer
task_type = dataset_config.task_type
audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)
prompt = prompt_template.format(prompt)
prompt_ids = model.tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = get_input_ids(
audio_length + prompt_length, special_token_a, special_token_t, vocab_config
)
text_layer = example_ids[code_layer]
text_layer = torch.cat(
(
text_layer[:, : audio_length + 1],
prompt_ids.unsqueeze(0),
text_layer[:, -2:],
),
dim=1,
) # <bos> <audio> <prompt> <eos> <task>
example_ids[code_layer] = text_layer
input_length = audio_length
    example_mask = example_ids[0][0].ge(-1)  # all-True attention mask (token ids are never negative)
example_ids = torch.stack(example_ids).squeeze()
input_ids = example_ids.unsqueeze(0).to(device)
attention_mask = example_mask.unsqueeze(0).to(device)
audio_mel = audio_mel.unsqueeze(0).to(device)
input_length = torch.tensor([input_length]).to(device)
audio_length = torch.tensor([audio_length]).to(device)
task_type = [task_type]
modality_mask = torch.zeros_like(attention_mask)
padding_left = 1 # +1 for <bos>
modality_mask[0, padding_left : padding_left + audio_length] = True
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"audio_mel": audio_mel,
"input_length": input_length,
"audio_length": audio_length,
"modality_mask": modality_mask,
"task_types": task_type,
}
model_outputs = model.generate(**batch, **decode_config)
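    # model.generate returns one token stream per codec layer, followed by the text token stream.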
    text_outputs = model_outputs[code_layer]
    audio_outputs = model_outputs[:code_layer]
    output_text = model.tokenizer.decode(text_outputs, skip_special_tokens=True)
    if decode_config.decode_text_only:
        return None, output_text
    audio_tokens = [audio_outputs[layer] for layer in range(code_layer)]
audiolist = reconscruct_snac(audio_tokens)
audio = reconstruct_tensors(audiolist)
with torch.inference_mode():
audio_hat = codec_decoder.decode(audio)
return audio_hat, output_text
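

# Cached lazily so the model, codec decoder and device are only loaded once per process.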
model = None
codec_decoder = None
device = None
def generate(
wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
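    """Generate a spoken response for the audio at wav_path.

    The model, codec decoder and device are initialized lazily on first use.
    Returns the output waveform as a float numpy array plus its sample rate (24000 Hz).
    """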
global model, codec_decoder, device
config = OmegaConf.structured(InferenceConfig())
train_config, model_config, dataset_config, decode_config = (
config.train_config,
config.model_config,
config.dataset_config,
config.decode_config,
)
torch.cuda.manual_seed(train_config.seed)
torch.manual_seed(train_config.seed)
random.seed(train_config.seed)
if model is None or codec_decoder is None or device is None:
update_progress(progress_callback, "Loading model")
model_factory = get_custom_model_factory(model_config)
model, _ = model_factory(train_config, model_config, CKPT_PATH)
codec_decoder = model.codec_decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
update_progress(progress_callback, "Generating")
output_wav, output_text = generate_from_wav(
wav_path, model, codec_decoder, dataset_config, decode_config, device
)
return output_wav.squeeze().cpu().numpy(), 24000
if __name__ == "__main__":
wav_path = "sample.wav"
generate(wav_path)
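    # To write the result to disk, one option (assuming the optional soundfile
    # package is installed; it is not imported above) would be:
    #   import soundfile as sf
    #   sf.write("output.wav", audio, sample_rate)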