import os
import sys
import uuid
import tempfile
import time
import shutil
from pathlib import Path
from argparse import ArgumentParser
from threading import Thread
from queue import Queue

import torch
import torchaudio
import gradio as gr
import whisper
from transformers import (
    WhisperFeatureExtractor,
    AutoTokenizer,
    AutoModel,
)
from transformers.generation.streamers import BaseStreamer

from speech_tokenizer.modeling_whisper import WhisperVQEncoder
from speech_tokenizer.utils import extract_speech_token
# Add local paths so the CosyVoice and Matcha-TTS packages are importable
sys.path.insert(0, "./cosyvoice")
sys.path.insert(0, "./third_party/Matcha-TTS")
from flow_inference import AudioDecoder
# RAG imports
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
    CSVLoader,
    UnstructuredImageLoader,
    JSONLoader,
    BSHTMLLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
import joblib
import spaces
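# NOTE (assumption, not from the original source): `spaces` is the Hugging Face
# Spaces helper package. On ZeroGPU Spaces, functions that touch CUDA are
# normally decorated with @spaces.GPU; it is imported here but never applied,
# which is a plausible contributor to the Space's "Runtime error" state.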
# File loader mapping: file extension -> langchain document loader class
LOADER_MAPPING = {
    '.pdf': PyPDFLoader,
    '.txt': TextLoader,
    '.md': UnstructuredMarkdownLoader,
    '.csv': CSVLoader,
    '.jpg': UnstructuredImageLoader,
    '.jpeg': UnstructuredImageLoader,
    '.png': UnstructuredImageLoader,
    '.json': JSONLoader,
    '.html': BSHTMLLoader,
    '.htm': BSHTMLLoader,
}
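# Caveats (assumptions, not verified against every langchain version):
# JSONLoader typically requires a jq_schema argument, and the Unstructured*
# loaders need the `unstructured` extras installed, so some entries above may
# fail when instantiated as loader_class(file_path) alone.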
class SessionManager:
    """Creates and cleans up per-user session directories."""

    def __init__(self, base_path="./sessions"):
        self.base_path = Path(base_path)
        self.base_path.mkdir(exist_ok=True)

    def create_session(self):
        session_id = str(uuid.uuid4())
        session_path = self.base_path / session_id
        session_path.mkdir(exist_ok=True)
        return session_id

    def get_session_path(self, session_id):
        return self.base_path / session_id

    def cleanup_old_sessions(self, max_age_hours=24):
        current_time = time.time()
        for session_dir in self.base_path.iterdir():
            if session_dir.is_dir():
                age_hours = (current_time - session_dir.stat().st_mtime) / 3600
                if age_hours > max_age_hours:
                    shutil.rmtree(session_dir)
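# Intended usage (sketch): one session per page load, e.g.
#   sm = SessionManager(); sid = sm.create_session()
#   sm.get_session_path(sid)  # -> ./sessions/<uuid>
# cleanup_old_sessions() removes session directories untouched for 24 hours.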
class VectorStoreManager:
    """Builds and caches one FAISS vector store per session."""

    def __init__(self, session_manager, embedding_model):
        self.session_manager = session_manager
        self.embedding_model = embedding_model
        self.stores = {}

    def get_store_path(self, session_id):
        session_path = self.session_manager.get_session_path(session_id)
        return session_path / "vector_store.faiss"

    def create_store(self, session_id, files):
        if not files:
            return None
        store_path = self.get_store_path(session_id)
        file_paths = [f.name for f in files]
        pages = load_files(file_paths)
        if not pages:
            return None
        docs = split_text(pages)
        if not docs:
            return None
        vector_store = FAISS.from_documents(docs, self.embedding_model)
        vector_store.save_local(str(store_path))
        save_file_paths(str(store_path.parent), file_paths)
        self.stores[session_id] = vector_store
        return vector_store

    def get_store(self, session_id):
        if session_id in self.stores:
            return self.stores[session_id]
        store_path = self.get_store_path(session_id)
        if store_path.exists():
            # Recent langchain versions refuse to unpickle a local FAISS index
            # unless deserialization is explicitly allowed.
            vector_store = FAISS.load_local(
                str(store_path),
                self.embedding_model,
                allow_dangerous_deserialization=True,
            )
            self.stores[session_id] = vector_store
            return vector_store
        return None
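# Note: FAISS.save_local treats its argument as a *directory* (it writes
# index.faiss and index.pkl inside it), so "vector_store.faiss" above is a
# folder name despite the extension.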
class TokenStreamer(BaseStreamer):
    """Streams generated token ids through a thread-safe queue."""

    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        return value
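# Producer/consumer pattern: model.generate() runs in a worker thread and calls
# put() for each new token; the consuming thread iterates over the streamer,
# blocking on the queue until end() enqueues the stop signal (None).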
class ModelWorker:
    def __init__(self, model_path, device='cuda'):
        self.device = device
        self.glm_model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
        ).to(device).eval()
        self.glm_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )

    def generate_stream(self, params):
        prompt = params["prompt"]
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        inputs = self.glm_tokenizer([prompt], return_tensors="pt").to(self.device)
        streamer = TokenStreamer(skip_prompt=True)
        # Run generation in a background thread so we can consume tokens from
        # the streamer as they are produced.
        thread = Thread(
            target=self.glm_model.generate,
            kwargs=dict(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                streamer=streamer,
            ),
        )
        thread.start()
        for token_id in streamer:
            yield token_id

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except Exception as e:
            print("Caught unknown error:", e)
            yield "Server Error"
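# inference_fn distinguishes error yields from token yields by type: ints are
# token ids, while the string "Server Error" signals that generation failed.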
def load_single_file(file_path):
    """Load one document with the loader registered for its extension."""
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    loader_class = LOADER_MAPPING.get(ext)
    if not loader_class:
        print(f"Unsupported file type: {ext}")
        return None
    loader = loader_class(file_path)
    return list(loader.lazy_load())
def load_files(file_paths: list):
    if not file_paths:
        return []
    docs = []
    for file_path in tqdm(file_paths):
        print("Loading docs:", file_path)
        loaded_docs = load_single_file(file_path)
        if loaded_docs:
            docs.extend(loaded_docs)
    return docs


def split_text(documents, chunk_size=200, overlap=20):
    if not documents:
        return []
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    return splitter.split_documents(documents)
def create_embedding_model(model_file):
    return HuggingFaceEmbeddings(model_name=model_file, model_kwargs={'trust_remote_code': True})


def save_file_paths(store_path, file_paths):
    joblib.dump(file_paths, f'{store_path}/file_paths.pkl')


def create_vector_store(docs, store_file, embeddings):
    if not docs:
        raise ValueError("No documents provided for creating vector store")
    vector_store = FAISS.from_documents(docs, embeddings)
    vector_store.save_local(store_file)
    return vector_store
def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": relevance_threshold, "k": k}
    )
    similar_docs = retriever.invoke(query)
    return [doc.page_content for doc in similar_docs]
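# With search_type="similarity_score_threshold", langchain normalizes
# similarity to a relevance score in [0, 1] and drops documents below
# score_threshold, so raising the threshold returns fewer (but closer) chunks.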
def initialize_fn():
    global audio_decoder, feature_extractor, whisper_model, glm_tokenizer
    global session_manager, vector_store_manager, whisper_transcribe_model, model_worker
    if audio_decoder is not None:
        return  # already initialized
    print("Initializing models and managers...")
    # Initialize the session manager first: the Gradio UI needs it to create a
    # session id before the first request arrives.
    session_manager = SessionManager()
    model_worker = ModelWorker(args.model_path, device)
    glm_tokenizer = model_worker.glm_tokenizer
    audio_decoder = AudioDecoder(
        config_path=flow_config,
        flow_ckpt_path=flow_checkpoint,
        hift_ckpt_path=hift_checkpoint,
        device=device
    )
    whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
    embedding_model = create_embedding_model(Embedding_Model)
    vector_store_manager = VectorStoreManager(session_manager, embedding_model)
    # A separate small Whisper model, used only to transcribe the user's audio
    # into a text query for retrieval.
    whisper_transcribe_model = whisper.load_model("base")
    print("Initialization complete.")
def clear_fn():
    # Return values must match reset_btn's outputs:
    # chatbot, history_state, output_audio, complete_audio
    return [], [], None, None
def reinitialize_database(files, session_id, progress=gr.Progress()):
    if not files:
        return "No files uploaded. Please upload files first."
    progress(0.5, desc="Processing documents and creating vector store...")
    vector_store = vector_store_manager.create_store(session_id, files)
    if vector_store is None:
        return "Failed to create vector store. Please check your documents."
    return "Database initialized successfully!"
def inference_fn(
    temperature: float,
    top_p: float,
    max_new_token: int,
    input_mode,
    audio_path: str | None,
    input_text: str | None,
    history: list[dict],
    session_id: str,
):
    vector_store = vector_store_manager.get_store(session_id)
    context = None

    if input_mode == "audio":
        assert audio_path is not None
        history.append({"role": "user", "content": {"path": audio_path}})
        audio_tokens = extract_speech_token(
            whisper_model, feature_extractor, [audio_path]
        )[0]
        if len(audio_tokens) == 0:
            raise gr.Error("No audio tokens extracted")
        audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
        audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
        user_input = audio_tokens
        system_prompt = "User will provide you with a speech instruction. Do it step by step."
        if vector_store:
            # Retrieval needs text, so transcribe the audio query first.
            whisper_result = whisper_transcribe_model.transcribe(audio_path)
            context = query_vector_store(vector_store, whisper_result['text'], 4, 0.7)
    else:
        assert input_text is not None
        history.append({"role": "user", "content": input_text})
        user_input = input_text
        system_prompt = "User will provide you with a text instruction. Do it step by step."
        if vector_store:
            context = query_vector_store(vector_store, input_text, 4, 0.7)

    # Assemble the prompt: system prompt, optional retrieved context, user turn.
    inputs = f"<|system|>\n{system_prompt}"
    if context:
        inputs += f"<|context|> According to the following content: {context}, Please answer the question"
    inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
    with torch.no_grad():
        text_tokens, audio_tokens = [], []
        audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
        end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
        prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
        this_uuid = str(uuid.uuid4())
        tts_speechs = []
        tts_mels = []
        prev_mel = None
        is_finalize = False
        block_size = 10
        for token_id in model_worker.generate_stream_gate({
            "prompt": inputs,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_token,
        }):
            if isinstance(token_id, str):
                # generate_stream_gate yields a string when generation failed
                yield history, None, None
                return
            if token_id == end_token_id:
                is_finalize = True
            # Decode audio tokens in blocks: a short first block for low
            # latency, larger blocks afterwards.
            if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
                block_size = 20
                tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
                if prev_mel is not None:
                    prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
                tts_speech, tts_mel = audio_decoder.token2wav(
                    tts_token,
                    uuid=this_uuid,
                    prompt_token=flow_prompt_speech_token.to(device),
                    prompt_feat=prompt_speech_feat.to(device),
                    finalize=is_finalize
                )
                prev_mel = tts_mel
                tts_speechs.append(tts_speech.squeeze())
                tts_mels.append(tts_mel)
                yield history, (22050, tts_speech.squeeze().cpu().numpy()), None
                flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                audio_tokens = []
            if not is_finalize:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)
        # Concatenate all blocks and save the full waveform for download
        tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
        history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
        history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, skip_special_tokens=False)})
        yield history, None, (22050, tts_speech.numpy())
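
# Each yield above maps onto submit_btn's outputs: (history_state,
# output_audio, complete_audio). output_audio is declared with streaming=True,
# so per-block chunks are appended to the playing stream; complete_audio
# receives the full waveform once at the end.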
def update_input_interface(input_mode):
    if input_mode == "audio":
        return [gr.update(visible=True), gr.update(visible=False)]
    else:
        return [gr.update(visible=False), gr.update(visible=True)]
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
    parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
    parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
    parser.add_argument("--share", action='store_true')
    args = parser.parse_args()

    # Flow-matching decoder checkpoints
    flow_config = os.path.join(args.flow_path, "config.yaml")
    flow_checkpoint = os.path.join(args.flow_path, "flow.pt")
    hift_checkpoint = os.path.join(args.flow_path, "hift.pt")
    device = "cuda"
    # Global model handles, populated by initialize_fn()
    audio_decoder = None
    whisper_model = None
    feature_extractor = None
    glm_tokenizer = None
    session_manager = None
    vector_store_manager = None
    whisper_transcribe_model = None
    model_worker = None

    # Embedding model used for the RAG vector store
    Embedding_Model = 'intfloat/multilingual-e5-large-instruct'

    # Initialize models before building the UI (session_manager is needed below)
    initialize_fn()
    # Create Gradio interface
    with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
        # One session per page load; session_manager is initialized above
        session_id = gr.State(session_manager.create_session())
        with gr.Row():
            # Left column for chat interface
            with gr.Column(scale=2):
                gr.Markdown("## Chat Interface")
                with gr.Row():
                    temperature = gr.Number(label="Temperature", value=0.2, minimum=0, maximum=1)
                    top_p = gr.Number(label="Top p", value=0.8, minimum=0, maximum=1)
                    max_new_token = gr.Number(label="Max new tokens", value=2000, minimum=1)
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    bubble_full_width=False,
                    type="messages",
                    scale=1,
                    height=500
                )
                with gr.Row():
                    input_mode = gr.Radio(
                        ["audio", "text"],
                        label="Input Mode",
                        value="audio"
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label="Input audio",
                        type='filepath',
                        show_download_button=True,
                        visible=True
                    )
                    text_input = gr.Textbox(
                        label="Input text",
                        placeholder="Enter your text here...",
                        lines=2,
                        visible=False
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    reset_btn = gr.Button("Clear")
                output_audio = gr.Audio(
                    label="Play",
                    streaming=True,
                    autoplay=True,
                    show_download_button=False
                )
                complete_audio = gr.Audio(
                    label="Last Output Audio (If Any)",
                    show_download_button=True
                )
            # Right column for database management
            with gr.Column(scale=1):
                gr.Markdown("## Database Management")
                file_upload = gr.Files(
                    label="Upload Database Files",
                    file_types=[".txt", ".pdf", ".md", ".csv", ".json", ".html", ".htm"],
                    file_count="multiple"
                )
                reinit_btn = gr.Button("Initialize Database", variant="secondary")
                status_text = gr.Textbox(label="Status", interactive=False)
        history_state = gr.State([])
        # Setup interaction handlers
        respond = submit_btn.click(
            inference_fn,
            inputs=[
                temperature,
                top_p,
                max_new_token,
                input_mode,
                audio,
                text_input,
                history_state,
                session_id,
            ],
            # Must match the 3-tuples yielded by inference_fn
            outputs=[
                history_state,
                output_audio,
                complete_audio
            ]
        )
        respond.then(lambda s: s, [history_state], chatbot)

        reset_btn.click(
            clear_fn,
            outputs=[
                chatbot,
                history_state,
                output_audio,
                complete_audio
            ]
        )

        input_mode.change(
            update_input_interface,
            inputs=[input_mode],
            outputs=[audio, text_input]
        )

        # Database initialization handler
        reinit_btn.click(
            reinitialize_database,
            inputs=[file_upload, session_id],
            outputs=[status_text]
        )

    # One-off cleanup of old sessions at startup (optional)
    if session_manager:
        session_manager.cleanup_old_sessions()

    demo.launch(
        server_port=args.port,
        server_name=args.host,
        share=args.share
    )