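"""Streamlit chatbot that answers questions as web-search-grounded research documents.

The app downloads GGUF models from the Hugging Face Hub, lets the user pick one in
the sidebar, uses a llama-cpp-agent web-search agent to gather sources, and streams
a final answer followed by the cited sources.
"""
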
import subprocess
import sys

# Install the required libraries into the environment of the running interpreter
subprocess.run(
    [sys.executable, "-m", "pip", "install", "streamlit", "llama-cpp-agent", "huggingface_hub"],
    check=True,
)

try:
    from llama_cpp import Llama
    from llama_cpp_agent.providers import LlamaCppPythonProvider
    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
    from llama_cpp_agent.chat_history import BasicChatHistory
    from llama_cpp_agent.chat_history.messages import Roles
    from llama_cpp_agent.llm_output_settings import (
        LlmStructuredOutputSettings,
        LlmStructuredOutputType,
    )
    from llama_cpp_agent.tools import WebSearchTool
    from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
    from utils import CitingSources
    from settings import get_context_by_model, get_messages_formatter_type
except ImportError as e:
    # Re-raise with context so a missing dependency is easy to spot in the logs
    raise ImportError(f"Error importing llama-cpp-agent modules: {e}") from e

import logging
import streamlit as st
from huggingface_hub import hf_hub_download

# Download the models
hf_hub_download(
    repo_id="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
    filename="Mistral-7B-Instruct-v0.3-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
    filename="Meta-Llama-3-8B-Instruct-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
    filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf",
    local_dir="./models"
)
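# hf_hub_download skips files that already exist in ./models, so the Streamlit
# reruns triggered by UI interaction should not re-download the models.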

# Function to respond to user messages
def respond(message, history, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    chat_template = get_messages_formatter_type(model)
    llm = Llama(
        model_path=f"models/{model}",
        flash_attn=True,
        n_gpu_layers=81,
        n_batch=1024,
        n_ctx=get_context_by_model(model),
    )
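    # n_gpu_layers=81 offloads effectively all layers to the GPU; this assumes a
    # CUDA (or Metal) build of llama-cpp-python with flash-attention support.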
    provider = LlamaCppPythonProvider(llm)
    logging.info(f"Loaded chat examples: {chat_template}")
    search_tool = WebSearchTool(
        llm_provider=provider,
        message_formatter_type=chat_template,
        max_tokens_search_results=12000,
        max_tokens_per_summary=2048,
    )
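    # The search tool fetches web results and summarizes them, bounded by the
    # token budgets configured above.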

    # First agent: drives the web searches. Its system prompt comes from the
    # sidebar "System message" field (which defaults to web_search_system_prompt).
    web_search_agent = LlamaCppAgent(
        provider,
        system_prompt=system_message,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )

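    # Second agent: turns the gathered search results into the final research
    # document, using the library's research_system_prompt.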
    answer_agent = LlamaCppAgent(
        provider,
        system_prompt=research_system_prompt,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )

    settings = provider.get_provider_default_settings()
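    # Streaming stays off for the tool-call stage; it is switched back on below
    # when the answer is generated.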
    settings.stream = False
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p

    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty

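    # Constrain the first agent's output to a structured call to the search tool
    # rather than free-form text.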
    output_settings = LlmStructuredOutputSettings.from_functions(
        [search_tool.get_tool()]
    )

    messages = BasicChatHistory()

    # Replay prior (user, assistant) turns into the agent-side chat history
    for user_msg, assistant_msg in history:
        messages.add_message({"role": Roles.user, "content": user_msg})
        messages.add_message({"role": Roles.assistant, "content": assistant_msg})

    result = web_search_agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        structured_output_settings=output_settings,
        add_message_to_chat_history=False,
        add_response_to_chat_history=False,
        print_output=False,
    )
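    # result[0]["return_value"] is assumed to carry the text returned by the
    # search tool; it is handed to the answer agent below.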

    outputs = ""

    settings.stream = True
    response_text = answer_agent.get_chat_response(
        f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
        result[0]["return_value"],
        role=Roles.tool,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False,
    )

    for text in response_text:
        outputs += text
        yield outputs

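    # CitingSources (imported from utils) is assumed to be a Pydantic model with a
    # `sources` list; structured output fills it with the URLs used in the answer.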
    output_settings = LlmStructuredOutputSettings.from_pydantic_models(
        [CitingSources], LlmStructuredOutputType.object_instance
    )

    citing_sources = answer_agent.get_chat_response(
        "Cite the sources you used in your response.",
        role=Roles.tool,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=False,
        structured_output_settings=output_settings,
        print_output=False,
    )
    outputs += "\n\nSources:\n"
    outputs += "\n".join(citing_sources.sources)
    yield outputs

# Streamlit app
st.title("Llama-CPP-Agent Chatbot with Web Search")

# Sidebar for settings
st.sidebar.title("Settings")
model = st.sidebar.selectbox(
    "Model",
    [
        'Mistral-7B-Instruct-v0.3-Q6_K.gguf',
        'mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf',
        'Meta-Llama-3-8B-Instruct-Q6_K.gguf'
    ]
)
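# The choices above must match the GGUF filenames downloaded at startup, since
# respond() builds the model path directly from the selected value.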
system_message = st.sidebar.text_area("System message", value=web_search_system_prompt)
max_tokens = st.sidebar.slider("Max tokens", min_value=1, max_value=4096, value=2048, step=1)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, value=0.45, step=0.05)
top_p = st.sidebar.slider("Top-p", min_value=0.1, max_value=1.0, value=0.95, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=40, step=1)
repeat_penalty = st.sidebar.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.1, step=0.1)

# Chat history, kept in st.session_state so it survives Streamlit reruns
if "history" not in st.session_state:
    st.session_state.history = []

# Chat input
message = st.text_input("You:", key="input")

if st.button("Send"):
    history = st.session_state.history
    response = respond(
        message,
        history,
        model,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
        repeat_penalty
    )
    
    for res in response:
        st.session_state.history.append((message, res))
        st.text_area("Chat", value=f"You: {message}\nBot: {res}", height=300)

# Display chat history
for i, (user_msg, bot_msg) in enumerate(st.session_state.history):
    st.text_area("Chat", value=f"You: {user_msg}\nBot: {bot_msg}", height=300, key=f"history_{i}")