import spaces
import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit

try:
    import flash_attn
except ImportError:

    @spaces.GPU
    def install_flash_attn():
        """Install flash-attn with pip (defined here; the call below is disabled)."""
        os.system("pip install flash-attn==2.5.9.post1")

    # install_flash_attn()
    # import flash_attn


def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose `ps -ef` line contains `cmd_substring`.

    Used before launching a service so a stale copy started with the same
    command line does not keep holding its port. The current process is never
    signalled. Processes that disappear before they can be signalled are
    skipped, and ones we lack permission to signal are reported and skipped.

    Args:
        cmd_substring: Text to look for anywhere in a `ps -ef` output line.
    """
    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    own_pid = os.getpid()
    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        parts = line.split()
        # `ps -ef` columns are UID, PID, PPID, ...; skip anything unparsable.
        try:
            pid = int(parts[1])
        except (IndexError, ValueError):
            continue
        if pid == own_pid:
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited between `ps` and `kill`; nothing to do.
            pass
        except PermissionError as err:
            print(f"Could not kill process with PID: {pid}: {err}")


def _launch(role, cmd_args):
    """Kill stale copies of `cmd_args`, start it, and terminate it at interpreter exit.

    Args:
        role: Label used in the "Launching <role>: " log line.
        cmd_args: Command line as a list of strings (run without a shell).

    Returns:
        The started `subprocess.Popen` object.
    """
    cmd_str = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_str)
    print(f"Launching {role}: ", cmd_str)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    worker_port=10088,
    **kwargs,
):
    """Launch the controller, model workers, gradio server and optional SD worker.

    Each service is started as a child process. Any existing process with the
    identical command line is killed first. Each child is terminated when this
    script exits. The function then blocks until all started children exit.

    Args:
        python_path: Interpreter used to run each service script.
        run_controller: Start `controller.py`.
        run_worker: Start one `model_worker.py` per entry of `worker_names`.
        run_gradio: Start `gradio_web_server.py`.
        controller_port: Port the controller listens on. Workers and gradio
            are pointed at it.
        gradio_port: Port for the gradio web server.
        worker_names: Model paths, one worker each. A single string is
            accepted and treated as one name.
        run_sd_worker: Start `sd_worker.py`.
        worker_port: Port of the first worker. Subsequent workers use
            consecutive ports.
        **kwargs: Accepted and ignored.
    """
    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"

    controller_process = None
    if run_controller:
        # python controller.py --host 0.0.0.0 --port <controller_port>
        controller_process = _launch(
            "controller",
            [
                f"{python_path}",
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        # A single name (e.g. passed once on the command line) may arrive as a
        # plain str; iterating it would launch one worker per character.
        if isinstance(worker_names, str):
            worker_names = [worker_names]
        port = worker_port
        for worker_name in worker_names:
            worker_processes.append(
                _launch(
                    "worker",
                    [
                        f"{python_path}",
                        "model_worker.py",
                        "--port",
                        f"{port}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )
            port += 1

    # Fixed pause before starting the web server; presumably gives the
    # controller/workers time to come up — TODO confirm the required delay.
    time.sleep(10)

    gradio_process = None
    if run_gradio:
        # python gradio_web_server.py --port <gradio_port> --controller-url <controller_url>
        gradio_process = _launch(
            "gradio",
            [
                f"{python_path}",
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        # python sd_worker.py
        sd_worker_process = _launch("sd_worker", [f"{python_path}", "sd_worker.py"])

    # Block until every started child exits (same order as before: workers,
    # controller, gradio, sd_worker).
    for process in [*worker_processes, controller_process, gradio_process, sd_worker_process]:
        if process is not None:
            process.wait()


if __name__ == "__main__":
    fire.Fire(main)