Update handler.py
handler.py (CHANGED: +57 −7)
```diff
@@ -13,7 +13,6 @@ import base64
 from hyvideo.utils.file_utils import save_videos_grid
 from hyvideo.inference import HunyuanVideoSampler
 from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
-from hyvideo.modules.attenion import get_attention_modes
 
 try:
     import triton
```
```diff
@@ -37,8 +36,45 @@ DEFAULT_NB_FRAMES = (4 * 30) + 1 # or 129 (note: hunyan requires an extra +1 frame)
 DEFAULT_NB_STEPS = 22 # Default for standard model
 DEFAULT_FPS = 24
 
+def get_attention_modes():
+    """Get available attention modes - fallback if module function isn't available"""
+    modes = ["sdpa"] # Always available
+
+    try:
+        import torch
+        if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+            modes.append("sdpa")
+    except:
+        pass
+
+    try:
+        import flash_attn
+        modes.append("flash")
+    except:
+        pass
+
+    try:
+        import sageattention
+        modes.append("sage")
+        if hasattr(sageattention, 'efficient_attention_v2'):
+            modes.append("sage2")
+    except:
+        pass
+
+    try:
+        import xformers
+        modes.append("xformers")
+    except:
+        pass
+
+    return modes
+
 # Get supported attention modes
-attention_modes_supported = get_attention_modes()
+try:
+    from hyvideo.modules.attenion import get_attention_modes
+    attention_modes_supported = get_attention_modes()
+except:
+    attention_modes_supported = get_attention_modes()
 
 def setup_vae_path(vae_path: Path) -> Path:
     """Create a temporary directory with correctly named VAE config file"""
```
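Two small wrinkles in the fallback above: `modes` starts with `"sdpa"` and then appends `"sdpa"` a second time when the `torch` check succeeds, and the bare `except:` clauses hide anything other than a missing package. A tightened sketch of the same capability probe (illustrative only, not what this commit ships):

```python
def detect_attention_modes() -> list:
    """Illustrative rewrite of the fallback probe: report each backend once
    and catch only import errors instead of every exception."""
    modes = []

    try:
        import torch
        # SDPA ships with torch >= 2.0
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            modes.append("sdpa")
    except ImportError:
        pass

    try:
        import flash_attn  # noqa: F401
        modes.append("flash")
    except ImportError:
        pass

    try:
        import sageattention
        modes.append("sage")
        if hasattr(sageattention, "efficient_attention_v2"):
            modes.append("sage2")
    except ImportError:
        pass

    try:
        import xformers  # noqa: F401
        modes.append("xformers")
    except ImportError:
        pass

    return modes
```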
```diff
@@ -317,10 +353,20 @@ class EndpointHandler:
         try:
             logger.info("Attempting to initialize HunyuanVideoSampler...")
 
-            #
-
+            # Extract necessary paths
+            transformer_path = str(self.args.dit_weight)
+            text_encoder_path = str(Path(self.args.model_base) / "text_encoder")
 
-
+            logger.info(f"Transformer path: {transformer_path}")
+            logger.info(f"Text encoder path: {text_encoder_path}")
+
+            # Initialize the model using the exact signature from gradio_server.py
+            self.model = HunyuanVideoSampler.from_pretrained(
+                transformer_path,
+                text_encoder_path,
+                attention_mode=self.attention_mode,
+                args=self.args
+            )
 
             # Set attention mode for transformer blocks
             if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'):
```
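Because `from_pretrained` receives raw filesystem paths derived from `self.args`, a quick existence check before initialization turns a cryptic model-load failure into a clear error. A minimal sketch assuming the same `dit_weight` and `model_base` arguments used above (the helper name is hypothetical, not part of the commit):

```python
from pathlib import Path

def resolve_model_paths(dit_weight: str, model_base: str):
    """Resolve and sanity-check the transformer and text-encoder locations
    before passing them to HunyuanVideoSampler.from_pretrained()."""
    transformer_path = Path(dit_weight)
    text_encoder_path = Path(model_base) / "text_encoder"

    if not transformer_path.exists():
        raise FileNotFoundError(f"Transformer weights not found: {transformer_path}")
    if not text_encoder_path.is_dir():
        raise FileNotFoundError(f"Text encoder directory not found: {text_encoder_path}")

    return str(transformer_path), str(text_encoder_path)
```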
```diff
@@ -362,7 +408,7 @@ class EndpointHandler:
             logger.error(f"Error initializing model: {str(e)}")
             raise
 
-    def __call__(self, data: Dict[str, Any]) ->
+    def __call__(self, data: Dict[str, Any]) -> str:
         """Process a single request"""
         # Log incoming request
         logger.info(f"Processing request with data: {data}")
```
```diff
@@ -385,6 +431,7 @@ class EndpointHandler:
         flow_shift = float(data.pop("flow_shift", 7.0))
         embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
         enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
+        tea_cache = float(data.pop("tea_cache", 0.0))
 
         logger.info(f"Processing with parameters: width={width}, height={height}, "
                     f"video_length={video_length}, seed={seed}, "
```
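Taken together with the new `-> str` annotation, a request now carries the generation parameters in its payload (including the new `tea_cache` field) and gets a data-URI string back. A hedged local smoke test might look like the following; the constructor argument, the `inputs` key, and all values follow the usual Hugging Face Inference Endpoints conventions and are assumptions, not part of this diff:

```python
# Hypothetical local smoke test; keys and values are placeholders.
handler = EndpointHandler(path=".")

payload = {
    "inputs": "a corgi running on a beach at sunset",
    "width": 848,
    "height": 480,
    "video_length": 129,        # number of frames
    "seed": 42,
    "flow_shift": 7.0,
    "embedded_guidance_scale": 6.0,
    "tea_cache": 0.15,          # 0 disables TeaCache
}

video_data_uri = handler(payload)          # returns a string per the new -> str annotation
assert video_data_uri.startswith("data:")  # the handler encodes the MP4 inline
```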
```diff
@@ -392,10 +439,12 @@ class EndpointHandler:
 
         try:
             # Set up TeaCache for this generation if enabled
-            if hasattr(self.model.pipeline, 'transformer') and
+            if hasattr(self.model.pipeline, 'transformer') and tea_cache > 0:
                 transformer = self.model.pipeline.transformer
+                transformer.enable_teacache = True
                 transformer.num_steps = num_inference_steps
                 transformer.cnt = 0
+                transformer.rel_l1_thresh = tea_cache
                 transformer.accumulated_rel_l1_distance = 0
                 transformer.previous_modulated_input = None
                 transformer.previous_residual = None
```
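The attributes reset in this hunk (`enable_teacache`, `num_steps`, `cnt`, `rel_l1_thresh`, `accumulated_rel_l1_distance`, `previous_modulated_input`, `previous_residual`) are the per-generation state of the TeaCache heuristic: at each denoising step the transformer compares the current modulated input against the previous one and, while the accumulated relative change stays under `rel_l1_thresh`, reuses the cached residual instead of recomputing the blocks. A simplified, self-contained sketch of that decision (not the pipeline's actual implementation, which also rescales the distance before accumulating):

```python
from typing import Optional, Tuple
import torch

def teacache_should_skip(current: torch.Tensor,
                         previous: Optional[torch.Tensor],
                         accumulated: float,
                         threshold: float) -> Tuple[bool, float]:
    """Simplified TeaCache skip test: accumulate the relative L1 change of the
    modulated input and skip recomputation while it stays below `threshold`
    (the request's `tea_cache` value). Returns (skip, new_accumulated)."""
    if previous is None or threshold <= 0:
        return False, 0.0  # first step, or caching disabled

    rel_l1 = ((current - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rel_l1

    if accumulated < threshold:
        return True, accumulated   # skip: reuse transformer.previous_residual
    return False, 0.0              # compute normally and reset the accumulator
```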
```diff
@@ -450,6 +499,7 @@ class EndpointHandler:
 
             logger.info("Successfully generated and encoded video")
 
+            # Return exactly what the demo.py expects
            return video_data_uri
 
         except Exception as e:
```
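`video_data_uri` suggests the handler returns the rendered clip inline as a base64 `data:` URI (the file imports `base64`, as the first hunk's context shows). A minimal sketch of what that encoding step looks like, with a hypothetical helper name and path:

```python
import base64

def mp4_to_data_uri(mp4_path: str) -> str:
    """Encode a rendered MP4 file as a data: URI string for the HTTP response."""
    with open(mp4_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:video/mp4;base64,{encoded}"

# e.g. video_data_uri = mp4_to_data_uri("/tmp/output.mp4")
```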