import gradio as gr
from keras import backend as K
from keras.models import load_model
from patchify import patchify, unpatchify
import numpy as np
import cv2
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

# Define colors for classes (RGB)
class_building = np.array([60, 16, 152])
class_land = np.array([132, 41, 246])
class_road = np.array([110, 193, 228])
class_vegetation = np.array([254, 221, 58])
class_water = np.array([226, 169, 41])
class_unlabeled = np.array([155, 155, 155])

# Number of classes in your segmentation task
total_classes = 6  # Update this with your total number of classes

# Colour painted for each predicted class id: row i is used where argmax == i.
# NOTE(review): this id order (water, land, road, building, vegetation,
# unlabeled) differs from the order the colours are declared above and listed
# in the UI legend (building, land, road, vegetation, water, unlabeled);
# confirm it matches the label encoding satmodel.h5 was trained with.
class_palette = np.stack(
    [class_water, class_land, class_road,
     class_building, class_vegetation, class_unlabeled]
).astype(np.uint8)


# Custom loss / metric functions. They are registered as custom_objects so the
# names stored in the saved model can be resolved; prediction never calls them.
def jaccard_coef(y_true, y_pred):
    """Return the batch-mean Jaccard (IoU) coefficient.

    Intersection and union are summed per sample over axes 1-3, so rank-4
    tensors are expected (presumably batch, height, width, classes).
    """
    smooth = 1e-12
    intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
    union = K.sum(y_true, [1, 2, 3]) + K.sum(y_pred, [1, 2, 3]) - intersection
    jac = K.mean((intersection + smooth) / (union + smooth), axis=0)
    return jac


def dice_loss(y_true, y_pred):
    """Return 1 - (batch-mean Dice coefficient), summing over axes 1-3."""
    smooth = 1e-12
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
    return 1.0 - dice


def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """Return the mean focal loss, summing class terms over the last axis."""
    # Clip so log() never sees exactly 0 or 1.
    y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
    ce_loss = -y_true * K.log(y_pred)
    weight = alpha * y_true * K.pow((1 - y_pred), gamma)
    fl_loss = ce_loss * weight
    return K.mean(K.sum(fl_loss, axis=-1))


def total_loss(y_true, y_pred):
    """Return dice_loss + focal_loss (equal weighting)."""
    return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))


# Load the pre-trained model.
# compile=False: the app only runs inference, so Keras does not need to rebuild
# the training (optimizer/loss/metric) configuration.
model_path = 'satmodel.h5'  # Replace with your model path
model = load_model(
    model_path,
    custom_objects={
        'total_loss': total_loss,
        'jaccard_coef': jaccard_coef,
        'dice_loss': dice_loss,
        'focal_loss': focal_loss,
    },
    compile=False,
)

# Module-level scaler kept for backward compatibility only. predict_full_image
# fits a fresh scaler per call so concurrent requests never share fitted state.
minmaxscaler = MinMaxScaler()


def predict_full_image(image, patch_size, model):
    """Segment an image of arbitrary size by tiling it into square patches.

    Args:
        image: H x W x C array; the first 3 channels are fed to the model.
        patch_size: side length, in pixels, of each tile given to the model.
        model: Keras model mapping a batch of patch_size x patch_size x 3
            tiles to per-pixel class scores.

    Returns:
        H x W x n_classes array of per-pixel model outputs, where n_classes is
        the size of the model output's last axis.
    """
    original_shape = image.shape
    print(f"Original image shape: {original_shape}")

    # Normalise each channel to [0, 1] using the ORIGINAL pixels only. Fitting
    # after padding lets the zero padding drag each channel's minimum to 0, so
    # the same content would be scaled differently depending on whether the
    # image size happens to be a multiple of patch_size.
    scaler = MinMaxScaler()
    image = scaler.fit_transform(
        image.reshape(-1, image.shape[-1])
    ).reshape(image.shape)

    # Pad bottom/right so both spatial dimensions are multiples of patch_size.
    pad_height = (patch_size - image.shape[0] % patch_size) % patch_size
    pad_width = (patch_size - image.shape[1] % patch_size) % patch_size
    image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)),
                   mode='constant', constant_values=0)
    padded_shape = image.shape
    print(f"Padded image shape: {padded_shape}")

    # Non-overlapping tiles; leading axes are (tile row, tile column, 1).
    patched_images = patchify(image, (patch_size, patch_size, 3), step=patch_size)
    print(f"Patched image shape: {patched_images.shape}")
    n_rows, n_cols = patched_images.shape[:2]

    # A single predict() over all tiles (Keras batches internally) instead of
    # one predict() call per tile in a Python loop.
    tiles = patched_images.reshape(-1, patch_size, patch_size, 3)
    predicted_patches = model.predict(tiles)
    print(f"Predicted patches shape: {predicted_patches.shape}")
    n_classes = predicted_patches.shape[-1]
    predicted_patches = predicted_patches.reshape(
        n_rows, n_cols, patch_size, patch_size, n_classes)
    print(f"Reshaped predicted patches shape: {predicted_patches.shape}")

    # Stitch tiles back: (row, y, col, x, class) -> (row*y, col*x, class),
    # i.e. tile (i, j) lands at [i*patch_size:(i+1)*patch_size,
    # j*patch_size:(j+1)*patch_size].
    reconstructed_image = predicted_patches.transpose(0, 2, 1, 3, 4).reshape(
        padded_shape[0], padded_shape[1], n_classes)
    print(f"Reconstructed image shape (with padding): {reconstructed_image.shape}")

    # Remove padding
    reconstructed_image = reconstructed_image[:original_shape[0], :original_shape[1]]
    print(f"Final reconstructed image shape: {reconstructed_image.shape}")
    return reconstructed_image


def process_input_image(input_image):
    """Gradio callback: segment input_image and return a colour-coded mask.

    Returns:
        (status message, H x W x 3 uint8 mask) for the label and image
        outputs. When no image was supplied the mask is None.
    """
    if input_image is None:
        # Button clicked before an image was selected.
        return "No image provided", None

    image = np.asarray(input_image)
    if image.ndim == 2:
        # Greyscale: replicate into the 3 channels the model input expects.
        image = np.stack([image] * 3, axis=-1)
    image = image[..., :3]  # drop an alpha channel if present

    image_patch_size = 256
    predicted_full_image = predict_full_image(image, image_patch_size, model)

    # Pick each pixel's class once, then paint it through the palette.
    labels = predicted_full_image.argmax(axis=-1)
    predicted_full_image_rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    # Class ids without a palette entry are left black.
    known = labels < len(class_palette)
    predicted_full_image_rgb[known] = class_palette[labels[known]]

    return "Image processed", predicted_full_image_rgb


# Gradio application
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    gr.Markdown("Building: #3C1098,Land (unpaved area): #8429F6,Road: #6EC1E4,Vegetation: #FEDD3A,Water: #E2A929,Unlabeled: #9B9B9B")
    gr.Markdown("Building: Purple,Land (unpaved area): Violet, Road:Blue, Vegetation: Gold/yellow, Water: Copper, Unlabeled: Gray")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Prediction Image Info ")
                    img_output = gr.Image(label="Image Output")

    source_image_loader.click(
        process_input_image,
        inputs=[img_source],
        outputs=[output_label, img_output]
    )

# Launch the app
my_app.launch(share=True)