Daniel Marques committed on
Commit
b775dc2
·
1 Parent(s): 0a6d582

feat: add stream

Browse files
Files changed (2) hide show
  1. main.py +1 -1
  2. run_localGPT.py +4 -2
main.py CHANGED
@@ -44,7 +44,7 @@ DB = Chroma(
44
 
45
  RETRIEVER = DB.as_retriever()
46
 
47
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
48
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
49
 
50
  template = """you are a helpful, respectful and honest assistant.
 
44
 
45
  RETRIEVER = DB.as_retriever()
46
 
47
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
48
  prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
49
 
50
  template = """you are a helpful, respectful and honest assistant.
run_localGPT.py CHANGED
@@ -36,9 +36,9 @@ from constants import (
36
  MODELS_PATH,
37
  )
38
 
39
- streamer = TextStreamer(tokenizer, skip_prompt=True)
40
 
41
- def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
 
42
  """
43
  Select a model for text generation using the HuggingFace library.
44
  If you are running this for the first time, it will download a model for you.
@@ -56,6 +56,8 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
56
  Raises:
57
  ValueError: If an unsupported model or device type is provided.
58
  """
 
 
59
  logging.info(f"Loading Model: {model_id}, on: {device_type}")
60
  logging.info("This action can take a few minutes!")
61
 
 
36
  MODELS_PATH,
37
  )
38
 
 
39
 
40
+
41
+ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
42
  """
43
  Select a model for text generation using the HuggingFace library.
44
  If you are running this for the first time, it will download a model for you.
 
56
  Raises:
57
  ValueError: If an unsupported model or device type is provided.
58
  """
59
+ streamer = TextStreamer(tokenizer, skip_prompt=stream)
60
+
61
  logging.info(f"Loading Model: {model_id}, on: {device_type}")
62
  logging.info("This action can take a few minutes!")
63