import re

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

hf_repo = "khanhdhq/test_finetune_bloom_3b"
config = PeftConfig.from_pretrained(hf_repo)
# Load the 8-bit quantized base model and its tokenizer.
finetuned_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Attach the LoRA adapter weights on top of the base model.
finetuned_model = PeftModel.from_pretrained(finetuned_model, hf_repo)


@torch.no_grad()
def infer(text):
    # Pick the best available device: CUDA, then MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except AttributeError:
        pass

    inputs = tokenizer(text, add_special_tokens=True, return_tensors="pt").to(device)
    outputs = finetuned_model.generate(**inputs, max_new_tokens=30)
    response = tokenizer.decode(outputs[0])

    # Keep only the bot's reply: drop everything up to the last '<bot>:' marker,
    # then cut off anything generated after a new '<human>:' turn.
    response = response.split('<bot>:')[-1]
    response = re.split(r'<human>:|"codepoints"', response, flags=re.IGNORECASE)[0].strip()

    def split_string(string):
        # Truncate at the first character outside the allowed set
        # (ASCII alphanumerics, whitespace, Vietnamese letters, and common punctuation).
        pattern = r'[^a-zA-Z0-9\sđđăâàáảạãầấẩậẫằắẳặẵẻẹẽèéẻêệễểỉịĩìíỏọõôồốổộỗơờớởợỡủụũưừứửựữỷỵỹỳýỷỹỵĐđÀÁẢẠĂÃẤẦẤẨẬẪẰẮẲẶẴẺẸẼÈÉẺÊỆỄỂỈỊĨÌÍỎỌÕÔỒỐỔỘỖƠỜỚỞỢỠỦỤŨƯỪỨỬỰỮỶỴỸỲÝỶỸỴ\.\?,<>!:;\'\"\(\)\{\}\[\]]'
        result = re.split(pattern, string, flags=re.IGNORECASE)
        return result[0].strip()

    response = split_string(response)
    return response
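
# Illustrative usage only: the question below is a made-up example in the
# '<human>: ... <bot>:' prompt format that bot() builds further down; the actual
# reply depends on the fine-tuned weights.
# print(infer('<human>: OmiCall có những sản phẩm gì? <bot>:'))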


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # OmiCall chatbot
        Chat với tôi nếu bạn có hứng thú với các sản phẩm của OmiCall.
        """)  # "Chat with me if you are interested in OmiCall's products."
    chatbot = gr.Chatbot()
    # Textbox label/placeholder: "Chatbot OmiCall" / "chat here"
    msg = gr.Textbox(label="Chatbot OmiCall", placeholder="chat ở đây")
    clear = gr.Button("Xóa lịch sử chat")  # "Clear chat history"
    def user(user_message, history):
        # Append the new user turn and disable the textbox until the bot answers.
        return gr.update(value="", interactive=False), history + [[user_message, None]]

    def bot(history):
        # Build the prompt from up to the last four completed turns plus the new message.
        messages = []
        for human, assistant in history[-5:-1]:
            messages.append(f'<human>: {human}')
            messages.append(f'<bot>: {assistant}')
        messages.append(f'<human>: {history[-1][0]} <bot>:')
        mess = ' '.join(messages)
        history[-1][1] = infer(mess)
        return history

    # On submit: record the user turn, generate the bot reply, then re-enable the textbox.
    response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()