dkdaniz commited on
Commit
5ccd0e0
·
1 Parent(s): 70a5935

Update run_localGPT.py

Browse files
Files changed (1) hide show
  1. run_localGPT.py +106 -212
run_localGPT.py CHANGED
@@ -1,163 +1,95 @@
1
- import os
2
  import logging
 
 
 
3
  import click
4
  import torch
5
- from langchain.chains import RetrievalQA
6
  from langchain.embeddings import HuggingFaceInstructEmbeddings
7
- from langchain.llms import HuggingFacePipeline
8
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler # for streaming response
9
- from langchain.callbacks.manager import CallbackManager
10
-
11
- torch.set_grad_enabled(False)
12
-
13
- callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
14
-
15
- from prompt_template_utils import get_prompt_template
16
-
17
- # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
18
  from langchain.vectorstores import Chroma
19
- from transformers import (
20
- GenerationConfig,
21
- pipeline,
22
- )
23
 
24
- from load_models import (
25
- load_quantized_model_gguf_ggml,
26
- load_quantized_model_qptq,
27
- load_full_model,
28
- )
29
 
30
  from constants import (
 
 
31
  EMBEDDING_MODEL_NAME,
 
32
  PERSIST_DIRECTORY,
33
- MODEL_ID,
34
- MODEL_BASENAME,
35
- MAX_NEW_TOKENS,
36
- MODELS_PATH,
37
  )
38
 
39
 
40
- def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
41
- """
42
- Select a model for text generation using the HuggingFace library.
43
- If you are running this for the first time, it will download a model for you.
44
- subsequent runs will use the model from the disk.
45
-
46
- Args:
47
- device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
48
- model_id (str): Identifier of the model to load from HuggingFace's model hub.
49
- model_basename (str, optional): Basename of the model if using quantized models.
50
- Defaults to None.
51
-
52
- Returns:
53
- HuggingFacePipeline: A pipeline object for text generation using the loaded model.
54
-
55
- Raises:
56
- ValueError: If an unsupported model or device type is provided.
57
- """
58
- logging.info(f"Loading Model: {model_id}, on: {device_type}")
59
- logging.info("This action can take a few minutes!")
60
-
61
- if model_basename is not None:
62
- if ".gguf" in model_basename.lower():
63
- llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
64
- return llm
65
- elif ".ggml" in model_basename.lower():
66
- model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
67
- else:
68
- model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING)
69
- else:
70
- model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING)
71
-
72
- # Load configuration from the model to avoid warnings
73
- generation_config = GenerationConfig.from_pretrained(model_id)
74
- # see here for details:
75
- # https://huggingface.co/docs/transformers/
76
- # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
77
-
78
- # Create a pipeline for text generation
79
- pipe = pipeline(
80
- "text-generation",
81
- model=model,
82
- tokenizer=tokenizer,
83
- max_length=50,
84
- temperature=0.2,
85
- # top_p=0.95,
86
- repetition_penalty=1.15,
87
- generation_config=generation_config,
88
- )
89
-
90
- local_llm = HuggingFacePipeline(pipeline=pipe)
91
- logging.info("Local LLM Loaded")
92
-
93
- return local_llm
94
-
95
-
96
- def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
97
- """
98
- Initializes and returns a retrieval-based Question Answering (QA) pipeline.
99
-
100
- This function sets up a QA system that retrieves relevant information using embeddings
101
- from the HuggingFace library. It then answers questions based on the retrieved information.
102
-
103
- Parameters:
104
- - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'cuda', etc.
105
- - use_history (bool): Flag to determine whether to use chat history or not.
106
-
107
- Returns:
108
- - RetrievalQA: An initialized retrieval-based QA system.
109
-
110
- Notes:
111
- - The function uses embeddings from the HuggingFace library, either instruction-based or regular.
112
- - The Chroma class is used to load a vector store containing pre-computed embeddings.
113
- - The retriever fetches relevant documents or data based on a query.
114
- - The prompt and memory, obtained from the `get_prompt_template` function, might be used in the QA system.
115
- - The model is loaded onto the specified device using its ID and basename.
116
- - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
117
- """
118
-
119
- embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
120
- # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
121
- # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
122
-
123
- # load the vectorstore
124
- db = Chroma(
125
- persist_directory=PERSIST_DIRECTORY,
126
- embedding_function=embeddings,
127
- )
128
- retriever = db.as_retriever()
129
-
130
- # get the prompt template and memory if set by the user.
131
- prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
132
-
133
- # load the llm pipeline
134
- llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
135
-
136
- if use_history:
137
- qa = RetrievalQA.from_chain_type(
138
- llm=llm,
139
- chain_type="stuff", # try other chains types as well. refine, map_reduce, map_rerank
140
- retriever=retriever,
141
- return_source_documents=True, # verbose=True,
142
- callbacks=callback_manager,
143
- chain_type_kwargs={"prompt": prompt, "memory": memory},
144
- )
145
  else:
