RAHMAN00700 committed on
Commit
519d5f1
1 Parent(s): 70f8d5c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import tempfile
4
+ import pandas as pd
5
+ import json
6
+ import xml.etree.ElementTree as ET
7
+ import yaml
8
+ from bs4 import BeautifulSoup
9
+ from pptx import Presentation
10
+ from docx import Document
11
+
12
+ from langchain.document_loaders import PyPDFLoader, TextLoader
13
+ from langchain.indexes import VectorstoreIndexCreator
14
+ from langchain.chains import RetrievalQA
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain.embeddings import HuggingFaceEmbeddings
17
+ from langchain.chains import LLMChain
18
+ from langchain.prompts import PromptTemplate
19
+
20
+ from ibm_watson_machine_learning.foundation_models import Model
21
+ from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
22
+ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
23
+ from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
24
+
25
# Module-level state: both the vector index and the retrieval chain are
# created lazily — they stay None until a document has been uploaded and
# processed further down in the script.
index = None
rag_chain = None
28
+
29
class DocxLoader:
    """Custom loader that extracts plain text from a Word (.docx) file."""

    def __init__(self, file_path):
        # Path of the .docx document to read.
        self.file_path = file_path

    def load(self):
        """Return every paragraph's text from the document, space-joined."""
        document = Document(self.file_path)
        return " ".join(paragraph.text for paragraph in document.paragraphs)
38
+
39
class PptxLoader:
    """Custom loader that extracts plain text from a PowerPoint (.pptx) file."""

    def __init__(self, file_path):
        # Path of the .pptx presentation to read.
        self.file_path = file_path

    def load(self):
        """Return the text of every text-bearing shape on every slide, space-joined."""
        deck = Presentation(self.file_path)
        pieces = []
        for slide in deck.slides:
            for shape in slide.shapes:
                # Not all shapes carry text (pictures, charts, ...).
                if hasattr(shape, "text"):
                    pieces.append(shape.text)
        return " ".join(pieces)
48
+
49
# Custom loader for additional file types
def load_csv(file_path):
    """Load a CSV file, display it paginated in the UI, and return the full
    frame rendered as a plain string (no index column)."""
    frame = pd.read_csv(file_path)
    # Paginate the preview so very large CSVs don't overwhelm the page.
    st.write("Large dataset detected, displaying data in pages.")
    rows_per_page = 100
    total_pages = (len(frame) // rows_per_page) + 1
    page = st.number_input("Page number", min_value=1, max_value=total_pages, step=1, value=1)

    first = (page - 1) * rows_per_page
    window = frame.iloc[first:first + rows_per_page]

    st.dataframe(window)
    # The whole (un-paginated) table is what feeds the index downstream.
    return frame.to_string(index=False)
63
+
64
def load_json(file_path):
    """Read a JSON file and return it pretty-printed with 2-space indentation."""
    with open(file_path, 'r') as fh:
        parsed = json.load(fh)
    return json.dumps(parsed, indent=2)
68
+
69
def load_xml(file_path):
    """Parse an XML file and return its root element serialized as unicode text."""
    root = ET.parse(file_path).getroot()
    return ET.tostring(root, encoding="unicode")
73
+
74
def load_yaml(file_path):
    """Read a YAML file (safe parse) and return it re-serialized as YAML text."""
    with open(file_path, 'r') as fh:
        content = yaml.safe_load(fh)
    return yaml.dump(content)
78
+
79
def load_html(file_path):
    """Extract the visible text content from an HTML file (markup stripped)."""
    with open(file_path, 'r', encoding='utf-8') as fh:
        markup = BeautifulSoup(fh, 'html.parser')
        return markup.get_text()
83
+
84
# Caching function to load various file types
@st.cache_resource
def load_file(file_name, file_type):
    """Build and return a vector index for the given file, or None.

    PDF and TXT files are handed to LangChain loaders directly; every other
    supported type is first converted to plain text by a custom loader and
    routed through an intermediate temp .txt file. Unsupported types show an
    error in the UI and return None.
    """
    loaders = []
    text = None  # set only by the custom text-extraction branches

    if file_type == "pdf":
        loaders = [PyPDFLoader(file_name)]
    elif file_type == "txt":
        loaders = [TextLoader(file_name)]
    elif file_type == "docx":
        text = DocxLoader(file_name).load()
    elif file_type == "pptx":
        text = PptxLoader(file_name).load()
    elif file_type == "csv":
        text = load_csv(file_name)
    elif file_type == "json":
        text = load_json(file_name)
    elif file_type == "xml":
        text = load_xml(file_name)
    elif file_type == "yaml":
        text = load_yaml(file_name)
    elif file_type == "html":
        text = load_html(file_name)
    else:
        st.error("Unsupported file type.")
        return None

    # BUG FIX: the original wrote `text` to a temp file unconditionally, which
    # raised NameError for pdf/txt (where only `loaders` is set) and would have
    # clobbered their loaders. Only take the temp-file route when a custom
    # loader actually produced text.
    if text is not None:
        # Use TextLoader for intermediate text files from custom loaders
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
            temp_file.write(text.encode("utf-8"))
            temp_file_path = temp_file.name
        loaders = [TextLoader(temp_file_path)]

    # Embed + chunk the document(s) into a searchable vector store.
    index = VectorstoreIndexCreator(
        embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2"),
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=450, chunk_overlap=50)
    ).from_loaders(loaders)
    return index
