import gradio as gr
import torch
from transformers import pipeline
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load the pre-existing vector database
persist_directory = "db"
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
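
# The Chroma store above is assumed to already exist on disk. A minimal
# ingestion sketch for populating ./db (docs_dir, the glob pattern, and the
# chunk sizes are illustrative assumptions, not part of the original app):
def build_db(docs_dir="docs"):
    from langchain_community.document_loaders import DirectoryLoader, TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Load plain-text files and split them into overlapping chunks
    loader = DirectoryLoader(docs_dir, glob="**/*.txt", loader_cls=TextLoader)
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(loader.load())

    # Embed the chunks and persist them where the app expects to find them
    Chroma.from_documents(chunks, embeddings, persist_directory=persist_directory)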

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Marco-o1 model
pipe = pipeline(
    "text-generation",
    model="AIDC-AI/Marco-o1",
    device=device,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
)

def get_relevant_context(query, k=3):
    # Search the vector database for relevant documents
    docs = vectordb.similarity_search(query, k=k)
    # Combine the relevant documents into a single context string
    context = "\n".join([doc.page_content for doc in docs])
    return context

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    try:
        # Get relevant context from the vector database
        context = get_relevant_context(message)
        
        # Prepare the messages for the model, injecting the retrieved context
        # only when the search actually returned something
        messages = [{"role": "system", "content": system_message}]
        if context:
            messages.append({"role": "user", "content": f"Context:\n{context}"})
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
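        # A sturdier alternative (an assumption, not part of the original app)
        # would be the model's own chat template:
        #   pipe.tokenizer.apply_chat_template(messages, tokenize=False,
        #                                      add_generation_prompt=True)
        # The plain "role: content" transcript below keeps the original approach.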
        
        # Format the conversation as a plain "role: content" transcript and
        # cue the model to answer as the assistant
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

        response = pipe(
            input_text,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            return_full_text=False,
        )[0]["generated_text"]

        # return_full_text=False strips the prompt, so the output is the reply itself
        new_response = response.strip()
        
        yield new_response
    except Exception as e:
        yield f"An error occurred: {e}"

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant. Use the provided context to answer questions accurately.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Marco-O1 Assistant with Knowledge Base",
    description="Ask questions about the documents in the knowledge base. The assistant will use the relevant context to provide accurate answers."
)

if __name__ == "__main__":
    demo.launch()
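
# To run locally (assumes the ./db store exists; see the build_db sketch above):
#   pip install gradio torch transformers langchain-community langchain-huggingface \
#       chromadb sentence-transformers
#   python app.py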