"""Streamlit app that runs an OpenLander ONNX segmentation model on an uploaded image.

The uploaded picture is split into fixed-size tiles (mosaic mode), each tile is
passed through the selected ONNX model, and the coloured class mask is blended
over the original picture.
"""

import argparse
import os
import time
from functools import wraps
from io import BytesIO

import cv2
import numpy as np
import onnxruntime as ort
import streamlit as st
import torch
from PIL import Image

# Parse command-line arguments
# parser = argparse.ArgumentParser()
# parser.add_argument("--mosaic", help="Enable mosaic processing mode", action="store_true")
# args = parser.parse_args()
# mosaic = args.mosaic  # Set this based on your command line argument

# For streamlit use let's just set mosaic to "true", but I'm leaving the
# command-line arg here for anyone to use
mosaic = True

# Side length (in pixels) of the square input fed to the model.
TILE_SIZE = 416


def center_crop(img, new_height, new_width):
    """Return the central ``new_height`` x ``new_width`` window of ``img``.

    If ``img`` is smaller than the requested window along an axis, the crop
    starts at 0 on that axis and the result is simply as large as the image
    allows (the start offset is clamped so it can never go negative).
    """
    height, width = img.shape[:2]
    start_x = max(width // 2 - new_width // 2, 0)
    start_y = max(height // 2 - new_height // 2, 0)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]


def mosaic_crop(img, size):
    """Split ``img`` into ``size`` x ``size`` tiles, zero-padding bottom/right.

    Returns:
        A tuple ``(tiles, rows, cols, padding_height, padding_width)`` where
        ``tiles`` is a row-major list of tiles, ``rows``/``cols`` is the grid
        shape and the two padding values are the number of black pixels added
        at the bottom and on the right to reach a multiple of ``size``.
    """
    height, width = img.shape[:2]
    padding_height = (size - height % size) % size
    padding_width = (size - width % size) % size
    padded_img = cv2.copyMakeBorder(
        img, 0, padding_height, 0, padding_width,
        cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )
    padded_height, padded_width = padded_img.shape[:2]
    tiles = [
        padded_img[row:row + size, col:col + size]
        for row in range(0, padded_height, size)
        for col in range(0, padded_width, size)
    ]
    return tiles, padded_height // size, padded_width // size, padding_height, padding_width


def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major list of tiles into one ``rows`` x ``cols`` image.

    ``size`` is not needed for the concatenation; it is kept so existing
    callers that pass it keep working.
    """
    stitched_rows = [
        np.concatenate(tiles[i * cols:(i + 1) * cols], axis=1)
        for i in range(rows)
    ]
    return np.concatenate(stitched_rows, axis=0)


def timing_decorator(func):
    """Decorate ``func`` so that every call prints how long it took."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the duration cannot go negative if the
        # wall clock is adjusted while the function runs.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' took {duration:.6f} seconds")
        return result
    return wrapper


@timing_decorator
def process_image(session, img, colors, mosaic=False):
    """Run the model on one image/tile and return the colour overlay.

    Args:
        session: ONNX Runtime inference session; its first input receives the
            image blob.
        img: BGR image as a NumPy array (H x W x 3).
        colors: mapping from mask value to the BGR colour painted for it.
        mosaic: when False the image is first centre-cropped to the model
            resolution; when True ``img`` is treated as an already-cut tile.

    Returns:
        A ``TILE_SIZE`` x ``TILE_SIZE`` BGR image: 60% input, 40% class colours.
    """
    if not mosaic:
        # Crop the center of the image to 416x416 pixels
        img = center_crop(img, TILE_SIZE, TILE_SIZE)

    # The mask below is always TILE_SIZE x TILE_SIZE, so the image blended
    # with it must be too (this matters for inputs smaller than the crop).
    if img.shape[:2] != (TILE_SIZE, TILE_SIZE):
        img = cv2.resize(img, (TILE_SIZE, TILE_SIZE))

    blob = cv2.dnn.blobFromImage(
        img, 1 / 255.0, (TILE_SIZE, TILE_SIZE), swapRB=True, crop=False
    )

    # Perform inference
    output = session.run(None, {session.get_inputs()[0].name: blob})

    # Drop the batch axis and move channels last.
    # NOTE(review): the scaling by 122 assumes the model emits small integer
    # class ids that map onto the keys of `colors` (0, 122, 244) - confirm
    # against the exported model.
    output_img = output[0].squeeze(0).transpose(1, 2, 0)
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)
    output_mask = output_img.max(axis=2)

    output_mask_color = np.zeros((TILE_SIZE, TILE_SIZE, 3), dtype=np.uint8)

    # Assign specific colors to the classes in the mask
    for class_idx in np.unique(output_mask):
        if class_idx in colors:
            output_mask_color[output_mask == class_idx] = colors[class_idx]

    # Pixels of the "transparent" class (value 122) show the input image
    # instead of a colour, so the blend below leaves them unchanged.
    transparent_mask = output_mask == 122
    output_mask_color[transparent_mask] = img[transparent_mask]

    # Make the colorful mask semi-transparent
    overlay = cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0)
    return overlay


st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")

models = {
    "Embedded model better trained: DeeplabV3+, MobilenetV2, 416px resolution": "20230608_onnx_416_mbnv2_dl3/end2end.onnx",
    "test model training system V2: DV3+, 40k, 416px": "20230613_40k_test_v2.onnx"
}

# Create a Streamlit radio button to select the desired model
selected_model = st.radio("Select a model", list(models.keys()))

# set cuda = true if you have an NVIDIA GPU
cuda = torch.cuda.is_available()
if cuda:
    print("We have a GPU!")
# Keep the CPU provider as a fallback: torch may see a GPU even when the
# installed onnxruntime build was made without CUDA support.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']

# Get the selected model's path
model_path = models[selected_model]
session = ort.InferenceSession(model_path, providers=providers)

# Define colors for classes 0, 122 and 244 (BGR order)
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}  # Red, Black, Yellow


def load_image(uploaded_file):
    """Decode an uploaded file into a 3-channel BGR array, or None on failure.

    The picture is converted to RGB first so grayscale, palette and RGBA
    uploads all yield an H x W x 3 array. Any decoding error is reported in
    the page and ``None`` is returned instead of raising.
    """
    try:
        image = Image.open(uploaded_file).convert("RGB")
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    except Exception as e:  # UI boundary: report the problem, don't crash the app
        st.write("Could not load image: ", e)
        return None


uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_file is not None:
    img = load_image(uploaded_file)

    # load_image already reported the error; there is nothing to process.
    if img is not None:
        img_processed = None

        if st.button('Process'):
            with st.spinner('Processing...'):
                start = time.perf_counter()
                if mosaic:
                    tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, TILE_SIZE)
                    processed_tiles = [
                        process_image(session, tile, colors, mosaic=True)
                        for tile in tiles
                    ]
                    overlay = stitch_tiles(processed_tiles, rows, cols, TILE_SIZE)

                    # Crop the padding back out
                    overlay = overlay[:overlay.shape[0] - padding_height,
                                      :overlay.shape[1] - padding_width]
                    img_processed = overlay
                else:
                    img_processed = process_image(session, img, colors)
                end = time.perf_counter()
                st.write(f"Processing time: {end - start} seconds")

        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)

        if img_processed is not None:
            st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
            st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")