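"""Gradio demo of DiarizationLM GGUF inference on CPU."""
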
import gradio as gr
from gpt4all import GPT4All
from huggingface_hub import hf_hub_download
from diarizationlm import utils

title = "DiarizationLM GGUF inference on CPU"

description = """
A demo of the DiarizationLM model finetuned from Llama 3. Here we run a 4-bit quantized GGUF version of the model on CPU.

To learn more about DiarizationLM, check out our paper: https://arxiv.org/abs/2401.03506
"""

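# DiarizationLM prompt format: the prompt ends with " --> " and the model's
# completion ends with " [eod]".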
model_path = "models"
model_name = "q4_k_m.gguf"
prompt_suffix = " --> "
completion_suffix = " [eod]"

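# Fetch the 4-bit quantized GGUF checkpoint from the Hugging Face Hub.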
hf_hub_download(repo_id="google/DiarizationLM-8b-Fisher-v2", filename=model_name, local_dir=model_path)

print("Start the model init process")
model = GPT4All(model_name=model_name,
                model_path=model_path,
                allow_download = False,
                device="cpu")
print("Finish the model init process")

def generate(message, history):
    """Streams the LLM completion, then yields the final transferred result."""
    prompt = message + prompt_suffix
    # Completions are roughly as long as the prompt; assume ~3 characters per
    # token and allow 20% headroom.
    max_new_tokens = round(len(prompt) / 3.0 * 1.2)
    outputs = []
    completion = ""
    for token in model.generate(prompt=prompt,
                                temp=0.1,
                                top_k=50,
                                top_p=0.5,
                                max_tokens=max_new_tokens,
                                streaming=True):
        outputs.append(token)
        completion = "".join(outputs)
        yield completion
        # Stop as soon as the model emits the end-of-completion marker.
        if completion.endswith(completion_suffix):
            break
    # Transfer the speaker labels from the completion back onto the original
    # words of the input transcript.
    transferred_completion = utils.transfer_llm_completion(completion, message)
    yield transferred_completion

print("Create chatbot")
chatbot = gr.Chatbot()
print("Created chatbot")

iface = gr.ChatInterface(
    fn=generate,
    title=title,
    description=description,
    chatbot=chatbot,
    additional_inputs=[],
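    # Each example is a transcript with interleaved "<speaker:K>" tags, as
    # produced by a speaker diarization system.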
    examples=[
        ["<speaker:1> Hello, my name is Tom. May I speak to Laura <speaker:2> please? Hello, this is Laura. <speaker:1> Hi Laura, how are you? This is <speaker:2> Tom. Hi Tom, I haven't seen you for a <speaker:1> while."],
        ["<speaker:1> This demo looks really <speaker:2> good! Thanks, I am glad to hear that."],
    ]
)

with gr.Blocks() as demo:
    iface.render()


if __name__ == "__main__":
    # Queue requests so concurrent users are served in order on the single CPU model.
    demo.queue(max_size=3).launch()