chore: Add local backend option for LLM

Files changed:
- kitt/core/model.py  +57 -1
- main.py  +5 -2
kitt/core/model.py

```diff
@@ -345,6 +345,58 @@ def run_inference_ollama(prompt):
     return res
 
 
+def load_gpu_model():
+    import bitsandbytes
+    import flash_attn
+    from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        "NousResearch/Hermes-2-Pro-Llama-3-8B", trust_remote_code=True
+    )
+    model = LlamaForCausalLM.from_pretrained(
+        "NousResearch/Hermes-2-Pro-Llama-3-8B",
+        torch_dtype=torch.float16,
+        device_map="auto",
+        load_in_8bit=False,
+        load_in_4bit=True,
+        use_flash_attention_2=True,
+    )
+
+    return model, tokenizer
+
+
+try:
+    model, tokenizer = load_gpu_model()
+except Exception as e:
+    logger.error(f"Could not load model: {e}")
+    model, tokenizer = None, None
+
+
+def run_inference_local(prompt):
+    """Run inference on local model using huggingface transformers"""
+
+    if not model:
+        logger.error("Model not loaded. Exiting.")
+        raise ValueError("Model not loaded. Exiting.")
+
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
+    generated_ids = model.generate(
+        input_ids,
+        max_new_tokens=1500,
+        temperature=TEMPERATURE,
+        repetition_penalty=REPEAT_PENALTY,
+        do_sample=True,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    response = tokenizer.decode(
+        generated_ids[0][input_ids.shape[-1] :],
+        skip_special_tokens=True,
+        clean_up_tokenization_space=True,
+    )
+
+    return response
+
+
 def run_inference(prompt, backend="ollama"):
     prompt += AI_PREAMBLE
 
```
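Two of the keyword arguments the new loader relies on, `load_in_4bit=True` and `use_flash_attention_2=True`, are deprecated in recent transformers releases in favor of `quantization_config` and `attn_implementation`. As a rough sketch of what an equivalent load could look like on a newer version (the checkpoint name comes from the diff; the helper name and everything else is illustrative, not part of this commit):

```python
# Sketch only: equivalent 4-bit load using BitsAndBytesConfig instead of the
# deprecated load_in_4bit / use_flash_attention_2 kwargs. Not part of the commit.
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, LlamaForCausalLM

MODEL_ID = "NousResearch/Hermes-2-Pro-Llama-3-8B"  # same checkpoint as the diff


def load_gpu_model_4bit():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # 4-bit weights via bitsandbytes
        bnb_4bit_compute_dtype=torch.float16,  # compute in fp16, as in the diff
    )
    model = LlamaForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=quant_config,
        device_map="auto",
        attn_implementation="flash_attention_2",  # replaces use_flash_attention_2=True
    )
    return model, tokenizer
```

The second hunk in kitt/core/model.py then extends the run_inference dispatcher so "local" routes to the new code path (the single removed line's original text was not captured in this view):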
```diff
@@ -352,8 +404,12 @@ def run_inference(prompt, backend="ollama"):
 
     if backend == "ollama":
         output = run_inference_ollama(prompt)
-
+    elif backend == "replicate":
         output = run_inference_replicate(prompt)
+    elif backend == "local":
+        output = run_inference_local(prompt)
+    else:
+        raise ValueError(f"Backend {backend} not supported")
 
     logger.debug(f"Response from model: {output}")
     return output
```
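With the dispatcher updated, callers select the backend with a plain string; a minimal usage sketch (the prompt text is illustrative only):

```python
# Sketch: invoking the dispatcher with the new backend value.
# Accepted values after this commit are "ollama", "replicate" and "local";
# anything else raises ValueError. "local" also requires the GPU model to have
# loaded successfully at import time, otherwise run_inference_local raises.
from kitt.core.model import run_inference

reply = run_inference("How far can we drive on the current charge?", backend="local")
print(reply)
```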
main.py

```diff
@@ -198,7 +198,8 @@ def set_tts_enabled(tts_enabled, state):
 
 
 def set_llm_backend(llm_backend, state):
-
+    assert llm_backend in ["Ollama", "Replicate", "Local"], "Invalid LLM backend"
+    new_llm_backend = llm_backend.lower()
     logger.info(
         f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
     )
@@ -283,6 +284,8 @@ def create_demo(tts_server: bool = False, model="llama3"):
     global_context["map"] = plot
 
     with gr.Row():
+        # with gr.Row():
+        #     gr.Text("KITT", interactive=False)
         with gr.Column(scale=1, min_width=300):
             vehicle_status = gr.JSON(
                 value=vehicle.model_dump(), label="Vehicle status"
@@ -375,7 +378,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
                 interactive=True,
             )
             llm_backend = gr.Radio(
-                choices=["Ollama", "Replicate"],
+                choices=["Ollama", "Replicate", "Local"],
                 label="LLM Backend",
                 value=DEFAULT_LLM_BACKEND.title(),
                 interactive=True,
```
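Putting the two main.py changes together: the radio now offers "Local", and the handler validates the label before lowercasing it into the session state. A sketch of how that handler and its Gradio wiring typically look is below; only the assert and the lowercasing line come from the commit itself, while the state update, the return value, and the loguru import are assumptions for illustration.

```python
# Sketch of the backend handler plus its Gradio wiring; not the project's
# actual implementation beyond the two lines added in this commit.
import gradio as gr
from loguru import logger  # assumption: the project logs via loguru's `logger`


def set_llm_backend(llm_backend, state):
    assert llm_backend in ["Ollama", "Replicate", "Local"], "Invalid LLM backend"
    new_llm_backend = llm_backend.lower()
    logger.info(
        f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
    )
    state["llm_backend"] = new_llm_backend  # assumption: state is a plain dict
    return state


# Typical wiring inside create_demo(): feed the radio value and session state in,
# write the updated state back out.
# llm_backend.change(fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state])
```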