"""Gradio chat app that answers questions with Marco-o1 using retrieved context.

For every user message the most similar documents are looked up in a
persisted Chroma vector store, prepended to the conversation as context,
and the resulting prompt is sent to a Hugging Face text-generation pipeline.
"""

import logging
import os

import gradio as gr
import torch
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import pipeline

logger = logging.getLogger(__name__)

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load the pre-existing vector database
persist_directory = "db"
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Marco-o1 model
# NOTE(review): trust_remote_code=True allows code shipped in the model
# repository to run locally; keep it only while that repository is trusted.
pipe = pipeline(
    "text-generation",
    model="AIDC-AI/Marco-o1",
    device=device,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
)


def get_relevant_context(query, k=3):
    """Return the text of the ``k`` stored documents most similar to ``query``.

    The ``page_content`` of each hit is joined with newlines. An empty
    string is returned when the search yields no documents.
    """
    docs = vectordb.similarity_search(query, k=k)
    return "\n".join(doc.page_content for doc in docs)


def _build_messages(system_message, context, history, message):
    """Assemble the role/content message list for one chat turn.

    Order: system message, retrieved context (only when non-empty), the
    previous (user, assistant) pairs from ``history``, then the new user
    message.
    """
    messages = [{"role": "system", "content": system_message}]
    # Skip the context turn entirely when retrieval found nothing, instead of
    # sending the model an empty user message.
    if context:
        messages.append({"role": "user", "content": f"Context:\n{context}"})
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    return messages


def _format_prompt(messages):
    """Render messages as ``role: content`` lines ending with an assistant cue."""
    lines = [f"{msg['role']}: {msg['content']}" for msg in messages]
    # The trailing cue tells the model that the next thing to write is the
    # assistant's reply, so the continuation can be used as-is.
    lines.append("assistant:")
    return "\n".join(lines)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for a Gradio chat turn.

    Args:
        message: The new user message.
        history: Previous turns as (user message, assistant message) pairs.
        system_message: System prompt placed at the start of the prompt.
        max_tokens: Maximum number of tokens to generate for the reply.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The generated reply, or ``"An error occurred: ..."`` if retrieval or
        generation raised an exception.
    """
    try:
        context = get_relevant_context(message)
        prompt = _format_prompt(
            _build_messages(system_message, context, history, message)
        )
        outputs = pipe(
            prompt,
            # Bound only the newly generated tokens. ``max_length`` counts
            # prompt + output in tokens, so deriving it from the prompt's
            # character length gave the wrong limit.
            max_new_tokens=int(max_tokens),
            # temperature / top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            # Return just the continuation rather than echoing the prompt,
            # so no string-splitting on role markers is needed.
            return_full_text=False,
        )
        reply = outputs[0]["generated_text"].strip()
    except Exception as e:
        # UI boundary: log the traceback and show the error in the chat
        # instead of crashing the request.
        logger.exception("Failed to generate a response")
        reply = f"An error occurred: {e}"
    yield reply


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant. Use the provided context to answer questions accurately.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Marco-O1 Assistant with Knowledge Base",
    description="Ask questions about the documents in the knowledge base. The assistant will use the relevant context to provide accurate answers.",
)

if __name__ == "__main__":
    demo.launch()