146
- qa = RetrievalQA.from_chain_type(
147
- llm=llm,
148
- chain_type="stuff", # try other chains types as well. refine, map_reduce, map_rerank
149
- retriever=retriever,
150
- return_source_documents=True, # verbose=True,
151
- callbacks=callback_manager,
152
- chain_type_kwargs={
153
- "prompt": prompt,
154
- },
155
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- return qa
158
 
159
 
160
- # chose device typ to run on as well as to show source documents.
161
  @click.command()
162
  @click.option(
163
  "--device_type",
@@ -187,78 +119,40 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
187
  ),
188
  help="Device to run on. (Default is cuda)",
189
  )
190
- @click.option(
191
- "--show_sources",
192
- "-s",
193
- is_flag=True,
194
- help="Show sources along with answers (Default is False)",
195
- )
196
- @click.option(
197
- "--use_history",
198
- "-h",
199
- is_flag=True,
200
- help="Use history (Default is False)",
201
- )
202
- @click.option(
203
- "--model_type",
204
- default="llama",
205
- type=click.Choice(
206
- ["llama", "mistral", "non_llama"],
207
- ),
208
- help="model type, llama, mistral or non_llama",
209
- )
210
- def main(device_type, show_sources, use_history, model_type):
211
- """
212
- Implements the main information retrieval task for a localGPT.
213
-
214
- This function sets up the QA system by loading the necessary embeddings, vectorstore, and LLM model.
215
- It then enters an interactive loop where the user can input queries and receive answers. Optionally,
216
- the source documents used to derive the answers can also be displayed.
217
-
218
- Parameters:
219
- - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'mps', 'cuda', etc.
220
- - show_sources (bool): Flag to determine whether to display the source documents used for answering.
221
- - use_history (bool): Flag to determine whether to use chat history or not.
222
-
223
- Notes:
224
- - Logging information includes the device type, whether source documents are displayed, and the use of history.
225
- - If the models directory does not exist, it creates a new one to store models.
226
- - The user can exit the interactive loop by entering "exit".
227
- - The source documents are displayed if the show_sources flag is set to True.
228
-
229
- """
230
-
231
- logging.info(f"Running on: {device_type}")
232
- logging.info(f"Display Source Documents set to: {show_sources}")
233
- logging.info(f"Use history set to: {use_history}")
234
-
235
- # check if models directory do not exist, create a new one and store models here.
236
- if not os.path.exists(MODELS_PATH):
237
- os.mkdir(MODELS_PATH)
238
 
239
- qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type=model_type)
240
- # Interactive questions and answers
241
- while True:
242
- query = input("\nEnter a query: ")
243
- if query == "exit":
244
- break
245
- # Get the answer from the chain
246
- res = qa(query)
247
- answer, docs = res["result"], res["source_documents"]
248
 
249
- # Print the result
250
- print("\n\n> Question:")
251
- print(query)
252
- print("\n> Answer:")
253
- print(answer)
254
 
255
- if show_sources: # this is a flag that you can set to disable showing answers.
256
- # # Print the relevant sources used for the answer
257
- print("----------------------------------SOURCE DOCUMENTS---------------------------")
258
- for document in docs:
259
- print("\n> " + document.metadata["source"] + ":")
260
- print(document.page_content)
261
- print("----------------------------------SOURCE DOCUMENTS---------------------------")
262
 
263
 
264
  if __name__ == "__main__":
 
 
1
  import logging
2
+ import os
3
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
4
+
5
  import click
6
  import torch
7
+ from langchain.docstore.document import Document
8
  from langchain.embeddings import HuggingFaceInstructEmbeddings
9
+ from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
 
 
 
 
 
 
 
 
 
 
10
  from langchain.vectorstores import Chroma
 
 
 
 
11
 
12
+ torch.cuda.empty_cache()
13
+ torch.cuda.memory_summary(device=None, abbreviated=False)
 
 
 
14
 
