File size: 1,689 Bytes
8a581f5
 
a79cc2a
 
8a581f5
 
 
a79cc2a
8a581f5
 
 
 
 
 
 
b985e31
f4cffb9
 
b985e31
212387e
b985e31
8a581f5
 
 
 
 
 
 
 
 
 
 
597d161
b985e31
597d161
 
8a581f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597d161
8a581f5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import subprocess
from typing import List, Sequence, Union

import gradio as gr

# execute a CLI command
def execute_command(command: str) -> None:
    subprocess.run(command, check=True)


def infer(video_frames, masks_frames):

    video_frames_folder = "inputs/object_removal/bmx-trees"
    masks_folder = "inputs/object_removal/bmx-trees_mask"

    # Create the "results" folder if it doesn't exist
    output_folder = "results"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)


    bmx_trees_folder = os.path.join(output_folder, "bmx-trees")
    
    command = [
      f"python", 
      f"inference_propainter.py",
      f"--video={video_frames_folder}",
      f"--mask={masks_folder}",
      f"--output={output_folder}"
    ]

    execute_command(command)

    # Get the list of files in the "results" folder
    result_files = os.listdir(bmx_trees_folder)

    return result_files

# Custom stylesheet handed to gr.Blocks: horizontally centers the element with
# id "col-container", caps its width at 840px and left-aligns its text.
css="""
#col-container{
    margin: 0 auto;
    max-width: 840px;
    text-align: left;
}
"""
    
# Build the UI. Components are laid out in the order they are created inside
# the nested context managers, so statement order here defines the layout.
with gr.Blocks(css=css) as demo:
    # Single column carrying the "col-container" id targeted by the stylesheet.
    with gr.Column(elem_id="col-container"):
        # Page header: centered title plus an (currently empty) description.
        gr.HTML("""
        <h2 style="text-align: center;">ProPainter</h2>
        <p style="text-align: center;">
            
        </p>
                """)

        with gr.Row():
            # Left column: the two file inputs and the submit button.
            with gr.Column():
                video_frames = gr.Files(label="Video frames")
                masks_frames = gr.Files(label="Masks frames")
        
                submit_btn = gr.Button("Submit")

            # Right column: files returned by infer().
            with gr.Column():
                result = gr.Files(label="Result")

        
            
    # Registered inside the Blocks context: clicking the button calls infer()
    # with both file inputs and writes its return value to the result widget.
    submit_btn.click(fn=infer, inputs=[video_frames, masks_frames], outputs=[result])
    
# Enable the request queue (max_size=12) and start the app. This runs at
# import time; there is no `if __name__ == "__main__"` guard.
demo.queue(max_size=12).launch()