# Spaces: Paused (page-banner text captured along with the source; not part of the program)
import time | |
import logging | |
import hashlib | |
import uuid | |
import os | |
import io | |
import shutil | |
import asyncio | |
import base64 | |
from concurrent.futures import ThreadPoolExecutor | |
from queue import Queue | |
from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple, Union | |
from functools import lru_cache | |
import av | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from liveportrait.config.argument_config import ArgumentConfig | |
from liveportrait.utils.camera import get_rotation_matrix | |
from liveportrait.utils.io import load_image_rgb, load_driving_info, resize_to_limit | |
from liveportrait.utils.crop import prepare_paste_back, paste_back | |
# Configure logging
# NOTE: basicConfig runs at import time, so importing this module configures
# the root logger (DEBUG level, timestamped format) for the whole process.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Global constants
# Root directory for runtime data; override with the DATA_ROOT environment variable.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
# "models" sub-directory of DATA_ROOT (presumably where model files are kept — confirm against the loader).
MODELS_DIR = os.path.join(DATA_ROOT, "models")
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
    """
    Convert a base64 data URI to a PIL Image.

    Args:
        base64_string (str): Either a full data URI
            (``data:<mime>;base64,<payload>``) or a bare base64 payload.

    Returns:
        Image.Image: The PIL Image opened from the decoded bytes.
    """
    # A data URI separates its header from the payload with the first comma;
    # keep everything after it so the payload is never truncated.
    if ',' in base64_string:
        base64_string = base64_string.split(',', 1)[1]
    img_data = base64.b64decode(base64_string)
    # The BytesIO buffer stays referenced by the returned image, so PIL can
    # keep reading from it after this function returns.
    return Image.open(io.BytesIO(img_data))
