File size: 884 Bytes
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import torch

from modules.utils.SeedContext import SeedContext

from modules import models, config


@torch.inference_mode()
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Run the ChatTTS refiner prompt over ``text`` and return its result.

    The model object comes from ``models.load_chat_tts()`` and the refiner
    call is made inside ``SeedContext(seed)``, with torch inference mode
    enabled by the decorator.

    Args:
        text: Input text handed to ``refiner_prompt``.
        prompt: Control-token string forwarded under the ``"prompt"`` key.
        seed: Value passed to ``SeedContext``.
            NOTE(review): the default of -1 presumably selects a
            random/unfixed seed -- confirm against ``SeedContext``.
        top_P: Forwarded under the ``"top_P"`` key.
        top_K: Forwarded under the ``"top_K"`` key.
        temperature: Forwarded under the ``"temperature"`` key.
        repetition_penalty: Forwarded under the ``"repetition_penalty"`` key.
        max_new_token: Forwarded under the ``"max_new_token"`` key.

    Returns:
        Whatever ``refiner_prompt`` returns (annotated as ``str``).
    """
    tts_model = models.load_chat_tts()

    # Options for the refiner; progress-bar visibility comes from the
    # project-wide config rather than from a caller argument.
    refiner_params = {
        "prompt": prompt,
        "top_K": top_K,
        "top_P": top_P,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_token": max_new_token,
        "disable_tqdm": config.disable_tqdm,
    }

    # The refiner call stays inside the seed context so the seed applies to it.
    with SeedContext(seed):
        # Text normalization is explicitly switched off for this call.
        return tts_model.refiner_prompt(
            text,
            refiner_params,
            do_text_normalization=False,
        )