Spaces:
Running on CPU Upgrade
File size: 4,200 Bytes
6974603 ca95568 6974603 ca95568 6974603 ca95568 6974603 ca95568 6974603 ca95568 6974603 ca95568 caeb1f4 ca95568 caeb1f4 c9267e5 ca95568 6974603 caeb1f4 ca95568 caeb1f4 ca95568 caeb1f4 ca95568 6974603 ca95568 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr
import aiohttp
import asyncio
from io import BytesIO
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only worthwhile on CUDA; CPU float16 support in PyTorch is
# limited and typically slower, so keep full precision there.
dtype = torch.float16 if device.startswith("cuda") else torch.float32


def _make_classifier(repo_id):
    """Build an image-classification pipeline for the Hub checkpoint *repo_id*.

    The model and its feature extractor are both loaded from *repo_id*, and
    the module-level ``device`` and ``dtype`` are passed through to
    ``transformers.pipeline``.
    """
    return pipeline("image-classification",
                    model=AutoModelForImageClassification.from_pretrained(repo_id),
                    feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
                    device=device,
                    torch_dtype=dtype)


nsfw_pipe = _make_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _make_classifier("cafeai/cafe_style")
aesthetic_pipe = _make_classifier("cafeai/cafe_aesthetic")
async def fetch_image(session, image_url):
    """Download *image_url* through *session* and decode it as an RGB PIL image.

    Returns ``None`` (never raises for these cases) when:
      * the server answers with a status other than 200,
      * the response has no Content-Type header or it does not start with
        ``"image"``,
      * the downloaded bytes cannot be decoded by Pillow.
    """
    print(f"fetching image {image_url}")
    async with session.get(image_url) as response:
        # .get() instead of [] so a response without a Content-Type header is
        # treated as "not an image" rather than raising KeyError, which would
        # fail every other download gathered together with this one.
        content_type = response.headers.get("content-type", "")
        if response.status != 200 or not content_type.startswith("image"):
            return None
        payload = await response.read()
    try:
        return Image.open(BytesIO(payload)).convert("RGB")
    except OSError:
        # Pillow signals undecodable or truncated data with OSError
        # (UnidentifiedImageError is a subclass); honour the None contract.
        return None
async def fetch_images(image_urls):
    """Download all *image_urls* concurrently over a single shared HTTP session.

    The returned list is in the same order as *image_urls*; each entry is
    whatever ``fetch_image`` produced for the corresponding URL.
    """
    async with aiohttp.ClientSession() as http:
        downloads = (fetch_image(http, url) for url in image_urls)
        return await asyncio.gather(*map(asyncio.ensure_future, downloads))
async def predict(json=None, enable_gallery=True, image=None, files=None):
    """Score images with the style, aesthetic and NSFW classifiers.

    The images come from the first source that is set, in this order:
      1. ``image``: a single local file path;
      2. ``files``: uploaded files, opened via each object's ``.name`` path;
      3. ``json``: a dict whose ``"urls"`` list is downloaded.

    Returns:
        A ``(results, label_data, gallery)`` tuple.

        * ``results[i]`` is the concatenation of the style, aesthetic and NSFW
          pipeline outputs for input ``i``. It is ``None`` when that URL could
          not be fetched or decoded.
        * ``label_data`` maps label -> score for the single ``image`` input
          and is ``{}`` otherwise.
        * ``gallery`` is the list of successfully loaded images when
          ``enable_gallery`` is truthy, else ``None``.

        With no usable input at all, ``results`` is ``[]``.
    """
    print(json)
    if image or files:
        if image:
            image_paths = [image]
        else:
            # presumably gradio upload wrappers whose .name is the on-disk path
            image_paths = [uploaded.name for uploaded in files]
        pil_images = []
        for path in image_paths:
            # `with` releases the file handle that Image.open keeps open lazily.
            with Image.open(path) as img:
                pil_images.append(img.convert("RGB"))
    elif isinstance(json, dict) and json.get("urls"):
        pil_images = await fetch_images(json["urls"])
    else:
        # No image, no files and no URL list: nothing to classify.
        pil_images = []

    # fetch_image returns None for URLs it could not turn into an image;
    # only real images are sent through the pipelines.
    loaded = [img for img in pil_images if img is not None]
    if loaded:
        style = style_pipe(loaded)
        aesthetic = aesthetic_pipe(loaded)
        nsfw = nsfw_pipe(loaded)
        merged = iter([a + b + c for a, b, c in zip(style, aesthetic, nsfw)])
    else:
        merged = iter(())
    # Re-align with the inputs so results[i] still corresponds to input i.
    results = [None if img is None else next(merged) for img in pil_images]

    label_data = {}
    if image and results and results[0] is not None:
        label_data = {row["label"]: row["score"] for row in results[0]}
    return results, label_data, (loaded if enable_gallery else None)
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            # Local inputs: one image handed to predict as a file path, or
            # several uploaded image files.
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images", file_types=[
                            "image"], file_count="multiple")
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            # Remote input: {"urls": [...]}; predict only downloads these when
            # neither `image` nor `files` is set.
            # NOTE(review): this input shares the label "Results" with the
            # output JSON below; consider renaming once API callers are checked.
            json = gr.JSON(label="Results", value={"urls": [
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
        with gr.Column():
            # Outputs, in the order predict returns them: label dict, raw
            # per-image results, and the gallery of loaded images.
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
            gallery = gr.Gallery().style(grid=[2], height="auto")
    btn = gr.Button("Run")
    # predict's parameters are (json, enable_gallery, image, files), so the
    # inputs list must stay in exactly that order.
    btn.click(fn=predict, inputs=[json, enable_gallery, image, files],
              outputs=[results, label, gallery], api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)
|