aiqcamp committed
Commit 9ae8acd · verified · 1 Parent(s): 595f5ec

Update app.py

Files changed (1): app.py +308 -88
app.py CHANGED
@@ -1,12 +1,17 @@
- import spaces
- import logging
  from datetime import datetime
- from pathlib import Path
-
  import gradio as gr
  import torch
  import torchaudio
- import os

  try:
      import mmaudio
@@ -20,13 +25,7 @@ from mmaudio.model.flow_matching import FlowMatching
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
  from mmaudio.model.sequence_config import SequenceConfig
  from mmaudio.model.utils.features_utils import FeaturesUtils
- import tempfile
-
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- log = logging.getLogger()
-
  device = 'cuda'
  dtype = torch.bfloat16
 
@@ -35,83 +34,304 @@ model.download_if_needed()
  output_dir = Path('./output/gradio')

  setup_eval_logging()

- def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
-     seq_cfg = model.seq_cfg
-
-     net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
-     net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
-     log.info(f'Loaded weights from {model.model_path}')
-
-     feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
-                                   synchformer_ckpt=model.synchformer_ckpt,
-                                   enable_conditions=True,
-                                   mode=model.mode,
-                                   bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
-                                   need_vae_encoder=False)
-     feature_utils = feature_utils.to(device, dtype).eval()
-
-     return net, feature_utils, seq_cfg
-
-
- net, feature_utils, seq_cfg = get_model()
-
-
- @spaces.GPU(duration=120)
- @torch.inference_mode()
- def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
-                    cfg_strength: float, duration: float):
-
-     rng = torch.Generator(device=device)
-     if seed >= 0:
-         rng.manual_seed(seed)
-     else:
-         rng.seed()
-     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
-
-     video_info = load_video(video, duration)
-     clip_frames = video_info.clip_frames
-     sync_frames = video_info.sync_frames
-     duration = video_info.duration_sec
-     clip_frames = clip_frames.unsqueeze(0)
-     sync_frames = sync_frames.unsqueeze(0)
-     seq_cfg.duration = duration
-     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
-
-     audios = generate(clip_frames,
-                       sync_frames, [prompt],
-                       negative_text=[negative_prompt],
-                       feature_utils=feature_utils,
-                       net=net,
-                       fm=fm,
-                       rng=rng,
-                       cfg_strength=cfg_strength)
-     audio = audios.float().cpu()[0]
-
-     # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
-     video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-     # output_dir.mkdir(exist_ok=True, parents=True)
-     # video_save_path = output_dir / f'{current_time_string}.mp4'
-     make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
-     log.info(f'Saved video to {video_save_path}')
-     return video_save_path
-
-
- video_to_audio_tab = gr.Interface(
-     fn=video_to_audio,
-
-     inputs=[
-         gr.Video(),
-         gr.Text(label='Prompt'),
-         gr.Text(label='Negative prompt', value='music'),
-         gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-         gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-         gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-         gr.Number(label='Duration (sec)', value=8, minimum=1),
-     ],
-     outputs='playable_video',
- )

  if __name__ == "__main__":
-     video_to_audio_tab.launch(allowed_paths=[output_dir])
 
+ import os
+ import time
  from datetime import datetime
  import gradio as gr
  import torch
+ import logging
+ import requests
+ from pathlib import Path
+ import cv2
+ from PIL import Image
+ import json
+ import spaces
  import torchaudio
+ import tempfile

  try:
      import mmaudio
 
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
  from mmaudio.model.sequence_config import SequenceConfig
  from mmaudio.model.utils.features_utils import FeaturesUtils
+ # Audio model setup
  device = 'cuda'
  dtype = torch.bfloat16

  output_dir = Path('./output/gradio')

  setup_eval_logging()
+ net, feature_utils, seq_cfg = get_model()  # get_model uses the code provided earlier
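+ # NOTE: get_model() (and video_to_audio(), called further below) are not re-defined
+ # in this revision; they are assumed to carry over from the earlier version of app.py.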
+
+ # Logging setup
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # API configuration
+ CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
+ REPLICATE_API_TOKEN = os.getenv("API_KEY")
+
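+ # NOTE: REPLICATE_API_TOKEN mirrors the API_KEY env var but is not used below;
+ # generate_video() re-reads os.getenv("API_KEY") directly.
+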
+ def upload_to_catbox(file_path):
+     """Upload a file using the catbox.moe API."""
+     try:
+         logger.info(f"Preparing to upload file: {file_path}")
+         url = "https://catbox.moe/user/api.php"
+
+         mime_types = {
+             '.jpg': 'image/jpeg',
+             '.jpeg': 'image/jpeg',
+             '.png': 'image/png',
+             '.gif': 'image/gif',
+             '.webp': 'image/webp',
+             '.jfif': 'image/jpeg'
+         }
+
+         file_extension = Path(file_path).suffix.lower()
+
+         if file_extension not in mime_types:
+             try:
+                 img = Image.open(file_path)
+                 if img.mode != 'RGB':
+                     img = img.convert('RGB')
+
+                 new_path = file_path.rsplit('.', 1)[0] + '.png'
+                 img.save(new_path, 'PNG')
+                 file_path = new_path
+                 file_extension = '.png'
+                 logger.info(f"Converted image to PNG: {file_path}")
+             except Exception as e:
+                 logger.error(f"Failed to convert image: {str(e)}")
+                 return None
+
+         files = {
+             'fileToUpload': (
+                 os.path.basename(file_path),
+                 open(file_path, 'rb'),
+                 mime_types.get(file_extension, 'application/octet-stream')
+             )
+         }
+
+         data = {
+             'reqtype': 'fileupload',
+             'userhash': CATBOX_USER_HASH
+         }
+
+         response = requests.post(url, files=files, data=data)
+
+         if response.status_code == 200 and response.text.startswith('http'):
+             file_url = response.text
+             logger.info(f"File uploaded successfully: {file_url}")
+             return file_url
+         else:
+             raise Exception(f"Upload failed: {response.text}")
+
+     except Exception as e:
+         logger.error(f"File upload error: {str(e)}")
+         return None
+     finally:
+         if 'new_path' in locals() and os.path.exists(new_path):
+             try:
+                 os.remove(new_path)
+             except:
+                 pass
+
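+ # catbox returns the bare file URL as the response body on success, which is why
+ # upload_to_catbox checks response.text.startswith('http').
+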
+ def add_watermark(video_path):
+     """Add a watermark to the video using OpenCV."""
+     try:
+         cap = cv2.VideoCapture(video_path)
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+         text = "GiniGEN.AI"
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = height * 0.05 / 30
+         thickness = 2
+         color = (255, 255, 255)
+
+         (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
+         margin = int(height * 0.02)
+         x_pos = width - text_width - margin
+         y_pos = height - margin
+
+         output_path = "watermarked_output.mp4"
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
+             out.write(frame)
+
+         cap.release()
+         out.release()
+
+         return output_path
+
+     except Exception as e:
+         logger.error(f"Error adding watermark: {str(e)}")
+         return video_path
+
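+ # add_watermark logs and returns the original path on failure, so the pipeline
+ # still yields a playable (unwatermarked) video.
+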
+ def generate_video(image, prompt):
+     logger.info("Starting video generation with API")
+     try:
+         API_KEY = os.getenv("API_KEY", "").strip()
+         if not API_KEY:
+             return "API key not properly configured"
+
+         temp_dir = "temp_videos"
+         os.makedirs(temp_dir, exist_ok=True)
+
+         image_url = None
+         if image:
+             image_url = upload_to_catbox(image)
+             if not image_url:
+                 return "Failed to upload image"
+             logger.info(f"Input image URL: {image_url}")
+
+         generation_url = "https://api.minimaxi.chat/v1/video_generation"
+         headers = {
+             'authorization': f'Bearer {API_KEY}',
+             'Content-Type': 'application/json'
+         }
+
+         payload = {
+             "model": "video-01",
+             "prompt": prompt if prompt else "",
+             "prompt_optimizer": True
+         }
+
+         if image_url:
+             payload["first_frame_image"] = image_url
+
+         logger.info(f"Sending request with payload: {payload}")
+
+         response = requests.post(generation_url, headers=headers, json=payload)
+
+         if not response.ok:
+             error_msg = f"Failed to create video generation task: {response.text}"
+             logger.error(error_msg)
+             return error_msg
+
+         response_data = response.json()
+         task_id = response_data.get('task_id')
+         if not task_id:
+             return "Failed to get task ID from response"
+
+         query_url = "https://api.minimaxi.chat/v1/query/video_generation"
+         max_attempts = 30
+         attempt = 0
+
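+         # Poll the task status every 10 seconds, up to max_attempts (about 5 minutes).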
+         while attempt < max_attempts:
+             time.sleep(10)
+             query_response = requests.get(
+                 f"{query_url}?task_id={task_id}",
+                 headers={'authorization': f'Bearer {API_KEY}'}
+             )
+
+             if not query_response.ok:
+                 attempt += 1
+                 continue
+
+             status_data = query_response.json()
+             status = status_data.get('status')
+
+             if status == 'Success':
+                 file_id = status_data.get('file_id')
+                 if not file_id:
+                     return "Failed to get file ID"
+
+                 retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
+                 params = {'file_id': file_id}
+
+                 file_response = requests.get(
+                     retrieve_url,
+                     headers={'authorization': f'Bearer {API_KEY}'},
+                     params=params
+                 )
+
+                 if not file_response.ok:
+                     return "Failed to retrieve video file"
+
+                 try:
+                     file_data = file_response.json()
+                     download_url = file_data.get('file', {}).get('download_url')
+                     if not download_url:
+                         return "Failed to get download URL"
+
+                     result_info = {
+                         "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
+                         "input_image": image_url,
+                         "output_video_url": download_url,
+                         "prompt": prompt
+                     }
+                     logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")
+
+                     video_response = requests.get(download_url)
+                     if not video_response.ok:
+                         return "Failed to download video"
+
+                     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                     output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
+
+                     with open(output_path, 'wb') as f:
+                         f.write(video_response.content)
+
+                     final_path = add_watermark(output_path)
+
+                     # Add the audio track
+                     try:
+                         final_path_with_audio = video_to_audio(
+                             final_path,
+                             prompt=prompt,
+                             negative_prompt="music",
+                             seed=-1,
+                             num_steps=25,
+                             cfg_strength=4.5,
+                             duration=8
+                         )
+
+                         # Clean up temporary files
+                         if output_path != final_path:
+                             os.remove(output_path)
+                         if final_path != final_path_with_audio:
+                             os.remove(final_path)
+
+                         return final_path_with_audio
+                     except Exception as e:
+                         logger.error(f"Error in audio processing: {str(e)}")
+                         return final_path  # fall back to the watermarked video if audio processing fails
+
+                 except Exception as e:
+                     logger.error(f"Error processing video file: {str(e)}")
+                     return "Error processing video file"
+
+             # task still pending; try again
+             attempt += 1
+
+         return "Video generation timed out"
+
+     except Exception as e:
+         logger.error(f"Error in video generation: {str(e)}")
+         return "Error in video generation"
+
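+ # End-to-end flow: MiniMax video-01 renders the clip, add_watermark overlays the
+ # "GiniGEN.AI" text, and video_to_audio (MMAudio) adds a generated soundtrack.
+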
+ css = """
+ footer {display: none}
+ .gradio-container {max-width: 1200px !important}
+ """
+
+ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+     gr.HTML('<div style="text-align: center; font-size: 1.5em; margin: 10px 0;">🎥 Image to Video Generator</div>')
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             video_prompt = gr.Textbox(
+                 label="Video Description",
+                 placeholder="Enter video description...",
+                 lines=3
+             )
+             upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
+             video_generate_btn = gr.Button("🎬 Generate Video")
+
+         with gr.Column(scale=4):
+             video_output = gr.Video(label="Generated Video")

+     def process_and_generate_video(image, prompt):
+         if image is None:
+             return "Please upload an image"
+
+         try:
+             img = Image.open(image)
+             if img.mode != 'RGB':
+                 img = img.convert('RGB')
+
+             temp_path = f"temp_{int(time.time())}.png"
+             img.save(temp_path, 'PNG')
+
+             result = generate_video(temp_path, prompt)
+
+             try:
+                 os.remove(temp_path)
+             except:
+                 pass
+
+             return result
+
+         except Exception as e:
+             logger.error(f"Error processing image: {str(e)}")
+             return "Error processing image"

+     video_generate_btn.click(
+         process_and_generate_video,
+         inputs=[upload_image, video_prompt],
+         outputs=video_output
+     )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
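
For reference, the polling flow that generate_video implements reduces to the minimal sketch below. Endpoint, header, and field names are taken from the diff above; API_KEY is a placeholder for the value the app reads from the environment:

    import time
    import requests

    API_BASE = "https://api.minimaxi.chat/v1"
    API_KEY = "..."  # placeholder; app.py reads os.getenv("API_KEY")

    def wait_for_file_id(task_id, attempts=30, interval=10):
        """Poll the query endpoint until the task reports Success or attempts run out."""
        headers = {"authorization": f"Bearer {API_KEY}"}
        for _ in range(attempts):
            time.sleep(interval)
            resp = requests.get(f"{API_BASE}/query/video_generation?task_id={task_id}",
                                headers=headers)
            if resp.ok and resp.json().get("status") == "Success":
                return resp.json().get("file_id")
        return None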