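"""Streamlit app that generates text with a fine-tuned Falcon 7B model on CPU."""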
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM


# Load the model and tokenizer, cached so Streamlit does not reload them on
# every script rerun
@st.cache_resource
def load_model():
    try:
        # Fine-tuned, merged, quantized Falcon 7B model on the Hugging Face Hub
        base_model = "aman-augurs/falcon7b-final-merged-quantized-fine-tuned"

        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model)

        # Load the model for CPU inference with standard settings
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map="cpu",           # pin all weights to the CPU
            torch_dtype=torch.float32,  # float32 is the safe dtype on CPU
            low_cpu_mem_usage=True,     # reduce peak RAM while loading
        )

        # Confirm which device the model ended up on
        print(f"Model loaded successfully on device: {model.device}")
        return tokenizer, model
    except Exception as e:
        print(f"Model loading error: {e}")
        return None, None


# Load the model and tokenizer once at startup
tokenizer, model = load_model()
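
# Note: Streamlit reruns this script top to bottom on every user interaction;
# st.cache_resource above keeps the loaded model in memory across reruns.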


# Generate a completion from the model for a user prompt
def generate_response(prompt):
    try:
        if model is None or tokenizer is None:
            return "Model is not loaded correctly."

        # Tokenize the prompt; this yields input_ids and an attention_mask
        inputs = tokenizer(prompt, return_tensors="pt")
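
        # Passing the attention_mask explicitly below avoids transformers
        # having to infer it (and warning about it) during generation.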
        # Generate output; do_sample=True is required for temperature, top_k,
        # and top_p to take effect (otherwise decoding is greedy and the
        # sampling arguments are ignored)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=200,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,  # moderate creativity
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.eos_token_id,  # Falcon has no pad token
            )
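
        # Note: max_length caps prompt plus generated tokens combined; use
        # max_new_tokens instead to fix the number of generated tokens
        # regardless of prompt length.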
        # Decode and return the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        print(f"Response generation error: {e}")
        return "An error occurred during response generation."


# Streamlit UI
def main():
    st.title("Text Generation with Falcon 7B Model")

    # Input text area for the user prompt
    prompt = st.text_area("Enter your prompt:", "Once upon a time")

    # Button to generate a response
    if st.button("Generate Response"):
        if prompt:
            st.write("Generating response...")
            response = generate_response(prompt)
            st.subheader("Generated Text:")
            st.write(response)
        else:
            st.error("Please enter a prompt.")


if __name__ == "__main__":
    main()