import os
import logging
import gradio as gr
from huggingface_hub import hf_hub_download
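
# Configure logging so the logging output below is actually visible.
logging.basicConfig(level=logging.INFO)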

# Install runtime dependencies at startup (a common pattern in hosted demos
# such as Hugging Face Spaces); gradio and huggingface_hub must already be
# available, since they are imported above before this install runs.
os.system("pip install --upgrade pip")
os.system("pip install llama-cpp-agent huggingface_hub trafilatura beautifulsoup4 requests duckduckgo-search googlesearch-python")

# Attempt to import all required modules
try:
    from llama_cpp import Llama
    from llama_cpp_agent.providers import LlamaCppPythonProvider
    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
    from llama_cpp_agent.chat_history import BasicChatHistory
    from llama_cpp_agent.chat_history.messages import Roles
    from llama_cpp_agent.llm_output_settings import (
        LlmStructuredOutputSettings,
        LlmStructuredOutputType,
    )
    from llama_cpp_agent.tools import WebSearchTool
    from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
    from utils import CitingSources
    from settings import get_context_by_model, get_messages_formatter_type
except ImportError as e:
    raise ImportError(f"Error importing modules: {e}") from e
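
# The local helper modules are assumed to provide the following (a sketch
# inferred from their usage below, not the actual files):
#   settings.get_context_by_model(model_name) -> int   (n_ctx for a model)
#   settings.get_messages_formatter_type(model_name) -> MessagesFormatterType
#   utils.CitingSources: pydantic model with a `sources: list[str]` field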

# Download the GGUF model weights into ./models (hf_hub_download skips files
# that are already cached). Note that respond() below only loads the Mixtral
# model; the other two are downloaded but currently unused.
hf_hub_download(
    repo_id="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
    filename="Mistral-7B-Instruct-v0.3-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
    filename="Meta-Llama-3-8B-Instruct-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
    filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf",
    local_dir="./models"
)

# Respond to a user message: run a web search, then write a cited research document
def respond(message, temperature, top_p, top_k, repeat_penalty):
    try:
        model = "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf"
        max_tokens = 3000
        chat_template = get_messages_formatter_type(model)
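        # NOTE: the model is re-loaded from disk on every request; a
        # long-running app would normally load it once at module scope.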
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=get_context_by_model(model),
        )
        provider = LlamaCppPythonProvider(llm)
        logging.info(f"Loaded chat examples: {chat_template}")
        search_tool = WebSearchTool(
            llm_provider=provider,
            message_formatter_type=chat_template,
            max_tokens_search_results=12000,
            max_tokens_per_summary=2048,
        )
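        # These token budgets bound how much fetched web content the tool
        # passes back to the agents (total search results and each per-page
        # summary, per the parameter names above).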

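        # First agent: turns the user message into a web-search tool call.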
        web_search_agent = LlamaCppAgent(
            provider,
            system_prompt=web_search_system_prompt,
            predefined_messages_formatter_type=chat_template,
            debug_output=True,
        )

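        # Second agent: writes the research document from the search results.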
        answer_agent = LlamaCppAgent(
            provider,
            system_prompt=research_system_prompt,
            predefined_messages_formatter_type=chat_template,
            debug_output=True,
        )

        settings = provider.get_provider_default_settings()
        settings.stream = False
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty

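        # Constrain the search agent's output to a structured call of the
        # web-search tool.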
        output_settings = LlmStructuredOutputSettings.from_functions(
            [search_tool.get_tool()]
        )

        messages = BasicChatHistory()

        result = web_search_agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            structured_output_settings=output_settings,
            add_message_to_chat_history=False,
            add_response_to_chat_history=False,
            print_output=False,
        )
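        # result[0]["return_value"] holds the text the web-search tool returned.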

        outputs = ""

        settings.stream = True
        response_text = answer_agent.get_chat_response(
            f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
            result[0]["return_value"],
            role=Roles.tool,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        for text in response_text:
            outputs += text

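        # Switch the structured output to a pydantic schema so the follow-up
        # response is parsed into a CitingSources object rather than free text.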
        output_settings = LlmStructuredOutputSettings.from_pydantic_models(
            [CitingSources], LlmStructuredOutputType.object_instance
        )

        citing_sources = answer_agent.get_chat_response(
            "Cite the sources you used in your response.",
            role=Roles.tool,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=False,
            structured_output_settings=output_settings,
            print_output=False,
        )
        outputs += "\n\nSources:\n"
        outputs += "\n".join(citing_sources.sources)
        return outputs

    except Exception as e:
        logging.exception("respond() failed")
        return f"An error occurred: {e}"

# Gradio interface
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your message:"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.45, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
    ],
    outputs="text",
    title="Novav2 Web Engine"
)

if __name__ == "__main__":
    demo.launch()
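
# Gradio's launch() also accepts deployment options, e.g.
# demo.launch(server_name="0.0.0.0", share=True) to bind all interfaces or
# create a temporary public share link.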