Daniel Marques commited on
Commit
2ea73cf
·
1 Parent(s): 3861b3b

fix: add streamer

Browse files
Files changed (2) hide show
  1. load_models.py +1 -6
  2. main.py +13 -4
load_models.py CHANGED
@@ -222,9 +222,4 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
222
  local_llm = HuggingFacePipeline(pipeline=pipe)
223
  logging.info("Local LLM Loaded")
224
 
225
- generated_text = ""
226
- for new_text in streamer:
227
- generated_text += new_text
228
- print(generated_text)
229
-
230
- return local_llm
 
222
  local_llm = HuggingFacePipeline(pipeline=pipe)
223
  logging.info("Local LLM Loaded")
224
 
225
+ return [local_llm, streamer]
 
 
 
 
 
main.py CHANGED
@@ -42,9 +42,11 @@ DB = Chroma(
42
 
43
  RETRIEVER = DB.as_retriever()
44
 
45
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
 
 
46
 
47
- template = """Your name is Katara and you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
48
  You should only respond only topics that contains in documents use to training.
49
  Use the following pieces of context to answer the question at the end.
50
  Always answer in the most helpful and safe way possible.
@@ -70,7 +72,6 @@ QA = RetrievalQA.from_chain_type(
70
  },
71
  )
72
 
73
-
74
  class Predict(BaseModel):
75
  prompt: str
76
 
@@ -145,7 +146,7 @@ def get_files():
145
  def delete_source_route(data: Delete):
146
  filename = data.filename
147
  path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
148
- file_to_delete = f"{path_source_documents}/${filename}"
149
 
150
  if os.path.exists(file_to_delete):
151
  try:
@@ -166,6 +167,9 @@ async def predict(data: Predict):
166
  # print(f'User Prompt: {user_prompt}')
167
  # Get the answer from the chain
168
  res = QA(user_prompt)
 
 
 
169
  answer, docs = res["result"], res["source_documents"]
170
 
171
  prompt_response_dict = {
@@ -179,6 +183,11 @@ async def predict(data: Predict):
179
  (os.path.basename(str(document.metadata["source"])), str(document.page_content))
180
  )
181
 
 
 
 
 
 
182
  return {"response": prompt_response_dict}
183
  else:
184
  raise HTTPException(status_code=400, detail="Prompt Incorrect")
 
42
 
43
  RETRIEVER = DB.as_retriever()
44
 
45
+ models = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
46
+ LLM = models[0]
47
+ STREAMER = models[1]
48
 
49
+ template = """you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
50
  You should only respond only topics that contains in documents use to training.
51
  Use the following pieces of context to answer the question at the end.
52
  Always answer in the most helpful and safe way possible.
 
72
  },
73
  )
74
 
 
75
  class Predict(BaseModel):
76
  prompt: str
77
 
 
146
  def delete_source_route(data: Delete):
147
  filename = data.filename
148
  path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
149
+ file_to_delete = f"{path_source_documents}/{filename}"
150
 
151
  if os.path.exists(file_to_delete):
152
  try:
 
167
  # print(f'User Prompt: {user_prompt}')
168
  # Get the answer from the chain
169
  res = QA(user_prompt)
170
+
171
+ print(res)
172
+
173
  answer, docs = res["result"], res["source_documents"]
174
 
175
  prompt_response_dict = {
 
183
  (os.path.basename(str(document.metadata["source"])), str(document.page_content))
184
  )
185
 
186
+ generated_text = ""
187
+ for new_text in STREAMER:
188
+ generated_text += new_text
189
+ print(generated_text)
190
+
191
  return {"response": prompt_response_dict}
192
  else:
193
  raise HTTPException(status_code=400, detail="Prompt Incorrect")