Spaces:
Sleeping
Sleeping
from typing import Generator | |
import numpy as np | |
import torch | |
from modules import config, models | |
from modules.utils.SeedContext import SeedContext | |
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Run the ChatTTS text refiner over ``text`` and return the refined text.

    The model is obtained from ``models.load_chat_tts()``. Its
    ``refiner_prompt`` method is called inside a ``SeedContext(seed)`` so the
    sampling seed is scoped to this call.

    Args:
        text: Input text handed to the refiner.
        prompt: Control-token prompt forwarded under the ``"prompt"`` key.
            NOTE(review): the meaning of the individual tokens is defined by
            ChatTTS, not by this module.
        seed: Value passed to ``SeedContext``. NOTE(review): how ``-1`` is
            interpreted is up to ``SeedContext``; confirm there.
        top_P: Nucleus-sampling parameter forwarded as ``"top_P"``.
        top_K: Top-k sampling parameter forwarded as ``"top_K"``.
        temperature: Sampling temperature forwarded as ``"temperature"``.
        repetition_penalty: Forwarded as ``"repetition_penalty"``.
        max_new_token: Forwarded as ``"max_new_token"``.

    Returns:
        The refiner's output. If the refiner returns a list, its items are
        joined with ``"\\n"``. Any other non-generator result is returned
        unchanged.

    Raises:
        NotImplementedError: If the refiner yields a generator, which this
            function does not support.
    """
    tts_model = models.load_chat_tts()

    with SeedContext(seed):
        # The parameter dict is built inside the seed context so the
        # config read happens in the same order as before.
        refiner_params = {
            "prompt": prompt,
            "top_K": top_K,
            "top_P": top_P,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_token": max_new_token,
            "disable_tqdm": config.runtime_env_vars.off_tqdm,
        }
        output = tts_model.refiner_prompt(text, refiner_params)

    # Streaming (generator) output is not supported yet; fail loudly
    # rather than returning a half-consumed iterator.
    if isinstance(output, Generator):
        raise NotImplementedError("Refiner is not yet implemented for generator output")

    # A batch of refined strings is flattened into a single newline-joined string.
    if isinstance(output, list):
        return "\n".join(output)

    return output