ad-image-generation / app_v1.py
import gradio as gr
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageOps
# Set device_name to 'cuda' to run on a GPU; 'cpu' works everywhere but is slow.
device_name = 'cpu'
device = torch.device(device_name)
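
# CLIPSeg predicts a segmentation mask from a text label; the Stable Diffusion 2
# inpainting pipeline then regenerates the masked region of the image.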
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting").to(device)
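
# Convert a float array in [0, 1] (HWC or NHWC) into a list of PIL images.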
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
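
# Run CLIPSeg on the (label, image) pair and return the predicted mask,
# resized back to the input image size.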
def get_mask(text, image):
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    outputs = model(**inputs)
    mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil
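
# Mask the object named by `obj2mask`, invert the mask so the object is kept,
# inpaint a new background, then composite the original object back on top.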
def predict(prompt, negative_prompt, image, obj2mask, generator=None):
    mask = get_mask(obj2mask, image)
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask.convert("RGB").resize((512, 512))
    mask_image = ImageOps.invert(mask_image)
    images = inpainting_pipeline(prompt=prompt, negative_prompt=negative_prompt, image=image,
                                 mask_image=mask_image, generator=generator).images
    # Keep the original object: the composite takes generated pixels only where
    # the (inverted) mask is white, i.e. the background.
    mask = mask_image.convert('L')
    return Image.composite(images[0], image, mask)
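
# Gradio callback: wire the UI inputs through to `predict`.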
def inference(prompt, negative_prompt, obj2mask, image_numpy):
    # Fixed seed so repeated runs are reproducible.
    generator = torch.Generator(device=device_name).manual_seed(52362)
    # Gradio hands us a uint8 array in [0, 255], so build the PIL image directly
    # rather than going through numpy_to_pil (which expects floats in [0, 1]).
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    return predict(prompt, negative_prompt, image, obj2mask, generator)
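
# Minimal Gradio UI: prompts and an input image on the left, result on the right.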
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="cinematic, landscape, sharp focus")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="illustration, 3d render")
            mask = gr.Textbox(label="Mask", value="shoe")
            input_img = gr.Image()
            run = gr.Button(value="Generate")
        with gr.Column():
            output_img = gr.Image()
    run.click(
        inference,
        inputs=[prompt, negative_prompt, mask, input_img],
        outputs=output_img,
    )

demo.queue(concurrency_count=1)
demo.launch()