# FocusOnDepth / app.py
import torch
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from torchvision import transforms
from transformers import AutoConfig, AutoModel
from focusondepth.model_config import FocusOnDepthConfig
from focusondepth.model_definition import FocusOnDepth
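# Register the custom config and model classes with the transformers Auto*
# factories so that AutoModel.from_pretrained can resolve the "focusondepth"
# architecture defined in this repository.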
AutoConfig.register("focusondepth", FocusOnDepthConfig)
AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
original_image_cache = {}  # note: defined but never used in this script
# Preprocessing: resize to the 384x384 input expected by the ViT-base/16
# backbone, then normalize from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
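# The model is assembled in two steps: from_pretrained fetches the remote
# code and config from the 'ybelkada/focusondepth' Hub repo, then the trained
# weights are restored from the local checkpoint file.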
model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
state_dict = torch.load(
    './focusondepth/FocusOnDepth_vit_base_patch16_384.p',
    map_location=torch.device('cpu'),
)['model_state_dict']
model.load_state_dict(state_dict)
@torch.no_grad()
def inference(input_image):
    model.eval()
    # Gradio passes the image as a numpy array; convert back to PIL.
    input_image = Image.fromarray(input_image)
    original_size = input_image.size
    tensor_image = transform(input_image)
    # The model predicts a depth map and a segmentation map in one pass.
    depth, segmentation = model(tensor_image.unsqueeze(0))
    depth = 1 - depth  # invert the predicted map for display
    depth = transforms.ToPILImage()(depth[0, :])
    segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())
    # Resize predictions back to the original resolution: bicubic for the
    # continuous depth map, nearest-neighbor for the discrete class labels.
    return [
        depth.resize(original_size, resample=Image.BICUBIC),
        segmentation.resize(original_size, resample=Image.NEAREST),
    ]
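# A minimal sketch of calling the predictor directly, without the Gradio UI
# (assumes a local test image 'example.jpg', which is not part of this repo):
#
#   img = np.array(Image.open('example.jpg').convert('RGB'))
#   depth_map, seg_map = inference(img)   # returns two PIL images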
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Depth Map"),
        gr.Image(label="Segmentation Map"),
    ],
)
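# launch() starts the local Gradio server; passing share=True would expose a
# temporary public URL when running outside of a hosted Space.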
iface.launch()