|
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
from PIL import Image
from skimage.transform import resize
from tensorflow.keras.models import load_model
|
|
|
# Hugging Face Hub repository that load_model_file() fetches with
# snapshot_download(); it is expected to contain the .h5 checkpoints
# named in ensemble_predict().
REPO_ID = "amosfang/segmentation_u_net"
|
|
|
def pil_image_as_numpy_array(pilimg):
    """Convert an image to a NumPy array using Keras' ``img_to_array``.

    Args:
        pilimg: The image to convert (a PIL image or array-like accepted by
            ``tf.keras.utils.img_to_array``).

    Returns:
        The NumPy array produced by ``tf.keras.utils.img_to_array``.
    """
    return tf.keras.utils.img_to_array(pilimg)
|
|
|
def resize_image(image, input_shape=(224, 224, 3)):
    """Scale pixel values to [0, 1] and resize the image to ``input_shape``.

    Args:
        image: Input image, converted via ``pil_image_as_numpy_array``.
        input_shape: Target output shape passed to ``skimage.transform.resize``.

    Returns:
        The resized array produced by ``skimage.transform.resize`` with
        anti-aliasing enabled.
    """
    # Cast to float32 before dividing so the normalisation happens in float.
    pixels = pil_image_as_numpy_array(image).astype(np.float32)
    normalized = pixels / 255.
    return resize(normalized, input_shape, anti_aliasing=True)
|
|
|
# Models already loaded by load_model_file(), keyed by checkpoint filename.
_MODEL_CACHE = {}


def load_model_file(filename):
    """Load one Keras checkpoint from the ``REPO_ID`` Hub repository.

    The repository is fetched with ``snapshot_download`` and ``filename`` is
    resolved relative to the returned snapshot directory. Loaded models are
    memoised in ``_MODEL_CACHE``, so repeated calls for the same file reuse
    the model instead of repeating the download call and deserialisation.
    This matters because the function is called once per model on every
    prediction request.

    Args:
        filename: Name of the checkpoint file inside the repository.

    Returns:
        The model returned by ``tensorflow.keras.models.load_model``.
    """
    model = _MODEL_CACHE.get(filename)
    if model is None:
        model_dir = snapshot_download(REPO_ID)
        # Load the specific checkpoint file, not the snapshot directory.
        saved_model_path = os.path.join(model_dir, filename)
        model = load_model(saved_model_path)
        _MODEL_CACHE[filename] = model
    return model
|
|
|
def ensemble_predict(X_array):
    """Average the predictions of the three U-Net variants for one input.

    A leading batch axis is added to ``X_array`` (presumably the single
    resized image from ``resize_image`` -- confirm against callers). The
    base, VGG16 and ResNet50 U-Net checkpoints are then loaded, each is run
    on the batch, and the arithmetic mean of the three outputs is returned.

    Args:
        X_array: A single model input without a batch dimension.

    Returns:
        The element-wise mean of the three models' ``predict`` outputs.
    """
    batch = np.expand_dims(X_array, axis=0)

    checkpoint_names = (
        'base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5',
        'vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5',
        'resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5',
    )
    # Load every model first, then run inference (same order as before).
    models = [load_model_file(name) for name in checkpoint_names]
    unet_out, vgg16_out, resnet50_out = [m.predict(batch) for m in models]

    return (unet_out + vgg16_out + resnet50_out) / 3
|
|
|
def get_predictions(y_prediction_encoded):
    """Turn per-class scores into 1-based class labels.

    Takes the argmax over the last axis and adds 1, so the first class maps
    to label 1 (``predict`` plots these labels with ``vmin=1, vmax=7``).

    Args:
        y_prediction_encoded: Array-like of scores with classes on the last axis.

    Returns:
        Integer label array with the last axis removed.
    """
    class_axis = -1
    zero_based = np.argmax(y_prediction_encoded, axis=class_axis)
    return zero_based + 1
|
|
|
def predict(image):
    """Segment ``image`` and return the label map rendered as a PIL image.

    The image is resized and normalised with ``resize_image`` and run through
    the three-model ensemble. The resulting scores are converted to 1-based
    labels with ``get_predictions``, and the squeezed label map is drawn with
    matplotlib (``viridis`` colormap, colour range fixed to labels 1..7) and
    encoded as PNG.

    Args:
        image: Input image as delivered by the Gradio ``Image`` component.

    Returns:
        A ``PIL.Image.Image`` decoded from the rendered PNG.
    """
    sample_image_resized = resize_image(image)
    y_pred = ensemble_predict(sample_image_resized)
    y_pred = get_predictions(y_pred).squeeze()

    fig, ax = plt.subplots()
    # Fix the colour scale so a given label always gets the same colour.
    ax.imshow(y_pred, cmap='viridis', vmin=1, vmax=7)

    image_buffer = io.BytesIO()
    try:
        # Save this specific figure rather than whatever pyplot considers
        # "current"; always close it so figures do not accumulate per request.
        fig.savefig(image_buffer, format='png')
    finally:
        plt.close(fig)

    image_buffer.seek(0)
    image_pil = Image.open(image_buffer)
    # Image.open is lazy; decode now so the result does not depend on the
    # buffer's later state.
    image_pil.load()
    return image_pil
|
|
|
|
|
# Example rows for the Gradio UI: one image path per row, matching the
# single image input of the interface.
sample_images = [['989953_sat.jpg'], ['999380_sat.jpg'], ['988205_sat.jpg']]

# Keep a reference to the interface so it is launched exactly once.
iface = gr.Interface(
    predict,
    title='Land Cover Segmentation',
    inputs=[gr.Image()],
    outputs=[gr.Image()],
    examples=sample_images,
)

iface.launch(debug=True, share=True)