import io
import os

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image
from PIL import Image as im

# Hosted text-to-image endpoint used to generate the background picture.
API_URL = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"

# Read the access token from the environment instead of sending the literal
# placeholder text "HF_TOKEN". When no token is configured the Authorization
# header is omitted entirely rather than sending an invalid credential.
HF_TOKEN = os.environ.get("HF_TOKEN")
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

model_id = "instruction-tuning-sd/cartoonizer"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only requested when a CUDA device was selected; the CPU
# fallback loads the weights in float32.
_pipeline_dtype = torch.float16 if device.type == "cuda" else torch.float32

# NOTE(review): `use_auth_token` is kept as in the original call; recent
# diffusers releases appear to prefer `token` - confirm against the installed
# version before switching.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id, torch_dtype=_pipeline_dtype, use_auth_token=True
).to(device)

def query(payload, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up, so a
            stalled endpoint cannot block the caller indefinitely.

    Returns:
        The response body as bytes.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx HTTP status.
    """
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=timeout
    )
    # Fail loudly on HTTP errors instead of returning the error body as if
    # it were the requested content.
    response.raise_for_status()
    return response.content

def cartoonizer(input_img, bg_prompt):
    """Cartoonize *input_img* and generate a background image from a prompt.

    Args:
        input_img: Image array from the UI image input (passed to
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        bg_prompt: Text describing the background. ``None``, empty, or
            whitespace-only text falls back to "orange background image".

    Returns:
        A two-item list ``[cartoon_image, background_image]``. Both items are
        ``None`` when no input image was supplied; the background is ``None``
        when it could not be fetched or decoded.
    """
    if input_img is None:
        gr.Warning("Please upload an Input Image!")
        return [input_img, input_img]

    # The input is resized to a fixed 300x300 before being handed to the
    # pipeline.
    data = im.fromarray(input_img).resize((300, 300))
    org_image = load_image(data)
    cart_image = pipeline(
        "Cartoonize the following image", image=org_image
    ).images[0]

    # Tolerate a missing or blank prompt instead of failing on len(None) or
    # sending whitespace to the text-to-image endpoint.
    prompt = (bg_prompt or "").strip() or "orange background image"

    try:
        image_bytes = query({"inputs": prompt})
        bg_image = im.open(io.BytesIO(image_bytes))
        # Image.open is lazy; force a full decode here so corrupt data is
        # caught inside this handler rather than later.
        bg_image.load()
    except OSError:
        # Pillow's decode errors and requests' exceptions both derive from
        # OSError. Keep the finished cartoon rather than discarding it
        # because the background step failed.
        gr.Warning("Could not generate the background image. Please try again.")
        bg_image = None

    return [cart_image, bg_image]


# UI layout: one tab holding the input image, the cartoonized output and the
# generated background side by side, followed by the prompt box and the
# button that triggers `cartoonizer`.
with gr.Blocks(theme=gr.themes.Citrus()) as cart:
    gr.HTML("""<h1 align="center">Cartoonize your Image with best backgrounds!</h1>""")
    with gr.Tab("Cartoonize"):
        with gr.Row():
            image_input = gr.Image()
            image_output = gr.Image()
            text_img_output = gr.Image()

        # The caption is attached to the textbox itself rather than rendered
        # through a separate gr.Label component.
        txt_input = gr.Textbox(label="Enter your photo frame description:")
        image_btn = gr.Button("Convert")

    image_btn.click(
        cartoonizer,
        inputs=[image_input, txt_input],
        outputs=[image_output, text_img_output],
    )


# Only start the web server when executed as a script, so importing this
# module does not launch the app as a side effect.
if __name__ == "__main__":
    cart.launch()