Spaces:
Sleeping
Sleeping
import os
import tempfile
from io import BytesIO

import cv2
import gradio as gr
import openai
import requests
import torch
from openai import OpenAI
from PIL import Image, ImageEnhance
from transformers import CLIPProcessor, CLIPModel
# Set OpenAI API Key on the openai package from the environment (None if unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Load CLIP model and processor from one shared checkpoint name so the
# preprocessing always matches the weights.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
# Expanded object labels: the candidate captions CLIP scores each image against.
# Order is significant only for ties, because numpy's argmax returns the first
# maximal index.
object_labels = [
    "cat", "dog", "house", "tree", "car",
    "mountain", "flower", "bird", "person", "robot",
    "a digital artwork", "a portrait", "a landscape",
    "a futuristic cityscape", "horse",
    "lion", "tiger", "elephant", "giraffe", "airplane",
    "train", "ship", "book", "laptop", "keyboard",
    "pen", "clock", "cup", "bottle", "backpack",
    "chair", "table", "sofa", "bed", "building",
    "street", "forest", "desert", "waterfall", "sunset",
    "beach", "bridge", "castle", "statue", "3D model",
]
# Example image for contrast check
EXAMPLE_IMAGE_URL = "https://www.watercoloraffair.com/wp-content/uploads/2023/04/monet-houses-of-parliament-low-key.jpg"  # Square example image
# Bound the download time so app startup cannot hang on an unresponsive host,
# and fail with a clear HTTP error instead of letting PIL try to decode an
# error page.
_example_response = requests.get(EXAMPLE_IMAGE_URL, timeout=30)
_example_response.raise_for_status()
example_image = Image.open(BytesIO(_example_response.content))

# Initialize the OpenAI client (it reads OPENAI_API_KEY from the environment).
client = OpenAI()
def process_chat(user_text):
    """Stream an answer from the "Diane" digital-art assistant.

    Yields the accumulated answer after every streamed token, so a Gradio
    output component bound to this generator updates progressively.

    Args:
        user_text: The user's question. ``None`` or whitespace-only input
            yields a single warning message and makes no API call.

    Yields:
        str: The answer text so far, or a warning/error message.
    """
    if not user_text or not user_text.strip():
        yield "⚠️ Please enter a valid question."
        return
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a helpful assistant named Diane specializing in "
                        "digital art advice. Don't use text styling (i.e., bold, italics)."
                    ),
                },
                {"role": "user", "content": user_text},
            ],
            stream=True,  # Enable streaming
        )
        response_text = ""
        for chunk in response:
            # A stream can contain chunks with an empty `choices` list (e.g. a
            # final usage chunk); skip them instead of raising IndexError and
            # discarding the partial answer.
            if not chunk.choices:
                continue
            token = getattr(chunk.choices[0].delta, "content", None)
            if token:  # role-only / final chunks carry no content
                response_text += token
                yield response_text
    except Exception as e:  # UI boundary: surface the error to the user
        yield f"❌ An error occurred: {e}"
