# Spaces: Running
import streamlit as st | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face Hub ID shared by the model weights and the tokenizer.
MODEL_ID = "adi2606/MenstrualQA"

# Set up the device to use CPU only
device = torch.device("cpu")


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the causal LM and its tokenizer, cached across Streamlit reruns.

    Streamlit re-executes this script from the top on every widget
    interaction. Wrapping the load in ``st.cache_resource`` means the
    ``from_pretrained`` calls run once and later reruns reuse the same
    objects instead of loading the weights again on each button click.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to
        ``device``.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return loaded_model, loaded_tokenizer


# Module-level names kept so the rest of the file can keep using them.
model, tokenizer = _load_model_and_tokenizer()
# Function to generate a response from the chatbot
def generate_response(
    message: str,
    temperature: float = 0.4,
    repetition_penalty: float = 1.1,
    max_length: int = 512,
) -> str:
    """Generate the chatbot's reply to a single user message.

    The message is wrapped in a two-turn conversation (a fixed system prompt
    plus the user turn), rendered through the tokenizer's chat template, and
    passed to ``model.generate`` with sampling enabled.

    Args:
        message: The user's question.
        temperature: Sampling temperature forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to
            ``model.generate``.
        max_length: Forwarded to ``model.generate`` as ``max_length``; this
            budget counts the prompt tokens as well as the generated ones.

    Returns:
        The decoded reply text only, without the echoed prompt and with
        special tokens removed.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": message},
    ]

    # Apply the chat template and convert to PyTorch tensors
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Generate the response
    output = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    # A causal LM returns the prompt tokens followed by the continuation.
    # Decoding the whole sequence would repeat the system prompt and the
    # user's own message in the answer, so keep only what comes after the
    # prompt.
    prompt_length = input_ids.shape[-1]
    reply_ids = output[0][prompt_length:]

    # Decode the generated output
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# Streamlit app layout
st.title("Menstrual QA Chatbot")
st.write("Ask any question related to menstrual health.")

# User input
user_input = st.text_input("You:", "")

if st.button("Send"):
    # Strip first so whitespace-only input counts as empty rather than being
    # sent to the model as a question.
    question = user_input.strip()
    if question:
        with st.spinner("Generating response..."):
            response = generate_response(question)
        st.write(f"Chatbot: {response}")
    else:
        st.write("Please enter a question.")