import gradio as gr
from keras import backend as K
from keras.models import load_model
from patchify import patchify, unpatchify
import numpy as np
import cv2
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

# RGB colour used to paint each segmentation class in the output mask.
# Written in hex so each entry can be checked directly against the colour
# legend shown in the UI.
class_building = np.array([0x3C, 0x10, 0x98])    # #3C1098
class_land = np.array([0x84, 0x29, 0xF6])        # #8429F6
class_road = np.array([0x6E, 0xC1, 0xE4])        # #6EC1E4
class_vegetation = np.array([0xFE, 0xDD, 0x3A])  # #FEDD3A
class_water = np.array([0xE2, 0xA9, 0x29])       # #E2A929
class_unlabeled = np.array([0x9B, 0x9B, 0x9B])   # #9B9B9B

# Number of per-pixel output channels (classes) the model produces; used to
# reshape the model output. Keep in sync with the trained model.
total_classes = 6

# Custom metric and loss functions passed to load_model via custom_objects.
def jaccard_coef(y_true, y_pred):
    """Return the batch-mean Jaccard index (IoU) of ``y_true`` and ``y_pred``.

    Intersection and union are summed over axes 1-3, so the inputs must be
    rank-4 (presumably (batch, height, width, classes) -- confirm). The
    per-sample scores are then averaged over axis 0. A tiny epsilon in the
    numerator and denominator avoids 0/0 for empty masks.
    """
    eps = 1e-12
    per_sample_axes = [1, 2, 3]
    overlap = K.sum(K.abs(y_true * y_pred), axis=per_sample_axes)
    total = K.sum(y_true, per_sample_axes) + K.sum(y_pred, per_sample_axes)
    iou_per_sample = (overlap + eps) / (total - overlap + eps)
    return K.mean(iou_per_sample, axis=0)

def dice_loss(y_true, y_pred):
    """Return ``1 - mean Dice coefficient`` between ``y_true`` and ``y_pred``.

    The Dice coefficient is computed per sample by summing over axes 1-3
    (rank-4 inputs), then averaged over axis 0. A tiny epsilon guards the
    division for empty masks.
    """
    eps = 1e-12
    axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=axes)
    denominator = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)
    dice_per_sample = (2.0 * overlap + eps) / (denominator + eps)
    return 1.0 - K.mean(dice_per_sample, axis=0)

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """Alpha-weighted focal loss, summed over the last axis and then averaged.

    ``y_pred`` is clipped to ``[eps, 1 - eps]`` so the log stays finite. Each
    element contributes ``(-y_true * log(p)) * (alpha * y_true * (1 - p)**gamma)``.
    The terms are summed over the last axis (presumably the class axis) and
    averaged over everything else.

    NOTE(review): ``y_true`` appears in both factors. For one-hot targets this
    is the usual focal loss; for soft labels the target is effectively squared.
    """
    eps = K.epsilon()
    probs = K.clip(y_pred, eps, 1.0 - eps)
    cross_entropy = -y_true * K.log(probs)
    modulation = alpha * y_true * K.pow((1 - probs), gamma)
    per_position = K.sum(cross_entropy * modulation, axis=-1)
    return K.mean(per_position)

def total_loss(y_true, y_pred):
    """Combined loss: Dice loss plus focal loss, with the focal term weighted by 1."""
    dice_term = dice_loss(y_true, y_pred)
    focal_term = focal_loss(y_true, y_pred)
    return dice_term + (1 * focal_term)

# Load the pre-trained model. The custom losses/metric are supplied through
# custom_objects, presumably because the saved file references them by these
# names -- keep the dictionary keys unchanged.
model_path = 'satmodel.h5'  # Replace with your model path
custom_model_objects = {
    'total_loss': total_loss,
    'jaccard_coef': jaccard_coef,
    'dice_loss': dice_loss,
    'focal_loss': focal_loss,
}
model = load_model(model_path, custom_objects=custom_model_objects)

# Per-channel min-max scaler; predict_full_image re-fits it (fit_transform)
# on every image it processes.
minmaxscaler = MinMaxScaler()

