Daniel Marques committed
Commit 0a6d582 · 1 Parent(s): 2fa8d08

feat: add stream

Files changed (1): run_localGPT.py (+4 -2)

run_localGPT.py CHANGED
@@ -14,11 +14,11 @@ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
 from prompt_template_utils import get_prompt_template
 
-# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
 from transformers import (
     GenerationConfig,
     pipeline,
+    TextStreamer
 )
 
 from load_models import (
@@ -36,6 +36,7 @@ from constants import (
     MODELS_PATH,
 )
 
+streamer = TextStreamer(tokenizer, skip_prompt=True)
 
 def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     """
@@ -76,6 +77,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
 
     # Create a pipeline for text generation
+
     pipe = pipeline(
         "text-generation",
         model=model,
@@ -86,7 +88,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
         top_k=40,
         repetition_penalty=1.0,
         generation_config=generation_config,
-        streamer=True
+        streamer=streamer
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
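
The substantive fix is in the last hunk: streamer=True was invalid, because the pipeline forwards the streamer argument to generate(), which expects a streamer object rather than a boolean; the commit replaces it with a transformers.TextStreamer instance so tokens are printed to stdout as they are generated. One caveat: the new module-level streamer = TextStreamer(tokenizer, skip_prompt=True) relies on a tokenizer name already being defined at import time in run_localGPT.py, and the hunks above do not show that definition. Below is a minimal, self-contained sketch of the same streaming pattern, assuming a small stand-in model (gpt2) purely for illustration; localGPT itself resolves its real model and tokenizer through load_models and constants.

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline

    # Hypothetical stand-in model, chosen only so the sketch runs anywhere;
    # localGPT loads its real model via load_models.py / constants.py.
    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # TextStreamer writes each decoded token to stdout as soon as it is
    # generated; skip_prompt=True stops the input prompt from being echoed.
    streamer = TextStreamer(tokenizer, skip_prompt=True)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=64,
        streamer=streamer,  # extra kwargs are forwarded to generate()
    )

    pipe("Explain what a vector store is:")

With the streamer attached, output appears on the console token by token while generation runs, instead of arriving all at once when the call returns.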