euIaxs22 commited on
Commit
8623364
·
verified ·
1 Parent(s): b77d7cf

Update services/ltx_server.py

Browse files
Files changed (1) hide show
  1. services/ltx_server.py +49 -64
services/ltx_server.py CHANGED
@@ -8,9 +8,7 @@ from typing import Optional, Tuple
8
  import torch
9
  from PIL import Image
10
 
11
- # Importa a função de fábrica do LTX-Video
12
- # A importação só funcionará depois que o repo for clonado e adicionado ao sys.path
13
- # Portanto, faremos a importação dentro do __init__
14
 
15
  APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
16
 
@@ -27,33 +25,36 @@ class LTXServer:
27
  def __init__(self):
28
  if hasattr(self, '_initialized') and self._initialized: return
29
 
30
- print("🚀 LTXServer (Manual Pipeline Assembly) inicializando...")
31
 
32
  self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
33
- self.LTX_REPO_DIR = Path(os.getenv("LTX_REPO_DIR", "/data/LTX-Video"))
34
- self.MODELS_DIR = Path("/data/ltx_models") # Um diretório unificado para todos os pesos
35
- self.REPO_URL = "https://github.com/Lightricks/LTX-Video.git"
36
-
37
- self.CONFIG_PATH = APP_HOME / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml" # <--- Seu arquivo de config FP8
38
 
39
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
40
  self.dtype = torch.bfloat16 if self.device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
41
 
42
- for p in [self.LTX_REPO_DIR.parent, self.MODELS_DIR, self.OUTPUT_ROOT]:
43
  p.mkdir(parents=True, exist_ok=True)
44
 
45
  self.setup_dependencies()
46
 
47
- # Importações dinâmicas após o setup
48
  from ltx_video.inference import create_ltx_video_pipeline, create_latent_upsampler
49
 
50
  try:
51
- print("[LTXServer] Montando pipelines a partir dos arquivos baixados...")
52
 
53
  with open(self.CONFIG_PATH, "r") as f:
54
  self.config_yaml = yaml.safe_load(f)
55
 
56
- # Monta a pipeline principal
 
 
 
 
57
  self._pipeline = create_ltx_video_pipeline(
58
  ckpt_path=str(self.MODELS_DIR / self.config_yaml["checkpoint_path"]),
59
  precision=self.config_yaml["precision"],
@@ -68,7 +69,7 @@ class LTXServer:
68
  device=self.device
69
  )
70
 
71
- print("✅ LTXServer (Manual Assembly) pronto.")
72
  except Exception as e:
73
  print(f"ERRO CRÍTICO ao montar as pipelines LTX: {e}")
74
  raise
@@ -76,53 +77,47 @@ class LTXServer:
76
  self._initialized = True
77
 
78
  def setup_dependencies(self):
79
- self._ensure_repo()
80
- self._ensure_models()
81
-
82
- def _ensure_repo(self) -> None:
83
- if not (self.LTX_REPO_DIR / ".git").exists():
84
- print(f"[LTXServer] Clonando repositório de '{self.REPO_URL}'...")
85
- subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.LTX_REPO_DIR)], check=True)
86
- # Instala o pacote localmente
87
- print("[LTXServer] Instalando LTX-Video em modo editável...")
88
- subprocess.run([sys.executable, "-m", "pip", "install", "-e", f"{self.LTX_REPO_DIR}[inference-script]"], check=True)
 
89
  else:
90
- print("[LTXServer] Repositório LTX-Video já existe.")
91
 
92
  if str(self.LTX_REPO_DIR) not in sys.path:
93
  sys.path.insert(0, str(self.LTX_REPO_DIR))
94
 
95
- def _ensure_models(self) -> None:
96
- """Baixa todos os arquivos de modelo necessários para a pasta MODELS_DIR."""
97
- from huggingface_hub import hf_hub_download
98
-
99
- print(f"[LTXServer] Verificando arquivos de modelo em {self.MODELS_DIR}...")
100
 
101
- # Lista de arquivos a serem baixados do repositório principal
102
- files_to_download = [
103
- "ltxv-13b-0.9.8-distilled-fp8.safetensors", # Modelo principal
104
- "ltxv-spatial-upscaler-0.9.8.safetensors", # Upscaler
105
- # Componentes adicionais como VAE e Text Encoder serão baixados
106
- # pela `create_ltx_video_pipeline` usando o cache do HF_HOME.
107
- ]
108
-
109
- for filename in files_to_download:
110
- if not (self.MODELS_DIR / filename).exists():
111
- print(f"Baixando {filename}...")
112
- hf_hub_download(
113
- repo_id="Lightricks/LTX-Video",
114
- filename=filename,
115
- local_dir=str(self.MODELS_DIR),
116
- token=os.getenv("HF_TOKEN")
117
- )
118
- print("[LTXServer] Arquivos de modelo essenciais verificados/baixados.")
119
 
