Spaces:
Running
Running
File size: 2,232 Bytes
3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 87c57a3 3a6f1f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import hashlib
import os
import sys
from contextlib import redirect_stdout
from pathlib import Path
from typing import Type
import onnxruntime as ort
import pooch
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession
def new_session(model_name: str = "u2net") -> BaseSession:
    """Fetch the ONNX weights for *model_name* and wrap them in a session.

    The weights file is retrieved with ``pooch.retrieve`` (verified against a
    pinned MD5) into the directory named by ``$U2NET_HOME``; when that is
    unset, ``$XDG_DATA_HOME/.u2net`` is used, or ``~/.u2net`` if
    ``XDG_DATA_HOME`` is unset as well.

    Args:
        model_name: One of ``"u2net"``, ``"u2netp"``, ``"u2net_human_seg"``,
            ``"u2net_cloth_seg"`` or ``"silueta"``. Any other value falls
            back to the ``u2net`` weights (see the note in the body).

    Returns:
        A ``ClothSession`` for ``"u2net_cloth_seg"``, otherwise a
        ``SimpleSession``, constructed with *model_name* and an
        ``onnxruntime.InferenceSession`` over the retrieved file.

    Raises:
        ValueError: If ``OMP_NUM_THREADS`` is set but is not an integer.
    """
    release_url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0"

    # model name -> (md5 of "<name>.onnx" in the release, session wrapper class)
    known_models = {
        "u2net": ("60024c5c889badc19c04ad937298a77b", SimpleSession),
        "u2netp": ("8e83ca70e441ab06c318d82300c84806", SimpleSession),
        "u2net_human_seg": ("c09ddc2e0104f800e3e1bb4652583d1f", SimpleSession),
        "u2net_cloth_seg": ("2434d1f3cb744e0e49386c906e5a08bb", ClothSession),
        "silueta": ("55e59e0d8062d2f5d013f4725ee84782", SimpleSession),
    }

    # NOTE(review): an unrecognised name silently uses the u2net weights, yet
    # the file is still cached as "<model_name>.onnx" below. Kept as-is for
    # backward compatibility -- confirm whether this should raise instead.
    weights_name = model_name if model_name in known_models else "u2net"

    session_class: Type[BaseSession]
    md5, session_class = known_models[weights_name]
    url = f"{release_url}/{weights_name}.onnx"

    # Resolve the cache directory once and reuse it for download and load.
    default_home = os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
    model_dir = Path(os.getenv("U2NET_HOME", default_home)).expanduser()
    fname = f"{model_name}.onnx"
    full_path = model_dir / fname

    pooch.retrieve(
        url,
        f"md5:{md5}",
        fname=fname,
        path=model_dir,
        progressbar=True,
    )

    sess_opts = ort.SessionOptions()
    if "OMP_NUM_THREADS" in os.environ:
        # Mirror the OpenMP thread limit onto ONNX Runtime's inter-op setting.
        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])

    return session_class(
        model_name,
        ort.InferenceSession(
            str(full_path),
            providers=ort.get_available_providers(),
            sess_options=sess_opts,
        ),
    )
|