# Function to predict the full image
def predict_full_image(image, patch_size, model):
    """Segment an image of arbitrary size by tiling it into model-sized patches.

    Steps:
      1. Min-max normalise each channel of the (unpadded) image.
      2. Zero-pad the bottom/right edges so height and width are multiples of
         ``patch_size``.
      3. Cut the padded image into non-overlapping ``patch_size`` x
         ``patch_size`` tiles and run them all through ``model`` in a single
         batched ``predict`` call.
      4. Stitch the per-tile outputs back together and crop the padding off.

    Args:
        image: (H, W, C) array with at least 3 channels (only the first three
            are fed to the model), or a single-channel (H, W) / (H, W, 1)
            array, which is replicated to 3 channels.
        patch_size: Edge length of the square tiles the model expects.
        model: Keras model whose ``predict`` maps a batch of
            (patch_size, patch_size, 3) tiles to
            (patch_size, patch_size, total_classes) per-pixel outputs.

    Returns:
        (H, W, total_classes) array of the model's per-pixel outputs.
    """
    image = np.asarray(image)
    original_shape = image.shape
    print(f"Original image shape: {original_shape}")

    # Bring the input to exactly 3 channels, which is what the tiles use.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    image = image[..., :3]
    height, width = image.shape[:2]

    # Normalise BEFORE padding. If the zero padding took part in the
    # per-channel min/max, the scaling of the real pixels would change
    # depending on whether the image size is a multiple of patch_size.
    image = minmaxscaler.fit_transform(image.reshape(-1, image.shape[-1])).reshape(image.shape)

    # Pad image to make its dimensions divisible by the patch size
    pad_height = (patch_size - height % patch_size) % patch_size
    pad_width = (patch_size - width % patch_size) % patch_size
    image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
    padded_height, padded_width = image.shape[:2]
    print(f"Padded image shape: {image.shape}")

    # Non-overlapping tiles, shape (rows, cols, 1, patch_size, patch_size, 3).
    patched_images = patchify(image, (patch_size, patch_size, 3), step=patch_size)
    print(f"Patched image shape: {patched_images.shape}")
    rows, cols = patched_images.shape[:2]

    # One predict() call over all tiles, in row-major order (the same order as
    # iterating rows then columns). Keras batches internally, which avoids the
    # heavy per-call overhead of predicting one tile at a time.
    tiles = patched_images.reshape(rows * cols, patch_size, patch_size, 3)
    predicted_patches = np.asarray(model.predict(tiles))
    print(f"Predicted patches shape: {predicted_patches.shape}")

    predicted_patches = predicted_patches.reshape(rows, cols, patch_size, patch_size, total_classes)
    print(f"Reshaped predicted patches shape: {predicted_patches.shape}")

    # Stitch the tiles back together:
    # (rows, cols, ph, pw, C) -> (rows, ph, cols, pw, C) -> (H_pad, W_pad, C).
    reconstructed_image = predicted_patches.transpose(0, 2, 1, 3, 4).reshape(
        padded_height, padded_width, total_classes
    )
    print(f"Reconstructed image shape (with padding): {reconstructed_image.shape}")

    # Remove padding
    reconstructed_image = reconstructed_image[:height, :width]
    print(f"Final reconstructed image shape: {reconstructed_image.shape}")

    return reconstructed_image

# Function to process the input image
def process_input_image(input_image):
    """Segment ``input_image`` and return a status message and a colour mask.

    Args:
        input_image: Image array from the ``gr.Image`` input, or ``None`` when
            the button is pressed before an image has been selected.

    Returns:
        ``(message, mask)``. ``mask`` is an (H, W, 3) uint8 RGB image in which
        every pixel has the colour of its arg-max class. It is ``None`` when
        there was no input image.
    """
    if input_image is None:
        # Gradio passes None when nothing has been uploaded; don't crash.
        return "Please select an image first", None

    image_patch_size = 256
    predicted_full_image = predict_full_image(input_image, image_patch_size, model)

    # Class index -> RGB colour table.
    # NOTE(review): this index order (water, land, road, building, vegetation,
    # unlabeled) must match the label encoding used when the model was
    # trained -- confirm before changing it.
    palette = np.array(
        [class_water, class_land, class_road, class_building, class_vegetation, class_unlabeled],
        dtype=np.uint8,
    )

    # Compute argmax once and colour all pixels with a single lookup. The
    # result is always a 3-channel uint8 image, so RGBA, grayscale or float
    # inputs no longer break the colour assignment.
    class_map = predicted_full_image.argmax(axis=-1)
    predicted_full_image_rgb = palette[class_map]

    return "Image processed", predicted_full_image_rgb

# Gradio application
# Layout: one tab containing two columns -- the source image picker and a
# button on the left; the prediction label and the colour-coded mask on the
# right. Gradio places components in the order they are created inside these
# `with` contexts, so statement order here defines the on-screen layout.
my_app = gr.Blocks()
with my_app:
    # Title plus colour legend. The hex codes match the class_* RGB constants
    # used to paint the output mask.
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    gr.Markdown("Building: #3C1098,Land (unpaved area): #8429F6,Road: #6EC1E4,Vegetation: #FEDD3A,Water: #E2A929,Unlabeled: #9B9B9B")
    gr.Markdown("Building: Purple,Land (unpaved area): Violet,  Road:Blue,  Vegetation: Gold/yellow,  Water: Copper,  Unlabeled: Gray")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # Input image and the button that triggers segmentation.
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    # Outputs: status text and the coloured segmentation mask.
                    output_label = gr.Label(label="Prediction Image Info ")
                    img_output = gr.Image(label="Image Output")
    # Clicking the button calls process_input_image on the selected image.
    # Its two return values fill output_label and img_output, in that order.
    source_image_loader.click(
        process_input_image,
        inputs=[img_source],
        outputs=[output_label, img_output]
    )

