import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cpu")  # run inference on CPU only

# Cache the model and tokenizer so Streamlit loads them once, not on every rerun.
@st.cache_resource
def load_model():
    model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
    tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
    # Fall back to EOS if the tokenizer defines no pad token, so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
model, tokenizer = load_model()
# Generate a response from the chatbot for a single question.
def generate_response(message: str, temperature: float = 0.4,
                      repetition_penalty: float = 1.1,
                      max_input_length: int = 256) -> str:
    # Tokenize the question, truncating overly long inputs.
    inputs = tokenizer(
        message,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length
    ).to(device)

    # Sample a response; use EOS as the padding id during generation.
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens so the reply does not simply
    # echo the question back to the user.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
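
# Example: calling generate_response outside the UI for a quick smoke test.
# A minimal sketch; the sample question is made up for illustration. It is
# left commented out because Streamlit also executes this script with
# __name__ == "__main__", so an active guard would trigger generation on
# every rerun.
#
# if __name__ == "__main__":
#     print(generate_response("What can help relieve menstrual cramps?"))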
# Streamlit app layout
st.title("Menstrual QA Chatbot")
st.write("Ask any question related to menstrual health.")

# User input
user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input:
        with st.spinner("Generating response..."):
            response = generate_response(user_input)
        st.write(f"Chatbot: {response}")
    else:
        st.write("Please enter a question.")