abalakrishnaTRI
commited on
Commit
•
5b53c67
1
Parent(s):
6ba6dce
clean
Browse files- interactive_demo.py +16 -39
interactive_demo.py
CHANGED
@@ -47,20 +47,12 @@ def heart_beat_worker(controller):
|
|
47 |
|
48 |
|
49 |
class ModelWorker:
|
50 |
-
def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm,
|
51 |
self.controller_addr = controller_addr
|
52 |
self.worker_addr = worker_addr
|
53 |
self.worker_id = worker_id
|
54 |
self.model_name = model_name
|
55 |
-
|
56 |
-
# logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
57 |
self.vlm = vlm
|
58 |
-
self.tokenizer, self.model, self.image_processor, self.context_len = (
|
59 |
-
vlm.tokenizer,
|
60 |
-
vlm.model,
|
61 |
-
vlm.image_processor,
|
62 |
-
vlm.max_length,
|
63 |
-
)
|
64 |
|
65 |
if not no_register:
|
66 |
self.register_to_controller()
|
@@ -68,18 +60,12 @@ class ModelWorker:
|
|
68 |
self.heart_beat_thread.start()
|
69 |
|
70 |
def register_to_controller(self):
|
71 |
-
# logger.info("Register to controller")
|
72 |
-
|
73 |
url = self.controller_addr + "/register_worker"
|
74 |
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
|
75 |
r = requests.post(url, json=data)
|
76 |
assert r.status_code == 200
|
77 |
|
78 |
def send_heart_beat(self):
|
79 |
-
# logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
80 |
-
# f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
81 |
-
# f"global_counter: {global_counter}")
|
82 |
-
|
83 |
url = self.controller_addr + "/receive_heart_beat"
|
84 |
|
85 |
while True:
|
@@ -91,7 +77,6 @@ class ModelWorker:
|
|
91 |
break
|
92 |
except requests.exceptions.RequestException:
|
93 |
pass
|
94 |
-
# logger.error(f"heart beat error: {e}")
|
95 |
time.sleep(5)
|
96 |
|
97 |
if not exist:
|
@@ -145,12 +130,12 @@ class ModelWorker:
|
|
145 |
else:
|
146 |
question_prompt = [prompt_fn()]
|
147 |
|
148 |
-
if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
|
149 |
# This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
|
150 |
-
pixel_values = self.image_processor(images[0].convert("RGB"))
|
151 |
else:
|
152 |
# Assume `image_transform` is a HF ImageProcessor...
|
153 |
-
pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
|
154 |
|
155 |
if type(pixel_values) is dict:
|
156 |
for k in pixel_values.keys():
|
@@ -227,31 +212,29 @@ overwatch = initialize_overwatch(__name__)
|
|
227 |
class DemoConfig:
|
228 |
# fmt: off
|
229 |
|
230 |
-
# === Model Parameters =>>
|
231 |
-
model_family: str = "
|
232 |
-
model_id: str = "
|
233 |
-
model_dir:
|
234 |
-
"resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
|
235 |
-
)
|
236 |
|
237 |
# === Model Parameters =>> Official LLaVa ===
|
238 |
# model_family: str = "llava-v15"
|
239 |
# model_id: str = "llava-v1.5-13b"
|
240 |
# model_dir: Path = "liuhaotian/llava-v1.5-13b"
|
241 |
|
|
|
|
|
|
|
|
|
|
|
242 |
# Model Worker Parameters
|
243 |
host: str = "0.0.0.0"
|
244 |
port: int = 40000
|
245 |
controller_address: str = "http://localhost:10000"
|
246 |
-
model_base: str = "llava-v15"
|
247 |
limit_model_concurrency: int = 5
|
248 |
stream_interval: int = 1
|
249 |
no_register: bool = False
|
250 |
|
251 |
-
# Inference Parameters
|
252 |
-
device_batch_size: int = 1 # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
|
253 |
-
num_workers: int = 2 # Number of Dataloader Workers (on each process)
|
254 |
-
|
255 |
# HF Hub Credentials (for LLaMa-2)
|
256 |
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
257 |
|
@@ -259,14 +242,8 @@ class DemoConfig:
|
|
259 |
seed: int = 21 # Random Seed (for reproducibility)
|
260 |
|
261 |
def __post_init__(self) -> None:
|
262 |
-
|
263 |
-
|
264 |
-
self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
|
265 |
-
elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
|
266 |
-
self.model_name = MODEL_ID_TO_NAME[self.model_id]
|
267 |
-
self.run_dir = self.model_dir
|
268 |
-
else:
|
269 |
-
raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
|
270 |
self.worker_address = f"http://localhost:{self.port}"
|
271 |
|
272 |
# fmt: on
|
@@ -286,7 +263,7 @@ def interactive_demo(cfg: DemoConfig):
|
|
286 |
global limit_model_concurrency
|
287 |
limit_model_concurrency = cfg.limit_model_concurrency
|
288 |
worker = ModelWorker(
|
289 |
-
cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.
|
290 |
)
|
291 |
uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
|
292 |
|
|
|
47 |
|
48 |
|
49 |
class ModelWorker:
|
50 |
+
def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_name):
|
51 |
self.controller_addr = controller_addr
|
52 |
self.worker_addr = worker_addr
|
53 |
self.worker_id = worker_id
|
54 |
self.model_name = model_name
|
|
|
|
|
55 |
self.vlm = vlm
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
if not no_register:
|
58 |
self.register_to_controller()
|
|
|
60 |
self.heart_beat_thread.start()
|
61 |
|
62 |
def register_to_controller(self):
|
|
|
|
|
63 |
url = self.controller_addr + "/register_worker"
|
64 |
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
|
65 |
r = requests.post(url, json=data)
|
66 |
assert r.status_code == 200
|
67 |
|
68 |
def send_heart_beat(self):
|
|
|
|
|
|
|
|
|
69 |
url = self.controller_addr + "/receive_heart_beat"
|
70 |
|
71 |
while True:
|
|
|
77 |
break
|
78 |
except requests.exceptions.RequestException:
|
79 |
pass
|
|
|
80 |
time.sleep(5)
|
81 |
|
82 |
if not exist:
|
|
|
130 |
else:
|
131 |
question_prompt = [prompt_fn()]
|
132 |
|
133 |
+
if isinstance(self.vlm.image_processor, Compose) or hasattr(self.vlm.image_processor, "is_prismatic"):
|
134 |
# This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
|
135 |
+
pixel_values = self.vlm.image_processor(images[0].convert("RGB"))
|
136 |
else:
|
137 |
# Assume `image_transform` is a HF ImageProcessor...
|
138 |
+
pixel_values = self.vlm.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
|
139 |
|
140 |
if type(pixel_values) is dict:
|
141 |
for k in pixel_values.keys():
|
|
|
212 |
class DemoConfig:
|
213 |
# fmt: off
|
214 |
|
215 |
+
# === Model Parameters =>> Prismatic ===
|
216 |
+
model_family: str = "prismatic" # Model family to load from in < `prismatic` | `llava-v15` | ... >
|
217 |
+
model_id: str = "prism-dinosiglip+7b" # Model ID to load and run (instance of `model_family`)
|
218 |
+
model_dir: str = None # Can optionally supply model_dir instead of model_id
|
|
|
|
|
219 |
|
220 |
# === Model Parameters =>> Official LLaVa ===
|
221 |
# model_family: str = "llava-v15"
|
222 |
# model_id: str = "llava-v1.5-13b"
|
223 |
# model_dir: Path = "liuhaotian/llava-v1.5-13b"
|
224 |
|
225 |
+
# === Model Parameters =>> Official InstructBLIP ===
|
226 |
+
# model_family: str = "instruct-blip"
|
227 |
+
# model_id: str = "instructblip-vicuna-7b"
|
228 |
+
# model_dir: Path = "Salesforce/instructblip-vicuna-7b"
|
229 |
+
|
230 |
# Model Worker Parameters
|
231 |
host: str = "0.0.0.0"
|
232 |
port: int = 40000
|
233 |
controller_address: str = "http://localhost:10000"
|
|
|
234 |
limit_model_concurrency: int = 5
|
235 |
stream_interval: int = 1
|
236 |
no_register: bool = False
|
237 |
|
|
|
|
|
|
|
|
|
238 |
# HF Hub Credentials (for LLaMa-2)
|
239 |
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
240 |
|
|
|
242 |
seed: int = 21 # Random Seed (for reproducibility)
|
243 |
|
244 |
def __post_init__(self) -> None:
|
245 |
+
self.run_dir = self.model_dir
|
246 |
+
self.model_name = MODEL_ID_TO_NAME[str(self.model_id)]
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
self.worker_address = f"http://localhost:{self.port}"
|
248 |
|
249 |
# fmt: on
|
|
|
263 |
global limit_model_concurrency
|
264 |
limit_model_concurrency = cfg.limit_model_concurrency
|
265 |
worker = ModelWorker(
|
266 |
+
cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_name
|
267 |
)
|
268 |
uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
|
269 |
|