# Launch the app only when this file is run as a script, so importing the
# module does not start a server as a side effect. share=True additionally
# asks Gradio for a public share link.
if __name__ == "__main__":
    my_app.launch(share=True)

























# import gradio as gr
# from keras.models import load_model
# from patchify import patchify, unpatchify
# import numpy as np
# import cv2
# from sklearn.preprocessing import MinMaxScaler
# import matplotlib.pyplot as plt

# # Define colors for classes
# class_building = np.array([60, 16, 152])
# class_land = np.array([132, 41, 246])
# class_road = np.array([110, 193, 228])
# class_vegetation = np.array([254, 221, 58])
# class_water = np.array([226, 169, 41])
# class_unlabeled = np.array([155, 155, 155])

# # Number of classes in your segmentation task
# total_classes = 6  # Update this with your total number of classes

# # Define custom loss functions
# def jaccard_coef(y_true, y_pred):
#     smooth = 1e-12
#     intersection = K.sum(K.abs(y_true * y_pred), axis=[1,2,3])
#     union = K.sum(y_true,[1,2,3])+K.sum(y_pred,[1,2,3])-intersection
#     jac = K.mean((intersection + smooth) / (union + smooth), axis=0)
#     return jac

# def dice_loss(y_true, y_pred):
#     smooth = 1e-12
#     intersection = K.sum(y_true * y_pred, axis=[1,2,3])
#     union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
#     dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
#     return 1.0 - dice

# def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
#     y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
#     ce_loss = -y_true * K.log(y_pred)
#     weight = alpha * y_true * K.pow((1 - y_pred), gamma)
#     fl_loss = ce_loss * weight
#     return K.mean(K.sum(fl_loss, axis=-1))

# def total_loss(y_true, y_pred):
#     return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))

# # Load the pre-trained model
# model_path = 'satmodel.h5'  # Replace with your model path
# model = load_model(model_path, custom_objects={'total_loss': total_loss, 'jaccard_coef': jaccard_coef, 'dice_loss': dice_loss, 'focal_loss': focal_loss})

# # MinMaxScaler for normalization
# minmaxscaler = MinMaxScaler()

# # Function to predict the full image
# def predict_full_image(image, patch_size, model):
#     original_shape = image.shape
#     print(f"Original image shape: {original_shape}")
    
#     # Pad image to make its dimensions divisible by the patch size
#     pad_height = (patch_size - image.shape[0] % patch_size) % patch_size
#     pad_width = (patch_size - image.shape[1] % patch_size) % patch_size
#     image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
#     padded_shape = image.shape
#     print(f"Padded image shape: {padded_shape}")
    
#     # Normalize the image
#     image = minmaxscaler.fit_transform(image.reshape(-1, image.shape[-1])).reshape(image.shape)
    
#     # Create patches
#     patched_images = patchify(image, (patch_size, patch_size, 3), step=patch_size)
#     print(f"Patched image shape: {patched_images.shape}")
    
#     predicted_patches = []
    
#     # Predict on each patch
#     for i in range(patched_images.shape[0]):
#         for j in range(patched_images.shape[1]):
#             single_patch = patched_images[i, j, 0]
#             single_patch = np.expand_dims(single_patch, axis=0)
#             prediction = model.predict(single_patch)
#             predicted_patches.append(prediction[0])
    
#     # Reshape predicted patches
#     predicted_patches = np.array(predicted_patches)
#     print(f"Predicted patches shape: {predicted_patches.shape}")
    
#     predicted_patches = predicted_patches.reshape(patched_images.shape[0], patched_images.shape[1], patch_size, patch_size, total_classes)
#     print(f"Reshaped predicted patches shape: {predicted_patches.shape}")
    
#     # Unpatchify the image
#     reconstructed_image = np.zeros((padded_shape[0], padded_shape[1], total_classes))
#     for i in range(patched_images.shape[0]):
#         for j in range(patched_images.shape[1]):
#             reconstructed_image[i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size, :] = predicted_patches[i, j]
#     print(f"Reconstructed image shape (with padding): {reconstructed_image.shape}")
    
#     # Remove padding
#     reconstructed_image = reconstructed_image[:original_shape[0], :original_shape[1]]
#     print(f"Final reconstructed image shape: {reconstructed_image.shape}")
    
