Daniel Marques commited on
Commit
c18ec7e
·
1 Parent(s): 198843f

fix: add streamer

Browse files
Files changed (2) hide show
  1. load_models.py +8 -32
  2. main.py +40 -18
load_models.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- import asyncio
3
  import logging
4
  from typing import Any, Dict, List
5
 
@@ -7,9 +7,6 @@ from auto_gptq import AutoGPTQForCausalLM
7
  from huggingface_hub import hf_hub_download
8
  from langchain.llms import LlamaCpp, HuggingFacePipeline
9
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
10
- from langchain.schema import LLMResult
11
-
12
- from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
13
 
14
  from transformers import (
15
  AutoModelForCausalLM,
@@ -28,30 +25,7 @@ torch.set_grad_enabled(False)
28
  from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
29
 
30
 
31
- class MyCustomSyncHandler(BaseCallbackHandler):
32
- def on_llm_new_token(self, token: str, **kwargs) -> None:
33
- print(f"Sync handler being called in a `thread_pool_executor`: token: {token}")
34
-
35
- class MyCustomAsyncHandler(AsyncCallbackHandler):
36
- """Async callback handler that can be used to handle callbacks from langchain."""
37
-
38
- async def on_llm_start(
39
- self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
40
- ) -> None:
41
- """Run when chain starts running."""
42
- print("zzzz....")
43
- await asyncio.sleep(0.3)
44
- class_name = serialized["name"]
45
- print("Hi! I just woke up. Your llm is starting")
46
-
47
- async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
48
- """Run when chain ends running."""
49
- print("zzzz....")
50
- await asyncio.sleep(0.3)
51
- print("Hi! I just woke up. Your llm is ending")
52
-
53
-
54
- def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging, stream = False):
55
  """
56
  Load a GGUF/GGML quantized model using LlamaCpp.
57
 
@@ -93,9 +67,10 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, loggin
93
  if device_type.lower() == "cuda":
94
  kwargs["n_gpu_layers"] = N_GPU_LAYERS # set this based on your GPU
95
 
96
- #add stream
97
  kwargs["stream"] = stream
98
- kwargs["callbacks"] = [MyCustomSyncHandler(), MyCustomAsyncHandler()]
 
 
99
 
100
  return LlamaCpp(**kwargs)
101
  except:
@@ -145,6 +120,7 @@ def load_quantized_model_qptq(model_id, model_basename, device_type, logging):
145
  use_triton=False,
146
  quantize_config=None,
147
  )
 
148
  return model, tokenizer
149
 
150
 
@@ -195,7 +171,7 @@ def load_full_model(model_id, model_basename, device_type, logging):
195
  return model, tokenizer
196
 
197
 
198
- def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
199
  """
200
  Select a model for text generation using the HuggingFace library.
201
  If you are running this for the first time, it will download a model for you.
@@ -219,7 +195,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
219
 
220
  if model_basename is not None:
221
  if ".gguf" in model_basename.lower():
222
- llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING, stream)
223
  return llm
224
  elif ".ggml" in model_basename.lower():
225
  model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
 
1
  import torch
2
+
3
  import logging
4
  from typing import Any, Dict, List
5
 
 
7
  from huggingface_hub import hf_hub_download
8
  from langchain.llms import LlamaCpp, HuggingFacePipeline
9
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 
10
 
11
  from transformers import (
12
  AutoModelForCausalLM,
 
25
  from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
26
 
27
 
28
+ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging, stream = False, callbacks = []):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
  Load a GGUF/GGML quantized model using LlamaCpp.
31
 
 
67
  if device_type.lower() == "cuda":
68
  kwargs["n_gpu_layers"] = N_GPU_LAYERS # set this based on your GPU
69
 
 
70
  kwargs["stream"] = stream
71
+
72
+ if stream == True:
73
+ kwargs["callbacks"] = callbacks
74
 
75
  return LlamaCpp(**kwargs)
76
  except:
 
120
  use_triton=False,
121
  quantize_config=None,
122
  )
123
+
124
  return model, tokenizer
125
 
126
 
 
171
  return model, tokenizer
172
 
173
 
