Anonymous commited on
Commit
4dd34ca
·
1 Parent(s): 610f540

add example

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -1,8 +1,10 @@
1
- import gradio as gr
2
-
3
  import sys
 
4
  import random
5
  import pandas as pd
 
 
 
6
  import os
7
  import argparse
8
  import random
@@ -19,6 +21,7 @@ from funcs import (
19
  load_model_checkpoint,
20
  )
21
  from utils.utils import instantiate_from_config
 
22
 
23
  MAX_KEYS = 5
24
 
@@ -521,29 +524,25 @@ def demo_update_w(mode):
521
 
522
  def plot_update(*positions):
523
  if type(positions[-1]) != int:
524
- traj_plot = gr.ScatterPlot(
525
- label="Trajectory",
526
- width=512,
527
- height=320,
528
  )
529
  return traj_plot
530
  key_length = positions[-1]
531
  frame_indices = positions[:key_length]
 
532
  h_positions = positions[MAX_KEYS:MAX_KEYS+key_length]
533
- h_positions_re = []
534
- for i in h_positions:
535
- h_positions_re.append(-i)
536
  w_positions = positions[2*MAX_KEYS:2*MAX_KEYS+key_length]
537
- traj_plot = gr.ScatterPlot(
538
- value=pd.DataFrame({"x": w_positions, "y": h_positions_re, "frame": frame_indices}),
539
- x="x",
540
- y="y",
541
- color='frame',
542
- x_lim= [-0.05, 1.05],
543
- y_lim= [-1.05, 0.05],
 
544
  label="Trajectory",
545
- width=512,
546
- height=320,
547
  )
548
  return traj_plot
549
 
@@ -640,10 +639,8 @@ with gr.Blocks(css=css) as demo:
640
  dropdown_demo.change(demo_update_w, dropdown_demo, w_positions)
641
  radio_mode.change(mode_update, radio_mode, [row_demo, row_diy])
642
 
643
- traj_plot = gr.ScatterPlot(
644
- label="Trajectory",
645
- width=512,
646
- height=320,
647
  )
648
 
649
  h_positions[0].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot)
@@ -731,6 +728,6 @@ with gr.Blocks(css=css) as demo:
731
  submit_btn.click(fn=infer,
732
  inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions],
733
  outputs=[video_result, video_result_bbox],
734
- api_name="zrscp")
735
 
736
- demo.queue(max_size=12).launch(show_api=True)
 
 
 
1
  import sys
2
+ import gradio as gr
3
  import random
4
  import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
  import os
9
  import argparse
10
  import random
 
21
  load_model_checkpoint,
22
  )
23
  from utils.utils import instantiate_from_config
24
+ from utils.utils_freetraj import plan_path
25
 
26
  MAX_KEYS = 5
27
 
 
524
 
525
  def plot_update(*positions):
526
  if type(positions[-1]) != int:
527
+ traj_plot = gr.Plot(
528
+ label="Trajectory"
 
 
529
  )
530
  return traj_plot
531
  key_length = positions[-1]
532
  frame_indices = positions[:key_length]
533
+ frame_indices = [int(i) for i in frame_indices]
534
  h_positions = positions[MAX_KEYS:MAX_KEYS+key_length]
 
 
 
535
  w_positions = positions[2*MAX_KEYS:2*MAX_KEYS+key_length]
536
+ frame_indices, h_positions, w_positions = zip(*sorted(zip(frame_indices, h_positions, w_positions)))
537
+ plt.cla()
538
+ plt.xlim(0, 1)
539
+ plt.ylim(0, 1)
540
+ plt.gca().invert_yaxis()
541
+ plt.gca().xaxis.tick_top()
542
+ plt.plot(w_positions, h_positions, linestyle='-', marker = 'o', markerfacecolor='r')
543
+ traj_plot = gr.Plot(
544
  label="Trajectory",
545
+ value = plt
 
546
  )
547
  return traj_plot
548
 
 
639
  dropdown_demo.change(demo_update_w, dropdown_demo, w_positions)
640
  radio_mode.change(mode_update, radio_mode, [row_demo, row_diy])
641
 
642
+ traj_plot = gr.Plot(
643
+ label="Trajectory"
 
 
644
  )
645
 
646
  h_positions[0].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot)
 
728
  submit_btn.click(fn=infer,
729
  inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions],
730
  outputs=[video_result, video_result_bbox],
731
+ api_name="freetraj")
732
 
733
+ demo.queue(max_size=8).launch(show_api=True)