120
  def run_inference(self, **kwargs) -> str:
121
- # Importa as classes necessárias aqui, pois o sys.path foi modificado no __init__
 
122
  from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline, ConditioningItem
123
- from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
124
 
125
- # Extrai os parâmetros com valores padrão
126
  prompt = kwargs.get("prompt")
127
  image_path = kwargs.get("image_path")
128
  target_height = kwargs.get("target_height")
@@ -132,28 +127,20 @@ class LTXServer:
132
 
133
  output_file_path = self.OUTPUT_ROOT / f"run_{int(time.time())}.mp4"
134
  generator = torch.Generator(device=self.device).manual_seed(seed)
135
-
136
- # Monta o objeto de pipeline multi-escala
137
  multi_scale_pipeline = LTXMultiScalePipeline(self._pipeline, self._latent_upsampler)
138
 
139
- # Prepara a condição da imagem
140
  conditions = None
141
  if image_path:
142
- from diffusers.utils import export_to_video, load_image, load_video
143
- from ltx_video.pipelines.pipeline_ltx_condition import LTXVideoCondition
144
-
145
  image = load_image(image_path)
146
  video_condition_input = load_video(export_to_video([image]))
147
  condition = LTXVideoCondition(video=video_condition_input, frame_index=0)
148
  conditions = [condition]
149
 
