FacePoke / loader.py
jbilcke-hf's picture
jbilcke-hf HF staff
working on improvements
e123fec
raw
history blame
6.21 kB
import os
import logging
import torch
import asyncio
import aiohttp
import requests
from huggingface_hub import hf_hub_download
# Configure logging.
# NOTE(review): this runs at import time and configures the *root* logger
# (level DEBUG, timestamped format), so it affects logging for the whole
# process, not just this module.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configuration
# Root folder for on-disk data; override with the DATA_ROOT environment variable.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
# Every entry of MODEL_FILES is stored under this directory.
MODELS_DIR = os.path.join(DATA_ROOT, "models")
# Use CUDA when torch reports it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face repository information
# Single repository that hosts every file listed in MODEL_FILES.
HF_REPO_ID = "jbilcke-hf/model-cocktail"

# Model files to download, as paths relative to the root of HF_REPO_ID.
# download_hf_file() stores each one at MODELS_DIR/<path> and skips it when
# that path already exists.
MODEL_FILES = [
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",

    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",

    # Animal-mode weights (currently disabled).
    # The upstream changelog suggests animal mode does not support stitching yet:
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    #"liveportrait-animals/warping_module.pth",
    #"liveportrait-animals/spade_generator.pth",
    #"liveportrait-animals/motion_extractor.pth",
    #"liveportrait-animals/appearance_feature_extractor.pth",
    #"liveportrait-animals/stitching_retargeting_module.pth",
    #"liveportrait-animals/xpose.pth",

    # HACK: the insightface assets are pre-downloaded here. The proper fix is
    # probably in liveportrait/utils/dependencies/insightface/utils/storage.py.
    "insightface/models/buffalo_l.zip",
    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",

    # NOTE(review): both the .bin and the .safetensors weights are fetched;
    # presumably only one format is actually loaded — confirm and drop the other.
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",

    # Not used yet.
    #"flux-dev/flux-dev-fp8.safetensors",
    #"flux-dev/flux_dev_quantization_map.json",
    #"pulid-flux/pulid_flux_v0.9.0.safetensors",
    #"pulid-flux/pulid_v1.bin"
]
def create_directory(directory):
    """Create ``directory`` (and any missing parents) if needed and log the outcome.

    The directory is created first and ``FileExistsError`` is handled
    afterwards, with no ``exists()`` pre-check. As a result, another process
    creating the same directory between a check and the ``makedirs`` call
    cannot make this function fail.

    Args:
        directory: Path of the directory to ensure.

    Raises:
        NotADirectoryError: If something other than a directory already
            occupies ``directory``.
        OSError: For any other failure reported by ``os.makedirs``.
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        # The path is taken. Only an actual directory satisfies the caller;
        # a regular file there must not be reported as "already exists".
        if not os.path.isdir(directory):
            raise NotADirectoryError(
                f"Path exists but is not a directory: {directory}"
            ) from None
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
def print_directory_structure(startpath):
    """Log a tree view of everything under ``startpath``.

    Directories are logged with a trailing ``/`` and files beneath them.
    Each nesting level adds four spaces of indentation. Entries are sorted
    so that the output is stable from run to run.

    Args:
        startpath: Directory to walk.
    """
    for root, dirs, files in os.walk(startpath):
        # Sorting in place also makes os.walk descend in a deterministic order.
        dirs.sort()
        # Depth relative to startpath. relpath() yields "." for startpath
        # itself and is unaffected by a trailing separator on startpath, which
        # a plain string replace() would miscount.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath() strips a trailing separator so basename() is never empty.
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in sorted(files):
            logger.info(f"{subindent}{f}")
async def download_hf_file(filename: str) -> None:
    """Download one file from ``HF_REPO_ID`` into ``MODELS_DIR``.

    The file is written to ``MODELS_DIR/<filename>``. If that path already
    exists, nothing is downloaded. The blocking ``hf_hub_download`` call runs
    in a worker thread via ``asyncio.to_thread``, so several downloads can
    proceed concurrently without blocking the event loop.

    Args:
        filename: Path of the file inside the Hugging Face repository.

    Raises:
        Exception: Whatever ``hf_hub_download`` raises. It is logged (with the
            file name) and re-raised after removing anything left at the
            destination path.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.exists(dest):
        # this is really for debugging purposes only
        logger.debug(f" βœ… {filename}")
        return

    logger.info(f" ⏳ Downloading {HF_REPO_ID}/{filename}")
    try:
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        # Downloads run concurrently, so the message must say which file failed.
        logger.error(f"🚨 Error downloading {HF_REPO_ID}/{filename} from Hugging Face: {e}")
        # Presumably this avoids a partial file later passing the exists()
        # short-circuit above — confirm against hf_hub_download's write behavior.
        if os.path.exists(dest):
            os.remove(dest)
        raise
    logger.info(f" βœ… Downloaded {filename}")
async def download_all_models():
    """Make sure every entry of MODEL_FILES is present under MODELS_DIR.

    One download_hf_file() coroutine is started per entry, and all of them
    are awaited together; files already on disk are skipped by that helper.
    """
    logger.info(" πŸ”Ž Looking for models...")
    # gather() runs the per-file coroutines together and raises if any fails.
    await asyncio.gather(*map(download_hf_file, MODEL_FILES))
    logger.info(" βœ… All models are available")

    # Debugging aid: to verify what actually landed on disk, un-comment the
    # two lines below.
    #logger.info("πŸ'‘ Printing directory structure of models:")
    #print_directory_structure(MODELS_DIR)
class ModelLoader:
    """Builds the model pipelines used by the application.

    Attributes are set from module-level configuration:
    ``device`` comes from DEVICE and ``models_dir`` from MODELS_DIR.
    """

    def __init__(self):
        self.models_dir = MODELS_DIR
        self.device = DEVICE

    async def load_live_portrait(self):
        """Build and return a LivePortraitPipeline.

        The liveportrait package is imported lazily, only when this method is
        called. The pipeline constructor runs in a worker thread via
        ``asyncio.to_thread``.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" ⏳ Loading LivePortrait models...")

        # These flag values are the defaults.
        inference_cfg = InferenceConfig(
            flag_stitching=True,   # recommended setting
            flag_relative=True,    # use relative motion
            flag_pasteback=True,   # paste the animated face crop back into the original image space
            flag_do_crop=True,     # crop the source portrait into the face-cropping space
            flag_do_rot=True,      # apply rotation when flag_do_crop is True
        )
        crop_cfg = CropConfig()

        pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=inference_cfg,
            crop_cfg=crop_cfg,
        )

        logger.info(" βœ… LivePortrait models loaded successfully.")
        return pipeline
async def initialize_models():
    """Ensure all model files are downloaded, then build the LivePortrait pipeline.

    Returns:
        The object returned by ModelLoader.load_live_portrait().
    """
    logger.info("πŸš€ Starting model initialization...")

    # Files that already exist under MODELS_DIR are not downloaded again.
    await download_all_models()

    pipeline = await ModelLoader().load_live_portrait()

    logger.info("βœ… Model initialization completed.")
    return pipeline
# Initial setup: runs when this module is imported and makes sure MODELS_DIR
# exists (create_directory logs whether it was created or already present).
logger.info("πŸš€ Setting up storage directories...")
create_directory(MODELS_DIR)
logger.info("βœ… Storage directories setup completed.")