import streamlit as st
from streamlit_pannellum import streamlit_pannellum
from diffusers import StableDiffusionLDM3DPipeline
from PIL import Image
from typing import Optional
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Conv2d
from torch.nn.modules.utils import _pair

def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Drop-in replacement for ``Conv2d._conv_forward`` with mixed padding.

    Instead of letting the convolution pad the input itself, the input is
    padded explicitly: circularly (wrap-around) along the last dimension and
    with zeros along the second-to-last dimension. The padded tensor is then
    convolved with zero extra padding. Bind this function to a ``Conv2d``
    instance (e.g. via ``__get__``) so that ``self`` is that module.

    Args:
        self: The ``Conv2d`` module whose padding amounts, stride, dilation
            and groups are used.
        input: Tensor to convolve (whatever layout ``Conv2d`` accepts).
        weight: Convolution kernel, normally ``self.weight``.
        bias: Optional bias, normally ``self.bias``.

    Returns:
        The result of ``F.conv2d`` on the explicitly padded input.
    """
    # F.pad takes pad widths last-dim-first: (left, right, top, bottom).
    pad_left, pad_right, pad_top, pad_bottom = self._reversed_padding_repeated_twice
    horizontally_wrapped = F.pad(input, (pad_left, pad_right, 0, 0), mode='circular')
    fully_padded = F.pad(horizontally_wrapped, (0, 0, pad_top, pad_bottom), mode='constant')
    # Padding has already been applied above, so the convolution adds none.
    return F.conv2d(fully_padded, weight, bias, self.stride, _pair(0), self.dilation, self.groups)

# Load the LDM3D panorama pipeline once per Streamlit server process.
# Streamlit re-executes this whole script on every widget interaction, so
# without caching the model would be re-downloaded/re-loaded and re-patched
# on every click.
@st.cache_resource
def _load_pipeline():
    """Load ``Intel/ldm3d-pano`` onto ``"cuda"`` and patch its Conv2d layers.

    Every ``Conv2d`` found in the VAE, text encoder and UNet has its
    ``_conv_forward`` replaced by ``asymmetricConv2DConvForward``, which pads
    circularly along the width and with zeros along the height. This is
    presumably intended to make the left/right edges of the generated 360°
    panorama wrap seamlessly.

    Returns:
        The patched ``StableDiffusionLDM3DPipeline``.
    """
    pipeline = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-pano").to("cuda")
    for target in (pipeline.vae, pipeline.text_encoder, pipeline.unet):
        for module in target.modules():
            if isinstance(module, Conv2d):
                # Bind as an instance method so the replacement receives the
                # module itself as ``self``.
                module._conv_forward = asymmetricConv2DConvForward.__get__(module, Conv2d)
    return pipeline


pipe = _load_pipeline()

def generate_panoramic_image(prompt, name, *, width=1024, height=512,
                             guidance_scale=7.0, num_inference_steps=50):
    """Generate an RGB panorama plus its depth map and save both to disk.

    Args:
        prompt: Text prompt passed to the module-level ``pipe``.
        name: Filename prefix. The files are written to the relative paths
            ``<name>_ldm3d_rgb.jpg`` and ``<name>_ldm3d_depth.png``.
        width: Output width in pixels (default 1024).
        height: Output height in pixels (default 512).
        guidance_scale: Guidance scale forwarded to ``pipe`` (default 7.0).
        num_inference_steps: Number of inference steps forwarded to ``pipe``
            (default 50).

    Returns:
        A ``(rgb_path, depth_path)`` tuple with the two paths that were written.
    """
    rgb_path = f"{name}_ldm3d_rgb.jpg"
    # Use the same "ldm3d" tag as the RGB file (this was misspelled "ldd3d").
    depth_path = f"{name}_ldm3d_depth.png"
    output = pipe(
        prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    # Only the first image of each output list is saved.
    output.rgb[0].save(rgb_path)
    output.depth[0].save(depth_path)
    return rgb_path, depth_path

# Streamlit interface: page header, prompt input and the generate button.
st.title("Pannellum Streamlit plugin")
# The original link had no URL and no closing parenthesis, so it rendered
# as literal text instead of a hyperlink.
st.markdown(
    "This space is a showcase of the "
    "[streamlit_pannellum](https://pypi.org/project/streamlit-pannellum/) lib."
)

prompt = st.text_input(
    "Enter a prompt for the panoramic image",
    "360, Ben Erdt, Ognjen Sporin, Raphael Lacoste. A garden of oversized flowers...",
)

generate_button = st.button("Generate Panoramic Image")

if generate_button:
    name = "generated_image"  # This can be dynamic
    rgb_image_path, _ = generate_panoramic_image(prompt, name)

    # Display the generated panoramic image in Pannellum viewer
            "default": {
                "firstScene": "generated",
                "autoLoad": True
            "scenes": {
                "generated": {
                    "title": "Generated Panoramic Image",
                    "type": "equirectangular",
                    "panorama": rgb_image_path,
                    "autoLoad": True,