# radames's picture
# add json input
# ca95568
# raw
# history blame contribute delete
# No virus
# 4.2 kB
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr
import aiohttp
import asyncio
from io import BytesIO
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Only request half precision on the GPU. float16 inference on the CPU is
# poorly supported by many PyTorch kernels, so fall back to float32 there.
dtype = torch.float16 if device != "cpu" else torch.float32


def _load_classifier(model_id):
    """Build an image-classification pipeline for the checkpoint *model_id*.

    The model and its feature extractor are both loaded from the same
    checkpoint name, and the pipeline is created with the module-level
    ``device`` and ``dtype`` settings.
    """
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(model_id),
        feature_extractor=AutoFeatureExtractor.from_pretrained(model_id),
        device=device,
        torch_dtype=dtype,
    )


# One classifier per task; `predict` concatenates their per-image outputs.
nsfw_pipe = _load_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_classifier("cafeai/cafe_aesthetic")
async def fetch_image(session, image_url):
    """Download *image_url* and decode it into an RGB PIL image.

    Args:
        session: Object whose ``get(url)`` returns an async context manager
            yielding the response (an ``aiohttp.ClientSession`` in this file).
        image_url: URL of the image to download.

    Returns:
        The decoded image converted to ``'RGB'``, or ``None`` when the
        response status is not 200, the ``content-type`` header is missing
        or does not start with ``image``, the request fails, or the payload
        cannot be opened as an image.
    """
    print(f"fetching image {image_url}")
    try:
        async with session.get(image_url) as response:
            # A missing header is treated like a non-image response instead
            # of raising KeyError.
            content_type = response.headers.get('content-type', '')
            if response.status != 200 or not content_type.startswith('image'):
                return None
            payload = await response.read()
        return Image.open(BytesIO(payload)).convert('RGB')
    except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
        # One bad URL must not abort the whole batch gathered by the caller;
        # report it and signal failure with None like the other reject paths.
        print(f"failed to fetch image {image_url}: {err}")
        return None
async def fetch_images(image_urls):
    """Fetch all of *image_urls* concurrently over one shared client session.

    Each URL is handed to ``fetch_image`` as its own scheduled task. The
    returned list holds the tasks' results in the same order as
    *image_urls*.
    """
    async with aiohttp.ClientSession() as http:
        pending = []
        for url in image_urls:
            pending.append(asyncio.ensure_future(fetch_image(http, url)))
        return await asyncio.gather(*pending)
async def predict(json=None, enable_gallery=True, image=None, files=None):
    """Classify images with the style, aesthetic and NSFW pipelines.

    Exactly one input source is used, in this order of precedence:
    *image*, then *files*, then the URLs listed in *json*.

    Args:
        json: Mapping with a ``"urls"`` list of image URLs to download.
        enable_gallery: When true, the loaded images are returned as the
            third element so they can be shown in the gallery.
        image: Path of a single image file.
        files: Uploaded file objects; each one's ``.name`` is opened as an
            image path.

    Returns:
        A ``(results, label_data, gallery_images)`` tuple. ``results`` has
        one entry per classified image: the concatenated outputs of the
        three pipelines. ``label_data`` maps label to score for the first
        image, and is only filled when *image* was given. ``gallery_images``
        is the list of loaded images, or ``None`` when *enable_gallery* is
        false. URLs that could not be fetched are skipped, so ``results``
        may be shorter than the URL list. With no usable input, the result
        is ``([], {}, [])`` (or ``None`` as the last element).
    """
    print(json)
    if image:
        pil_images = [Image.open(image).convert("RGB")]
    elif files:
        pil_images = [Image.open(upload.name).convert("RGB")
                      for upload in files]
    elif json:
        urls = json.get("urls") or []
        fetched = await fetch_images(urls) if urls else []
        # fetch_image signals a failed download with None; the pipelines
        # cannot classify that, so drop those entries.
        pil_images = [img for img in fetched if img is not None]
    else:
        pil_images = []

    if not pil_images:
        # Nothing to classify: return empty outputs instead of failing on an
        # unbound variable or handing an empty batch to the pipelines.
        return [], {}, [] if enable_gallery else None

    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    # Merge the three per-image prediction lists into one list per image.
    results = [a + b + c for (a, b, c) in zip(style, aesthetic, nsfw)]

    label_data = {}
    if image:
        label_data = {row["label"]: row["score"] for row in results[0]}
    return results, label_data, pil_images if enable_gallery else None
# Gradio UI: inputs in the left column, outputs in the right column, and a
# "Run" button that calls `predict` (also exposed as the "inference" API).
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images", file_types=[
                            "image"], file_count="multiple")
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            # Input payload for URL-based classification, pre-filled with
            # example images. Labelled distinctly from the "Results" output.
            json = gr.JSON(label="Image URLs", value={"urls": [
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
            # NOTE(review): `.style()` is the Gradio 3.x styling API; confirm
            # the pinned Gradio version still provides it.
            gallery = gr.Gallery().style(grid=[2], height="auto")
    # NOTE(review): the button is placed below the two-column row; confirm
    # this matches the intended layout.
    btn = gr.Button("Run")
    # Input order must match predict's positional parameters, and output
    # order must match the tuple it returns.
    btn.click(fn=predict, inputs=[json, enable_gallery, image, files],
              outputs=[results, label, gallery], api_name="inference")
blocks.queue()
blocks.launch(debug=True, inline=True)