|
import gradio as gr |
|
from together import Together |
|
from helper import get_together_api_key |
|
from guardrail import is_safe |
|
|
|
|
|
# Single Together API client created at import time and shared by every
# request; the API key is obtained from helper.get_together_api_key().
client = Together(
    api_key=get_together_api_key(),
)
|
|
|
|
|
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content chat messages.

    Both Gradio history formats are accepted:

    * ``type="messages"``: a list of ``{"role": ..., "content": ...}`` dicts;
    * ``type="tuples"``: a list of ``[user_message, bot_message]`` pairs, which
      Gradio delivers as lists (not tuples), possibly with ``None`` on one side.

    Only plain-text user/assistant turns are kept; anything else (non-string
    content such as file attachments, other roles, empty slots) is skipped so
    the chat-completions request never receives ``None`` content.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, bot_text = turn
            if isinstance(user_text, str):
                messages.append({"role": "user", "content": user_text})
            if isinstance(bot_text, str):
                messages.append({"role": "assistant", "content": bot_text})
    return messages


def generate_response(message, history):
    """Ask the Together chat model to answer *message* as a finance-only assistant.

    Args:
        message: The latest user message text.
        history: Prior conversation as supplied by ``gr.ChatInterface``, in
            either Gradio's ``"messages"`` or ``"tuples"`` history format.

    Returns:
        The content of the first completion choice returned by the model.
    """
    system_prompt = """You are an AI assistant specialized in financial discussions. Please answer questions only related to finance. If the question is unrelated, respond with: 'I am sorry, I can only answer financial-related questions.'"""

    # System prompt first, then the prior turns in order, then the new message.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    model_output = client.chat.completions.create(
        model="meta-llama/Llama-3-70b-chat-hf",
        messages=messages,
    )

    return model_output.choices[0].message.content
|
|
|
|
|
def main_loop(message, history):
    """Answer one chat turn, screening both the question and the reply.

    This is the ``fn`` given to ``gr.ChatInterface``. ChatInterface passes the
    new user message plus the prior history and records the returned reply
    in the chat itself. This function therefore returns only the reply text
    and leaves *history* untouched. Returning a ``(text, history)`` tuple
    would be stored as the bot message, and appending here would corrupt
    Gradio's history.

    Args:
        message: The user's latest message.
        history: Prior conversation as supplied by Gradio; forwarded to
            ``generate_response`` unchanged.

    Returns:
        The model's reply, or a fixed refusal message when ``is_safe``
        rejects either the user's message or the generated reply.
    """
    # Screen the user's input before spending a model call on it.
    if not is_safe(message):
        return "Your input violates safety guidelines. Please rephrase your question."

    response = generate_response(message, history)

    # Screen the model's output too; the guardrail applies in both directions.
    if not is_safe(response):
        return "The generated response violates safety guidelines. Please try a different question."

    return response
|
|
|
|
|
# Chat UI built around main_loop. ``type="messages"`` is set on both the
# ChatInterface and its Chatbot so they agree on the history format. The
# ChatInterface's own ``type`` (default "tuples") decides the shape of the
# history passed to main_loop, so setting it only on the Chatbot is not enough.
demo = gr.ChatInterface(
    main_loop,
    type="messages",
    chatbot=gr.Chatbot(
        height=450,
        placeholder="Type your financial question here...",
        type="messages",
    ),
    textbox=gr.Textbox(
        placeholder="Ask about finance (e.g., investments, savings, etc.)",
        container=False,
        scale=7,
    ),
    title="Financial Chatbot",
    theme="Monochrome",
    examples=["What are mutual funds?", "How can I save for retirement?"],
    cache_examples=False,
)

# share=True requests a public share link and server_name="0.0.0.0" listens
# on all network interfaces, so the app is reachable beyond localhost.
# NOTE(review): launching at import time means any `import` of this module
# starts the server; consider an `if __name__ == "__main__":` guard.
demo.launch(share=True, server_name="0.0.0.0")
|
|