#     return reconstructed_image

# # Function to process the input image
# def process_input_image(input_image):
#     image_patch_size = 256
#     predicted_full_image = predict_full_image(input_image, image_patch_size, model)
    
#     # Convert the predictions to RGB
#     predicted_full_image_rgb = np.zeros_like(input_image)
    
#     # Map the predicted class labels to RGB colors
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 0] = class_water
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 1] = class_land
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 2] = class_road
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 3] = class_building
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 4] = class_vegetation
#     predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 5] = class_unlabeled
    
#     return "Image processed", predicted_full_image_rgb

# # Gradio application
# my_app = gr.Blocks()
# with my_app:
#     gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
#     with gr.Tabs():
#         with gr.TabItem("Select your image"):
#             with gr.Row():
#                 with gr.Column():
#                     img_source = gr.Image(label="Please select source Image")
#                     source_image_loader = gr.Button("Load above Image")
#                 with gr.Column():
#                     output_label = gr.Label(label="Image Info")
#                     img_output = gr.Image(label="Image Output")
#     source_image_loader.click(
#         process_input_image,
#         inputs=[img_source],
#         outputs=[output_label, img_output]
#     )

# # Launch the app
# my_app.launch()













# import os
# import cv2
# from PIL import Image
# import numpy as np
# from matplotlib import pyplot as plt
# import random
# import gradio as gr
# from keras import backend as K
# from keras.models import load_model
# def jaccard_coef(y_true, y_pred):
#   y_true_flatten = K.flatten(y_true)
#   y_pred_flatten = K.flatten(y_pred)
#   intersection = K.sum(y_true_flatten * y_pred_flatten)
#   final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
#   return final_coef_value


# # Define Dice Loss
# def dice_loss(y_true, y_pred):
#     smooth = 1e-12
#     intersection = K.sum(y_true * y_pred, axis=[1,2,3])
#     union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
#     dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
#     return 1.0 - dice

# # Define Focal Loss
# def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
#     y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
#     ce_loss = -y_true * K.log(y_pred)
#     weight = alpha * y_true * K.pow((1 - y_pred), gamma)
#     fl_loss = ce_loss * weight
#     return K.mean(K.sum(fl_loss, axis=-1))

# # Define Total Loss
# def total_loss(y_true, y_pred):
#     return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))

# weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]


# from keras.models import load_model
# import numpy as np
# from PIL import Image
# import matplotlib.pyplot as plt
# saved_model=load_model('satmodel.h5', custom_objects={'total_loss': total_loss, 'dice_loss': dice_loss, 'focal_loss': focal_loss, 'jaccard_coef': jaccard_coef})
# # def process_input_image(image_source):
# #   image = np.expand_dims(image_source, 0)

# #   prediction = saved_model.predict(image)
# #   predicted_image = np.argmax(prediction, axis=3)

# #   predicted_image = predicted_image[0,:,:]
# #   predicted_image = predicted_image * 50
# #   return 'Predicted Masked Image', predicted_image

# import matplotlib.pyplot as plt
# import matplotlib.colors as mcolors

# # # Define the image processing function

# # Define the image processing function
# def process_input_image(image):
#     image = Image.fromarray(image)
#     image = image.convert('RGB')  # Convert the image to RGB
#     image = image.resize((256, 256))
#     image = np.array(image)
#     image = np.expand_dims(image, 0)

#     prediction = saved_model.predict(image)
#     predicted_image = np.argmax(prediction, axis=3)

#     predicted_image = predicted_image[0,:,:]
#     predicted_image = predicted_image * 50


#     # Apply a colormap to the predicted image
#     cmap = plt.get_cmap('viridis')  # You can choose any colormap you prefer
#     colored_image = cmap(predicted_image / predicted_image.max())  # Normalize to [0, 1]
#     colored_image = (colored_image[:, :, :3] * 255).astype(np.uint8)  # Convert to RGB and scale to [0, 255]

#     return 'Predicted Masked Image', colored_image
#     # return 'Predicted Masked Image', predicted_image

# my_app = gr.Blocks()
# with my_app:
#   gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
#   with gr.Tabs():
#     with gr.TabItem("Select your image"):
#       with gr.Row():
#         with gr.Column():
#             img_source = gr.Image(label="Please select source Image")
#             source_image_loader = gr.Button("Load above Image")
#         with gr.Column():
#             output_label = gr.Label(label="Image Info")
#             img_output = gr.Image(label="Image Output")
#     source_image_loader.click(
#         process_input_image,
#         [
#             img_source
#         ],
#         [
#             output_label,
#             img_output
#         ]
#     )
# my_app.launch(debug=True,share=True)