|
from huggingface_hub import InferenceClient |
|
import gradio as gr |
|
import logging |
|
import sys |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# Single shared client for the Mixtral instruct model; every chat request goes through it.
try:
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
except Exception as e:
    # logger.exception keeps the traceback, which the bare str(e) message discarded.
    logger.exception("Failed to initialize Hugging Face client: %s", e)
    # Without a client the app cannot answer anything, so exit instead of starting the UI.
    sys.exit(1)
|
|
|
def format_prompt(message, history):
    """Build a Mixtral-Instruct style prompt from prior turns plus a new message.

    The result has the shape
    ``<s>[INST] u1 [/INST] b1</s> [INST] u2 [/INST] b2</s> [INST] message [/INST]``.

    Args:
        message: The new user message, emitted as the final ``[INST]`` block.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.
            May be empty or None. A ``None`` response (a turn that was never
            answered) is rendered as an empty reply.

    Returns:
        The prompt string, or None if *history* could not be unpacked/formatted.
    """
    try:
        parts = ["<s>"]
        for user_prompt, bot_response in history or ():
            # Without this, an unanswered turn would put the literal text "None" in the prompt.
            reply = "" if bot_response is None else bot_response
            parts.append(f"[INST] {user_prompt} [/INST]")
            parts.append(f" {reply}</s> ")
        parts.append(f"[INST] {message} [/INST]")
        prompt = "".join(parts)

        # The prompt contains the user's private conversation: keep it out of INFO logs.
        logger.debug("Formatted prompt: %s", prompt)
        return prompt
    except Exception as e:
        # Callers treat None as "formatting failed"; the traceback goes to the log.
        logger.exception("Error in format_prompt: %s", e)
        return None
|
|
|
def _with_system_prompt(system_prompt, user_message):
    """Return *user_message* prefixed with *system_prompt*, or unchanged if the prompt is blank."""
    if system_prompt and system_prompt.strip():
        return f"{system_prompt}\n\nUser: {user_message}"
    return user_message


def generate(message, chat_history, system_prompt, temperature=0.9, max_new_tokens=512, top_p=0.95):
    """Stream the model's reply to *message*, yielding the accumulated text.

    Args:
        message: The new user message.
        chat_history: Prior ``(user, bot)`` turns, oldest first (may be empty or
            None). It is not modified.
        system_prompt: Instructions attached to the first user turn of the
            conversation (skipped when blank).
        temperature: Sampling temperature, converted to float.
        max_new_tokens: Maximum number of generated tokens, converted to int.
        top_p: Nucleus-sampling threshold, converted to float.

    Yields:
        The reply generated so far, growing with each streamed token. If
        formatting or generation fails, a single error message is yielded instead.
    """
    try:
        # Conversation content can be sensitive, so it is logged at DEBUG, not INFO.
        logger.debug("Received message: %s", message)
        logger.debug("System prompt: %s", system_prompt)

        # The stored chat history only contains the raw user messages, so the system
        # prompt must be re-attached to the first user turn on every call; otherwise
        # the model only sees it while producing the very first reply. The turns are
        # copied so the caller's (UI) history is left untouched.
        history = [[user_turn, bot_turn] for user_turn, bot_turn in (chat_history or [])]
        if history:
            history[0][0] = _with_system_prompt(system_prompt, history[0][0])
            full_message = message
        else:
            full_message = _with_system_prompt(system_prompt, message)

        formatted_prompt = format_prompt(full_message, history)

        if not formatted_prompt:
            # This is a generator: `return <value>` would never reach the UI, so yield it.
            yield "I encountered an error formatting your message. Please try again."
            return

        generate_kwargs = dict(
            temperature=float(temperature),
            max_new_tokens=int(max_new_tokens),
            top_p=float(top_p),
            do_sample=True,
            seed=42,  # fixed sampling seed
        )

        logger.info("Starting generation with parameters: %s", generate_kwargs)

        response_stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )

        partial_message = ""
        for response in response_stream:
            token = response.token
            # Special tokens (e.g. the end-of-sequence marker "</s>") are not reply text.
            if getattr(token, "special", False) or not token.text:
                continue
            partial_message += token.text
            yield partial_message

    except Exception as e:
        logger.exception("Error in generate function: %s", e)
        yield f"I encountered an error: {str(e)}"
|
|
|
|
|
# Default instructions for the assistant, pre-filled in the "System Prompt" box of
# the "Advanced Options" accordion (the user may edit it there). The prompt format
# built by `format_prompt` has no dedicated system slot, so `generate` prepends this
# text to the first user message of the conversation.
DEFAULT_SYSTEM_PROMPT = """You are a supportive AI assistant trained to provide emotional support and general guidance.
Remember to:
1. Show empathy and understanding
2. Ask clarifying questions when needed
3. Provide practical coping strategies
4. Encourage professional help when appropriate
5. Maintain boundaries and ethical guidelines"""
|
|
|
|
|
with gr.Blocks() as demo:
    # Title and safety disclaimer are rendered first so users see them before chatting
    # (previously they were placed below the chat and the buttons).
    gr.Markdown("""
# PsyAssist - ADVANCING MENTAL HEALTH SUPPORT WITH AI-DRIVEN INTERACTION

**Important Notice**: This is an AI-powered mental health support chatbot. While it can provide emotional support
and general guidance, it is not a replacement for professional mental health services. In case of emergency,
please contact your local mental health crisis hotline.
""")

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")

    with gr.Accordion("Advanced Options", open=False):
        system_prompt = gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Prompt",
            lines=3,
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.1,
            label="Temperature",
        )
        max_new_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="Max Tokens",
        )

    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the message to the chat as a pending turn.

        A pending turn is ``[user_message, None]``; `bot` fills in the reply.
        Blank messages are ignored so no empty request is sent to the model.
        """
        history = history or []  # the chat value may be None after "Clear"
        if not user_message or not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history, system_prompt, temperature, max_new_tokens):
        """Stream the assistant's reply into the last (pending) chat turn."""
        if not history or history[-1][1] is not None:
            # Nothing to answer: empty chat, or `user` ignored a blank submission.
            # (A bare `return value` in a generator would never update the UI.)
            yield history
            return

        user_message = history[-1][0]
        history[-1][1] = ""

        received_reply = False
        for chunk in generate(
            user_message,
            history[:-1],
            system_prompt,
            temperature,
            max_new_tokens,
        ):
            history[-1][1] = chunk
            received_reply = True
            yield history

        if not received_reply:
            # Don't leave the user looking at a silently empty reply.
            history[-1][1] = "I'm sorry, I couldn't generate a response. Please try again."
            yield history

    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot,
        [chatbot, system_prompt, temperature, max_new_tokens],
        chatbot,
    )

    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
if __name__ == "__main__":
    try:
        # queue() enables streamed updates from generator handlers such as `bot`.
        demo.queue().launch(show_api=False)
    except Exception as e:
        # logger.exception keeps the traceback, which the bare str(e) message discarded.
        logger.exception("Failed to launch Gradio interface: %s", e)
        sys.exit(1)