|
import functools

import cv2
import gradio as gr
import numpy as np
import onnxruntime
|
|
|
|
|
def pre_process(img: np.array) -> np.array:
    """Turn a height x width x channels image into a batched float32 tensor.

    At most the first three channels are kept (a fourth channel, if any, is
    dropped), the channel axis is moved to the front, and a leading batch
    axis of size one is added.
    """
    # Keep the first three channels and reorder HWC -> CHW.
    channels_first = img[:, :, :3].transpose(2, 0, 1)
    # Prepend the batch axis, then cast to float32.
    return channels_first[np.newaxis, ...].astype(np.float32)
|
|
|
|
|
def post_process(img: np.ndarray) -> np.ndarray:
    """Convert a channels-first model output into an H x W x C uint8 image.

    Args:
        img: Array laid out as (C, H, W), optionally preceded by singleton
            batch axes. Values are used as-is on a 0-255 scale; no rescaling
            is applied.

    Returns:
        A C-contiguous uint8 array of shape (H, W, C) with the channel order
        reversed relative to the input.
    """
    chw = np.asarray(img)
    # Strip only leading singleton axes. A bare np.squeeze would also remove
    # a height or width of 1 and make the transpose below fail.
    while chw.ndim > 3 and chw.shape[0] == 1:
        chw = chw[0]

    # CHW -> HWC, then flip the channel axis.
    hwc = np.transpose(chw, (1, 2, 0))[:, :, ::-1]

    # Saturate before casting: without the clip, values outside 0-255 wrap
    # around when converted to uint8 and show up as speckle artifacts.
    return np.clip(hwc, 0, 255).astype(np.uint8, order="C")
|
|
|
|
|
@functools.lru_cache(maxsize=8)
def _load_session(model_path: str):
    """Build the ONNX Runtime session for ``model_path`` (cached per path).

    The session is limited to one intra-op and one inter-op thread.
    """
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    return onnxruntime.InferenceSession(model_path, options)


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the model stored at ``model_path`` on ``img_array``.

    The session is created on first use of a given path and reused on later
    calls, so the model file is not reloaded for every request (or twice for
    an image with an alpha channel).

    Args:
        model_path: Path to the model file handed to ONNX Runtime.
        img_array: Input tensor, fed to the model's first declared input.

    Returns:
        The first output produced by the model.
    """
    # NOTE(review): reusing one session across calls assumes
    # InferenceSession.run may be invoked concurrently -- confirm against the
    # ONNX Runtime documentation if requests are served from several threads.
    session = _load_session(model_path)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: img_array})
    return outputs[0]
|
|
|
|
|
def convert_pil_to_cv2(image):
    """Convert an image to a NumPy array with the colour channels reversed.

    Args:
        image: Anything ``np.array`` accepts (for example a PIL image or an
            existing array).

    Returns:
        A new array:
        * 2-D (single-channel) input is returned unchanged in layout;
        * 4-channel input has its first three channels reversed while the
          fourth (alpha) channel stays last, e.g. RGBA -> BGRA;
        * any other channel count has its whole channel axis reversed,
          e.g. RGB -> BGR.
    """
    pixels = np.array(image)

    if pixels.ndim < 3:
        # No channel axis to reorder; indexing a third axis would raise.
        return pixels

    if pixels.shape[2] == 4:
        # Reversing all four channels would move alpha to the front (ABGR),
        # so only the colour channels are swapped.
        return np.ascontiguousarray(pixels[:, :, [2, 1, 0, 3]])

    # .copy() materialises the reversed view as a contiguous array.
    return pixels[:, :, ::-1].copy()
|
|
|
|
|
def _run_model(model_path, three_channel_img):
    """Push one 3-channel image through pre-processing, the model and post-processing."""
    return post_process(inference(model_path, pre_process(three_channel_img)))


def upscale(image, model):
    """Upscale ``image`` with the model named ``model``.

    Args:
        image: Input image, converted with ``convert_pil_to_cv2``.
        model: Model name; the file ``models/<model>.ort`` is loaded.

    Returns:
        The upscaled image as produced by ``post_process``. For 4-channel
        input the result has four channels, the last being the upscaled
        alpha channel.

    Raises:
        ValueError: If no model is given, or if the image has a channel
            count other than 1 (2-D), 3 or 4.
    """
    if not model:
        # Otherwise the path below would become "models/None.ort" and fail
        # later with a much less helpful error.
        raise ValueError("No model selected; please choose a model.")

    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)

    if img.ndim == 2:
        # Replicate the single channel so the model receives three channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    channels = img.shape[2]

    if channels == 3:
        return _run_model(model_path, img)

    if channels == 4:
        # Alpha is upscaled on its own: expand it to three channels, run the
        # same model, then collapse the result back to a single channel.
        alpha = cv2.cvtColor(np.ascontiguousarray(img[:, :, 3]), cv2.COLOR_GRAY2BGR)
        alpha_output = cv2.cvtColor(_run_model(model_path, alpha), cv2.COLOR_BGR2GRAY)

        # Upscale the colour channels, append a fourth channel, and fill it
        # with the upscaled alpha.
        image_output = _run_model(model_path, img[:, :, 0:3])
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
        image_output[:, :, 3] = alpha_output
        return image_output

    # Previously this fell through and raised UnboundLocalError on return.
    raise ValueError(
        f"Unsupported number of image channels: {channels} (expected 1, 3 or 4)."
    )
|
|
|
|
|
# Models selectable in the UI; each name maps to a file under models/.
model_choices = ["modelx2", "modelx4", "minecraft_modelx4"]

# Fix the height of the image widgets.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "

# Example rows are [image path, model name]: five general examples followed
# by five Minecraft examples.
examples = [[f"examples/example_{i+1}.png", "modelx4"] for i in range(5)]
examples.extend(
    [f"minecraft_examples/minecraft-{j+1}.png", "minecraft_modelx4"] for j in range(5)
)

# Input widgets: the picture to upscale and the model to use.
image_input = gr.inputs.Image(type="pil")
model_input = gr.inputs.Radio(
    model_choices,
    type="value",
    default=None,
    label="Choose a Model",
    optional=False,
)

demo = gr.Interface(
    fn=upscale,
    inputs=[image_input, model_input],
    outputs="image",
    examples=examples,
    examples_per_page=5,
    title="Image Upscaling 🦆",
    allow_flagging="never",
    css=css,
)
demo.launch(debug=True)
|
|