File size: 3,541 Bytes
f289b70
 
 
 
 
 
 
c09265b
f289b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b33d6d
f289b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4779b65
f289b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit


def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose ``ps -ef`` line contains *cmd_substring*.

    Used before (re)launching a service so that a stale copy started with the
    same command line is stopped first. This is best-effort: processes that
    vanish before they can be signalled, or that we lack permission to signal,
    are reported/skipped instead of aborting the caller.

    Args:
        cmd_substring: Text searched for in each line of ``ps -ef`` output
            (callers pass the space-joined command line of the service).

    Returns:
        list[int]: PIDs that were successfully sent SIGTERM.
    """
    own_pid = os.getpid()
    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)

    killed = []
    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        parts = line.split()
        # `ps -ef` columns are: UID PID PPID C STIME TTY TIME CMD -> PID is field 1.
        try:
            pid = int(parts[1])
        except (IndexError, ValueError):
            continue
        if pid == own_pid:
            # Never terminate the launcher itself if its own argv happens to match.
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited between the `ps` snapshot and the kill.
            continue
        except PermissionError:
            print(f"Cannot kill PID {pid}: permission denied")
            continue
        killed.append(pid)
    return killed


def _launch(label, cmd_args):
    """Kill stale copies of *cmd_args*, start it, and terminate it at interpreter exit.

    Args:
        label: Human-readable name used in the log line ("controller", ...).
        cmd_args: Argument list passed to ``subprocess.Popen`` (strings).

    Returns:
        subprocess.Popen: The started child process.
    """
    cmd_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_line)
    print(f"Launching {label}: ", cmd_line)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    worker_base_port=10088,
    load_8bit=True,
    startup_wait=60,
    **kwargs,
):
    """Launch the controller, model workers, Gradio server and SD worker.

    Each service is started as a child process (after killing any process
    with the identical command line), registered for termination at exit,
    and then waited on until it finishes.

    Args:
        python_path: Interpreter used to run each service script.
        run_controller: Start ``controller.py``.
        run_worker: Start one ``model_worker.py`` per entry in *worker_names*.
        run_gradio: Start ``gradio_web_server.py``.
        controller_port: Port the controller listens on; also used to build
            the controller URL passed to workers and the Gradio server.
        gradio_port: Port for the Gradio web server.
        worker_names: Model paths, one worker each. A single string is
            treated as one model (fire passes ``--worker_names foo`` as str).
        run_sd_worker: Start ``sd_worker.py``.
        worker_base_port: Port of the first worker; later workers use
            consecutive ports.
        load_8bit: Pass ``--load-8bit`` to each model worker.
        startup_wait: Seconds to wait after starting controller/workers and
            before starting the Gradio server / SD worker.
        **kwargs: Unrecognized CLI flags; ignored with a warning.
    """
    if kwargs:
        # fire routes unknown/mistyped flags here; surface them instead of
        # silently ignoring them.
        print(f"Warning: ignoring unrecognized arguments: {sorted(kwargs)}")

    if isinstance(worker_names, str):
        # Otherwise the string would be iterated character by character.
        worker_names = [worker_names]

    python_path = str(python_path)
    controller_url = f"http://0.0.0.0:{controller_port}"

    controller_process = None
    if run_controller:
        # python controller.py --host 0.0.0.0 --port <controller_port>
        controller_process = _launch(
            "controller",
            [python_path, "controller.py", "--host", "0.0.0.0", "--port", str(controller_port)],
        )

    worker_processes = []
    if run_worker:
        for offset, worker_name in enumerate(worker_names):
            cmd_args = [
                python_path,
                "model_worker.py",
                "--port",
                str(worker_base_port + offset),
                "--controller-url",
                controller_url,
                "--model-path",
                str(worker_name),
            ]
            if load_8bit:
                cmd_args.append("--load-8bit")
            worker_processes.append(_launch("worker", cmd_args))

    launched_backend = controller_process is not None or bool(worker_processes)
    if startup_wait > 0 and launched_backend and (run_gradio or run_sd_worker):
        # Presumably gives the controller/workers time to load models and
        # register before the front-ends start querying the controller.
        time.sleep(startup_wait)

    gradio_process = None
    if run_gradio:
        # python gradio_web_server.py --port <gradio_port>
        #   --controller-url http://0.0.0.0:<controller_port> --model-list-mode reload
        gradio_process = _launch(
            "gradio",
            [
                python_path,
                "gradio_web_server.py",
                "--port",
                str(gradio_port),
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        # python sd_worker.py
        sd_worker_process = _launch("sd_worker", [python_path, "sd_worker.py"])

    # Block until every launched child exits (workers first, as before).
    for process in [*worker_processes, controller_process, gradio_process, sd_worker_process]:
        if process is not None:
            process.wait()


if __name__ == "__main__":
    # Expose `main`'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)