# Function to analyze image contrast
def analyze_contrast_opencv(image_path):
    """Return the contrast of an image file as its grayscale standard deviation.

    The image is loaded with OpenCV in grayscale mode, which yields 8-bit
    intensities (0-255). The spread of those values is used as a simple
    contrast measure.

    Args:
        image_path: Path of the image file to analyze.

    Returns:
        float: Standard deviation of the grayscale pixel intensities.

    Raises:
        ValueError: If OpenCV cannot read or decode ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the caller would get an opaque AttributeError.
        raise ValueError(f"OpenCV could not read image file: {image_path!r}")
    return float(img.std())
# Function to identify objects using CLIP
def identify_objects_with_clip(image_path):
    """Pick the entry of ``object_labels`` that CLIP matches best to an image.

    Zero-shot classification: the image and every candidate label are scored
    together by the CLIP model, and the highest-probability label wins.

    Args:
        image_path: Path of an image file readable by PIL.

    Returns:
        str: The best-matching label from ``object_labels``.
    """
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    model_inputs = clip_processor(
        text=object_labels,
        images=rgb_image,
        return_tensors="pt",
        padding=True,
    )
    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        image_text_logits = clip_model(**model_inputs).logits_per_image
    # Softmax across the label axis, then take the most probable label.
    label_probs = image_text_logits.softmax(dim=1).numpy().flatten()
    return object_labels[label_probs.argmax()]
# Function to enhance image contrast
def enhance_contrast(image, factor=1.5, output_path="enhanced_image.png"):
    """Boost the contrast of a PIL image and save the result to a file.

    Args:
        image: The PIL image to enhance.
        factor: Contrast multiplier for ``ImageEnhance.Contrast``. A value of
            1.0 returns the original image, and larger values increase
            contrast. Defaults to 1.5.
        output_path: Destination file for the enhanced image. The file format
            is inferred from the extension. Defaults to ``"enhanced_image.png"``
            in the current working directory; pass a unique path if several
            requests may run at once.

    Returns:
        str: ``output_path``, the location of the saved image.
    """
    enhanced_image = ImageEnhance.Contrast(image).enhance(factor)
    enhanced_image.save(output_path)
    return output_path
def provide_suggestions_streaming(object_identified): | |
if not object_identified: | |
yield "β οΈ Sorry, I couldn't identify an object in your artwork. Try uploading a different image." | |
return | |
try: | |
# Use the OpenAI client for suggestions | |
response = client.chat.completions.create( | |
model="gpt-4o-mini", | |
messages=[ | |
{"role": "system", "content": "You are an expert digital art advisor."}, | |
{"role": "user", "content": f"Suggest ways to improve a digital artwork featuring a {object_identified}."}, | |
], | |
stream=True # Enable streaming | |
) | |
response_text = "" | |
for chunk in response: | |
# Extract the content safely | |
delta = chunk.choices[0].delta # Get the delta object | |
token = getattr(delta, "content", None) # Safely access the "content" field | |
if token: # Only process non-None tokens | |
response_text += token | |
yield response_text | |
except Exception as e: | |
yield f"β An error occurred while providing suggestions: {str(e)}" | |
# Main image processing function

# Grayscale standard deviation below which an image is treated as low-contrast.
LOW_CONTRAST_THRESHOLD = 25


def process_image(image):
    """Analyze an uploaded artwork's contrast and subject.

    Args:
        image: PIL image from the Gradio upload component, or ``None`` when
            nothing was uploaded.

    Returns:
        tuple: ``(message, enhanced_image_path, object_identified)``:
        - ``message``: the user-facing analysis text.
        - ``enhanced_image_path``: path of a contrast-enhanced copy, set only
          when contrast is below ``LOW_CONTRAST_THRESHOLD``; otherwise ``None``.
        - ``object_identified``: the CLIP label, or ``None`` when no image was
          given.
    """
    if image is None:
        return "⚠️ Please upload an image.", None, None

    # Analyze a private temporary copy instead of a shared "uploaded_image.png"
    # in the working directory. Concurrent requests then cannot overwrite each
    # other's upload between save and analysis, and the copy is always removed.
    fd, upload_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(upload_path)
        contrast = analyze_contrast_opencv(upload_path)
        object_identified = identify_objects_with_clip(upload_path)
    finally:
        os.remove(upload_path)

    if contrast < LOW_CONTRAST_THRESHOLD:
        # Enhance the in-memory image directly; re-reading the saved copy adds nothing.
        enhanced_image_path = enhance_contrast(image)
        return (
            f"Hey, great artwork of {object_identified}! However, it looks like the contrast is a little low. I've improved the contrast for you. ✨",
            enhanced_image_path,
            object_identified,
        )
    # NOTE(review): the trailing emoji below was reconstructed from a
    # mis-encoded character; confirm it is the intended one.
    return (
        f"Hey, great artwork of {object_identified}! Looks like the color contrast is great. Be proud of yourself! 😊",
        None,
        object_identified,
    )
# Gradio Blocks Interface
# NOTE(review): the component nesting below was reconstructed because the
# source's indentation was lost; confirm the layout matches the intended design.
demo = gr.Blocks(css="""
#upload-image, #example-image {
height: 300px !important;
}
.button {
height: 50px;
font-size: 16px;
}
""")
with demo:
    gr.Markdown("## 🎨 DIANE (Digital Imaging and Art Neural Enhancer)")
    gr.Markdown("DIANE is here to assist you in refining your digital art. She can answer questions about digital art, analyze your images, and provide creative suggestions to enhance your work.")

    # Chatbot Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Ask me about digital art")
            user_text = gr.Textbox(label="Enter your question", placeholder="What is the best tool for a beginner?...")
            chat_output = gr.Textbox(label="Answer", interactive=False)
            chat_button = gr.Button("Ask", elem_classes="button")
            # process_chat is a generator, so the answer streams into chat_output.
            chat_button.click(process_chat, inputs=user_text, outputs=chat_output)

    # Image Analysis Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🖼️ Upload an image to check its contrast levels")
            with gr.Row(equal_height=True):
                # Left: Image upload field
                with gr.Column():
                    image_input = gr.Image(label="Upload an image", type="pil", elem_id="upload-image")
                    image_button = gr.Button("Check", elem_classes="button")
                # Right: Example image field
                with gr.Column():
                    gr.Image(value=example_image, label="Example Image", interactive=False, elem_id="example-image")
                    example_button = gr.Button("Use Example Image", elem_classes="button")
            image_output_text = gr.Textbox(label="Analysis", interactive=False)
            image_output_image = gr.Image(label="Improved Image", interactive=False)
            # Hidden until an analysis has identified an object to advise on.
            suggestion_button = gr.Button("I want to improve this artwork. Any suggestions?", visible=False)
            suggestions_output = gr.Textbox(label="Suggestions", interactive=True)
            state_object = gr.State()  # To store identified object

    # Load example image into the input
    def use_example_image():
        return example_image

    example_button.click(use_example_image, inputs=None, outputs=image_input)

    def update_suggestions_visibility(analysis, enhanced_image, identified_object):
        """Show the suggestion button only when the analysis identified an object.

        The analysis text and enhanced image are passed through unchanged.
        """
        return gr.update(visible=identified_object is not None), analysis, enhanced_image

    # Analyze button. The visibility update is chained with .then(), so it runs
    # only after process_image has finished and reads its fresh outputs. A
    # second, independent .click() listener on the same button has no ordering
    # with process_image: it could read the stale analysis/image and write them
    # back over the new results.
    image_button.click(
        process_image,
        inputs=image_input,
        outputs=[image_output_text, image_output_image, state_object],
    ).then(
        update_suggestions_visibility,
        inputs=[image_output_text, image_output_image, state_object],
        outputs=[suggestion_button, image_output_text, image_output_image],
    )

    # Suggestion button functionality with streaming
    suggestion_button.click(
        provide_suggestions_streaming,
        inputs=state_object,
        outputs=suggestions_output,
    )
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module (tests, tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch(share=True)