import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face repo id of the fine-tuned model; loaded un-quantized in float32.
BASE_MODEL = "aman-augurs/llama3.2-30-fine-tuned"


@st.cache_resource
def _load_model_cached():
    """Load and cache the tokenizer and model for CPU inference.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects across those reruns so
    the weights are read from disk only once. Exceptions propagate, and
    Streamlit does not cache a call that raised, so a failed load is
    retried on the next rerun instead of being remembered.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="cpu",
        torch_dtype=torch.float32,  # half precision is poorly supported on CPU
        low_cpu_mem_usage=True,
    )
    return tokenizer, model


# Function to load model and tokenizer
def load_model():
    """Return ``(tokenizer, model)``, or ``(None, None)`` if loading fails.

    The error is printed to the server console rather than raised so the UI
    can still render and report that the model is unavailable.
    """
    try:
        tokenizer, model = _load_model_cached()
        # Confirm the model landed on the CPU.
        print(f"Model loaded successfully on device: {model.device}")
        return tokenizer, model
    except Exception as e:
        print(f"Model loading error: {e}")
        return None, None


# Load the model globally (cached across Streamlit reruns).
tokenizer, model = load_model()


# Function to generate a response from the model
def generate_response(prompt, max_new_tokens=200):
    """Generate a sampled continuation of ``prompt``.

    Args:
        prompt: Input text to continue.
        max_new_tokens: Maximum number of tokens to generate beyond the
            prompt. The prompt length is not counted, so long prompts still
            receive a continuation.

    Returns:
        The decoded text, which includes the prompt itself. If the model
        is unavailable or generation fails, a fixed error message is
        returned instead.
    """
    try:
        if model is None or tokenizer is None:
            return "Model is not loaded correctly."

        inputs = tokenizer(prompt, return_tensors="pt")

        # Llama-family tokenizers often define no pad token; fall back to EOS
        # so generate() does not have to guess (and warn).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                # Without do_sample=True, generation is greedy and
                # temperature/top_k/top_p are silently ignored.
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                pad_token_id=pad_token_id,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report on the console, show a generic message to the user.
        print(f"Response generation error: {e}")
        return "An error occurred during response generation."
# Streamlit UI
def main():
    """Render the prompt form and display the model's generated text."""
    st.title("Text Generation with Fine-Tuned Llama 3.2 Model")

    # Input text area for user prompt
    prompt = st.text_area("Enter your prompt:", "Once upon a time")

    if st.button("Generate Response"):
        # Reject empty and whitespace-only prompts.
        if prompt and prompt.strip():
            # Transient indicator: disappears once generation finishes.
            with st.spinner("Generating response..."):
                response = generate_response(prompt)
            st.subheader("Generated Text:")
            st.write(response)
        else:
            st.error("Please enter a prompt.")


if __name__ == '__main__':
    main()