import functools
import io
import logging
import os
from typing import Optional, Tuple
from urllib.parse import quote

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image

logger = logging.getLogger(__name__)

# Upload formats accepted by /remove-bg, as lower-cased Pillow format names.
# ('jpg' is never reported by Pillow; it is listed for the user-facing message.)
SUPPORTED_FORMATS = {'jpg', 'jpeg', 'png', 'bmp', 'webp', 'tiff'}

# Pillow reports camera "multi-picture" JPEGs as MPO; treat them as JPEG.
_FORMAT_ALIASES = {'mpo': 'jpeg'}

_MODEL_REPO = "skytnt/anime-seg"
_MODEL_FILE = "isnetis.onnx"
_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']


@functools.lru_cache(maxsize=1)
def _get_model() -> rt.InferenceSession:
    """Download the segmentation model once and return its ONNX session.

    Loaded lazily (and cached) so the app works both when run as a script and
    when imported by an ASGI server, where ``__main__`` code never runs.
    """
    model_path = huggingface_hub.hf_hub_download(_MODEL_REPO, _MODEL_FILE)
    return rt.InferenceSession(model_path, providers=_PROVIDERS)


def is_valid_image(file_content: bytes) -> Optional[str]:
    """Return the lower-cased image format of *file_content*, or None.

    Uses Pillow's lazy ``Image.open`` (header only, pixels are not decoded).
    ``imghdr`` was used before but is removed in Python 3.13 and misses
    many valid JPEG files.
    """
    try:
        with Image.open(io.BytesIO(file_content)) as image:
            image_format = image.format
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError subclasses OSError: not a recognised image.
        # Images Pillow flags as decompression bombs are rejected as well.
        return None
    if not image_format:
        return None
    image_format = image_format.lower()
    return _FORMAT_ALIASES.get(image_format, image_format)


def process_image_bytes(image_bytes: bytes) -> np.ndarray:
    """Decode *image_bytes* into an (H, W, 3) RGB numpy array.

    Every non-RGB mode (RGBA, L, P, LA, CMYK, ...) is converted to RGB, because
    ``get_mask`` and the RGBA encoding expect exactly three channels. For
    RGBA input the alpha channel is discarded, as before.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded or converted.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as image:
            rgb = image if image.mode == 'RGB' else image.convert('RGB')
            return np.array(rgb)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error: {str(e)}") from e


def get_mask(img, s=1024):
    """Predict a foreground mask for *img* with the anime-seg ONNX model.

    Args:
        img: (H, W, 3) image array with values in 0..255.
        s: Side length of the square model input.

    Returns:
        Array of shape (H, W, 1) at the input's resolution. Values are whatever
        the model emits, presumably in [0, 1] (NOTE(review): confirm). Assumes
        the model outputs a single-channel map.
    """
    img = (img / 255).astype(np.float32)
    h0, w0 = img.shape[:-1]
    # Scale the longer side to s and keep the aspect ratio (at least 1 px per side).
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s
    top, left = (s - h) // 2, (s - w) // 2
    # Letterbox: centre the resized image on a zero-filled s x s canvas.
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[top:top + h, left:left + w] = cv2.resize(img, (w, h))
    # HWC -> NCHW with a batch of one.
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]
    mask = _get_model().run(None, {'img': img_input})[0][0]
    mask = np.transpose(mask, (1, 2, 0))
    # Remove the letterbox padding and resize back to the original size.
    mask = mask[top:top + h, left:left + w]
    return cv2.resize(mask, (w0, h0))[:, :, np.newaxis]


def _composite(img: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Blend *img* over white using *mask* and attach the mask as alpha.

    Returns ``(mask_u8, rgba)``: the mask scaled to uint8 with shape (H, W, 1),
    and the (H, W, 4) uint8 RGBA result.
    """
    rgb = (mask * img + 255 * (1 - mask)).astype(np.uint8)
    mask_u8 = (mask * 255).astype(np.uint8)
    rgba = np.concatenate([rgb, mask_u8], axis=2, dtype=np.uint8)
    return mask_u8, rgba


def rmbg_fn(img):
    """Gradio handler: return ``(3-channel mask preview, RGBA cut-out)``."""
    if img is None:
        # No image uploaded: show a readable message instead of a TypeError.
        raise gr.Error("Please upload an image first.")
    mask, rgba = _composite(img, get_mask(img))
    return mask.repeat(3, axis=2), rgba


def _remove_bg_png(contents: bytes) -> bytes:
    """Decode *contents*, remove the background and return PNG bytes (RGBA)."""
    img = process_image_bytes(contents)
    _, rgba = _composite(img, get_mask(img))
    buffer = io.BytesIO()
    # An (H, W, 4) uint8 array is inferred as RGBA; the mode= argument is deprecated.
    Image.fromarray(rgba).save(buffer, format='PNG')
    return buffer.getvalue()


def _content_disposition(filename: Optional[str]) -> str:
    """Build an attachment header that names the download ``<stem>_nobg.png``.

    Starlette encodes header values as latin-1, so names such as Japanese
    filenames would raise. The full name is therefore sent in the RFC 5987
    ``filename*`` parameter, with a printable-ASCII fallback in ``filename``.
    """
    stem = os.path.splitext(os.path.basename(filename or ""))[0]
    ascii_stem = "".join(
        c for c in stem if " " <= c <= "~" and c not in '"\\'
    ) or "image"
    full_name = f"{stem or 'image'}_nobg.png"
    return (
        f'attachment; filename="{ascii_stem}_nobg.png"; '
        f"filename*=UTF-8''{quote(full_name, safe='')}"
    )


app = FastAPI()

with gr.Blocks() as gradio_app:
    gr.Markdown("# Anime Remove Background\n\n"
                "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
                "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
    with gr.Column():
        input_img = gr.Image(label="input image")
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        gr.Examples(examples=examples_data, inputs=[input_img])
        run_btn = gr.Button(variant="primary")
    with gr.Row():
        output_mask = gr.Image(label="mask", format="png")
        output_img = gr.Image(label="result", image_mode="RGBA", format="png")
    run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])


@app.post("/remove-bg")
async def remove_background(file: UploadFile = File(...)):
    """Remove the background of an uploaded image and return it as a PNG.

    Raises:
        HTTPException: 400 for unsupported or undecodable images, 500 for
            failures during inference or encoding.
    """
    contents = await file.read()
    image_format = is_valid_image(contents)
    if image_format not in SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}",
        )
    try:
        # Inference is CPU/GPU-bound; keep it off the event loop.
        png_bytes = await run_in_threadpool(_remove_bg_png, contents)
    except HTTPException:
        # Preserve the 400 from process_image_bytes instead of turning it into a 500.
        raise
    except Exception as e:
        logger.exception("Background removal failed")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={"Content-Disposition": _content_disposition(file.filename)},
    )


# Mount the UI after the API route so /remove-bg is matched before the "/" mount.
app = gr.mount_gradio_app(app, gradio_app, path="/")


if __name__ == "__main__":
    import uvicorn

    # Load the model up front so the first request does not pay for the download.
    _get_model()
    uvicorn.run(app, host="0.0.0.0", port=7860)