import os
import logging

import gradio as gr
from huggingface_hub import hf_hub_download
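
# Install the agent and web-search dependencies at runtime (a common pattern for
# Hugging Face Spaces, where extra packages are pulled in at startup).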
os.system("pip install --upgrade pip")
os.system("pip install llama-cpp-agent huggingface_hub trafilatura beautifulsoup4 requests duckduckgo-search googlesearch-python")

try:
    from llama_cpp import Llama
    from llama_cpp_agent.providers import LlamaCppPythonProvider
    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
    from llama_cpp_agent.chat_history import BasicChatHistory
    from llama_cpp_agent.chat_history.messages import Roles
    from llama_cpp_agent.llm_output_settings import (
        LlmStructuredOutputSettings,
        LlmStructuredOutputType,
    )
    from llama_cpp_agent.tools import WebSearchTool
    from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
    from utils import CitingSources
    from settings import get_context_by_model, get_messages_formatter_type
except ImportError as e:
    raise ImportError(f"Error importing modules: {e}") from e
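
# Download the GGUF model weights into ./models before the app starts.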
hf_hub_download(
    repo_id="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
    filename="Mistral-7B-Instruct-v0.3-Q6_K.gguf",
    local_dir="./models",
)
hf_hub_download(
    repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
    filename="Meta-Llama-3-8B-Instruct-Q6_K.gguf",
    local_dir="./models",
)
hf_hub_download(
    repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
    filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf",
    local_dir="./models",
)


def respond(message, temperature, top_p, top_k, repeat_penalty):
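    """Search the web for the user's request with one agent, then have a second
    agent write a research document, with cited sources, from the results."""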
    try:
        model = "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf"
        max_tokens = 3000
        chat_template = get_messages_formatter_type(model)
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=get_context_by_model(model),
        )
        provider = LlamaCppPythonProvider(llm)
        logging.info(f"Using messages formatter: {chat_template}")
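
        # Web-search tool; the token limits cap how much retrieved and summarized
        # page text is handed back to the agents.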
        search_tool = WebSearchTool(
            llm_provider=provider,
            message_formatter_type=chat_template,
            max_tokens_search_results=12000,
            max_tokens_per_summary=2048,
        )
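
        # Agent that decides what to search for and invokes the search tool.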
        web_search_agent = LlamaCppAgent(
            provider,
            system_prompt=web_search_system_prompt,
            predefined_messages_formatter_type=chat_template,
            debug_output=True,
        )
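
        # Agent that writes the final research document from the search results.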
        answer_agent = LlamaCppAgent(
            provider,
            system_prompt=research_system_prompt,
            predefined_messages_formatter_type=chat_template,
            debug_output=True,
        )

        settings = provider.get_provider_default_settings()
        settings.stream = False
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
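
        # Constrain the search agent's reply to a structured call of the search tool.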
        output_settings = LlmStructuredOutputSettings.from_functions(
            [search_tool.get_tool()]
        )

        messages = BasicChatHistory()
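
        # Run the search agent; the gathered web content is read back below from
        # result[0]["return_value"].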
        result = web_search_agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            structured_output_settings=output_settings,
            add_message_to_chat_history=False,
            add_response_to_chat_history=False,
            print_output=False,
        )

        outputs = ""
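
        # Switch to streaming and have the answer agent write the research document,
        # collecting the streamed chunks into `outputs`.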
        settings.stream = True
        response_text = answer_agent.get_chat_response(
            f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
            result[0]["return_value"],
            role=Roles.tool,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        for text in response_text:
            outputs += text
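
        # Ask the answer agent to list its sources as a structured CitingSources object.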
        output_settings = LlmStructuredOutputSettings.from_pydantic_models(
            [CitingSources], LlmStructuredOutputType.object_instance
        )

        citing_sources = answer_agent.get_chat_response(
            "Cite the sources you used in your response.",
            role=Roles.tool,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=False,
            structured_output_settings=output_settings,
            print_output=False,
        )
        outputs += "\n\nSources:\n"
        outputs += "\n".join(citing_sources.sources)
        return outputs

    except Exception as e:
        return f"An error occurred: {e}"
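

# Gradio UI: a text box for the request plus sliders for the sampling parameters.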
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your message:"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.45, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
    ],
    outputs="text",
    title="Novav2 Web Engine",
)

if __name__ == "__main__":
    demo.launch()