import io
import os
import uuid
from datetime import datetime

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image

# Size limit (bytes) this app enforces for the image and mask sent to the edit API.
MAX_IMAGE_BYTES = 4 * 1024 * 1024

# Modes Pillow can convert straight to RGBA. Routing these through RGB first
# would discard their transparency (LA, palette-with-transparency), and the
# mask's transparent pixels are what mark the region to edit.
_DIRECT_RGBA_MODES = ("RGBA", "RGB", "LA", "L", "P")


def _unique_name(prefix=""):
    """Return a timestamped base filename (no extension) with a random suffix.

    The suffix keeps files created within the same second (e.g. the original
    image and its mask) from overwriting one another.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{prefix}{timestamp}_{uuid.uuid4().hex[:8]}"


def _encode_png(image):
    """Encode a PIL image as optimized PNG into a BytesIO rewound to offset 0."""
    byte_arr = io.BytesIO()
    image.save(byte_arr, format="PNG", optimize=True)
    byte_arr.seek(0)
    return byte_arr


def convert_to_png(image_file):
    """Convert an uploaded image to RGBA PNG and shrink it until it is under 4MB.

    Args:
        image_file: A path or binary file-like object accepted by ``Image.open``.

    Returns:
        io.BytesIO: The encoded PNG, positioned at offset 0.

    Raises:
        ValueError: If the image cannot be shrunk below the size limit.
    """
    image = Image.open(image_file)
    if image.mode in _DIRECT_RGBA_MODES:
        image = image.convert("RGBA")
    else:
        # Other modes (e.g. CMYK) are normalised through RGB first.
        image = image.convert("RGB").convert("RGBA")

    byte_arr = _encode_png(image)
    while byte_arr.getbuffer().nbytes > MAX_IMAGE_BYTES:
        width, height = image.size
        # Reduce each side by 10% per pass (keeps the aspect ratio); never 0px.
        new_size = (max(1, int(width * 0.9)), max(1, int(height * 0.9)))
        if new_size == image.size:
            raise ValueError("Image cannot be reduced below the 4MB limit")
        image = image.resize(new_size, Image.Resampling.LANCZOS)
        byte_arr = _encode_png(image)
    return byte_arr


def validate_image(uploaded_file):
    """Check an uploaded file against the 4MB size limit.

    Returns:
        tuple[bool, str]: ``(True, "OK")`` if small enough, otherwise
        ``(False, <reason>)``.
    """
    if uploaded_file.size > MAX_IMAGE_BYTES:
        return False, "File size must be less than 4MB"
    return True, "OK"


def save_uploaded_file(uploaded_file, folder="uploads"):
    """Convert an upload to a <4MB PNG, write it into *folder*, and return its path.

    Each call gets a unique filename, so saving two uploads back to back never
    clobbers the first one.
    """
    os.makedirs(folder, exist_ok=True)
    image_bytes = convert_to_png(uploaded_file)
    file_path = os.path.join(folder, f"{_unique_name()}.png")
    with open(file_path, "wb") as f:
        f.write(image_bytes.getvalue())
    return file_path


def download_image(url, folder="generated", timeout=60):
    """Download the image at *url* into *folder* under a unique filename.

    Args:
        url: Location of the generated image.
        folder: Directory to save into (created if missing).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        str | None: The saved file path, or ``None`` if the response status
        was not 200.

    Raises:
        requests.RequestException: On network errors or timeout.
    """
    os.makedirs(folder, exist_ok=True)
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None
    file_path = os.path.join(folder, f"{_unique_name('generated_image_')}.png")
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path


def _check_edit_inputs(original_path, mask_path):
    """Return an error message if the saved PNGs break DALL-E 2 edit rules, else None.

    The edits endpoint needs a square image and a mask of identical
    dimensions. Because an oversized original or mask is auto-resized
    independently, the two can diverge even if the uploads matched.
    """
    with Image.open(original_path) as original, Image.open(mask_path) as mask:
        original_size, mask_size = original.size, mask.size
    if original_size != mask_size:
        return (
            f"Original image ({original_size[0]}x{original_size[1]}) and mask "
            f"({mask_size[0]}x{mask_size[1]}) must have the same dimensions."
        )
    if original_size[0] != original_size[1]:
        return (
            "Images must be square for DALL-E 2 edits "
            f"(got {original_size[0]}x{original_size[1]})."
        )
    return None


def _upload_with_preview(title, uploader_label, warn_label, error_label):
    """Render one uploader column (size warning + preview) and return the upload or None."""
    st.subheader(title)
    uploaded = st.file_uploader(uploader_label, type=["png", "jpg", "jpeg"])
    if uploaded:
        valid, message = validate_image(uploaded)
        if not valid:
            st.warning(f"{warn_label}: {message}. Image will be automatically resized.")
        try:
            image = Image.open(uploaded)
            st.image(image, caption=title, use_container_width=True)
            st.caption(f"Original size: {uploaded.size/1024/1024:.2f}MB")
        except Exception as e:
            st.error(f"{error_label}: {str(e)}")
    return uploaded


def _generate_edits(api_key, original_image, mask_image, prompt, num_images, size):
    """Save the inputs, call the DALL-E 2 edit API, and render the results.

    Temporary upload files are always removed, even when an error occurs.
    """
    # Pre-bind so the cleanup in ``finally`` works even if saving fails.
    original_path = mask_path = None
    try:
        with st.spinner("Processing images and generating edited version..."):
            original_path = save_uploaded_file(original_image)
            mask_path = save_uploaded_file(mask_image)

            problem = _check_edit_inputs(original_path, mask_path)
            if problem:
                st.error(problem)
                return

            client = OpenAI(api_key=api_key)
            # Context managers close both handles once the request has been sent.
            with open(original_path, "rb") as image_fp, open(mask_path, "rb") as mask_fp:
                response = client.images.edit(
                    model="dall-e-2",
                    image=image_fp,
                    mask=mask_fp,
                    prompt=prompt,
                    n=num_images,
                    size=size,
                )

            st.subheader("Generated Images")
            cols = st.columns(num_images)
            for idx, image_data in enumerate(response.data):
                saved_image_path = download_image(image_data.url)
                with cols[idx]:
                    if not saved_image_path:
                        st.warning(f"Could not download generated image {idx+1}.")
                        continue
                    st.image(saved_image_path, caption=f"Generated Image {idx+1}")
                    with open(saved_image_path, "rb") as file:
                        st.download_button(
                            label=f"Download Image {idx+1}",
                            data=file,
                            file_name=f"edited_image_{idx+1}.png",
                            mime="image/png",
                        )
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
    finally:
        # Best-effort removal of the temporary uploads.
        for path in (original_path, mask_path):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass


def main():
    """Streamlit entry point: collect inputs and run a DALL-E 2 image edit."""
    st.title("🎨 Image Editor with DALL-E 2")

    # Sidebar for API key and generation options
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input("Enter your OpenAI API key", type="password")
        st.markdown("""
        ### How to get an API key
        1. Go to [OpenAI API Keys](https://platform.openai.com/api-keys)
        2. Create a new secret key
        3. Copy and paste it here
        """)

        size_option = st.selectbox(
            "Select image size:",
            ["1024x1024", "512x512", "256x256"]
        )
        num_images = st.slider("Number of images to generate", 1, 4, 1)

    st.markdown("""
    ### Requirements:
    - Original image and mask must be less than 4MB
    - Original image and mask must be square and have identical dimensions
    - Images will be automatically converted to PNG format
    - Images larger than 4MB will be automatically resized
    """)

    col1, col2 = st.columns(2)
    with col1:
        original_image = _upload_with_preview(
            "Original Image", "Upload original image", "Original image", "Error loading image"
        )
    with col2:
        mask_image = _upload_with_preview(
            "Mask Image", "Upload mask image", "Mask image", "Error loading mask"
        )

    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Describe the changes you want to make to the image...",
        help="Be specific about what you want to add or modify in the masked area"
    )

    if not st.button("Generate Edited Image"):
        return
    if not api_key:
        st.error("Please enter your OpenAI API key in the sidebar.")
        return
    if not original_image or not mask_image:
        st.error("Please upload both an original image and a mask image.")
        return
    if not prompt:
        st.error("Please enter a prompt describing the desired changes.")
        return

    _generate_edits(api_key, original_image, mask_image, prompt, num_images, size_option)


if __name__ == "__main__":
    main()