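# app.py — text-guided background replacement demo.
# CLIPSeg segments the object named in the "Mask" textbox, the mask is
# inverted, and Stable Diffusion inpainting repaints everything *except*
# that object according to the prompt.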
import gradio as gr
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageOps
# Select the device: use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CLIPSeg turns a text prompt into a segmentation heatmap.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Stable Diffusion inpainting fills the masked region according to a prompt.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting").to(device)
def numpy_to_pil(images):
    """Convert a float array in [0, 1] (HWC or NHWC) to a list of PIL images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for grayscale (single-channel) images.
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def get_mask(text, image):
    """Run CLIPSeg and return a soft segmentation mask as a PIL image sized like `image`."""
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits are a low-resolution heatmap; sigmoid maps them into [0, 1].
    mask = torch.sigmoid(outputs.logits).cpu().unsqueeze(-1).numpy()
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil
def predict(prompt, negative_prompt, image, obj2mask, generator=None):
    # Segment the named object, then invert the mask so the *background* is inpainted.
    mask = get_mask(obj2mask, image)
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask.convert("RGB").resize((512, 512))
    mask_image = ImageOps.invert(mask_image)
    images = inpainting_pipeline(prompt=prompt, negative_prompt=negative_prompt, image=image,
                                 mask_image=mask_image, generator=generator).images
    # Composite the generated background around the preserved object and return it
    # (the original code computed this composite but discarded the result).
    mask = mask_image.convert('L')
    result = Image.composite(images[0], image, mask)
    return result
def inference(prompt, negative_prompt, obj2mask, image_numpy):
    # Fixed seed so runs are reproducible; pass the generator into the pipeline
    # (the original created it but never used it).
    generator = torch.Generator(device=device).manual_seed(52362)
    # Gradio delivers the image as a uint8 numpy array, so convert it directly;
    # numpy_to_pil expects floats in [0, 1] and would overflow on uint8 input.
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    img = predict(prompt, negative_prompt, image, obj2mask, generator=generator)
    return img
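# Quick programmatic check without the UI (a sketch; "shoe.jpg" is a hypothetical
# local test image):
#     import numpy as np
#     arr = np.array(Image.open("shoe.jpg").convert("RGB"))
#     out = inference("cinematic, landscape, sharp focus",
#                     "illustration, 3d render", "shoe", arr)
#     out.save("result.png")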
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="cinematic, landscape, sharp focus")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="illustration, 3d render")
            mask = gr.Textbox(label="Mask", value="shoe")
            input_img = gr.Image()
            run = gr.Button(value="Generate")
        with gr.Column():
            output_img = gr.Image()
    run.click(
        inference,
        inputs=[prompt, negative_prompt, mask, input_img],
        outputs=output_img,
    )
# Note: queue(concurrency_count=...) is the Gradio 3.x API; Gradio 4 renamed it
# to queue(default_concurrency_limit=...).
demo.queue(concurrency_count=1)
demo.launch()
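# To run locally (dependency list inferred from the imports above):
#     pip install gradio torch transformers diffusers pillow
#     python app.py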