"""Gradio front end that transfers the colour look of one image onto another."""

import cv2
import gradio as gr
import numpy as np
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
from PIL import Image


def _to_rgb_array(image, label):
    """Return *image* as a 3-channel ``H x W x 3`` numpy array.

    Args:
        image: A PIL image (or anything ``np.array`` accepts).
        label: Human-readable name of the input, used in error messages.

    Raises:
        ValueError: If *image* is ``None`` or has a channel layout other than
            greyscale (2-D), 3-channel, or 4-channel.
    """
    if image is None:
        raise ValueError(f"{label} image is missing; please provide one.")

    pixels = np.array(image)

    if pixels.ndim == 2:
        # Greyscale input has no channel axis; replicate it so the matcher
        # always receives three channels.
        return np.repeat(pixels[:, :, np.newaxis], 3, axis=2)

    if pixels.ndim == 3:
        channels = pixels.shape[2]
        if channels == 3:
            return pixels
        if channels == 4:
            # Drop the alpha channel.
            return cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)

    raise ValueError(
        f"{label} image has unsupported shape {pixels.shape}; "
        "expected a greyscale, RGB or RGBA image."
    )


def color_match(source_img, reference_img):
    """Recolour *source_img* so its colours match those of *reference_img*.

    Both inputs are converted to 3-channel arrays, passed through
    ``ColorMatcher.transfer`` with the ``'mkl'`` method, normalised to 8-bit
    and returned as a PIL image.

    Args:
        source_img: PIL image whose colours will be altered.
        reference_img: PIL image providing the target colours.

    Returns:
        A ``PIL.Image.Image`` holding the colour-matched result.

    Raises:
        ValueError: If either image is missing or has an unsupported channel
            layout.
    """
    img_src = _to_rgb_array(source_img, "Source")
    img_ref = _to_rgb_array(reference_img, "Reference")

    # Apply colour matching.
    matcher = ColorMatcher()
    img_res = matcher.transfer(src=img_src, ref=img_ref, method='mkl')

    # Normalise the result to uint8 so PIL can build an image from it.
    img_res = Normalizer(img_res).uint8_norm()

    return Image.fromarray(img_res)


def gradio_interface():
    """Build the two-image-in, one-image-out Gradio app and launch it."""
    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=color_match,
        inputs=inputs,
        outputs=outputs,
        title="Color Matching Tool",
    ).launch()


if __name__ == "__main__":
    gradio_interface()