|
import os |
|
import gradio as gr |
|
from langchain_groq import ChatGroq |
|
from langchain import LLMChain, PromptTemplate |
|
from langchain.memory import ConversationBufferMemory |
|
|
|
# Groq API key from the environment. May be None here; the value is
# re-read and validated further down before the Groq client is built.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Prompt for the LangChain-based chatbot at the bottom of this file:
# prior turns are injected via {chat_history}, the new query via
# {user_message}, and the model continues after "Chatbot:".
template = """You are a helpful assistant to answer all user queries.
{chat_history}
User: {user_message}
Chatbot:"""
|
|
|
import os |
|
from groq import Groq |
|
import gradio as gr |
|
import logging |
|
|
|
|
|
# Module-level logger per stdlib convention.
# NOTE(review): DEBUG is very chatty for production — confirm the level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
# Fail fast if the Groq API key is missing or empty: log for operators,
# then raise so the process does not start half-configured.
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set.")
    raise ValueError("GROQ_API_KEY environment variable is required.")

# Shared Groq SDK client used by get_completion().
client = Groq(api_key=api_key)

# Chat model to query; overridable via the MODEL_NAME env var.
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-8b-8192")
|
|
|
def get_completion(user_input):
    """
    Send *user_input* to the Groq chat-completions API and return the reply.

    Args:
        user_input (str): The user's input query.

    Returns:
        str: The model's reply text (whitespace-stripped), or a fixed
        apology string if anything goes wrong during the API call.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_input},
    ]
    try:
        # Single non-streaming request against the configured model.
        chat = client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=False,
        )
        return chat.choices[0].message.content.strip()
    except Exception as e:
        # UI boundary: log the failure and hand the user a friendly
        # message rather than letting Gradio surface a traceback.
        logger.error(f"Error during completion: {e}")
        return "Sorry, I encountered an error while processing your request."
|
|
|
def launch_interface():
    """
    Build the Gradio UI around ``get_completion`` and launch it.

    Single text input, single text output; the description carries the
    page header plus inline CSS used to style the widgets.
    """
    # Inline CSS + header markup rendered in the interface description.
    header_html = """
    <style>
    .gr-box {
        border-radius: 10px;
        border: 2px solid #007BFF;
        padding: 15px;
        background-color: #F8F9FA;
    }
    .gr-input {
        font-size: 1.2em;
        padding: 10px;
        border-radius: 5px;
        border: 2px solid #007BFF;
        margin-bottom: 10px;
    }
    .gr-output {
        font-size: 1.2em;
        padding: 10px;
        border-radius: 5px;
        border: 2px solid #28A745;
        background-color: #E9F7EF;
    }
    .gr-interface-title {
        font-size: 2em;
        font-weight: bold;
        color: #007BFF;
        text-align: center;
        margin-bottom: 20px;
    }
    .gr-interface-description {
        font-size: 1.2em;
        color: #6C757D;
        text-align: center;
        margin-bottom: 20px;
    }
    .gr-button {
        background-color: #007BFF;
        color: white;
        border-radius: 5px;
        padding: 10px 20px;
        font-size: 1em;
        border: none;
        cursor: pointer;
        margin-top: 10px;
    }
    .gr-button:hover {
        background-color: #0056b3;
    }
    </style>
    <div class="gr-interface-title">Welcome to Mr AI</div>
    <div class="gr-interface-description">Ask anything and get a helpful response.</div>
    """

    query_box = gr.Textbox(
        label="Enter your query:",
        placeholder="Ask me anything...",
        lines=2,
        max_lines=5,
    )
    answer_box = gr.Textbox(
        label="Response:",
        lines=6,
        max_lines=10,
    )

    # NOTE(review): live=True re-runs get_completion on every input change,
    # i.e. one Groq API call per keystroke — confirm this is intended.
    app = gr.Interface(
        fn=get_completion,
        inputs=query_box,
        outputs=answer_box,
        title="Mr AI",
        description=header_html,
        allow_flagging="never",
        live=True,
    )

    logger.info("Starting Gradio interface")
    # share=True also exposes the app through a public gradio.live URL.
    app.launch(share=True)
|
|
|
# Entry point for the Groq-backed UI.
# NOTE(review): demo.launch() blocks, so the LangChain section below this
# guard only executes after the Gradio server shuts down — and it runs
# unconditionally on import. The two chatbot implementations look like
# they belong in separate files; confirm intent.
if __name__ == "__main__":
    launch_interface()
|
|
|
# Prompt for the LangChain pipeline: prior turns fill {chat_history},
# the new query fills {user_message} (see `template` above).
prompt = PromptTemplate(
    input_variables=["chat_history", "user_message"], template=template
)

# Conversation memory wired to the template's {chat_history} placeholder.
memory = ConversationBufferMemory(memory_key="chat_history")

# Chain combining the Groq-hosted chat model, the prompt, and the memory.
# Fixes two defects: the original referenced `ChatOpenAI`, which is never
# imported (only `ChatGroq` is — NameError at runtime), and passed
# temperature as the string '0.5' instead of a float.
llm_chain = LLMChain(
    llm=ChatGroq(temperature=0.5, model_name="llama3-8b-8192"),
    prompt=prompt,
    verbose=True,
    memory=memory,
)
|
|
|
def get_text_response(user_message, history):
    """
    Gradio ChatInterface callback: answer *user_message* via the LLM chain.

    Args:
        user_message (str): The latest user query.
        history: Chat history supplied by Gradio; unused here because the
            chain's ConversationBufferMemory tracks context itself.

    Returns:
        str: The model's reply.
    """
    return llm_chain.predict(user_message=user_message)
|
|
|
# Chat-style UI for the LangChain pipeline (distinct from the Groq-SDK
# interface built in launch_interface above).
demo = gr.ChatInterface(get_text_response)

# NOTE(review): second __main__ guard in the same file — only reached after
# the first guard's launch() returns. Confirm whether this script should be
# split in two.
if __name__ == "__main__":
    demo.launch()
|
|
|
|