sasan commited on
Commit
676fdbb
1 Parent(s): 411ba15

chore: Add local backend option for LLM

Browse files
Files changed (2) hide show
  1. kitt/core/model.py +57 -1
  2. main.py +5 -2
kitt/core/model.py CHANGED
@@ -345,6 +345,58 @@ def run_inference_ollama(prompt):
345
  return res
346
 
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  def run_inference(prompt, backend="ollama"):
349
  prompt += AI_PREAMBLE
350
 
@@ -352,8 +404,12 @@ def run_inference(prompt, backend="ollama"):
352
 
353
  if backend == "ollama":
354
  output = run_inference_ollama(prompt)
355
- else:
356
  output = run_inference_replicate(prompt)
 
 
 
 
357
 
358
  logger.debug(f"Response from model: {output}")
359
  return output
 
345
  return res
346
 
347
 
348
+ def load_gpu_model():
349
+ import bitsandbytes
350
+ import flash_attn
351
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
352
+
353
+ tokenizer = AutoTokenizer.from_pretrained(
354
+ "NousResearch/Hermes-2-Pro-Llama-3-8B", trust_remote_code=True
355
+ )
356
+ model = LlamaForCausalLM.from_pretrained(
357
+ "NousResearch/Hermes-2-Pro-Llama-3-8B",
358
+ torch_dtype=torch.float16,
359
+ device_map="auto",
360
+ load_in_8bit=False,
361
+ load_in_4bit=True,
362
+ use_flash_attention_2=True,
363
+ )
364
+
365
+ return model, tokenizer
366
+
367
+
368
+ try:
369
+ model, tokenizer = load_gpu_model()
370
+ except Exception as e:
371
+ logger.error(f"Could not load model: {e}")
372
+ model, tokenizer = None, None
373
+
374
+
375
+ def run_inference_local(prompt):
376
+ """Run inference on local model using huggingface transformers"""
377
+
378
+ if not model:
379
+ logger.error("Model not loaded. Exiting.")
380
+ raise ValueError("Model not loaded. Exiting.")
381
+
382
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
383
+ generated_ids = model.generate(
384
+ input_ids,
385
+ max_new_tokens=1500,
386
+ temperature=TEMPERATURE,
387
+ repetition_penalty=REPEAT_PENALTY,
388
+ do_sample=True,
389
+ eos_token_id=tokenizer.eos_token_id,
390
+ )
391
+ response = tokenizer.decode(
392
+ generated_ids[0][input_ids.shape[-1] :],
393
+ skip_special_tokens=True,
394
+ clean_up_tokenization_space=True,
395
+ )
396
+
397
+ return response
398
+
399
+
400
  def run_inference(prompt, backend="ollama"):
401
  prompt += AI_PREAMBLE
402
 
 
404
 
405
  if backend == "ollama":
406
  output = run_inference_ollama(prompt)
407
+ elif backend == "replicate":
408
  output = run_inference_replicate(prompt)
409
+ elif backend == "local":
410
+ output = run_inference_local(prompt)
411
+ else:
412
+ raise ValueError(f"Backend {backend} not supported")
413
 
414
  logger.debug(f"Response from model: {output}")
415
  return output
main.py CHANGED
@@ -198,7 +198,8 @@ def set_tts_enabled(tts_enabled, state):
198
 
199
 
200
  def set_llm_backend(llm_backend, state):
201
- new_llm_backend = "ollama" if llm_backend == "Ollama" else "replicate"
 
202
  logger.info(
203
  f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
204
  )
@@ -283,6 +284,8 @@ def create_demo(tts_server: bool = False, model="llama3"):
283
  global_context["map"] = plot
284
 
285
  with gr.Row():
 
 
286
  with gr.Column(scale=1, min_width=300):
287
  vehicle_status = gr.JSON(
288
  value=vehicle.model_dump(), label="Vehicle status"
@@ -375,7 +378,7 @@ def create_demo(tts_server: bool = False, model="llama3"):
375
  interactive=True,
376
  )
377
  llm_backend = gr.Radio(
378
- choices=["Ollama", "Replicate"],
379
  label="LLM Backend",
380
  value=DEFAULT_LLM_BACKEND.title(),
381
  interactive=True,
 
198
 
199
 
200
  def set_llm_backend(llm_backend, state):
201
+ assert llm_backend in ["Ollama", "Replicate", "Local"], "Invalid LLM backend"
202
+ new_llm_backend = llm_backend.lower()
203
  logger.info(
204
  f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
205
  )
 
284
  global_context["map"] = plot
285
 
286
  with gr.Row():
287
+ # with gr.Row():
288
+ # gr.Text("KITT", interactive=False)
289
  with gr.Column(scale=1, min_width=300):
290
  vehicle_status = gr.JSON(
291
  value=vehicle.model_dump(), label="Vehicle status"
 
378
  interactive=True,
379
  )
380
  llm_backend = gr.Radio(
381
+ choices=["Ollama", "Replicate", "Local"],
382
  label="LLM Backend",
383
  value=DEFAULT_LLM_BACKEND.title(),
384
  interactive=True,