Spaces:
Paused
Paused
Update services/ltx_server.py
Browse files- services/ltx_server.py +49 -64
services/ltx_server.py
CHANGED
|
@@ -8,9 +8,7 @@ from typing import Optional, Tuple
|
|
| 8 |
import torch
|
| 9 |
from PIL import Image
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
# A importação só funcionará depois que o repo for clonado e adicionado ao sys.path
|
| 13 |
-
# Portanto, faremos a importação dentro do __init__
|
| 14 |
|
| 15 |
APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
|
| 16 |
|
|
@@ -27,33 +25,36 @@ class LTXServer:
|
|
| 27 |
def __init__(self):
|
| 28 |
if hasattr(self, '_initialized') and self._initialized: return
|
| 29 |
|
| 30 |
-
print("🚀 LTXServer (
|
| 31 |
|
| 32 |
self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
|
| 33 |
-
self.LTX_REPO_DIR = Path(
|
| 34 |
-
self.MODELS_DIR = Path("/data/ltx_models") #
|
| 35 |
-
self.
|
| 36 |
-
|
| 37 |
-
self.CONFIG_PATH = APP_HOME / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml" # <--- Seu arquivo de config FP8
|
| 38 |
|
| 39 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 40 |
self.dtype = torch.bfloat16 if self.device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
|
| 41 |
|
| 42 |
-
for p in [self.
|
| 43 |
p.mkdir(parents=True, exist_ok=True)
|
| 44 |
|
| 45 |
self.setup_dependencies()
|
| 46 |
|
| 47 |
-
# Importações dinâmicas
|
| 48 |
from ltx_video.inference import create_ltx_video_pipeline, create_latent_upsampler
|
| 49 |
|
| 50 |
try:
|
| 51 |
-
print("[LTXServer] Montando pipelines a partir dos arquivos
|
| 52 |
|
| 53 |
with open(self.CONFIG_PATH, "r") as f:
|
| 54 |
self.config_yaml = yaml.safe_load(f)
|
| 55 |
|
| 56 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
self._pipeline = create_ltx_video_pipeline(
|
| 58 |
ckpt_path=str(self.MODELS_DIR / self.config_yaml["checkpoint_path"]),
|
| 59 |
precision=self.config_yaml["precision"],
|
|
@@ -68,7 +69,7 @@ class LTXServer:
|
|
| 68 |
device=self.device
|
| 69 |
)
|
| 70 |
|
| 71 |
-
print("✅ LTXServer (
|
| 72 |
except Exception as e:
|
| 73 |
print(f"ERRO CRÍTICO ao montar as pipelines LTX: {e}")
|
| 74 |
raise
|
|
@@ -76,53 +77,47 @@ class LTXServer:
|
|
| 76 |
self._initialized = True
|
| 77 |
|
| 78 |
def setup_dependencies(self):
|
| 79 |
-
|
| 80 |
-
self.
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
else:
|
| 90 |
-
print("[LTXServer] Repositório LTX-Video já existe.")
|
| 91 |
|
| 92 |
if str(self.LTX_REPO_DIR) not in sys.path:
|
| 93 |
sys.path.insert(0, str(self.LTX_REPO_DIR))
|
| 94 |
|
| 95 |
-
def
|
| 96 |
-
"""Baixa
|
| 97 |
-
from huggingface_hub import
|
| 98 |
-
|
| 99 |
-
print(f"[LTXServer] Verificando arquivos de modelo em {self.MODELS_DIR}...")
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
filename=filename,
|
| 115 |
-
local_dir=str(self.MODELS_DIR),
|
| 116 |
-
token=os.getenv("HF_TOKEN")
|
| 117 |
-
)
|
| 118 |
-
print("[LTXServer] Arquivos de modelo essenciais verificados/baixados.")
|
| 119 |
|
| 120 |
def run_inference(self, **kwargs) -> str:
|
| 121 |
-
#
|
|
|
|
| 122 |
from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline, ConditioningItem
|
| 123 |
-
from
|
| 124 |
|
| 125 |
-
# Extrai os parâmetros com valores padrão
|
| 126 |
prompt = kwargs.get("prompt")
|
| 127 |
image_path = kwargs.get("image_path")
|
| 128 |
target_height = kwargs.get("target_height")
|
|
@@ -132,28 +127,20 @@ class LTXServer:
|
|
| 132 |
|
| 133 |
output_file_path = self.OUTPUT_ROOT / f"run_{int(time.time())}.mp4"
|
| 134 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 135 |
-
|
| 136 |
-
# Monta o objeto de pipeline multi-escala
|
| 137 |
multi_scale_pipeline = LTXMultiScalePipeline(self._pipeline, self._latent_upsampler)
|
| 138 |
|
| 139 |
-
# Prepara a condição da imagem
|
| 140 |
conditions = None
|
| 141 |
if image_path:
|
| 142 |
-
from diffusers.utils import export_to_video, load_image, load_video
|
| 143 |
-
from ltx_video.pipelines.pipeline_ltx_condition import LTXVideoCondition
|
| 144 |
-
|
| 145 |
image = load_image(image_path)
|
| 146 |
video_condition_input = load_video(export_to_video([image]))
|
| 147 |
condition = LTXVideoCondition(video=video_condition_input, frame_index=0)
|
| 148 |
conditions = [condition]
|
| 149 |
|
| 150 |
-
# Configura os parâmetros da chamada com base no arquivo YAML
|
| 151 |
call_kwargs = {
|
| 152 |
-
"prompt": prompt,
|
| 153 |
-
"negative_prompt": "worst quality...",
|
| 154 |
"height": target_height, "width": target_width, "num_frames": num_frames,
|
| 155 |
-
"generator": generator, "output_type": "pt",
|
| 156 |
-
"conditions": conditions,
|
| 157 |
"decode_timestep": self.config_yaml["decode_timestep"],
|
| 158 |
"decode_noise_scale": self.config_yaml["decode_noise_scale"],
|
| 159 |
"first_pass": self.config_yaml["first_pass"],
|
|
@@ -164,8 +151,6 @@ class LTXServer:
|
|
| 164 |
print("[LTXServer] Executando pipeline multi-escala...")
|
| 165 |
result_tensor = multi_scale_pipeline(**call_kwargs).images
|
| 166 |
|
| 167 |
-
# Exporta para vídeo
|
| 168 |
-
from diffusers.utils import export_to_video
|
| 169 |
video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
|
| 170 |
video_np = (video_np * 255).astype("uint8")
|
| 171 |
export_to_video(video_np, str(output_file_path), fps=24)
|
|
|
|
| 8 |
import torch
|
| 9 |
from PIL import Image
|
| 10 |
|
| 11 |
+
# Importações serão feitas dinamicamente após o setup
|
|
|
|
|
|
|
| 12 |
|
| 13 |
APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
|
| 14 |
|
|
|
|
| 25 |
def __init__(self):
|
| 26 |
if hasattr(self, '_initialized') and self._initialized: return
|
| 27 |
|
| 28 |
+
print("🚀 LTXServer (Full Cache) inicializando...")
|
| 29 |
|
| 30 |
self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
|
| 31 |
+
self.LTX_REPO_DIR = Path("/opt/LTX-Video") # Instalado pelo Dockerfile
|
| 32 |
+
self.MODELS_DIR = Path("/data/ltx_models") # Pasta unificada para TODOS os modelos
|
| 33 |
+
self.CONFIG_PATH = APP_HOME / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
|
| 34 |
+
self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
|
|
|
|
| 35 |
|
| 36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
self.dtype = torch.bfloat16 if self.device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
|
| 38 |
|
| 39 |
+
for p in [self.MODELS_DIR, self.OUTPUT_ROOT]:
|
| 40 |
p.mkdir(parents=True, exist_ok=True)
|
| 41 |
|
| 42 |
self.setup_dependencies()
|
| 43 |
|
| 44 |
+
# Importações dinâmicas
|
| 45 |
from ltx_video.inference import create_ltx_video_pipeline, create_latent_upsampler
|
| 46 |
|
| 47 |
try:
|
| 48 |
+
print("[LTXServer] Montando pipelines a partir dos arquivos locais...")
|
| 49 |
|
| 50 |
with open(self.CONFIG_PATH, "r") as f:
|
| 51 |
self.config_yaml = yaml.safe_load(f)
|
| 52 |
|
| 53 |
+
# Para que a `create_ltx_video_pipeline` encontre os modelos,
|
| 54 |
+
# o `text_encoder_model_name_or_path` deve apontar para o nosso diretório local.
|
| 55 |
+
self.config_yaml["text_encoder_model_name_or_path"] = str(self.MODELS_DIR)
|
| 56 |
+
|
| 57 |
+
# Monta a pipeline principal, passando o caminho para os pesos e o diretório do text encoder
|
| 58 |
self._pipeline = create_ltx_video_pipeline(
|
| 59 |
ckpt_path=str(self.MODELS_DIR / self.config_yaml["checkpoint_path"]),
|
| 60 |
precision=self.config_yaml["precision"],
|
|
|
|
| 69 |
device=self.device
|
| 70 |
)
|
| 71 |
|
| 72 |
+
print("✅ LTXServer (Full Cache) pronto.")
|
| 73 |
except Exception as e:
|
| 74 |
print(f"ERRO CRÍTICO ao montar as pipelines LTX: {e}")
|
| 75 |
raise
|
|
|
|
| 77 |
self._initialized = True
|
| 78 |
|
| 79 |
def setup_dependencies(self):
|
| 80 |
+
"""Clona o repo (se Dockerfile não o fez) e baixa TODOS os modelos necessários."""
|
| 81 |
+
self._ensure_repo_and_install()
|
| 82 |
+
self._ensure_models_full_download()
|
| 83 |
+
|
| 84 |
+
def _ensure_repo_and_install(self) -> None:
|
| 85 |
+
"""Clona e instala o repositório LTX-Video."""
|
| 86 |
+
if not (self.LTX_REPO_DIR / "setup.py").exists():
|
| 87 |
+
print(f"[LTXServer] Clonando repositório LTX-Video para {self.LTX_REPO_DIR}...")
|
| 88 |
+
subprocess.run(["git", "clone", "--depth", "1", "https://github.com/Lightricks/LTX-Video.git", str(self.LTX_REPO_DIR)], check=True)
|
| 89 |
+
print("[LTXServer] Instalando LTX-Video em modo editável...")
|
| 90 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-e", f"{self.LTX_REPO_DIR}[inference-script]"], check=True)
|
| 91 |
else:
|
| 92 |
+
print("[LTXServer] Repositório LTX-Video já existe e está instalado.")
|
| 93 |
|
| 94 |
if str(self.LTX_REPO_DIR) not in sys.path:
|
| 95 |
sys.path.insert(0, str(self.LTX_REPO_DIR))
|
| 96 |
|
| 97 |
+
def _ensure_models_full_download(self) -> None:
|
| 98 |
+
"""Baixa o snapshot completo de todos os modelos necessários para o cache local."""
|
| 99 |
+
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
print(f"[LTXServer] Verificando snapshot completo dos modelos em {self.MODELS_DIR}...")
|
| 102 |
+
|
| 103 |
+
# Baixa todos os arquivos do repositório Lightricks/LTX-Video
|
| 104 |
+
# A função snapshot_download é idempotente e usa cache.
|
| 105 |
+
snapshot_download(
|
| 106 |
+
repo_id="Lightricks/LTX-Video",
|
| 107 |
+
local_dir=str(self.MODELS_DIR),
|
| 108 |
+
cache_dir=str(self.HF_HOME_CACHE),
|
| 109 |
+
token=os.getenv("HF_TOKEN"),
|
| 110 |
+
# Padrões para garantir que baixamos tudo, incluindo VAE, text encoder e os pesos
|
| 111 |
+
allow_patterns=["*.safetensors", "*.json", "*.py", "text_encoder/*", "vae/*", "scheduler/*"],
|
| 112 |
+
)
|
| 113 |
+
print("[LTXServer] Snapshot completo dos modelos verificado/baixado.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def run_inference(self, **kwargs) -> str:
|
| 116 |
+
# A lógica de inferência permanece a mesma da resposta anterior,
|
| 117 |
+
# pois ela já usa as pipelines que inicializamos.
|
| 118 |
from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline, ConditioningItem
|
| 119 |
+
from diffusers.utils import export_to_video, load_image, load_video
|
| 120 |
|
|
|
|
| 121 |
prompt = kwargs.get("prompt")
|
| 122 |
image_path = kwargs.get("image_path")
|
| 123 |
target_height = kwargs.get("target_height")
|
|
|
|
| 127 |
|
| 128 |
output_file_path = self.OUTPUT_ROOT / f"run_{int(time.time())}.mp4"
|
| 129 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 130 |
+
|
|
|
|
| 131 |
multi_scale_pipeline = LTXMultiScalePipeline(self._pipeline, self._latent_upsampler)
|
| 132 |
|
|
|
|
| 133 |
conditions = None
|
| 134 |
if image_path:
|
|
|
|
|
|
|
|
|
|
| 135 |
image = load_image(image_path)
|
| 136 |
video_condition_input = load_video(export_to_video([image]))
|
| 137 |
condition = LTXVideoCondition(video=video_condition_input, frame_index=0)
|
| 138 |
conditions = [condition]
|
| 139 |
|
|
|
|
| 140 |
call_kwargs = {
|
| 141 |
+
"prompt": prompt, "negative_prompt": "worst quality...",
|
|
|
|
| 142 |
"height": target_height, "width": target_width, "num_frames": num_frames,
|
| 143 |
+
"generator": generator, "output_type": "pt", "conditions": conditions,
|
|
|
|
| 144 |
"decode_timestep": self.config_yaml["decode_timestep"],
|
| 145 |
"decode_noise_scale": self.config_yaml["decode_noise_scale"],
|
| 146 |
"first_pass": self.config_yaml["first_pass"],
|
|
|
|
| 151 |
print("[LTXServer] Executando pipeline multi-escala...")
|
| 152 |
result_tensor = multi_scale_pipeline(**call_kwargs).images
|
| 153 |
|
|
|
|
|
|
|
| 154 |
video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
|
| 155 |
video_np = (video_np * 255).astype("uint8")
|
| 156 |
export_to_video(video_np, str(output_file_path), fps=24)
|