Spaces:
Paused
Paused
Daniel Marques
committed on
Commit
·
0a6d582
1
Parent(s):
2fa8d08
feat: add stream
Browse files
- run_localGPT.py +4 -2
run_localGPT.py
CHANGED
@@ -14,11 +14,11 @@ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
|
14 |
|
15 |
from prompt_template_utils import get_prompt_template
|
16 |
|
17 |
-
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
18 |
from langchain.vectorstores import Chroma
|
19 |
from transformers import (
|
20 |
GenerationConfig,
|
21 |
pipeline,
|
|
|
22 |
)
|
23 |
|
24 |
from load_models import (
|
@@ -36,6 +36,7 @@ from constants import (
|
|
36 |
MODELS_PATH,
|
37 |
)
|
38 |
|
|
|
39 |
|
40 |
def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
|
41 |
"""
|
@@ -76,6 +77,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
|
|
76 |
# main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
|
77 |
|
78 |
# Create a pipeline for text generation
|
|
|
79 |
pipe = pipeline(
|
80 |
"text-generation",
|
81 |
model=model,
|
@@ -86,7 +88,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
|
|
86 |
top_k=40,
|
87 |
repetition_penalty=1.0,
|
88 |
generation_config=generation_config,
|
89 |
-
streamer=
|
90 |
)
|
91 |
|
92 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
14 |
|
15 |
from prompt_template_utils import get_prompt_template
|
16 |
|
|
|
17 |
from langchain.vectorstores import Chroma
|
18 |
from transformers import (
|
19 |
GenerationConfig,
|
20 |
pipeline,
|
21 |
+
TextStreamer
|
22 |
)
|
23 |
|
24 |
from load_models import (
|
|
|
36 |
MODELS_PATH,
|
37 |
)
|
38 |
|
39 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
40 |
|
41 |
def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
|
42 |
"""
|
|
|
77 |
# main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
|
78 |
|
79 |
# Create a pipeline for text generation
|
80 |
+
|
81 |
pipe = pipeline(
|
82 |
"text-generation",
|
83 |
model=model,
|
|
|
88 |
top_k=40,
|
89 |
repetition_penalty=1.0,
|
90 |
generation_config=generation_config,
|
91 |
+
streamer=streamer
|
92 |
)
|
93 |
|
94 |
local_llm = HuggingFacePipeline(pipeline=pipe)
|