class Engine:
    """
    The main engine class for FacePoke.

    Wraps a LivePortrait pipeline and keeps two in-memory caches keyed by
    image hash: the source images, and the per-image data computed by
    ``_process_image`` so repeated edits of the same image skip that step.
    """

    def __init__(self, live_portrait):
        """
        Initialize the FacePoke engine with necessary models and processors.

        Args:
            live_portrait (LivePortraitPipeline): The LivePortrait model for video generation.
        """
        self.live_portrait = live_portrait
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # cache for the "modify image" workflow
        # NOTE(review): both caches grow without bound; they are only emptied
        # by cleanup(). Consider an eviction policy for long-running servers.
        self.image_cache = {}  # Stores the original images
        self.processed_cache = {}  # Stores the processed image data

        logger.info("β FacePoke Engine initialized successfully.")

    @staticmethod
    def _to_rgb_image(image):
        """Return ``image`` as an RGB PIL Image, decoding encoded bytes first if needed."""
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))
        # Assumes the LivePortrait helpers expect a 3-channel RGB array (the
        # variable they receive is named img_rgb) — TODO confirm.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def get_image_hash(self, image: Union[Image.Image, str, bytes]) -> str:
        """
        Compute or retrieve the hash for an image.

        Args:
            image (Union[Image.Image, str, bytes]): The input image, either as a PIL Image,
                base64 string, or bytes. A 32-character string is taken to
                already be a hash and is returned unchanged.

        Returns:
            str: The MD5 hex digest identifying the image.

        Raises:
            ValueError: If ``image`` is none of the supported types.
        """
        if isinstance(image, str):
            # Assume it's already a hash if it's a string of the right length
            if len(image) == 32:
                return image
            # Otherwise, assume it's a base64 string
            image = base64_data_uri_to_PIL_Image(image)

        if isinstance(image, bytes):
            return hashlib.md5(image).hexdigest()
        if isinstance(image, Image.Image):
            return hashlib.md5(image.tobytes()).hexdigest()
        raise ValueError("Unsupported image type")

    def _process_image(self, image_hash: str) -> Dict[str, Any]:
        """
        Process the input image and cache the results.

        Args:
            image_hash (str): Hash of the input image; must already be a key
                of ``self.image_cache``.

        Returns:
            Dict[str, Any]: Processed image data with keys ``img_rgb``,
                ``crop_info``, ``x_s_info``, ``f_s``, ``x_s`` and ``inference_cfg``.

        Raises:
            ValueError: If no image is cached under ``image_hash``.
        """
        logger.info(f"Processing image with hash: {image_hash}")

        if image_hash not in self.image_cache:
            raise ValueError(f"Image with hash {image_hash} not found in cache")

        # Normalise to RGB so byte inputs and RGBA/palette images yield a
        # regular pixel array instead of a malformed one.
        image = self._to_rgb_image(self.image_cache[image_hash])
        img_rgb = np.array(image)

        wrapper = self.live_portrait.live_portrait_wrapper
        inference_cfg = wrapper.cfg
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        crop_info = self.live_portrait.cropper.crop_single_image(img_rgb)
        img_crop_256x256 = crop_info['img_crop_256x256']

        I_s = wrapper.prepare_source(img_crop_256x256)
        x_s_info = wrapper.get_kp_info(I_s)
        f_s = wrapper.extract_feature_3d(I_s)
        x_s = wrapper.transform_keypoint(x_s_info)

        processed_data = {
            'img_rgb': img_rgb,
            'crop_info': crop_info,
            'x_s_info': x_s_info,
            'f_s': f_s,
            'x_s': x_s,
            'inference_cfg': inference_cfg
        }

        self.processed_cache[image_hash] = processed_data
        return processed_data

    async def modify_image(self, image_or_hash: Union[Image.Image, str, bytes], params: Dict[str, float]) -> str:
        """
        Modify the input image based on the provided parameters, using caching for efficiency
        and outputting the result as a WebP image.

        Args:
            image_or_hash (Union[Image.Image, str, bytes]): Input image as a PIL Image, base64-encoded string,
                image bytes, or a hash string.
            params (Dict[str, float]): Parameters for face transformation. The
                dict is only read, never modified.

        Returns:
            str: Modified image as a base64-encoded WebP data URI.

        Raises:
            ValueError: If there's an error modifying the image or WebP is not supported.
        """
        logger.info("Starting image modification")
        logger.debug(f"Modification parameters: {params}")

        try:
            # Decode a base64 payload once and reuse the result for both
            # hashing and caching (a 32-character string is a hash, not data).
            image = image_or_hash
            if isinstance(image, str) and len(image) != 32:
                image = base64_data_uri_to_PIL_Image(image)
            image_hash = self.get_image_hash(image)

            # If we don't have the image in cache yet, add it
            if image_hash not in self.image_cache:
                if isinstance(image, bytes):
                    # Raw encoded bytes must be decoded; caching them as-is
                    # would later be turned into a 0-d array, not pixels.
                    image = Image.open(io.BytesIO(image))
                if not isinstance(image, Image.Image):
                    raise ValueError("Image not found in cache and no valid image provided")
                self.image_cache[image_hash] = image

            # Process the image (this will use the cache if available)
            if image_hash not in self.processed_cache:
                processed_data = await asyncio.to_thread(self._process_image, image_hash)
            else:
                processed_data = self.processed_cache[image_hash]

            wrapper = self.live_portrait.live_portrait_wrapper
            x_s_info = processed_data['x_s_info']

            # Apply modifications based on params (on a clone, so the cached
            # keypoints stay untouched for later requests)
            x_d_new = x_s_info['kp'].clone()
            await self._apply_facial_modifications(x_d_new, params)

            # Apply rotation
            R_new = get_rotation_matrix(
                x_s_info['pitch'] + params.get('rotate_pitch', 0),
                x_s_info['yaw'] + params.get('rotate_yaw', 0),
                x_s_info['roll'] + params.get('rotate_roll', 0)
            )
            x_d_new = x_s_info['scale'] * (x_d_new @ R_new) + x_s_info['t']

            # Apply stitching
            x_d_new = await asyncio.to_thread(wrapper.stitching, processed_data['x_s'], x_d_new)

            # Generate the output
            out = await asyncio.to_thread(wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
            I_p = wrapper.parse_output(out['out'])[0]

            # Paste back to full size
            img_rgb = processed_data['img_rgb']
            M_c2o = processed_data['crop_info']['M_c2o']
            mask_ori = await asyncio.to_thread(
                prepare_paste_back,
                processed_data['inference_cfg'].mask_crop, M_c2o,
                dsize=(img_rgb.shape[1], img_rgb.shape[0])
            )
            I_p_to_ori_blend = await asyncio.to_thread(
                paste_back,
                I_p, M_c2o, img_rgb, mask_ori
            )

            # Convert the result to a PIL Image
            result_image = Image.fromarray(I_p_to_ori_blend)

            # Save as WebP
            buffered = io.BytesIO()
            result_image.save(buffered, format="WebP", quality=85)  # Adjust quality as needed
            modified_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

            logger.info("Image modification completed successfully")
            return f"data:image/webp;base64,{modified_image_base64}"

        except Exception as e:
            logger.error(f"Error in modify_image: {str(e)}")
            logger.exception("Full traceback:")
            # Callers only ever see ValueError; chain the cause for debugging.
            raise ValueError(f"Failed to modify image: {str(e)}") from e

    async def _apply_facial_modifications(self, x_d_new: torch.Tensor, params: Dict[str, float]) -> None:
        """
        Apply facial modifications to the keypoints based on the provided parameters.

        Each table entry ``(i, j, k, factor)`` adds ``params[name] * factor``
        to ``x_d_new[i, j, k]``. Missing parameters count as 0.

        Args:
            x_d_new (torch.Tensor): Tensor of facial keypoints to be modified; updated in place.
            params (Dict[str, float]): Parameters for face transformation; read only.
        """
        # These two parameters select different factors depending on their sign.
        pupil_x = params.get('pupil_x', 0)
        eyebrow = params.get('eyebrow', 0)

        modifications = [
            ('smile', [
                (0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003),
                (0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035)
            ]),
            ('aaa', [
                (0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001)
            ]),
            ('eee', [
                (0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001)
            ]),
            ('woo', [
                (0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005)
            ]),
            ('wink', [
                (0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003),
                (0, 17, 1, 0.0003), (0, 3, 1, -0.0003)
            ]),
            ('pupil_x', [
                (0, 11, 0, 0.0007 if pupil_x > 0 else 0.001),
                (0, 15, 0, 0.001 if pupil_x > 0 else 0.0007)
            ]),
            ('pupil_y', [
                (0, 11, 1, -0.001), (0, 15, 1, -0.001)
            ]),
            ('eyes', [
                (0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003),
                (0, 1, 1, -0.00025), (0, 2, 1, 0.00025)
            ]),
            ('eyebrow', [
                (0, 1, 1, 0.001 if eyebrow > 0 else 0.0003),
                (0, 2, 1, -0.001 if eyebrow > 0 else -0.0003),
                (0, 1, 0, -0.001 if eyebrow <= 0 else 0),
                (0, 2, 0, 0.001 if eyebrow <= 0 else 0)
            ])
        ]

        for param_name, adjustments in modifications:
            param_value = params.get(param_name, 0)
            if param_value == 0:
                # Adding zero changes nothing; skip the per-element tensor updates.
                continue
            for i, j, k, factor in adjustments:
                x_d_new[i, j, k] += param_value * factor

        # Special case for pupil_y affecting eyes
        pupil_y = params.get('pupil_y', 0)
        if pupil_y != 0:
            x_d_new[0, 11, 1] -= pupil_y * 0.001
            x_d_new[0, 15, 1] -= pupil_y * 0.001
        # NOTE(review): the previous code ended by writing
        # params['eyes'] = eyes - pupil_y / 2 back into the caller's dict.
        # That ran after 'eyes' had already been applied, so it never changed
        # this call's result, but it did alter the caller's dict (making
        # 'eyes' drift if the dict was reused). The write-back is removed and
        # the keypoint arithmetic above is kept exactly as it was. Whether the
        # eyes compensation should apply before the loop, and whether pupil_y
        # is meant to be applied twice (table entry plus this special case),
        # needs confirming against the intended visual behaviour.

    async def cleanup(self):
        """
        Perform cleanup operations for the Engine.
        This method should be called when shutting down the application.

        Empties both image caches and, when CUDA is available, asks PyTorch to
        release its cached GPU memory. Errors are logged, not raised.
        """
        logger.info("Starting Engine cleanup")
        try:
            # Drop cached images and processed data so they can be freed.
            self.image_cache.clear()
            self.processed_cache.clear()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # TODO: Add any additional cleanup operations here
            logger.info("Engine cleanup completed successfully")
        except Exception as e:
            logger.error(f"Error during Engine cleanup: {str(e)}")
            logger.exception("Full traceback:")
def create_engine(models):
    """
    Create the FacePoke Engine from the already-loaded models.

    Args:
        models: Object passed to ``Engine`` as its ``live_portrait`` argument.

    Returns:
        Engine: The newly constructed engine instance.
    """
    logger.info("Creating Engine instance...")

    live_portrait = models

    engine = Engine(
        live_portrait=live_portrait,
        # we might have more in the future
    )

    logger.info("Engine instance created successfully")
    return engine