vinesmsuic commited on
Commit
4dd992e
·
1 Parent(s): 9722d77

trying to tickle the dimension problem of video output tensor

Browse files
Files changed (1) hide show
  1. serve/vote_utils.py +15 -0
serve/vote_utils.py CHANGED
@@ -122,6 +122,9 @@ def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request)
122
  else:
123
  print("======== video shape: ========")
124
  print(state.output.shape)
 
 
 
125
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
126
  save_video_file_on_log_server(output_file)
127
 
@@ -148,6 +151,9 @@ def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Reque
148
  elif isinstance(state.output, torch.Tensor):
149
  print("======== video shape: ========")
150
  print(state.output.shape)
 
 
 
151
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
152
  else:
153
  r = requests.get(state.output)
@@ -1228,6 +1234,9 @@ def generate_vg(gen_func, state, text, model_name, request: gr.Request):
1228
  else:
1229
  print("======== video shape: ========")
1230
  print(state.output.shape)
 
 
 
1231
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1232
 
1233
  save_video_file_on_log_server(output_file)
@@ -1343,6 +1352,9 @@ def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, reque
1343
  print("======== video shape: ========")
1344
  print(state.output)
1345
  print(state.output.shape)
 
 
 
1346
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1347
  save_video_file_on_log_server(output_file)
1348
  yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
@@ -1472,6 +1484,9 @@ def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1,
1472
  else:
1473
  print("======== video shape: ========")
1474
  print(state.output.shape)
 
 
 
1475
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1476
  save_video_file_on_log_server(output_file)
1477
 
 
122
  else:
123
  print("======== video shape: ========")
124
  print(state.output.shape)
125
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
126
+ if state.output.shape[-1] != 3:
127
+ state.output = state.output.permute(0, 2, 3, 1)
128
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
129
  save_video_file_on_log_server(output_file)
130
 
 
151
  elif isinstance(state.output, torch.Tensor):
152
  print("======== video shape: ========")
153
  print(state.output.shape)
154
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
155
+ if state.output.shape[-1] != 3:
156
+ state.output = state.output.permute(0, 2, 3, 1)
157
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
158
  else:
159
  r = requests.get(state.output)
 
1234
  else:
1235
  print("======== video shape: ========")
1236
  print(state.output.shape)
1237
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1238
+ if state.output.shape[-1] != 3:
1239
+ state.output = state.output.permute(0, 2, 3, 1)
1240
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1241
 
1242
  save_video_file_on_log_server(output_file)
 
1352
  print("======== video shape: ========")
1353
  print(state.output)
1354
  print(state.output.shape)
1355
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1356
+ if state.output.shape[-1] != 3:
1357
+ state.output = state.output.permute(0, 2, 3, 1)
1358
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1359
  save_video_file_on_log_server(output_file)
1360
  yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4'
 
1484
  else:
1485
  print("======== video shape: ========")
1486
  print(state.output.shape)
1487
+ # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels]
1488
+ if state.output.shape[-1] != 3:
1489
+ state.output = state.output.permute(0, 2, 3, 1)
1490
  imageio.mimwrite(output_file, state.output, fps=8, quality=9)
1491
  save_video_file_on_log_server(output_file)
1492