124
+
125
# Watsonx API setup: credentials come from the environment so the app can
# be deployed without editing source.
watsonx_api_key = os.getenv("WATSONX_API_KEY")
watsonx_project_id = os.getenv("WATSONX_PROJECT_ID")

# Surface a visible UI error instead of failing later with an opaque
# authentication exception; note the script still continues executing below.
if not watsonx_api_key or not watsonx_project_id:
    st.error("API Key or Project ID is not set. Please set them as environment variables.")
131
+
132
# Llama-3-style chat prompt: retrieved document context is injected between
# the system preamble and the user question. The <|...|> tokens are the
# model's chat-format delimiters and must be preserved verbatim.
prompt_template_br = PromptTemplate(
    input_variables=["context", "question"],
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
I am a helpful assistant.

<|eot_id|>
{context}
<|start_header_id|>user<|end_header_id|>
{question}<|eot_id|>
"""
)
143
+
144
# Sidebar: model choice, generation parameters, and the document upload.
with st.sidebar:
    st.title("Watsonx RAG with Multiple docs")
    # Foundation model used for generation.
    watsonx_model = st.selectbox("Model", ["meta-llama/llama-3-405b-instruct", "codellama/codellama-34b-instruct-hf", "ibm/granite-20b-multilingual"])
    max_new_tokens = st.slider("Max output tokens", min_value=100, max_value=4000, value=600, step=100)
    decoding_method = st.radio("Decoding", (DecodingMethods.GREEDY.value, DecodingMethods.SAMPLE.value))
    # Generation parameters handed to the Watsonx model. TEMPERATURE is 0;
    # TOP_K/TOP_P presumably only matter when SAMPLE decoding is selected —
    # TODO confirm against the Watsonx GenParams docs.
    parameters = {
        GenParams.DECODING_METHOD: decoding_method,
        GenParams.MAX_NEW_TOKENS: max_new_tokens,
        GenParams.MIN_NEW_TOKENS: 1,
        GenParams.TEMPERATURE: 0,
        GenParams.TOP_K: 50,
        GenParams.TOP_P: 1,
        GenParams.STOP_SEQUENCES: [],
        GenParams.REPETITION_PENALTY: 1
    }
    st.info("Upload a file to use RAG")
    # Single-file upload; the extension determines the loader in load_file().
    uploaded_file = st.file_uploader("Upload file", accept_multiple_files=False, type=["pdf", "docx", "txt", "pptx", "csv", "json", "xml", "yaml", "html"])
161
+
162
# Persist the uploaded file to disk so the path-based loaders can open it,
# then build the vector index from it.
# NOTE(review): the file is written to the working directory under its
# client-supplied name and never deleted — consider a sanitized temp path.
if uploaded_file is not None:
    bytes_data = uploaded_file.read()
    st.write("Filename:", uploaded_file.name)

    with open(uploaded_file.name, 'wb') as f:
        f.write(bytes_data)

    # The file extension selects the loader branch inside load_file().
    file_type = uploaded_file.name.split('.')[-1].lower()
    index = load_file(uploaded_file.name, file_type)
171
+
172
# Instantiate the selected Watsonx foundation model and wrap it for LangChain.
model_name = watsonx_model

st.info("Setting up Watsonx...")
my_credentials = {
    "url": "https://us-south.ml.cloud.ibm.com",
    "apikey": watsonx_api_key
}
params = parameters
project_id = watsonx_project_id
space_id = None
# NOTE(review): verify=False disables TLS certificate verification for the
# Watsonx endpoint — confirm this is intentional before production use.
verify = False
model = WatsonxLLM(model=Model(model_name, my_credentials, params, project_id, space_id, verify))

# `model` is an object instance here, so this branch always runs; the
# check is defensive only.
if model:
    st.info(f"Model {model_name} ready.")
    # Plain LLM chain: the fallback when no document index exists.
    chain = LLMChain(llm=model, prompt=prompt_template_br, verbose=True)

    # With an uploaded document, answer via retrieval-augmented generation:
    # top-matching chunks from the index are "stuffed" into {context}.
    if chain and index is not None:
        rag_chain = RetrievalQA.from_chain_type(
            llm=model,
            chain_type="stuff",
            retriever=index.vectorstore.as_retriever(),
            chain_type_kwargs={"prompt": prompt_template_br},
            return_source_documents=False,
            verbose=True
        )
        st.info("Document-based retrieval is ready.")
    else:
        st.warning("No document uploaded or chain setup issue.")
201
+
202
# Chat loop: replay stored history on each rerun, then answer new input.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-render previous turns. Roles are stored as the same "user"/"assistant"
# names used for live rendering, so replayed turns get identical avatars.
for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])

prompt = st.chat_input("Ask your question here", disabled=not chain)

if prompt:
    # Show and record the user turn before generating, so the history stays
    # consistent even if generation raises.
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Prefer document-grounded retrieval when an index was built; otherwise
    # fall back to the plain LLM chain with an empty context.
    if rag_chain:
        response_text = rag_chain.run(prompt).strip()
    else:
        response_text = chain.run(question=prompt, context="").strip()

    st.chat_message("assistant").markdown(response_text)
    st.session_state.messages.append({'role': 'assistant', 'content': response_text})