174
+ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False, callbacks = []):
175
  """
176
  Select a model for text generation using the HuggingFace library.
177
  If you are running this for the first time, it will download a model for you.
 
195
 
196
  if model_basename is not None:
197
  if ".gguf" in model_basename.lower():
198
+ llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING, stream, callbacks)
199
  return llm
200
  elif ".ggml" in model_basename.lower():
201
  model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
main.py CHANGED
@@ -1,17 +1,21 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, WebSocket
2
- from fastapi.staticfiles import StaticFiles
3
-
4
- from pydantic import BaseModel
5
  import os
6
  import glob
7
  import shutil
8
  import subprocess
 
 
 
 
 
 
9
 
10
  # import torch
11
  from langchain.chains import RetrievalQA
12
  from langchain.embeddings import HuggingFaceInstructEmbeddings
13
  from langchain.prompts import PromptTemplate
14
  from langchain.memory import ConversationBufferMemory
 
 
15
 
16
  # from langchain.embeddings import HuggingFaceEmbeddings
17
  from load_models import load_model
@@ -21,6 +25,26 @@ from langchain.vectorstores import Chroma
21
 
22
  from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # if torch.backends.mps.is_available():
25
  # DEVICE_TYPE = "mps"
26
  # elif torch.cuda.is_available():
@@ -42,15 +66,13 @@ DB = Chroma(
42
 
43
  RETRIEVER = DB.as_retriever()
44
 
45
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
46
 
47
- template = """you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
48
- You should only respond only topics that contains in documents use to training.
49
- Use the following pieces of context to answer the question at the end.
50
- Always answer in the most helpful and safe way possible.
51
- If you don't know the answer to a question, just say that you don't know, don't try to make up an answer, don't share false information.
52
- Use 15 sentences maximum. Keep the answer as concise as possible.
53
- Always say "thanks for asking!" at the end of the answer.
54
  Context: {history} \n {context}
55
  Question: {question}
56
  """
@@ -70,12 +92,6 @@ QA = RetrievalQA.from_chain_type(
70
  },
71
  )
72
 
73
- class Predict(BaseModel):
74
- prompt: str
75
-
76
- class Delete(BaseModel):
77
- filename: str
78
-
79
  app = FastAPI(title="homepage-app")
80
  api_app = FastAPI(title="api app")
81
 
@@ -179,6 +195,12 @@ async def predict(data: Predict):
179
  (os.path.basename(str(document.metadata["source"])), str(document.page_content))
180
  )
181
 
 
 
 
 
 
 
182
  # generated_text = ""
183
  # for new_text in STREAMER:
184
  # generated_text += new_text
 
 
 
 
 
1
  import os
2
  import glob
3
  import shutil
4
  import subprocess
5
+ import asyncio
6
+
7
+ from fastapi import FastAPI, HTTPException, UploadFile, WebSocket
8
+ from fastapi.staticfiles import StaticFiles
9
+
10
+ from pydantic import BaseModel
11
 
12
  # import torch
13
  from langchain.chains import RetrievalQA
14
  from langchain.embeddings import HuggingFaceInstructEmbeddings
15
  from langchain.prompts import PromptTemplate
16
  from langchain.memory import ConversationBufferMemory
17
+ from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
18
+ from langchain.schema import LLMResult
19
 
20
  # from langchain.embeddings import HuggingFaceEmbeddings
21
  from load_models import load_model
 
25
 
26
  from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY
27
 
28
+ class Predict(BaseModel):
29
+ prompt: str
30
+
31
+ class Delete(BaseModel):
32
+ filename: str
33
+
34
+
35
+ class MyCustomAsyncHandler(AsyncCallbackHandler):
36
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
37
+ print(f" token: {token}")
38
+
39
+ async def on_llm_start(
40
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
41
+ ) -> None:
42
+ class_name = serialized["name"]
43
+ print("start")
44
+
45
+ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
46
+ print("finish")
47
+
48
  # if torch.backends.mps.is_available():
49
  # DEVICE_TYPE = "mps"
50
  # elif torch.cuda.is_available():
 
66
 
67
  RETRIEVER = DB.as_retriever()
68
 
69
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks = [MyCustomAsyncHandler])
70
 
71
+ template = """you are a helpful, respectful and honest assistant. When answering questions, you should only use the documents provided.
72
+ You should only answer the topics that appear in these documents.
73
+ Always answer in the most helpful and reliable way possible, if you don't know the answer to a question, just say you don't know, don't try to make up an answer,
74
+ don't share false information. you should use no more than 15 sentences and all your answers should be as concise as possible.
75
+ Always say "Thank you for asking!" at the end of your answer.
 
 
76
  Context: {history} \n {context}
77
  Question: {question}
78
  """
 
92
  },
93
  )
94
 
 
 
 
 
 
 
95
  app = FastAPI(title="homepage-app")
96
  api_app = FastAPI(title="api app")
97
 
 
195
  (os.path.basename(str(document.metadata["source"])), str(document.page_content))
196
  )
197
 
198
+ qa_chain_response = res.stream(
199
+ {"query": user_prompt},
200
+ )
201
+
202
+ print(f"{qa_chain_response} stream")
203
+
204
  # generated_text = ""
205
  # for new_text in STREAMER:
206
  # generated_text += new_text