# open-notebooklm / utils.py
# Last commit 8ddd281 by gabrielchua: "use SUNO"
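"""
utils.py
Functions:
- generate_script: Get the dialogue from the LLM, then ask it to improve its draft.
- call_llm: Call the LLM with the given prompt and dialogue format.
- parse_url: Fetch the text content of a URL via the Jina Reader service.
- generate_audio: Synthesize a line of dialogue with Bark and save it to disk.
"""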
import os
import requests
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav
# Fireworks-hosted model used for every chat completion in this module.
MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Reader-proxy prefix; parse_url appends the target URL to it.
JINA_URL = "https://r.jina.ai/"

# OpenAI-compatible client aimed at the Fireworks inference endpoint. The API
# key comes from the FIREWORKS_API_KEY environment variable (None if unset).
client = OpenAI(
    api_key=os.environ.get("FIREWORKS_API_KEY"),
    base_url="https://api.fireworks.ai/inference/v1",
)

# hf_client = Client("mrfakename/MeloTTS")

# Download and load all Bark models. NOTE: this runs as a side effect of
# importing this module.
preload_models()
def _call_and_validate(system_prompt: str, text: str, output_model):
    """Call the LLM and parse its reply into ``output_model``, retrying once.

    If the first reply fails validation, the validation error is appended to
    the system prompt and the LLM is asked again. A ValidationError raised by
    the second attempt propagates to the caller.
    """
    response = call_llm(system_prompt, text, output_model)
    try:
        return output_model.model_validate_json(response.choices[0].message.content)
    except ValidationError as e:
        error_message = f"Failed to parse dialogue JSON: {e}"

    system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
    response = call_llm(system_prompt_with_error, text, output_model)
    return output_model.model_validate_json(response.choices[0].message.content)


def generate_script(system_prompt: str, input_text: str, output_model):
    """Get the dialogue from the LLM, then ask the LLM to improve its own draft.

    Args:
        system_prompt: Instructions for the LLM.
        input_text: Source material sent as the user message for the draft.
        output_model: Pydantic model class the LLM's JSON reply must validate
            against; its JSON schema is also sent to the LLM.

    Returns:
        An ``output_model`` instance holding the improved dialogue.

    Raises:
        pydantic.ValidationError: If either the draft or the improved reply
            still fails validation after one retry.
    """
    dialogue = _call_and_validate(system_prompt, input_text, output_model)

    # Present the draft as JSON rather than the model's ``field=value`` str()
    # form, so the LLM sees its draft in the same format it must return.
    system_prompt_with_dialogue = (
        f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n"
        f"{dialogue.model_dump_json()}."
    )
    # The improvement pass gets the same validate-and-retry protection as the draft.
    return _call_and_validate(
        system_prompt_with_dialogue, "Please improve the dialogue.", output_model
    )
def call_llm(system_prompt: str, text: str, dialogue_format):
    """Send one system + user exchange to the Fireworks model in JSON mode.

    Args:
        system_prompt: Content of the system message.
        text: Content of the user message.
        dialogue_format: Pydantic model class whose JSON schema is passed in
            ``response_format`` to constrain the reply.

    Returns:
        The raw chat-completion response returned by ``client``.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    json_mode = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return client.chat.completions.create(
        messages=conversation,
        model=MODEL_ID,
        max_tokens=16_384,
        temperature=0.1,
        response_format=json_mode,
    )
def parse_url(url: str) -> str:
    """Fetch ``url`` through the Jina Reader proxy and return the response body.

    Args:
        url: Address of the page to read; it is appended verbatim to JINA_URL.

    Returns:
        The text body returned by the reader service.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status, so an
            error page is never passed along as if it were page content.
        requests.RequestException: On connection failure or the 60 s timeout.
    """
    full_url = f"{JINA_URL}{url}"
    response = requests.get(full_url, timeout=60)
    response.raise_for_status()
    return response.text
def generate_audio(text: str, speaker: str, language: str) -> str:
    """Synthesize ``text`` with Bark and save it as a WAV file.

    Args:
        text: The line of dialogue to speak.
        speaker: Speaker label. "Host (Jane)" uses Bark voice preset 1; every
            other label uses preset 3.
        language: Inserted into the Bark voice preset name
            ``v2/{language}_speaker_N`` and into the output file name.

    Returns:
        Path of the written WAV file.
    """
    # This function's name shadows the ``generate_audio`` imported from bark
    # at module level, so calling ``generate_audio(...)`` here would recurse
    # into this function and raise TypeError on ``history_prompt``. Bind
    # Bark's implementation under a distinct local name instead.
    from bark import generate_audio as bark_generate_audio

    voice_index = "1" if speaker == "Host (Jane)" else "3"
    audio_array = bark_generate_audio(
        text, history_prompt=f"v2/{language}_speaker_{voice_index}"
    )
    # write_wav emits WAV data, so the extension must be .wav (not .mp3).
    # NOTE(review): the path depends only on language and speaker, so each
    # call overwrites the previous file for the same pair; callers must
    # consume the file before the next call.
    file_path = f"audio_{language}_{speaker}.wav"
    write_wav(file_path, SAMPLE_RATE, audio_array)
    return file_path
# """Get the audio from the TTS model from HF Spaces and adjust pitch if necessary."""
# if speaker == "Guest":
# accent = "EN-US" if language == "EN" else language
# speed = 0.9
# else: # host
# accent = "EN-Default" if language == "EN" else language
# speed = 1
# if language != "EN" and speaker != "Guest":
# speed = 1.1
# # Generate audio
# result = hf_client.predict(
# text=text,
# language=language,
# speaker=accent,
# speed=speed,
# api_name="/synthesize",
# )
# return result