'''
@author: Zhigang Jiang
@time: 2022/05/23
@description: Gradio demo that predicts the 3d room layout of an rgb panorama with LGT-Net.
'''
import gradio as gr
import numpy as np
import os
import subprocess
import sys
import torch

# Upgrade gdown before it is imported further down. The install is run through
# the current interpreter (``python -m pip``) with an argument list, so it
# targets the environment this script runs in and does not go through a shell.
# ``check=False`` keeps the original best-effort behaviour: a failed upgrade
# does not stop the demo from starting.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '--upgrade', '--no-cache-dir', 'gdown'],
    check=False,
)

from PIL import Image
from utils.logger import get_logger
from config.defaults import get_config
from inference import preprocess, run_one_inference
from models.build import build_model
from argparse import Namespace
import gdown

# [config file, Google Drive file id] pairs of the downloadable checkpoints.
_CKPT_MODEL_IDS = [
    ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
    ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
    ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
    ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
    ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha'],
]

# Values shared by every example row of the demo interface.
_DEFAULT_VISUALIZATION = ['depth-normal-gradient', '2d-floorplan']
_DEFAULT_MESH_FORMAT = '.gltf'
_DEFAULT_MESH_RESOLUTION = '256'


def down_ckpt(model_cfg, ckpt_dir):
    """Download the checkpoint that belongs to ``model_cfg`` if it is missing.

    :param model_cfg: path of the model config file; configs that are not
        listed in ``_CKPT_MODEL_IDS`` are ignored.
    :param ckpt_dir: directory in which ``best.pkl`` is expected; it is created
        when a download is needed.

    Relies on the module-level ``logger`` created in the ``__main__`` section.
    """
    path = os.path.join(ckpt_dir, 'best.pkl')
    for model_id in _CKPT_MODEL_IDS:
        cfg_path, file_id = model_id
        if cfg_path != model_cfg:
            continue

        if not os.path.exists(path):
            logger.info(f"Downloading {model_id}")
            os.makedirs(ckpt_dir, exist_ok=True)
            gdown.download(f"https://drive.google.com/uc?id={file_id}", path, quiet=False)


def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Run LGT-Net on one panorama and return the paths of the generated files.

    :param img_path: path of the input rgb panorama.
    :param pre_processing: whether to run ``preprocess`` on the image first.
    :param weight_name: pre-trained weight to use, ``'mp3d'`` or ``'zind'``.
    :param post_processing: post-processing method forwarded via the args.
    :param visualization: collection of the enabled 2d visualizations
        (``'depth-normal-gradient'`` and/or ``'2d-floorplan'``).
    :param mesh_format: file extension of the 3d mesh, including the dot.
    :param mesh_resolution: mesh resolution; converted with ``int()``.
    :return: list of five paths: 2d visualization, 3d mesh (for the viewer),
        3d mesh (for download), vanishing point file and layout json.
    :raises NotImplementedError: if ``weight_name`` is unknown.

    Relies on the module-level ``args``, ``logger``, ``mp3d_model`` and
    ``zind_model`` created in the ``__main__`` section.
    """
    # Work on a per-call copy so that overlapping calls cannot overwrite each
    # other's options on the shared module-level namespace.
    request_args = Namespace(**vars(args))
    request_args.pre_processing = pre_processing
    request_args.post_processing = post_processing

    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError(f"unknown pre-trained weight name: {weight_name!r}")

    # Text before the first dot of the file name; used as prefix of all outputs.
    img_name = os.path.basename(img_path).split('.')[0]

    # The context manager closes the underlying file once the pixels are read.
    # Converting to RGB after the resize drops an alpha channel (as the former
    # ``[..., :3]`` slice did) and also yields three channels for grayscale or
    # palette images, where slicing the 2-D array would have cut the width.
    with Image.open(img_path) as pil_img:
        resized = pil_img.resize((1024, 512), Image.Resampling.BICUBIC)
        img = np.array(resized.convert('RGB'))

    vp_cache_path = 'src/demo/default_vp.txt'
    if request_args.pre_processing:
        vp_cache_path = os.path.join(request_args.output_dir, f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, _ = preprocess(img, vp_cache_path=vp_cache_path)

    img = (img / 255.0).astype(np.float32)
    run_one_inference(img, model, request_args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    output_dir = request_args.output_dir
    mesh_path = os.path.join(output_dir, f"{img_name}_3d{mesh_format}")
    return [os.path.join(output_dir, f"{img_name}_pred.png"),
            mesh_path,  # shown in the 3d viewer
            mesh_path,  # offered as a file download
            vp_cache_path,
            os.path.join(output_dir, f"{img_name}_pred.json")]


def get_model(args):
    """Build the model described by ``args.cfg``, fetching its checkpoint first.

    If a cuda device is requested (by ``args.device`` or the config) but cuda
    is not available, both ``args.device`` and ``config.TRAIN.DEVICE`` are
    switched to ``"cpu"``; note that this modifies the passed ``args``.

    :param args: namespace providing at least ``cfg`` and ``device``.
    :return: the first element of the tuple returned by ``build_model``.
    """
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)

    wants_cuda = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if wants_cuda and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    model, _, _, _ = build_model(config, logger)
    return model


def _example(img_path, pre_processing, weight_name, post_processing='manhattan'):
    """Return one example row for the interface, filled up with the defaults."""
    return [img_path, pre_processing, weight_name, post_processing,
            list(_DEFAULT_VISUALIZATION), _DEFAULT_MESH_FORMAT, _DEFAULT_MESH_RESOLUTION]


def _build_demo(description):
    """Create the gradio interface that wraps ``greet``."""
    inputs = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'], label='pre-trained weight', value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'], label='post-processing method', value='manhattan'),
        gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'], label='2d-visualization',
                         value=['depth-normal-gradient', '2d-floorplan']),
        gr.Radio(['.gltf', '.obj', '.glb'], label='output format of 3d mesh', value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'], label='output resolution of 3d mesh', value='256'),
    ]
    outputs = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]
    examples = [
        _example('src/demo/pano_demo1.png', True, 'mp3d'),
        _example('src/demo/mp3d_demo1.png', False, 'mp3d'),
        _example('src/demo/mp3d_demo2.png', False, 'mp3d'),
        _example('src/demo/mp3d_demo3.png', False, 'mp3d'),
        _example('src/demo/zind_demo1.png', True, 'zind'),
        _example('src/demo/zind_demo2.png', False, 'zind', 'atalanta'),
        _example('src/demo/zind_demo3.png', True, 'zind'),
        _example('src/demo/other_demo1.png', False, 'mp3d'),
        _example('src/demo/other_demo2.png', True, 'mp3d'),
    ]
    return gr.Interface(fn=greet,
                        inputs=inputs,
                        outputs=outputs,
                        examples=examples,
                        title='LGT-Net',
                        allow_flagging="never",
                        cache_examples=False,
                        description=description)


if __name__ == '__main__':
    # These names stay at module level on purpose: down_ckpt, greet and
    # get_model read ``logger``/``args``/``mp3d_model``/``zind_model`` as globals.
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)

    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    # The second fragment ends with a space so the sentences do not run
    # together ("project.It uses") when the pieces are concatenated.
    description = "This demo of the github project " \
                  "LGT-Net. If this project helped you, please add a star to the github project. " \
                  "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."

    demo = _build_demo(description)
    demo.launch(debug=True, enable_queue=False)