Spaces:
Sleeping
Sleeping
import torch | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Model/tokenizer loading.
@st.cache_resource
def _load_model_cached():
    """Load the tokenizer and model once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint would be reloaded on every button click.
    ``st.cache_resource`` keeps the returned objects alive across reruns.
    Exceptions propagate to the caller and are not cached, so a failed load
    is retried on the next rerun.
    """
    # Hub repository id of the checkpoint to load.
    base_model = "aman-augurs/falcon7b-final-merged-quantized-fine-tuned"

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Explicit CPU placement with float32 weights.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map='cpu',
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # Report where the model ended up, for the server log.
    print(f"Model loaded successfully on device: {model.device}")
    return tokenizer, model


def load_model():
    """Return ``(tokenizer, model)``, or ``(None, None)`` if loading fails.

    The heavy lifting is delegated to ``_load_model_cached`` so repeated
    calls (one per Streamlit rerun) reuse the already-loaded objects. Any
    exception raised while loading is printed and converted into the
    ``(None, None)`` sentinel that callers check for.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # Best-effort: the UI must still render when the model cannot load.
        print(f"Model loading error: {e}")
        return None, None
# Module-level handles read by generate_response(); load_model() returns
# (None, None) when loading fails, so both names may be None.
(tokenizer, model) = load_model()
# Text generation.
def generate_response(prompt):
    """Generate a continuation of ``prompt`` with the module-level model.

    Args:
        prompt: The text to continue.

    Returns:
        The decoded output sequence (special tokens stripped), or a fixed
        human-readable message if the model is not loaded or generation
        raises.
    """
    try:
        if model is None or tokenizer is None:
            return "Model is not loaded correctly."

        # Tokenize the prompt into PyTorch tensors.
        inputs = tokenizer(prompt, return_tensors="pt")

        generation_kwargs = dict(
            # Budget applies to generated tokens only, so a long prompt
            # cannot consume the whole length limit.
            max_new_tokens=200,
            num_return_sequences=1,
            # Sampling must be switched on explicitly; without it decoding
            # is greedy and temperature/top_k/top_p have no effect.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )

        # Forward the attention mask when the tokenizer provides one. Only
        # the known keys are passed on, never the whole encoding, so extra
        # tokenizer outputs cannot reach generate() as unexpected kwargs.
        if "attention_mask" in inputs:
            generation_kwargs["attention_mask"] = inputs["attention_mask"]

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model.generate(inputs["input_ids"], **generation_kwargs)

        # Decode the single returned sequence back to text.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: report the failure in the log and keep the UI alive.
        print(f"Response generation error: {e}")
        return "An error occurred during response generation."
# Streamlit UI
def main():
    """Render the page: a prompt box, a button, and the generated text.

    When the button is pressed with a non-blank prompt, the prompt is passed
    unchanged to ``generate_response`` and the result is shown under a
    subheader. A blank or whitespace-only prompt produces an error message
    instead.
    """
    st.title("Text Generation with Falcon 7B Model")

    # Prompt input, pre-filled with a default starter text.
    prompt = st.text_area("Enter your prompt:", "Once upon a time")

    # Nothing more to do until the user asks for a generation.
    if not st.button("Generate Response"):
        return

    # Reject empty and whitespace-only prompts alike.
    if not prompt or not prompt.strip():
        st.error("Please enter a prompt.")
        return

    # The spinner disappears once generation finishes, so the status text
    # does not linger next to the result.
    with st.spinner("Generating response..."):
        response = generate_response(prompt)

    st.subheader("Generated Text:")
    st.write(response)
# Entry point: build the UI only when this file is run as the main module.
if __name__ == "__main__":
    main()