# Spaces status: Runtime error
import cv2
import gradio as gr
import numpy as np
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
from PIL import Image
# Function to apply color correction
def to_rgb_array(img):
    """Return *img* as a contiguous ``H x W x 3`` NumPy array without alpha.

    Objects exposing ``.convert`` (PIL images) are first converted with
    ``convert("RGB")``, so grayscale, palette ("P") and alpha modes are all
    resolved by PIL. Plain arrays are handled directly:

    * 2-D arrays and ``H x W x 1`` arrays are treated as grayscale and
      replicated into three identical channels;
    * ``H x W x 2`` arrays are treated as gray + alpha: the gray channel is
      replicated and the second channel discarded;
    * arrays with three or more channels keep only the first three
      (e.g. RGBA -> RGB, i.e. alpha is dropped, not composited).

    The result never aliases the caller's data (``np.array`` copies).

    Raises:
        ValueError: if the input is not a 2-D or 3-D image with at least one
            channel (this includes ``None``).
    """
    # Duck-typed PIL check, so this helper only depends on NumPy.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    arr = np.array(img)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if arr.ndim != 3 or arr.shape[2] == 0:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {arr.shape}")
    if arr.shape[2] < 3:
        # Grayscale (optionally with alpha): copy luminance into R, G and B.
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    else:
        arr = arr[:, :, :3]
    return np.ascontiguousarray(arr)


def color_match(source_img, reference_img):
    """Transfer the colour distribution of *reference_img* onto *source_img*.

    Args:
        source_img: Image whose colours are adjusted. This is a PIL image when
            called from the Gradio UI; NumPy arrays are accepted as well.
        reference_img: Image that supplies the target colour distribution.

    Returns:
        PIL.Image.Image: the colour-matched source image, built from the
        uint8 array produced by ``Normalizer.uint8_norm``.

    Raises:
        gr.Error: if either image is missing. Gradio passes ``None`` when the
            user submits without uploading, and displays gr.Error messages in
            the UI.
        ValueError: if an image cannot be interpreted as a 2-D/3-D picture.
    """
    if source_img is None or reference_img is None:
        raise gr.Error("Please provide both a source image and a reference image.")

    # ColorMatcher is fed 3-channel arrays; normalise grayscale/palette/alpha inputs.
    img_src = to_rgb_array(source_img)
    img_ref = to_rgb_array(reference_img)

    # Colour transfer with color_matcher's 'mkl' method.
    cm = ColorMatcher()
    img_res = cm.transfer(src=img_src, ref=img_ref, method='mkl')

    # Convert the transfer result to uint8 so PIL can build an image from it.
    # NOTE(review): uint8_norm appears to rescale the value range rather than
    # clip it; confirm that stretching is preferred over np.clip(..., 0, 255).
    img_res = Normalizer(img_res).uint8_norm()

    return Image.fromarray(img_res)
# Gradio Interface
def gradio_interface(**launch_kwargs):
    """Build the colour-matching web UI and start serving it.

    The UI takes two PIL images (source and reference) and shows the result
    of ``color_match`` as a PIL image.

    Args:
        **launch_kwargs: Forwarded unchanged to ``Interface.launch`` (for
            example ``server_name``, ``server_port`` or ``share``). With no
            arguments the call is exactly ``launch()``, as before.
    """
    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    demo = gr.Interface(
        fn=color_match,
        inputs=inputs,
        outputs=outputs,
        title="Color Matching Tool",
    )
    demo.launch(**launch_kwargs)
# Run the Gradio Interface only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gradio_interface()