150
- # Configura os parâmetros da chamada com base no arquivo YAML
151
  call_kwargs = {
152
- "prompt": prompt,
153
- "negative_prompt": "worst quality...",
154
  "height": target_height, "width": target_width, "num_frames": num_frames,
155
- "generator": generator, "output_type": "pt",
156
- "conditions": conditions,
157
  "decode_timestep": self.config_yaml["decode_timestep"],
158
  "decode_noise_scale": self.config_yaml["decode_noise_scale"],
159
  "first_pass": self.config_yaml["first_pass"],
@@ -164,8 +151,6 @@ class LTXServer:
164
  print("[LTXServer] Executando pipeline multi-escala...")
165
  result_tensor = multi_scale_pipeline(**call_kwargs).images
166
 
167
- # Exporta para vídeo
168
- from diffusers.utils import export_to_video
169
  video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
170
  video_np = (video_np * 255).astype("uint8")
171
  export_to_video(video_np, str(output_file_path), fps=24)
 
8
  import torch
9
  from PIL import Image
10
 
11
+ # Importações serão feitas dinamicamente após o setup
 
 
12
 
13
  APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
14
 
 
25
  def __init__(self):
26
  if hasattr(self, '_initialized') and self._initialized: return
27
 
28
+ print("🚀 LTXServer (Full Cache) inicializando...")
29
 
30
  self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
31
+ self.LTX_REPO_DIR = Path("/opt/LTX-Video") # Instalado pelo Dockerfile
32
+ self.MODELS_DIR = Path("/data/ltx_models") # Pasta unificada para TODOS os modelos
33
+ self.CONFIG_PATH = APP_HOME / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
34
+ self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
 
35
 
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  self.dtype = torch.bfloat16 if self.device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
38
 
39
+ for p in [self.MODELS_DIR, self.OUTPUT_ROOT]:
40
  p.mkdir(parents=True, exist_ok=True)
41
 
42
  self.setup_dependencies()
43
 
44
+ # Importações dinâmicas
45
  from ltx_video.inference import create_ltx_video_pipeline, create_latent_upsampler
46
 
47
  try:
48
+ print("[LTXServer] Montando pipelines a partir dos arquivos locais...")
49
 
50
  with open(self.CONFIG_PATH, "r") as f:
51
  self.config_yaml = yaml.safe_load(f)
52
 
53
+ # Para que a `create_ltx_video_pipeline` encontre os modelos,
54
+ # o `text_encoder_model_name_or_path` deve apontar para o nosso diretório local.
55
+ self.config_yaml["text_encoder_model_name_or_path"] = str(self.MODELS_DIR)
56
+
57
+ # Monta a pipeline principal, passando o caminho para os pesos e o diretório do text encoder
58
  self._pipeline = create_ltx_video_pipeline(
59
  ckpt_path=str(self.MODELS_DIR / self.config_yaml["checkpoint_path"]),
60
  precision=self.config_yaml["precision"],
 
69
  device=self.device
70
  )
71
 
72
+ print("✅ LTXServer (Full Cache) pronto.")
73
  except Exception as e:
74
  print(f"ERRO CRÍTICO ao montar as pipelines LTX: {e}")
75
  raise
 
77
  self._initialized = True
78
 
79
  def setup_dependencies(self):
80
+ """Clona o repo (se Dockerfile não o fez) e baixa TODOS os modelos necessários."""
81
+ self._ensure_repo_and_install()
82
+ self._ensure_models_full_download()
83
+
84
+ def _ensure_repo_and_install(self) -> None:
85
+ """Clona e instala o repositório LTX-Video."""
86
+ if not (self.LTX_REPO_DIR / "setup.py").exists():
87
+ print(f"[LTXServer] Clonando repositório LTX-Video para {self.LTX_REPO_DIR}...")
88
+ subprocess.run(["git", "clone", "--depth", "1", "https://github.com/Lightricks/LTX-Video.git", str(self.LTX_REPO_DIR)], check=True)
89
+ print("[LTXServer] Instalando LTX-Video em modo editável...")
90
+ subprocess.run([sys.executable, "-m", "pip", "install", "-e", f"{self.LTX_REPO_DIR}[inference-script]"], check=True)
91
  else:
92
+ print("[LTXServer] Repositório LTX-Video já existe e está instalado.")
93
 
94
  if str(self.LTX_REPO_DIR) not in sys.path:
95
  sys.path.insert(0, str(self.LTX_REPO_DIR))
96
 
97
+ def _ensure_models_full_download(self) -> None:
98
+ """Baixa o snapshot completo de todos os modelos necessários para o cache local."""
99
+ from huggingface_hub import snapshot_download
 
 
100
 
101
+ print(f"[LTXServer] Verificando snapshot completo dos modelos em {self.MODELS_DIR}...")
102
+
103
+ # Baixa todos os arquivos do repositório Lightricks/LTX-Video
104
+ # A função snapshot_download é idempotente e usa cache.
105
+ snapshot_download(
106
+ repo_id="Lightricks/LTX-Video",
107
+ local_dir=str(self.MODELS_DIR),
108
+ cache_dir=str(self.HF_HOME_CACHE),
109
+ token=os.getenv("HF_TOKEN"),
110
+ # Padrões para garantir que baixamos tudo, incluindo VAE, text encoder e os pesos
111
+ allow_patterns=["*.safetensors", "*.json", "*.py", "text_encoder/*", "vae/*", "scheduler/*"],
112
+ )
113
+ print("[LTXServer] Snapshot completo dos modelos verificado/baixado.")
 
 
 
 
 
114
 
115
  def run_inference(self, **kwargs) -> str:
116
+ # A lógica de inferência permanece a mesma da resposta anterior,
117
+ # pois ela já usa as pipelines que inicializamos.
118
  from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline, ConditioningItem
119
+ from diffusers.utils import export_to_video, load_image, load_video
120
 
 
121
  prompt = kwargs.get("prompt")
122
  image_path = kwargs.get("image_path")
123
  target_height = kwargs.get("target_height")
 
127
 
128
  output_file_path = self.OUTPUT_ROOT / f"run_{int(time.time())}.mp4"
129
  generator = torch.Generator(device=self.device).manual_seed(seed)
130
+
 
131
  multi_scale_pipeline = LTXMultiScalePipeline(self._pipeline, self._latent_upsampler)
132
 
 
133
  conditions = None
134
  if image_path:
 
 
 
135
  image = load_image(image_path)
136
  video_condition_input = load_video(export_to_video([image]))
137
  condition = LTXVideoCondition(video=video_condition_input, frame_index=0)
138
  conditions = [condition]
139
 
 
140
  call_kwargs = {
141
+ "prompt": prompt, "negative_prompt": "worst quality...",
 
142
  "height": target_height, "width": target_width, "num_frames": num_frames,
143
+ "generator": generator, "output_type": "pt", "conditions": conditions,
 
144
  "decode_timestep": self.config_yaml["decode_timestep"],
145
  "decode_noise_scale": self.config_yaml["decode_noise_scale"],
146
  "first_pass": self.config_yaml["first_pass"],
 
151
  print("[LTXServer] Executando pipeline multi-escala...")
152
  result_tensor = multi_scale_pipeline(**call_kwargs).images
153
 
 
 
154
  video_np = result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
155
  video_np = (video_np * 255).astype("uint8")
156
  export_to_video(video_np, str(output_file_path), fps=24)