# amosfang's picture
# Update app.py
# 7164500 verified
# raw
# history blame
# 2.94 kB
import functools
import io
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.transform import resize
import tensorflow as tf
from tensorflow.keras.models import load_model
from huggingface_hub import snapshot_download
import gradio as gr
# Hugging Face Hub repository id passed to snapshot_download() when fetching model files.
REPO_ID = "amosfang/segmentation_u_net"
def pil_image_as_numpy_array(pilimg):
    """Return ``pilimg`` converted to an array by ``tf.keras.utils.img_to_array``."""
    return tf.keras.utils.img_to_array(pilimg)
def resize_image(image, input_shape=(224, 224, 3)):
    """Scale an image's pixel values by 1/255 and resize it to ``input_shape``.

    Args:
        image: Image accepted by ``pil_image_as_numpy_array``.
            Presumably 8-bit pixel data (0-255) -- TODO confirm with callers.
        input_shape: Output shape handed to ``skimage.transform.resize``.

    Returns:
        The resized float array produced by ``skimage.transform.resize``.
    """
    # Cast to float32 first so the division yields float32 values.
    pixels = pil_image_as_numpy_array(image).astype(np.float32)
    scaled = pixels / 255.
    # Anti-aliasing smooths the image before it is resampled.
    return resize(scaled, input_shape, anti_aliasing=True)
@functools.lru_cache(maxsize=8)
def load_model_file(filename):
    """Download the model repository snapshot and load one Keras model from it.

    Loaded models are cached per ``filename`` so repeated calls (for example,
    one per prediction request) do not hit the Hub or re-read the weights.

    Args:
        filename: Name of the model file inside the ``REPO_ID`` snapshot.

    Returns:
        The model returned by ``tensorflow.keras.models.load_model``.
    """
    # snapshot_download returns the local directory holding the repo files.
    model_dir = snapshot_download(REPO_ID)
    # Load the requested file itself, not the snapshot directory.
    model_path = os.path.join(model_dir, filename)
    return load_model(model_path)
def ensemble_predict(X_array):
    """Average the predictions of the three U-Net variants for one input.

    The base U-Net, the VGG16 U-Net and the ResNet50 U-Net are each loaded
    through ``load_model_file`` and asked to predict on the same input; the
    three outputs are combined with equal weights. Different per-model
    weights could be tried to improve performance.

    Args:
        X_array: A single input array; a leading batch axis is added here.

    Returns:
        The element-wise mean of the three models' predictions.
    """
    weight_files = (
        'base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5',
        'vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5',
        'resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5',
    )

    # The models' predict() is called on a batch, so wrap the sample in one.
    batch = np.expand_dims(X_array, axis=0)

    # Load every model first, then run them in the same order.
    models = [load_model_file(name) for name in weight_files]
    unet_pred, vgg16_pred, resnet50_pred = [m.predict(batch) for m in models]

    return (unet_pred + vgg16_pred + resnet50_pred) / 3
def get_predictions(y_prediction_encoded):
    """Turn per-class scores into 1-based class labels.

    Args:
        y_prediction_encoded: Array-like whose last axis holds one score per
            class.

    Returns:
        Array of the highest-scoring class index along the last axis,
        shifted by one so that labels start at 1 instead of 0.
    """
    zero_based_labels = np.argmax(y_prediction_encoded, axis=-1)
    return zero_based_labels + 1
def predict(image):
    """Segment ``image`` with the model ensemble and render the label map.

    Args:
        image: Input image as delivered by the Gradio image input.

    Returns:
        A PIL image of the PNG-rendered matplotlib figure that shows the
        predicted labels.
    """
    sample_image_resized = resize_image(image)
    y_pred = ensemble_predict(sample_image_resized)
    # Drop the batch axis added by ensemble_predict.
    y_pred = get_predictions(y_pred).squeeze()

    # Render into an in-memory buffer instead of a file on disk.
    image_buffer = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        # Fixed vmin/vmax keep the label-to-colour mapping stable across images.
        ax.imshow(y_pred, cmap='viridis', vmin=1, vmax=7)
        # Save this specific figure rather than pyplot's implicit current one.
        fig.savefig(image_buffer, format='png')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    image_buffer.seek(0)
    return Image.open(image_buffer)
# Paths to example images offered in the UI (one inner list per example row).
sample_images = [['989953_sat.jpg'], ['999380_sat.jpg'], ['988205_sat.jpg']]

# Build the Gradio interface around predict().
iface = gr.Interface(
    predict,
    title='Land Cover Segmentation',
    inputs=[gr.Image()],
    outputs=[gr.Image()],
    examples=sample_images,
)

# Launch the interface once; the earlier second launch used an undefined name.
iface.launch(debug=True, share=True)