from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr
import aiohttp
import asyncio
from io import BytesIO

# Run on GPU when available; fall back to fp32 on CPU, where fp16 inference is not reliably supported.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Three image-classification pipelines: NSFW detection, style classification, and aesthetic scoring.
nsfw_pipe = pipeline("image-classification",
                     model=AutoModelForImageClassification.from_pretrained(
                         "carbon225/vit-base-patch16-224-hentai"),
                     feature_extractor=AutoFeatureExtractor.from_pretrained(
                         "carbon225/vit-base-patch16-224-hentai"),
                     device=device,
                     torch_dtype=dtype)

style_pipe = pipeline("image-classification",
                      model=AutoModelForImageClassification.from_pretrained(
                          "cafeai/cafe_style"),
                      feature_extractor=AutoFeatureExtractor.from_pretrained(
                          "cafeai/cafe_style"),
                      device=device,
                      torch_dtype=dtype)

aesthetic_pipe = pipeline("image-classification",
                          model=AutoModelForImageClassification.from_pretrained(
                              "cafeai/cafe_aesthetic"),
                          feature_extractor=AutoFeatureExtractor.from_pretrained(
                              "cafeai/cafe_aesthetic"),
                          device=device,
                          torch_dtype=dtype)


async def fetch_image(session, image_url):
    """Download a single image and return it as an RGB PIL image, or None on failure."""
    print(f"fetching image {image_url}")
    async with session.get(image_url) as response:
        if response.status == 200 and response.headers.get("content-type", "").startswith("image"):
            pil_image = Image.open(BytesIO(await response.read())).convert("RGB")
            # Optionally resize the image proportionally:
            # pil_image = ImageOps.fit(pil_image, (400, 400), Image.LANCZOS)
            return pil_image
    return None


async def fetch_images(image_urls):
    """Download all images concurrently."""
    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.ensure_future(fetch_image(session, image_url))
                 for image_url in image_urls]
        return await asyncio.gather(*tasks)


async def predict(json=None, enable_gallery=True, image=None, files=None):
    print(json)
    # Collect PIL images from whichever input was provided: a single image,
    # a batch of uploaded files, or a JSON payload of image URLs.
    if image or files:
        if image is not None:
            images_paths = [image]
        else:
            images_paths = [file.name for file in files]
        pil_images = [Image.open(image_path).convert("RGB")
                      for image_path in images_paths]
    elif json is not None:
        # Drop any URLs that could not be fetched.
        pil_images = [img for img in await fetch_images(json["urls"]) if img is not None]
    else:
        raise ValueError("No input provided: pass an image, uploaded files, or a JSON payload of URLs.")

    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    # Merge the three classifiers' label/score lists per image.
    results = [a + b + c for (a, b, c) in zip(style, aesthetic, nsfw)]

    # The Label component is only populated for the single-image input.
    label_data = {}
    if image is not None:
        label_data = {row["label"]: row["score"] for row in results[0]}

    return results, label_data, pil_images if enable_gallery else None


with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images",
                            file_types=["image"], file_count="multiple")
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            json = gr.JSON(label="Image URLs", value={"urls": [
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
            gallery = gr.Gallery().style(grid=[2], height="auto")
    btn = gr.Button("Run")
    btn.click(fn=predict,
              inputs=[json, enable_gallery, image, files],
              outputs=[results, label, gallery],
              api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)
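
# Example of calling the "inference" API endpoint from another process — a sketch,
# assuming the gradio_client package is installed and this app is being served at
# http://127.0.0.1:7860 (both the URL and the example image URL below are placeholders;
# exact input/output serialization may vary across Gradio versions):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   results, label, gallery = client.predict(
#       {"urls": ["https://example.com/image.jpg"]},  # json: payload of image URLs
#       True,    # enable_gallery
#       None,    # image (single-image input unused here)
#       None,    # files (multi-file input unused here)
#       api_name="/inference",
#   )
#   print(results)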