Jamiiwej2903 commited on
Commit
6b4be8f
1 Parent(s): 4db1b30

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -10
main.py CHANGED
@@ -15,23 +15,33 @@ client = InferenceClient("stabilityai/stable-video-diffusion-img2vid-xt-1-1-tens
15
 
16
  @app.post("/generate_video/")
17
  async def generate_video_api(
18
- file: UploadFile = File(...),
19
- num_frames: int = Form(14),
20
- fps: int = Form(7)
 
 
 
21
  ):
22
  try:
23
  # Read the uploaded image file
24
  image_content = await file.read()
25
  image = Image.open(io.BytesIO(image_content))
26
 
 
 
 
 
 
27
  # Generate video frames using the stable-video-diffusion model
28
  video_frames = client.post(
29
- json={
30
- "inputs": image,
31
- "parameters": {
32
- "num_inference_steps": 25,
33
- "num_frames": num_frames,
34
- }
 
 
35
  }
36
  )
37
 
@@ -65,7 +75,7 @@ async def generate_video_api(
65
 
66
  except Exception as err:
67
  # Handle any errors
68
- raise HTTPException(status_code=500, detail=f"An error occurred: {err}")
69
 
70
  if __name__ == "__main__":
71
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
15
 
16
  @app.post("/generate_video/")
17
  async def generate_video_api(
18
+ file: UploadFile = File(...),
19
+ num_frames: int = Form(14),
20
+ fps: int = Form(7),
21
+ motion_bucket_id: int = Form(127),
22
+ cond_aug: float = Form(0.02),
23
+ seed: int = Form(0)
24
  ):
25
  try:
26
  # Read the uploaded image file
27
  image_content = await file.read()
28
  image = Image.open(io.BytesIO(image_content))
29
 
30
+ # Convert image to bytes
31
+ img_byte_arr = io.BytesIO()
32
+ image.save(img_byte_arr, format='PNG')
33
+ img_byte_arr = img_byte_arr.getvalue()
34
+
35
  # Generate video frames using the stable-video-diffusion model
36
  video_frames = client.post(
37
+ "video-generation",
38
+ files=[("image", img_byte_arr)],
39
+ data={
40
+ "num_inference_steps": 25,
41
+ "num_frames": num_frames,
42
+ "motion_bucket_id": motion_bucket_id,
43
+ "cond_aug": cond_aug,
44
+ "seed": seed
45
  }
46
  )
47
 
 
75
 
76
  except Exception as err:
77
  # Handle any errors
78
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(err)}")
79
 
80
  if __name__ == "__main__":
81
  uvicorn.run(app, host="0.0.0.0", port=7860)