Jamiiwej2903 commited on
Commit
03fe7d6
1 Parent(s): 6b4be8f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -15
main.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
  import uvicorn
3
  from fastapi.responses import StreamingResponse
4
  import io
5
- import requests
6
  from PIL import Image
7
  import ffmpeg
8
  import tempfile
@@ -27,29 +27,35 @@ async def generate_video_api(
27
  image_content = await file.read()
28
  image = Image.open(io.BytesIO(image_content))
29
 
30
- # Convert image to bytes
31
- img_byte_arr = io.BytesIO()
32
- image.save(img_byte_arr, format='PNG')
33
- img_byte_arr = img_byte_arr.getvalue()
34
 
35
  # Generate video frames using the stable-video-diffusion model
36
- video_frames = client.post(
37
- "video-generation",
38
- files=[("image", img_byte_arr)],
39
- data={
40
- "num_inference_steps": 25,
41
- "num_frames": num_frames,
42
- "motion_bucket_id": motion_bucket_id,
43
- "cond_aug": cond_aug,
44
- "seed": seed
 
45
  }
46
  )
47
 
 
 
 
 
48
  # Create a temporary directory
49
  with tempfile.TemporaryDirectory() as tmpdir:
50
  # Save frames as temporary files
51
  frame_files = []
52
- for i, frame in enumerate(video_frames):
 
53
  frame_file = os.path.join(tmpdir, f"frame_{i:03d}.png")
54
  frame.save(frame_file)
55
  frame_files.append(frame_file)
 
2
  import uvicorn
3
  from fastapi.responses import StreamingResponse
4
  import io
5
+ import base64
6
  from PIL import Image
7
  import ffmpeg
8
  import tempfile
 
27
  image_content = await file.read()
28
  image = Image.open(io.BytesIO(image_content))
29
 
30
+ # Convert image to base64
31
+ buffered = io.BytesIO()
32
+ image.save(buffered, format="PNG")
33
+ img_str = base64.b64encode(buffered.getvalue()).decode()
34
 
35
  # Generate video frames using the stable-video-diffusion model
36
+ response = client.post(
37
+ json={
38
+ "inputs": img_str,
39
+ "parameters": {
40
+ "num_inference_steps": 25,
41
+ "num_frames": num_frames,
42
+ "motion_bucket_id": motion_bucket_id,
43
+ "cond_aug": cond_aug,
44
+ "seed": seed
45
+ }
46
  }
47
  )
48
 
49
+ # Check if the response is a list of images
50
+ if not isinstance(response, list):
51
+ raise ValueError(f"Unexpected response from the model: {response}")
52
+
53
  # Create a temporary directory
54
  with tempfile.TemporaryDirectory() as tmpdir:
55
  # Save frames as temporary files
56
  frame_files = []
57
+ for i, frame_data in enumerate(response):
58
+ frame = Image.open(io.BytesIO(base64.b64decode(frame_data)))
59
  frame_file = os.path.join(tmpdir, f"frame_{i:03d}.png")
60
  frame.save(frame_file)
61
  frame_files.append(frame_file)