# diane / app.py
# Last commit 4aead6f by oscurantismo: "Update app.py" (verified)
import os
import tempfile
from io import BytesIO

import cv2
import gradio as gr
import openai
import requests
import torch
from openai import OpenAI
from PIL import Image, ImageEnhance
from transformers import CLIPModel, CLIPProcessor
# Make the OpenAI API key from the environment available to the `openai` package.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Hugging Face checkpoint for the CLIP model that labels uploaded artwork.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
# Candidate labels for CLIP zero-shot classification. Each entry is sent to
# CLIP as a text prompt and is also shown verbatim to the user as the detected
# subject. Keep the order stable: the winning score's index is used to pick the
# label from this list.
object_labels = [
    "cat",
    "dog",
    "house",
    "tree",
    "car",
    "mountain",
    "flower",
    "bird",
    "person",
    "robot",
    "a digital artwork",
    "a portrait",
    "a landscape",
    "a futuristic cityscape",
    "horse",
    "lion",
    "tiger",
    "elephant",
    "giraffe",
    "airplane",
    "train",
    "ship",
    "book",
    "laptop",
    "keyboard",
    "pen",
    "clock",
    "cup",
    "bottle",
    "backpack",
    "chair",
    "table",
    "sofa",
    "bed",
    "building",
    "street",
    "forest",
    "desert",
    "waterfall",
    "sunset",
    "beach",
    "bridge",
    "castle",
    "statue",
    "3D model",
]
# Example image offered next to the upload field for the contrast check.
EXAMPLE_IMAGE_URL = "https://www.watercoloraffair.com/wp-content/uploads/2023/04/monet-houses-of-parliament-low-key.jpg"  # Square example image
# Seconds to wait for the example-image download. Without a timeout, requests
# waits indefinitely, so a stalled server would hang app start-up.
EXAMPLE_IMAGE_TIMEOUT = 30
_example_response = requests.get(EXAMPLE_IMAGE_URL, timeout=EXAMPLE_IMAGE_TIMEOUT)
# Fail fast with a clear HTTP error instead of handing an error page to PIL.
_example_response.raise_for_status()
example_image = Image.open(BytesIO(_example_response.content))

# OpenAI client shared by the chat and suggestion handlers. Constructed with no
# arguments, the OpenAI library reads the key from OPENAI_API_KEY by default.
client = OpenAI()
def process_chat(user_text):
    """Stream Diane's answer to a free-form digital-art question.

    Used as a Gradio streaming callback. Each yielded value is the full answer
    accumulated so far, so the output textbox appears to grow as tokens arrive.

    Args:
        user_text: The question typed by the user. ``None``, empty or
            whitespace-only input is rejected with a warning and the API is
            not called.

    Yields:
        str: The progressively longer answer, or a single warning / error
        message.
    """
    if not user_text or not user_text.strip():
        yield "⚠️ Please enter a valid question."
        return
    try:
        stream = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant named Diane specializing in digital art advice. Don't use text styling (e.g., bold, italics)."},
                {"role": "user", "content": user_text},
            ],
            stream=True,
        )
        response_text = ""
        for chunk in stream:
            # A chunk can arrive with an empty ``choices`` list (e.g. a
            # usage-only chunk); skip it instead of raising IndexError.
            if not chunk.choices:
                continue
            # ``delta.content`` is None on role/finish chunks; read it defensively.
            token = getattr(chunk.choices[0].delta, "content", None)
            if token:
                response_text += token
                yield response_text
    except Exception as e:
        # UI boundary: show the failure in the textbox rather than crashing the app.
        yield f"❌ An error occurred: {str(e)}"
def analyze_contrast_opencv(image_path):
    """Measure the contrast of an image file.

    The image is loaded as 8-bit grayscale and the contrast is the standard
    deviation of its pixel intensities (often called RMS contrast). A perfectly
    flat image scores 0.

    Args:
        image_path: Path to an image file that OpenCV can decode.

    Returns:
        float: Standard deviation of the grayscale intensities (0-255 scale).

    Raises:
        ValueError: If OpenCV cannot read or decode ``image_path``.
    """
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image file: {image_path!r}")
    return float(gray.std())
def identify_objects_with_clip(image_path):
    """Return the label from ``object_labels`` that CLIP scores highest for an image.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        str: The best-matching entry of ``object_labels``.
    """
    # Close the file handle promptly; convert() returns an independent image.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    inputs = clip_processor(text=object_labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image has one row per input image (a single image here) and one
    # column per label. Softmax is monotonic, so the argmax of the raw logits
    # selects the same label as the argmax of the probabilities.
    best_index = int(outputs.logits_per_image[0].argmax())
    return object_labels[best_index]
def enhance_contrast(image, factor=1.5, output_path="enhanced_image.png"):
    """Increase the contrast of an image and save the result to disk.

    Args:
        image: PIL image to enhance. It is not modified; a new image is written.
        factor: Enhancement factor for ``ImageEnhance.Contrast``. 1.0 keeps the
            original image and larger values increase contrast. Defaults to
            1.5, the value this app has always used.
        output_path: File to write the enhanced image to. PIL picks the format
            from the extension. A relative path resolves against the current
            working directory. Defaults to ``"enhanced_image.png"``.

    Returns:
        str: ``output_path``, the location of the saved enhanced image.
    """
    enhanced = ImageEnhance.Contrast(image).enhance(factor)
    # NOTE(review): with the default output_path every call writes the same
    # file, so concurrent requests can overwrite each other's result; pass a
    # per-request path to avoid that.
    enhanced.save(output_path)
    return output_path
