import functools
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
from matplotlib.figure import Figure
from PIL import Image
from skimage.transform import resize
from tensorflow.keras.models import load_model

# Hugging Face Hub repository that hosts the trained segmentation checkpoints.
REPO_ID = "amosfang/segmentation_u_net"

# Checkpoints (files inside REPO_ID) whose predictions ensemble_predict averages.
MODEL_FILENAMES = (
    'base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5',
    'vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5',
    'resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5',
)


def pil_image_as_numpy_array(pilimg):
    """Convert an image (PIL image or array-like) to a NumPy array.

    Delegates to ``tf.keras.utils.img_to_array``, which returns an array in
    the Keras default float dtype.
    """
    return tf.keras.utils.img_to_array(pilimg)


def resize_image(image, input_shape=(224, 224, 3)):
    """Scale pixel values to [0, 1] and resize the image to ``input_shape``.

    Args:
        image: PIL image or array-like with 0-255 pixel values.
        input_shape: Target (height, width, channels) of the returned array.
            NOTE(review): if the input's channel count differs (e.g. RGBA),
            ``skimage.transform.resize`` interpolates across channels too.

    Returns:
        Float array of shape ``input_shape`` with values in [0, 1].
    """
    normalized = pil_image_as_numpy_array(image).astype(np.float32) / 255.
    return resize(normalized, input_shape, anti_aliasing=True)


@functools.lru_cache(maxsize=None)
def _model_dir():
    """Return the local snapshot directory of REPO_ID, downloading it once.

    Failures are not cached, so a failed download is retried on the next call.
    """
    return snapshot_download(REPO_ID)


@functools.lru_cache(maxsize=len(MODEL_FILENAMES))
def load_model_file(filename):
    """Load the Keras model ``filename`` from the REPO_ID snapshot.

    The result is cached per filename, so repeated calls (one per request)
    return the same model object instead of re-reading it from disk. Callers
    must therefore not mutate the returned model (e.g. re-compile or fit it).
    """
    return load_model(os.path.join(_model_dir(), filename))


def ensemble_predict(X_array):
    """Return the equal-weight mean of every model's prediction for one image.

    Args:
        X_array: A single preprocessed image (no batch axis); a batch axis of
            size 1 is added before calling ``predict``.

    Returns:
        The element-wise average of the models' outputs, batch axis included.
        A weighted average is a possible improvement to explore.
    """
    batch = np.expand_dims(X_array, axis=0)
    predictions = [load_model_file(name).predict(batch) for name in MODEL_FILENAMES]
    return sum(predictions) / len(predictions)


def get_predictions(y_prediction_encoded):
    """Convert per-class scores to 1-based class labels (argmax over last axis)."""
    return np.argmax(y_prediction_encoded, axis=-1) + 1


def predict(image):
    """Gradio handler: segment ``image`` and return the label map as a PIL image.

    Returns ``None`` when no image was supplied (e.g. Submit with an empty
    input), so the output component is simply cleared.
    """
    if image is None:
        return None

    label_map = get_predictions(ensemble_predict(resize_image(image))).squeeze()

    # matplotlib recommends the object-oriented Figure API (not pyplot) inside
    # servers: pyplot keeps global figure state, and a bare Figure needs no
    # plt.close() -- it is released with its last reference, even on error.
    fig = Figure()
    ax = fig.subplots()
    # Colour scale pinned to labels 1..7, presumably the class count -- confirm.
    ax.imshow(label_map, cmap='viridis', vmin=1, vmax=7)

    image_buffer = io.BytesIO()
    fig.savefig(image_buffer, format='png')
    image_buffer.seek(0)
    return Image.open(image_buffer)


# Example inputs shown in the UI; paths are resolved relative to the working dir.
sample_images = [['989953_sat.jpg'], ['999380_sat.jpg'], ['988205_sat.jpg']]

iface = gr.Interface(
    predict,
    title='Land Cover Segmentation',
    inputs=[gr.Image()],
    outputs=[gr.Image()],
    examples=sample_images,
)

# Launch exactly once; debug=True blocks the main thread while the app runs.
iface.launch(debug=True, share=True)