import torch
import gradio as gr
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoModel
from focusondepth.model_config import FocusOnDepthConfig
from focusondepth.model_definition import FocusOnDepth
# Register the custom architecture so AutoModel can resolve the
# "focusondepth" model type from the hub checkpoint loaded below.
AutoConfig.register("focusondepth", FocusOnDepthConfig)
AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
# Preprocessing expected by the ViT backbone: 384x384 input, normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
# Load the fine-tuned weights; the checkpoint stores them under 'model_state_dict'.
checkpoint = torch.load(
    './focusondepth/FocusOnDepth_vit_base_patch16_384.p',
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint['model_state_dict'])
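# Optional (a sketch, not in the original Space): move the model to GPU when
# one is available; inference() below would then also need .to(device) on the
# input tensor.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)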
def inference(input_image):
    model.eval()
    # Gradio passes the uploaded image as a numpy array.
    input_image = Image.fromarray(input_image)
    original_size = input_image.size
    tensor_image = transform(input_image)
    # Pure inference: no gradients needed.
    with torch.no_grad():
        depth, segmentation = model(tensor_image.unsqueeze(0))
    # Invert the depth map so that nearer objects appear brighter.
    depth = 1 - depth
    depth = transforms.ToPILImage()(depth[0, :])
    # Per-pixel argmax over the class dimension, scaled to 0-255 so the mask
    # is visible as an 8-bit image (assumes the binary person/background head).
    segmentation = transforms.ToPILImage()((segmentation.argmax(dim=1) * 255).to(torch.uint8))
    return [
        depth.resize(original_size, resample=Image.BICUBIC),
        segmentation.resize(original_size, resample=Image.NEAREST),
    ]
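# A minimal smoke test (a sketch, not part of the original Space): uncomment
# to run the pipeline once on a local file before starting the UI. The
# filename "example.jpg" is a hypothetical placeholder.
# test_depth, test_seg = inference(np.array(Image.open("example.jpg").convert("RGB")))
# test_depth.save("depth.png")
# test_seg.save("segmentation.png")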
iface = gr.Interface(
    fn=inference,
    # gr.inputs / gr.outputs were removed in recent Gradio releases; the
    # components now live at the top level.
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Depth Map"),
        gr.Image(label="Segmentation Map"),
    ],
)
iface.launch()