Spaces:
Running
Running
import spaces | |
import fire | |
import subprocess | |
import os | |
import time | |
import signal | |
import subprocess | |
import atexit | |
try:
    import flash_attn
except ImportError:
    def install_flash_attn(version="2.5.9.post1"):
        """Install flash-attn into the environment of the running interpreter.

        Runs ``<sys.executable> -m pip install flash-attn==<version>`` so the
        package is installed for this interpreter, rather than for whichever
        ``pip`` happens to be first on PATH.

        Args:
            version: flash-attn version to pin. Defaults to the previously
                hard-coded ``2.5.9.post1``.

        Returns:
            True if pip exited with status 0, False otherwise. A failure is
            reported on stdout but not raised, so this stays a best-effort step.
        """
        import sys  # local import: the module header does not import sys

        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", f"flash-attn=={version}"],
            check=False,
        )
        if result.returncode != 0:
            print(f"flash-attn installation failed (exit code {result.returncode})")
        return result.returncode == 0

    # On-the-fly installation is currently disabled; uncomment to enable:
    # install_flash_attn()
    # import flash_attn
def kill_processes_by_cmd_substring(cmd_substring, sig=signal.SIGTERM):
    """Signal every process whose ``ps -ef`` line contains *cmd_substring*.

    Each matching line is printed before its process is signalled. The
    current process is never signalled, even if its own line matches.

    Args:
        cmd_substring: Text searched for in each raw ``ps -ef`` output line
            (the whole line, including the command and its arguments).
        sig: Signal to send. Defaults to SIGTERM, as before.

    Returns:
        List of PIDs the signal was successfully delivered to.
    """
    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    own_pid = os.getpid()
    killed = []
    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        fields = line.split()
        try:
            pid = int(fields[1])  # second column of `ps -ef` is the PID
        except (IndexError, ValueError):
            # e.g. the "UID PID PPID ..." header row, or a malformed line.
            continue
        if pid == own_pid:
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # The process exited between the `ps` snapshot and this kill.
            continue
        except PermissionError:
            print(f"Permission denied when killing PID {pid}; skipping")
            continue
        killed.append(pid)
    return killed
def _launch_process(label, cmd_args):
    """Start *cmd_args* as a child process after killing stale copies of it.

    Any running process whose ``ps -ef`` line contains the same command line
    is signalled first (via kill_processes_by_cmd_substring). The new child's
    ``terminate`` is registered with ``atexit`` so it is stopped when this
    launcher's interpreter exits normally.

    Args:
        label: Short name used in the "Launching <label>: ..." log line.
        cmd_args: Argument vector passed to ``subprocess.Popen``.

    Returns:
        The started ``subprocess.Popen`` object.
    """
    cmd_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_line)
    print(f"Launching {label}: ", cmd_line)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    **kwargs,
):
    """Launch the controller, model workers, Gradio server and optional SD worker.

    Each component is started as a separate child process using
    *python_path*. The function then blocks until every started child exits.

    Args:
        python_path: Interpreter used to run each component script.
        run_controller: Start ``controller.py`` on *controller_port*.
        run_worker: Start one ``model_worker.py`` per entry in
            *worker_names*, on consecutive ports starting at 10088.
        run_gradio: Start ``gradio_web_server.py`` on *gradio_port*.
        controller_port: Port passed to the controller. The workers and the
            Gradio server receive ``http://0.0.0.0:<controller_port>`` as
            their controller URL.
        gradio_port: Port for the Gradio web server.
        worker_names: Model paths to serve, one worker each. A single string
            (which is what python-fire passes for ``--worker_names=org/model``)
            is treated as a one-element list.
        run_sd_worker: Also start ``sd_worker.py``.
        **kwargs: Accepted for CLI compatibility but otherwise unused. Any
            received are reported so mistyped flags are not silently ignored.
    """
    if kwargs:
        print(f"Ignoring unrecognized options: {sorted(kwargs)}")
    # Without this, a bare string would be iterated character by character,
    # launching one worker per character.
    if isinstance(worker_names, str):
        worker_names = [worker_names]

    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"

    controller_process = None
    if run_controller:
        # python controller.py --host 0.0.0.0 --port <controller_port>
        controller_process = _launch_process(
            "controller",
            [
                f"{python_path}",
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        # One worker per model, on ports 10088, 10089, ...
        for worker_port, worker_name in enumerate(worker_names, start=10088):
            # python model_worker.py --port <port> --controller-url <url>
            #     --model-path <name> --load-8bit
            worker_processes.append(
                _launch_process(
                    "worker",
                    [
                        f"{python_path}",
                        "model_worker.py",
                        "--port",
                        f"{worker_port}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )
        # Fixed pause before the web server starts; presumably gives the
        # workers time to come up and register with the controller.
        time.sleep(10)

    gradio_process = None
    if run_gradio:
        # python gradio_web_server.py --port <gradio_port>
        #     --controller-url <url> --model-list-mode reload
        gradio_process = _launch_process(
            "gradio",
            [
                f"{python_path}",
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        # python sd_worker.py
        sd_worker_process = _launch_process(
            "sd_worker", [f"{python_path}", "sd_worker.py"]
        )

    # Block until every launched child exits (workers first, then the rest).
    for process in worker_processes:
        process.wait()
    for process in (controller_process, gradio_process, sd_worker_process):
        if process is not None:
            process.wait()
if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)