Rahatara committed on
Commit
2a772ca
1 Parent(s): 061f74e

Delete pdfchatbot.py

Files changed (1)
  1. pdfchatbot.py +0 -199
pdfchatbot.py DELETED
@@ -1,199 +0,0 @@
- import os
- import yaml
- import fitz  # PyMuPDF, used to render PDF pages as images
- import torch
- import gradio as gr
- from PIL import Image
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import Chroma
- from langchain.llms import HuggingFacePipeline
- from langchain.chains import ConversationalRetrievalChain
- from langchain.document_loaders import PyPDFLoader
- from langchain.prompts import PromptTemplate
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
- # Read the Hugging Face token from the environment; a plain string is
- # needed for the authenticated from_pretrained calls below (a Gradio
- # component such as gr.Textbox() would not work here).
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
-
- class PDFChatBot:
-     def __init__(self, config_path="config.yaml"):
-         """
-         Initialize the PDFChatBot instance.
-
-         Parameters:
-             config_path (str): Path to the configuration file (default is "config.yaml").
-         """
-         self.processed = False
-         self.page = 0
-         self.chat_history = []
-         self.config = self.load_config(config_path)
-         # Initialize other attributes to None
-         self.prompt = None
-         self.documents = None
-         self.embeddings = None
-         self.vectordb = None
-         self.tokenizer = None
-         self.model = None
-         self.pipeline = None
-         self.chain = None
-
-     def load_config(self, file_path):
-         """
-         Load configuration from a YAML file.
-
-         Parameters:
-             file_path (str): Path to the YAML configuration file.
-
-         Returns:
-             dict: Configuration as a dictionary.
-         """
-         with open(file_path, 'r') as stream:
-             try:
-                 config = yaml.safe_load(stream)
-                 return config
-             except yaml.YAMLError as exc:
-                 print(f"Error loading configuration: {exc}")
-                 return None
-
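-     # A minimal sketch of the config.yaml this loader expects; the key
-     # names come from the getters used below, while the model ids are
-     # illustrative placeholders, not this Space's actual configuration:
-     #
-     #   modelEmbeddings: "sentence-transformers/all-MiniLM-L6-v2"
-     #   autoTokenizer: "meta-llama/Llama-2-7b-chat-hf"
-     #   autoModelForCausalLM: "meta-llama/Llama-2-7b-chat-hf"
-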
-     def add_text(self, history, text):
-         """
-         Add user-entered text to the chat history.
-
-         Parameters:
-             history (list): List of chat history pairs.
-             text (str): User-entered text.
-
-         Returns:
-             list: Updated chat history.
-         """
-         if not text:
-             raise gr.Error('Enter text')
-         # Append a mutable [user, bot] pair rather than a tuple so that
-         # generate_response can stream characters into the reply in place.
-         history.append([text, ''])
-         return history
-
-     def create_prompt_template(self):
-         """
-         Create the condense-question prompt template for the chatbot.
-         """
-         # The template must expose {chat_history} and {question} so that
-         # ConversationalRetrievalChain can fill them in.
-         template = (
-             "The assistant should provide detailed explanations. "
-             "Combine the chat history and follow up question into "
-             "a standalone question.\n"
-             "Chat history: {chat_history}\n"
-             "Follow up question: {question}"
-         )
-         self.prompt = PromptTemplate.from_template(template)
-
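-     # Illustration (values assumed): the chain fills both variables, e.g.
-     #   self.prompt.format(chat_history="Q: hi / A: hello",
-     #                      question="What is section 2 about?")
-     # to condense the running conversation into a standalone question.
-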
-     def load_embeddings(self):
-         """
-         Load the embedding model named in the config file from Hugging Face.
-         """
-         self.embeddings = HuggingFaceEmbeddings(model_name=self.config.get("modelEmbeddings"))
-
-     def load_vectordb(self):
-         """
-         Load the vector database from the documents and embeddings.
-         """
-         self.vectordb = Chroma.from_documents(self.documents, self.embeddings)
-
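-     # The store can also be queried directly when debugging retrieval
-     # (query string assumed for illustration):
-     #   docs = self.vectordb.similarity_search("What is this paper about?", k=1)
-     #   print(docs[0].page_content, docs[0].metadata["page"])
-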
-     def load_tokenizer(self):
-         """
-         Load the tokenizer named in the config file from Hugging Face.
-         """
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             self.config.get("autoTokenizer"),
-             use_auth_token=HUGGINGFACE_TOKEN
-         )
-
-     def load_model(self):
-         """
-         Load the causal language model named in the config file from Hugging Face.
-         """
-         self.model = AutoModelForCausalLM.from_pretrained(
-             self.config.get("autoModelForCausalLM"),
-             device_map='auto',
-             torch_dtype=torch.float32,
-             use_auth_token=HUGGINGFACE_TOKEN,
-             load_in_8bit=False
-         )
-
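-     # Note: torch.float32 keeps the weights in full precision; setting
-     # load_in_8bit=True instead (which requires the bitsandbytes package)
-     # would cut memory use roughly fourfold, an optional memory-saving
-     # variation, not enabled here.
-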
-     def create_pipeline(self):
-         """
-         Create a pipeline for text generation using the loaded model and tokenizer.
-         """
-         pipe = pipeline(
-             model=self.model,
-             task='text-generation',
-             tokenizer=self.tokenizer,
-             max_new_tokens=200
-         )
-         self.pipeline = HuggingFacePipeline(pipeline=pipe)
-
-     def create_chain(self):
-         """
-         Create the Conversational Retrieval Chain.
-         """
-         self.chain = ConversationalRetrievalChain.from_llm(
-             self.pipeline,
-             chain_type="stuff",
-             retriever=self.vectordb.as_retriever(search_kwargs={"k": 1}),
-             condense_question_prompt=self.prompt,
-             return_source_documents=True
-         )
-
-     def process_file(self, file):
-         """
-         Process the uploaded PDF file and initialize the components the chain
-         needs: prompt, documents, embeddings, vector DB, tokenizer, LLM and pipeline.
-
-         Parameters:
-             file: The uploaded PDF file object (exposes its temp path as .name).
-         """
-         self.create_prompt_template()
-         self.documents = PyPDFLoader(file.name).load()
-         self.load_embeddings()
-         self.load_vectordb()
-         self.load_tokenizer()
-         self.load_model()
-         self.create_pipeline()
-         self.create_chain()
-
-     def generate_response(self, history, query, file):
-         """
-         Generate a response based on the user query and chat history.
-
-         Parameters:
-             history (list): List of chat history pairs.
-             query (str): User's query.
-             file: The uploaded PDF file object.
-
-         Returns:
-             tuple: Updated chat history and a blank string to clear the input box.
-         """
-         if not query:
-             raise gr.Error(message='Submit a question')
-         if not file:
-             raise gr.Error(message='Upload a PDF')
-         if not self.processed:
-             self.process_file(file)
-             self.processed = True
-
-         result = self.chain({"question": query, 'chat_history': self.chat_history}, return_only_outputs=True)
-         self.chat_history.append((query, result["answer"]))
-         # Remember the page of the top-ranked source document so render_file
-         # can display it; read it straight from the document metadata.
-         self.page = result['source_documents'][0].metadata['page']
-
-         # Stream the answer into the last history entry character by character.
-         for char in result['answer']:
-             history[-1][-1] += char
-         return history, " "
-
-     def render_file(self, file):
-         """
-         Render the current page of a PDF file as an image.
-
-         Parameters:
-             file: The PDF file object.
-
-         Returns:
-             PIL.Image.Image: The rendered page as an image.
-         """
-         doc = fitz.open(file.name)
-         page = doc[self.page]
-         # Scale from PDF's native 72 DPI up to 300 DPI for a sharp preview.
-         pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
-         image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
-         return image
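
For reference, a minimal sketch of how the deleted class could be wired into a Gradio UI. The component names and event wiring below are assumptions for illustration; this commit does not show the Space's actual app.py:

import gradio as gr
from pdfchatbot import PDFChatBot

bot = PDFChatBot()  # expects config.yaml next to the script

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    image = gr.Image(label="PDF page")
    text = gr.Textbox(placeholder="Ask something about the PDF")
    pdf = gr.File(label="Upload PDF", file_types=[".pdf"])

    # Echo the question, generate the answer, then show the source page.
    text.submit(bot.add_text, [chatbot, text], chatbot).then(
        bot.generate_response, [chatbot, text, pdf], [chatbot, text]
    ).then(bot.render_file, pdf, image)

demo.launch()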