# testpbr / app.py
# ascarlettvfx's picture
# Update app.py
# 2addbd2 verified
import io
import os
import tempfile

import gradio as gr
import numpy as np
from gradio_client import Client, handle_file
from PIL import Image
def process_image(image):
    """Estimate a depth map for ``image`` using the hosted Marigold Space.

    The image is written to a temporary JPEG and submitted to the
    ``/submit_depth_fn`` endpoint of the ``prs-eth/marigold`` Space through
    ``gradio_client``. The first path listed under ``depth_outputs`` in the
    response is then opened and returned. The temporary file is always
    deleted afterwards.

    Args:
        image: The uploaded image, as a numpy array (what
            ``gr.Image(type='numpy')`` delivers) or a PIL Image. It is
            ``None`` when the user submits without uploading anything.

    Returns:
        PIL.Image.Image: The depth image, with its pixel data loaded.

    Raises:
        gr.Error: If no image was supplied, or if the API response contains
            no depth output. Gradio shows the message to the user instead of
            trying to render it as an image.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    client = Client("prs-eth/marigold")

    # Convert numpy array to PIL Image if needed. PIL infers the mode from the
    # array shape (2-D -> 'L', 3 channels -> 'RGB', 4 channels -> 'RGBA').
    # Forcing 'RGB' would reject or scramble non-3-channel input.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    # JPEG cannot store alpha or palette data, so normalise to RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # delete=False keeps the path valid after the handle is closed, because
    # the client reads the file by path. The file is removed in the finally
    # clause, even if saving or the API call fails.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpeg")
    try:
        with tmp:
            image.save(tmp, format='JPEG')
        tmp_path = tmp.name  # Get the file path
        # Call the API with necessary parameters
        result = client.predict(
            handle_file(tmp_path),  # filepath for 'Input Image' Image component
            20,  # Ensemble size
            10,  # Number of denoising steps
            "0",  # Processing resolution
            handle_file(tmp_path),  # Placeholder for 'Predicted depth (16-bit)'
            handle_file(tmp_path),  # Placeholder for 'Predicted depth (32-bit)'
            handle_file(tmp_path),  # Placeholder for 'Predicted depth (red-near, blue-far)'
            0,  # Relative position of the near plane
            0,  # Relative position of the far plane
            0,  # Embossing level
            1,  # Size of the smoothing filter
            -100,  # Frame's near plane offset
            api_name="/submit_depth_fn"
        )
    finally:
        os.remove(tmp.name)

    # NOTE(review): gradio_client typically returns a tuple for endpoints with
    # several outputs. Confirm that /submit_depth_fn really yields a dict with
    # a 'depth_outputs' key.
    depth_outputs = result.get('depth_outputs') if isinstance(result, dict) else None
    if not depth_outputs:
        raise gr.Error("No depth output received or error in processing")

    # Handle the returned file path for the depth image
    depth_image = Image.open(depth_outputs[0])
    depth_image.load()  # Read the pixel data now rather than lazily later
    return depth_image
# Wire process_image into a simple Gradio UI: one image upload in, one image out.
iface = gr.Interface(
    process_image,
    gr.Image(type='numpy'),  # process_image receives the upload as a numpy array
    gr.Image(),
    title="16-bit Depth Output using Marigold API",
    description="Upload an image to retrieve its 16-bit depth output using the Marigold API.",
)
# Run the Gradio app
if __name__ == "__main__":
iface.launch()