15
  from constants import (
16
+ CHROMA_SETTINGS,
17
+ DOCUMENT_MAP,
18
  EMBEDDING_MODEL_NAME,
19
+ INGEST_THREADS,
20
  PERSIST_DIRECTORY,
21
+ SOURCE_DIRECTORY,
 
 
 
22
  )
23
 
24
 
25
+ def load_single_document(file_path: str) -> Document:
26
+ # Loads a single document from a file path
27
+ file_extension = os.path.splitext(file_path)[1]
28
+ loader_class = DOCUMENT_MAP.get(file_extension)
29
+ if loader_class:
30
+ loader = loader_class(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  else:
32
+ raise ValueError("Document type is undefined")
33
+ return loader.load()[0]
34
+
35
+
36
+ def load_document_batch(filepaths):
37
+ logging.info("Loading document batch")
38
+ # create a thread pool
39
+ with ThreadPoolExecutor(len(filepaths)) as exe:
40
+ # load files
41
+ futures = [exe.submit(load_single_document, name) for name in filepaths]
42
+ # collect data
43
+ data_list = [future.result() for future in futures]
44
+ # return data and file paths
45
+ return (data_list, filepaths)
46
+
47
+
48
+ def load_documents(source_dir: str) -> list[Document]:
49
+ # Loads all documents from the source documents directory, including nested folders
50
+ paths = []
51
+ for root, _, files in os.walk(source_dir):
52
+ for file_name in files:
53
+ file_extension = os.path.splitext(file_name)[1]
54
+ source_file_path = os.path.join(root, file_name)
55
+ if file_extension in DOCUMENT_MAP.keys():
56
+ paths.append(source_file_path)
57
+
58
+ # Have at least one worker and at most INGEST_THREADS workers
59
+ n_workers = min(INGEST_THREADS, max(len(paths), 1))
60
+ chunksize = round(len(paths) / n_workers)
61
+ docs = []
62
+ with ProcessPoolExecutor(n_workers) as executor:
63
+ futures = []
64
+ # split the load operations into chunks
65
+ for i in range(0, len(paths), chunksize):
66
+ # select a chunk of filenames
67
+ filepaths = paths[i : (i + chunksize)]
68
+ # submit the task
69
+ future = executor.submit(load_document_batch, filepaths)
70
+ futures.append(future)
71
+ # process all results
72
+ for future in as_completed(futures):
73
+ # open the file and load the data
74
+ contents, _ = future.result()
75
+ docs.extend(contents)
76
+
77
+ return docs
78
+
79
+
80
+ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
81
+ # Splits documents for correct Text Splitter
82
+ text_docs, python_docs = [], []
83
+ for doc in documents:
84
+ file_extension = os.path.splitext(doc.metadata["source"])[1]
85
+ if file_extension == ".py":
86
+ python_docs.append(doc)
87
+ else:
88
+ text_docs.append(doc)
89
 
90
+ return text_docs, python_docs
91
 
92
 
 
93
  @click.command()
94
  @click.option(
95
  "--device_type",
 
119
  ),
120
  help="Device to run on. (Default is cuda)",
121
  )
122
+ def main(device_type):
123
+ # Load documents and split in chunks
124
+ logging.info(f"Loading documents from {SOURCE_DIRECTORY}")
125
+ documents = load_documents(SOURCE_DIRECTORY)
126
+ text_documents, python_documents = split_documents(documents)
127
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
128
+ python_splitter = RecursiveCharacterTextSplitter.from_language(
129
+ language=Language.PYTHON, chunk_size=880, chunk_overlap=200
130
+ )
131
+ texts = text_splitter.split_documents(text_documents)
132
+ texts.extend(python_splitter.split_documents(python_documents))
133
+ logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
134
+ logging.info(f"Split into {len(texts)} chunks of text")
135
+
136
+ # Create embeddings
137
+ embeddings = HuggingFaceInstructEmbeddings(
138
+ model_name=EMBEDDING_MODEL_NAME,
139
+ model_kwargs={"device": device_type},
140
+ )
141
+ # change the embedding type here if you are running into issues.
142
+ # These are much smaller embeddings and will work for most appications
143
+ # If you use HuggingFaceEmbeddings, make sure to also use the same in the
144
+ # run_localGPT.py file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
 
 
 
 
 
 
 
147
 
148
+ db = Chroma.from_documents(
149
+ texts,
150
+ embeddings,
151
+ persist_directory=PERSIST_DIRECTORY,
152
+ client_settings=CHROMA_SETTINGS,
153
 
154
+ )
155
+
 
 
 
 
 
156
 
157
 
158
  if __name__ == "__main__":