File size: 3,759 Bytes
8b33d6d
f289b70
 
 
 
 
 
 
c09265b
966d74c
 
 
8b33d6d
 
 
 
 
d3f3e42
 
966d74c
f289b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b33d6d
f289b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import spaces
import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit

try:
    import flash_attn
except ImportError:
    # flash-attn is not importable in this environment. An installer is
    # defined below, but it is not invoked at startup (see the commented-out
    # lines at the end of this block).

    @spaces.GPU
    def install_flash_attn():
        """Install the pinned ``flash-attn`` release into the running interpreter.

        Wrapped in ``spaces.GPU``, presumably so the install runs in the
        GPU-backed context of the Space. This is an assumption to be confirmed.
        A failed install is reported with a message instead of being silently
        ignored.
        """
        import sys

        # `sys.executable -m pip` installs into this interpreter's environment.
        # A bare `pip` on PATH may belong to a different Python installation.
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn==2.5.9.post1"],
            check=False,
        )
        if completed.returncode != 0:
            print(
                f"flash-attn installation failed with exit code {completed.returncode}"
            )

    # Installation at startup is currently disabled; uncomment to enable it.
    # install_flash_attn()
    # import flash_attn


def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose ``ps -ef`` line contains *cmd_substring*.

    This is best-effort cleanup. Matching processes are signalled, and the
    function returns without waiting for them to exit.

    Args:
        cmd_substring: Text searched for in each line of ``ps -ef`` output.
            An empty string matches nothing, because it would otherwise match
            every process on the machine.

    Edge cases:
        * The calling process itself is never signalled.
        * A process that exits between the listing and the kill, or that we
          lack permission to signal, is skipped (the latter with a message).
        * If ``ps`` is not installed, a message is printed and nothing is killed.
    """
    if not cmd_substring:
        return

    try:
        # execute `ps -ef` and obtain its output
        result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    except FileNotFoundError:
        print("`ps` not found; skipping cleanup of stale processes")
        return

    own_pid = os.getpid()
    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        parts = line.split()
        try:
            # The second column of `ps -ef` is the PID. A line without a
            # numeric second field (such as the "UID PID ..." header) is skipped.
            pid = int(parts[1])
        except (IndexError, ValueError):
            continue
        if pid == own_pid:
            # Never terminate the process that is doing the cleanup.
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited after `ps` listed it; nothing left to do.
            pass
        except PermissionError:
            print(f"No permission to kill PID {pid}; skipping")


def _launch_process(label, cmd_args):
    """Kill stale copies of *cmd_args*, start it as a child, and return the handle.

    Any process whose ``ps -ef`` line contains the same space-joined command
    line is sent SIGTERM first. The new child is then registered with
    ``atexit`` so it is terminated when this launcher exits normally.

    Args:
        label: Name shown in the "Launching ..." message.
        cmd_args: Argument vector passed to ``subprocess.Popen``.

    Returns:
        The ``subprocess.Popen`` object for the new child process.
    """
    cmd_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_line)
    print(f"Launching {label}: ", cmd_line)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    worker_port=10088,
    load_8bit=True,
    **kwargs,
):
    """Launch the controller, model workers, Gradio web server and SD worker.

    Each enabled service is started as a child process using ``python_path``.
    Processes already running with the identical command line are killed
    first. The function then blocks until every started child has exited.

    Args:
        python_path: Interpreter used to run each service script.
        run_controller: Start ``controller.py``.
        run_worker: Start one ``model_worker.py`` per entry in ``worker_names``.
        run_gradio: Start ``gradio_web_server.py``.
        controller_port: Port the controller listens on. It is also used to
            build the controller URL passed to the workers and the web server.
        gradio_port: Port for the Gradio web server.
        worker_names: Model paths, one worker each. A single string is
            treated as one model path.
        run_sd_worker: Start ``sd_worker.py``.
        worker_port: Port of the first model worker. Each further worker
            uses the next port up.
        load_8bit: Pass ``--load-8bit`` to every model worker.
        **kwargs: Accepted and ignored, e.g. extra command-line flags.

    NOTE: children are terminated via ``atexit``, which does not run when
    this launcher is killed by an unhandled signal (e.g. SIGTERM). In that
    case the children may outlive it.
    """
    if isinstance(worker_names, str):
        # A single model name from the command line arrives as a plain string.
        # Iterating it directly would start one worker per character.
        worker_names = [worker_names]

    interpreter = f"{python_path}"
    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"

    controller_process = None
    if run_controller:
        # python controller.py --host 0.0.0.0 --port 10086
        controller_process = _launch_process(
            "controller",
            [
                interpreter,
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        # python model_worker.py --port 10088 --controller-url http://0.0.0.0:10086
        #     --model-path <name> --load-8bit
        for offset, worker_name in enumerate(worker_names):
            cmd_args = [
                interpreter,
                "model_worker.py",
                "--port",
                f"{worker_port + offset}",
                "--controller-url",
                controller_url,
                "--model-path",
                f"{worker_name}",
            ]
            if load_8bit:
                cmd_args.append("--load-8bit")
            worker_processes.append(_launch_process("worker", cmd_args))

    # Fixed delay before the web server starts, presumably so the controller
    # and workers are up first. There is no readiness check.
    time.sleep(10)

    gradio_process = None
    if run_gradio:
        # python gradio_web_server.py --port 7860 --controller-url http://0.0.0.0:10086
        gradio_process = _launch_process(
            "gradio",
            [
                interpreter,
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        # python sd_worker.py
        sd_worker_process = _launch_process("sd_worker", [interpreter, "sd_worker.py"])

    # Block until every started child exits: workers first, then the others.
    for process in worker_processes:
        process.wait()
    for process in (controller_process, gradio_process, sd_worker_process):
        if process is not None:
            process.wait()


if __name__ == "__main__":
    # Turn `main`'s keyword arguments into command-line flags via python-fire.
    fire.Fire(component=main)