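# Spaces: Sleeping
# (The line above appears to be page-status text captured from the hosting
# page together with the file; it is not part of the program.)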
import io
import os
import uuid
from datetime import datetime

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image
def _encode_png(image):
    """Encode ``image`` as an optimized PNG into a fresh, rewound BytesIO."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG', optimize=True)
    buffer.seek(0)
    return buffer


def convert_to_png(image_file, max_bytes=4 * 1024 * 1024):
    """Convert an uploaded image to an RGBA PNG no larger than ``max_bytes``.

    The image is re-encoded as PNG; while the encoded result exceeds
    ``max_bytes`` it is scaled down by 10% per step (aspect ratio preserved)
    and re-encoded.

    Args:
        image_file: Whatever ``Image.open`` accepts; it is passed through
            unchanged.
        max_bytes: Upper bound on the encoded PNG size. Defaults to 4 MB.

    Returns:
        io.BytesIO: The encoded PNG, positioned at offset 0.
    """
    image = Image.open(image_file)

    if image.mode != 'RGBA':
        try:
            # Convert straight to RGBA so that alpha carried by LA / PA /
            # palette transparency survives. Going through RGB first drops
            # it, which would turn a transparent mask fully opaque.
            image = image.convert('RGBA')
        except ValueError:
            # Modes with no direct RGBA conversion take the RGB detour.
            image = image.convert('RGB').convert('RGBA')

    byte_arr = _encode_png(image)
    while byte_arr.getbuffer().nbytes > max_bytes:
        width, height = image.size
        if width <= 1 and height <= 1:
            # Cannot shrink any further; return the smallest encoding we have
            # instead of asking PIL for a zero-sized image.
            break
        new_size = (max(1, int(width * 0.9)), max(1, int(height * 0.9)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
        byte_arr = _encode_png(image)

    return byte_arr
def validate_image(uploaded_file):
    """Check that an uploaded file does not exceed the 4 MB size limit.

    Only the size is inspected; the file's format is not examined here.

    Args:
        uploaded_file: Object exposing a ``size`` attribute in bytes.

    Returns:
        A ``(is_valid, message)`` tuple: ``(True, "OK")`` when the size is at
        most 4 MB, otherwise ``(False, <reason>)``.
    """
    max_bytes = 4 * 1024 * 1024  # 4MB in bytes
    if uploaded_file.size > max_bytes:
        return False, "File size must be less than 4MB"
    return True, "OK"
def save_uploaded_file(uploaded_file, folder="uploads"):
    """Convert an upload to PNG, write it into ``folder`` and return its path.

    Args:
        uploaded_file: The upload to convert; handed to ``convert_to_png``.
        folder: Directory for the converted file; created if missing.

    Returns:
        str: Path of the PNG file that was written.
    """
    os.makedirs(folder, exist_ok=True)

    # Convert to PNG (and shrink if it is too large)
    image_bytes = convert_to_png(uploaded_file)

    # A second-resolution timestamp alone is not unique: two uploads saved
    # back-to-back (original + mask) would get the same name and the second
    # would overwrite the first. The UUID suffix keeps every file distinct.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(folder, f"{timestamp}_{uuid.uuid4().hex}.png")

    with open(file_path, "wb") as f:
        f.write(image_bytes.getvalue())
    return file_path
def download_image(url, folder="generated", timeout=30):
    """Download an image from ``url`` and save it as a PNG file in ``folder``.

    Args:
        url: Address to fetch with an HTTP GET.
        folder: Directory for the downloaded file; created if missing.
        timeout: Seconds to wait for the server before ``requests`` gives up,
            so a stalled connection cannot block the app forever.

    Returns:
        str | None: Path of the saved file, or ``None`` when the response
        status is not 200.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    os.makedirs(folder, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None

    # Several images are usually downloaded within the same second; the UUID
    # suffix stops them from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(
        folder, f"generated_image_{timestamp}_{uuid.uuid4().hex}.png"
    )
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
def _upload_with_preview(heading, uploader_label, warning_subject, error_subject):
    """Render one upload column (heading, picker, size warning, preview).

    Returns whatever ``st.file_uploader`` returned, which is falsy when no
    file has been chosen.
    """
    st.subheader(heading)
    uploaded = st.file_uploader(uploader_label, type=["png", "jpg", "jpeg"])
    if uploaded:
        valid, message = validate_image(uploaded)
        if not valid:
            # Oversized uploads are only warned about, not rejected.
            st.warning(f"{warning_subject}: {message}. Image will be automatically resized.")
        try:
            preview = Image.open(uploaded)
            st.image(preview, caption=heading, use_container_width=True)
            st.caption(f"Original size: {uploaded.size/1024/1024:.2f}MB")
        except Exception as e:
            st.error(f"Error loading {error_subject}: {str(e)}")
    return uploaded


def _remove_quietly(path):
    """Best-effort delete of a temporary file; a missing file is not an error."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def main():
    """Render the Streamlit page and run an image edit when requested.

    Collects the API key, output size and image count in the sidebar, takes
    an original image, a mask and a prompt, then calls the OpenAI image-edit
    endpoint and shows each result with a download button.
    """
    # U+1F3A8 (artist palette). The previous title text was this emoji's
    # UTF-8 bytes decoded with the wrong charset.
    st.title("\U0001F3A8 Image Editor with DALL-E 2")

    # Sidebar for API key and generation options
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input("Enter your OpenAI API key", type="password")
        st.markdown("""
        ### How to get an API key
        1. Go to [OpenAI API Keys](https://platform.openai.com/api-keys)
        2. Create a new secret key
        3. Copy and paste it here
        """)
        # Size selection
        size_option = st.selectbox(
            "Select image size:",
            ["1024x1024", "512x512", "256x256"]
        )
        # Number of images
        num_images = st.slider("Number of images to generate", 1, 4, 1)

    # Main content
    st.markdown("""
    ### Requirements:
    - Original image and mask must be less than 4MB
    - Images will be automatically converted to PNG format
    - Images larger than 4MB will be automatically resized
    """)

    # File uploaders
    col1, col2 = st.columns(2)
    with col1:
        original_image = _upload_with_preview(
            "Original Image", "Upload original image", "Original image", "image"
        )
    with col2:
        mask_image = _upload_with_preview(
            "Mask Image", "Upload mask image", "Mask image", "mask"
        )

    # Prompt input
    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Describe the changes you want to make to the image...",
        help="Be specific about what you want to add or modify in the masked area"
    )

    # Generate button
    if st.button("Generate Edited Image"):
        if not api_key:
            st.error("Please enter your OpenAI API key in the sidebar.")
            return
        if not original_image or not mask_image:
            st.error("Please upload both an original image and a mask image.")
            return
        if not prompt:
            st.error("Please enter a prompt describing the desired changes.")
            return

        # Bound before the try so the cleanup in `finally` is safe even when
        # saving the first upload fails.
        original_path = None
        mask_path = None
        try:
            with st.spinner("Processing images and generating edited version..."):
                # Save and convert uploaded files
                original_path = save_uploaded_file(original_image)
                mask_path = save_uploaded_file(mask_image)

                # Initialize OpenAI client
                client = OpenAI(api_key=api_key)

                # Make the API call; the handles are closed as soon as the
                # request returns so the temp files can be deleted afterwards.
                with open(original_path, "rb") as image_fh, \
                        open(mask_path, "rb") as mask_fh:
                    response = client.images.edit(
                        model="dall-e-2",
                        image=image_fh,
                        mask=mask_fh,
                        prompt=prompt,
                        n=num_images,
                        size=size_option
                    )

                # Display results
                st.subheader("Generated Images")
                cols = st.columns(num_images)
                for idx, (col, image_data) in enumerate(zip(cols, response.data), start=1):
                    # Download and save the generated image
                    saved_image_path = download_image(image_data.url)
                    with col:
                        if not saved_image_path:
                            st.warning(f"Could not download generated image {idx}.")
                            continue
                        st.image(saved_image_path, caption=f"Generated Image {idx}")
                        # Add download button for each image
                        with open(saved_image_path, "rb") as file:
                            st.download_button(
                                label=f"Download Image {idx}",
                                data=file,
                                file_name=f"edited_image_{idx}.png",
                                mime="image/png"
                            )
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
        finally:
            # Cleanup temporary files (single place, runs on success and failure)
            for path in (original_path, mask_path):
                _remove_quietly(path)
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()