# Spaces: Running on T4
import os

# Probe for perspective2d; if the import fails for any reason, install it from
# the hf-debug branch and upgrade numpy before the remaining imports run.
try:
    import perspective2d  # noqa: F401  (imported only to check availability)
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate instead of triggering an install.
    os.system("pip install git+https://github.com/jinlinyi/PerspectiveFields.git@hf-debug")
    os.system("pip install --upgrade numpy")
import gradio as gr | |
import cv2 | |
import copy | |
import numpy as np | |
import os.path as osp | |
from datetime import datetime | |
import torch | |
from PIL import Image, ImageDraw | |
from glob import glob | |
from perspective2d import PerspectiveFields | |
from perspective2d.utils import draw_perspective_fields, draw_from_r_p_f_cx_cy | |
from perspective2d.perspectivefields import model_zoo | |
# Heading shown for the demo (passed to gr.Interface as `title`).
title = "Perspective Fields Demo"

# HTML blurb passed to gr.Interface as `description`: project links followed
# by a short summary of each selectable model.
description = """
<p style="text-align: center">
<a href="https://jinlinyi.github.io/PerspectiveFields/" target="_blank">Project Page</a> |
<a href="https://arxiv.org/abs/2212.03239" target="_blank">Paper</a> |
<a href="https://github.com/jinlinyi/PerspectiveFields" target="_blank">Code</a> |
<a href="https://www.youtube.com/watch?v=sN5B_ZvMva8&themeRefresh=1" target="_blank">Video</a>
</p>
<h2>Gradio Demo</h2>
<p>Try our Gradio demo for Perspective Fields for single image camera calibration. You can click on one of the provided examples or upload your own image.</p>
<h3>Available Models:</h3>
<ol>
<li><span style="color:red;">[NEW!!!]</span><strong>Paramnet-360Cities-edina:</strong> Our latest model trained on <a href="https://www.360cities.net/">360cities</a> and <a href="https://github.com/tien-d/EgoDepthNormal/tree/main#egocentric-depth-on-everyday-indoor-activities-edina-dataset">EDINA</a> dataset.</li>
<li><strong>PersNet-360Cities:</strong> PerspectiveNet trained on the 360Cities dataset. This model predicts perspective fields and is designed to be robust and generalize well to both indoor and outdoor images.</li>
<li><strong>PersNet_Paramnet-GSV-uncentered:</strong> A combination of PerspectiveNet and ParamNet trained on the Google Street View (GSV) dataset. This model predicts camera Roll, Pitch, and Field of View (FoV), as well as the Principal Point location.</li>
<li><strong>PersNet_Paramnet-GSV-centered:</strong> PerspectiveNet+ParamNet trained on the GSV dataset. This model assumes the principal point is at the center of the image and predicts camera Roll, Pitch, and FoV.</li>
</ol>
"""
# Footer HTML passed to gr.Interface as `article`: links to the paper and the
# code repository. The link text carries the paper's title
# ("... Camera Calibration", singular).
article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2212.03239' target='_blank'>Perspective Fields for Single Image Camera Calibration</a> | <a href='https://github.com/jinlinyi/PerspectiveFields' target='_blank'>Github Repo</a></p>
"""
def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
    """Resize an image and its perspective-field maps, keeping the aspect ratio.

    Args:
        img: Image array; ``img.shape[0]`` is the height and ``img.shape[1]``
            the width.
        field: Dict of maps. Only the ``'up'`` and ``'lati'`` entries are
            resized; each must expose ``.numpy()``. Maps with three
            dimensions are treated as channel-first. The dict is updated in
            place.
        target_width: Requested output width, or ``None`` to derive it from
            ``target_height``.
        target_height: Requested output height, or ``None`` to derive it from
            ``target_width``.

    Returns:
        ``(img, field)``: the resized image and the same ``field`` dict with
        its ``'up'`` / ``'lati'`` entries replaced by resized tensors.

    Raises:
        TypeError: If both ``target_width`` and ``target_height`` are ``None``.

    When both targets are given, the larger of the two scale factors wins:
    the dimension that produced it keeps its requested size and the other one
    is recomputed from the image (truncated to an int).
    """
    if target_width is None and target_height is None:
        raise TypeError(
            "resize_fix_aspect_ratio() requires target_width and/or target_height"
        )

    height, width = img.shape[0], img.shape[1]

    if target_height is None:
        # Width drives the scale; derive the height.
        factor = target_width / width
        target_height = int(height * factor)
    elif target_width is None:
        # Height drives the scale; derive the width.
        factor = target_height / height
        target_width = int(width * factor)
    else:
        width_factor = target_width / width
        height_factor = target_height / height
        if width_factor >= height_factor:
            target_height = int(height * width_factor)
        else:
            target_width = int(width * height_factor)

    # cv2.resize takes the destination size as (width, height).
    size = (target_width, target_height)
    img = cv2.resize(img, size)

    for key in ('up', 'lati'):
        if key not in field:
            continue
        resized = field[key].numpy()
        channel_first = len(resized.shape) == 3
        if channel_first:
            # Move channels last for cv2.resize, then restore the layout.
            resized = resized.transpose(1, 2, 0)
        resized = cv2.resize(resized, size)
        if channel_first:
            resized = resized.transpose(2, 0, 1)
        field[key] = torch.tensor(resized)

    return img, field
def inference(img_rgb, model_type):
    """Run a Perspective Fields model on one image and render the prediction.

    Args:
        img_rgb: RGB image array with channels on the last axis, or ``None``
            when no image was supplied.
        model_type: Key into ``model_zoo`` selecting the model, or ``None``.

    Returns:
        ``(visualization, param_text)``. ``visualization`` is a PIL image of
        the rendered prediction. ``param_text`` lists roll, pitch, vertical
        FoV, focal length and principal point for models whose ``model_zoo``
        entry has a truthy ``'param'``, and is ``"Not Implemented"``
        otherwise. Returns ``(None, "")`` when either input is missing.
    """
    # Nothing to do without both an image and a model choice.
    if img_rgb is None or model_type is None:
        return None, ""

    pf_model = PerspectiveFields(model_type).eval().to(device)
    # The model's inference entry point takes BGR, so flip the channel axis.
    pred = pf_model.inference(img_bgr=img_rgb[..., ::-1])

    # Height of the input before resizing; the relative focal length is
    # scaled by it when reported below.
    img_h = img_rgb.shape[0]

    field = {
        'up': pred['pred_gravity_original'].cpu().detach(),
        'lati': pred['pred_latitude_original'].cpu().detach(),
    }
    # Bring image and maps to a width of 640 for visualization.
    img_rgb, field = resize_fix_aspect_ratio(img_rgb, field, 640)

    if not model_zoo[model_type]['param']:
        # Field-only model: draw the predicted up-vector and latitude maps.
        pred_vis = draw_perspective_fields(
            img_rgb,
            field['up'],
            torch.deg2rad(field['lati']),
            color=(0, 1, 0),
        )
        param = "Not Implemented"
    else:
        # Pull each scalar off the device once and reuse it below.
        roll = pred['pred_roll'].cpu().item()
        pitch = pred['pred_pitch'].cpu().item()
        vfov = pred['pred_general_vfov'].cpu().item()
        rel_focal = pred['pred_rel_focal'].cpu().item()
        rel_cx = pred['pred_rel_cx'].cpu().item()
        rel_cy = pred['pred_rel_cy'].cpu().item()

        param = (
            f"roll {roll:.2f}\n"
            f"pitch {pitch:.2f}\n"
            f"vertical fov {vfov:.2f}\n"
            f"focal_length {rel_focal * img_h:.2f}\n"
            f"principal point {rel_cx:.2f} {rel_cy:.2f}"
        )

        # The drawing helper is called in 'rad' mode, so convert the angles.
        r_p_f_rad = np.radians([roll, pitch, vfov])
        pred_vis = draw_from_r_p_f_cx_cy(
            img_rgb,
            *r_p_f_rad,
            rel_cx,
            rel_cy,
            'rad',
            up_color=(0, 1, 0),
        )

    # Log one record per request.
    timestamp = datetime.now().strftime("%H:%M:%S")
    print(
        f"time {timestamp}\n"
        f"img.shape {img_rgb.shape}\n"
        f"model_type {model_type}\n"
        f"param {param}\n"
    )
    return Image.fromarray(pred_vis), param
# Example inputs for the demo: one single-element row (the image path) for
# every file matching assets/imgs/*.*g (e.g. .jpg, .jpeg, .png).
examples = [[img_path] for img_path in glob('assets/imgs/*.*g')]
print(examples)
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info = """Select model\n"""

# Model selector: choices follow model_zoo's own key order, while the
# pre-selected value is the alphabetically first key.
model_names = list(model_zoo.keys())
model_selector = gr.Radio(
    model_names,
    value=sorted(model_names)[0],
    label="Model",
    info=info,
)

# Two outputs, matching the (image, text) pair returned by `inference`.
output_components = [
    gr.Image(label='Perspective Fields'),
    gr.Textbox(label='Pred Camera Parameters'),
]

demo = gr.Interface(
    fn=inference,
    inputs=["image", model_selector],
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()