def provide_suggestions_streaming(object_identified):
if not object_identified:
yield "⚠️ Sorry, I couldn't identify an object in your artwork. Try uploading a different image."
return
try:
# Use the OpenAI client for suggestions
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are an expert digital art advisor."},
{"role": "user", "content": f"Suggest ways to improve a digital artwork featuring a {object_identified}."},
],
stream=True # Enable streaming
)
response_text = ""
for chunk in response:
# Extract the content safely
delta = chunk.choices[0].delta # Get the delta object
token = getattr(delta, "content", None) # Safely access the "content" field
if token: # Only process non-None tokens
response_text += token
yield response_text
except Exception as e:
yield f"❌ An error occurred while providing suggestions: {str(e)}"
# Grayscale standard deviation (0-255 intensity scale) below which an upload is
# treated as low-contrast and an enhanced copy is produced.
LOW_CONTRAST_THRESHOLD = 25


def process_image(image):
    """Analyse an uploaded artwork's contrast and subject.

    Gradio callback for the "Check" button.

    Args:
        image: PIL image from the upload field, or ``None`` when nothing was
            uploaded.

    Returns:
        tuple: ``(message, enhanced_image_path, object_identified)``.
        ``message`` is the analysis text shown to the user.
        ``enhanced_image_path`` is the path of a contrast-enhanced copy, or
        ``None`` unless the contrast is below ``LOW_CONTRAST_THRESHOLD``.
        ``object_identified`` is the CLIP label, or ``None`` when no image
        was given.
    """
    if image is None:
        return "⚠️ Please upload an image.", None, None

    enhanced_image_path = None
    # Save into a private per-request temporary directory. A fixed shared
    # filename would let concurrent users overwrite each other's upload
    # between saving it and analysing it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        upload_path = os.path.join(tmp_dir, "uploaded_image.png")
        image.save(upload_path)
        contrast = analyze_contrast_opencv(upload_path)
        object_identified = identify_objects_with_clip(upload_path)
        if contrast < LOW_CONTRAST_THRESHOLD:
            # Enhance the saved PNG while the temporary file still exists.
            with Image.open(upload_path) as saved_image:
                # NOTE(review): enhance_contrast writes to a fixed
                # "enhanced_image.png" by default, so that output is still
                # shared between concurrent users.
                enhanced_image_path = enhance_contrast(saved_image)

    if enhanced_image_path is not None:
        return (
            f"Hey, great artwork of {object_identified}! However, it looks like the contrast is a little low. I've improved the contrast for you. ✨",
            enhanced_image_path,
            object_identified,
        )
    return (
        f"Hey, great artwork of {object_identified}! Looks like the color contrast is great. Be proud of yourself! 🌟",
        None,
        object_identified,
    )
# Custom CSS for the Blocks UI: fixes the height of both image panes (matched by
# elem_id) and enlarges every component tagged with the "button" class.
_DEMO_CSS = """
#upload-image, #example-image {
height: 300px !important;
}
.button {
height: 50px;
font-size: 16px;
}
"""
demo = gr.Blocks(css=_DEMO_CSS)
with demo:
    gr.Markdown("## 🎨 DIANE (Digital Imaging and Art Neural Enhancer)")
    gr.Markdown("DIANE is here to assist you in refining your digital art. She can answer questions about digital art, analyze your images, and provide creative suggestions to enhance your work.")

    # Chatbot section: free-form questions answered (streamed) by process_chat.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Ask me about digital art")
            user_text = gr.Textbox(label="Enter your question", placeholder="What is the best tool for a beginner?...")
            chat_output = gr.Textbox(label="Answer", interactive=False)
            chat_button = gr.Button("Ask", elem_classes="button")
            chat_button.click(process_chat, inputs=user_text, outputs=chat_output)

    # Image analysis section: contrast check, optional enhancement, suggestions.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🖼️ Upload an image to check its contrast levels")
            with gr.Row(equal_height=True):
                # Left: image upload field
                with gr.Column():
                    image_input = gr.Image(label="Upload an image", type="pil", elem_id="upload-image")
                    image_button = gr.Button("Check", elem_classes="button")
                # Right: example image field
                with gr.Column():
                    gr.Image(value=example_image, label="Example Image", interactive=False, elem_id="example-image")
                    example_button = gr.Button("Use Example Image", elem_classes="button")
            image_output_text = gr.Textbox(label="Analysis", interactive=False)
            image_output_image = gr.Image(label="Improved Image", interactive=False)
            suggestion_button = gr.Button("I want to improve this artwork. Any suggestions?", visible=False)
            suggestions_output = gr.Textbox(label="Suggestions", interactive=True)
            state_object = gr.State()  # Holds the object label returned by process_image

    def use_example_image():
        """Return the example image so it can be loaded into the upload field."""
        return example_image

    example_button.click(
        use_example_image,
        inputs=None,
        outputs=image_input,
    )

    def update_suggestions_visibility(*_):
        """Reveal the suggestions button. Any arguments are ignored."""
        return gr.update(visible=True)

    # Reveal the suggestions button only after process_image has finished, by
    # chaining with .then(). A second, independent click listener would run
    # concurrently and could write stale (pre-click) outputs back over the
    # fresh analysis.
    image_button.click(
        process_image,
        inputs=image_input,
        outputs=[
            image_output_text,
            image_output_image,
            state_object,
        ],
    ).then(
        update_suggestions_visibility,
        inputs=None,
        outputs=suggestion_button,
    )

    # Stream suggestions for the label stored by the last analysis.
    suggestion_button.click(
        provide_suggestions_streaming,
        inputs=state_object,
        outputs=suggestions_output,
    )

if __name__ == "__main__":
    demo.launch(share=True)