mayankchugh-learning committed
Commit 51dab23
1 Parent(s): e48c626

Create app.py

Files changed (1)
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
+
+ # Import the necessary libraries
+ import os
+ import uuid
+ import json
+
+ import gradio as gr
+
+ from openai import OpenAI
+
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+ from langchain_community.vectorstores import Chroma
+
+ from huggingface_hub import CommitScheduler
+ from pathlib import Path
+ from dotenv import load_dotenv
+
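+ # gradio provides the web UI, langchain_community/Chroma the retrieval layer,
+ # and huggingface_hub's CommitScheduler the log persistence on the Hub.
+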
+ # Load the API key and create the LLM client
+ load_dotenv()
+
+ client = OpenAI(
+     base_url="https://api.endpoints.anyscale.com/v1",
+     api_key=os.environ['ANYSCALE_API_KEY']
+ )
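+ # The Anyscale endpoint is OpenAI-compatible; ANYSCALE_API_KEY is expected in the
+ # environment (e.g. from the local .env file read by load_dotenv above).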
+
+ # Define the embedding model and load the persisted vector store
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
+
+ collection_name = 'report-10k-2024'
+
+ vectorstore_persisted = Chroma(
+     collection_name=collection_name,
+     persist_directory='./report_10kdb',
+     embedding_function=embedding_model
+ )
+
+ retriever = vectorstore_persisted.as_retriever(
+     search_type='similarity',
+     search_kwargs={'k': 5}
+ )
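+ # Note: predict() below queries vectorstore_persisted directly with a metadata
+ # filter; this retriever is an unfiltered top-5 alternative.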
+
+
+ # Prepare the logging functionality
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ scheduler = CommitScheduler(
+     repo_id="RAG-investment-recommendation-log",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2
+ )
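+ # CommitScheduler pushes the contents of log_folder to the dataset repo from a
+ # background thread every 2 minutes; writers should hold scheduler.lock meanwhile.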
+
+ # Define the Q&A system message
+ qna_system_message = """
+ You are an AI assistant to a data scientist working on a Retrieval-Augmented Generation (RAG) system for Finsights Grey Inc.
+
+ User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+ The context contains references to specific portions of the documentation relevant to the user's query, along with source links.
+ The source for a context will begin with the token: ###Source.
+
+ When crafting your response:
+ - Select only the context relevant to answering the question.
+ - Include the source links in your response.
+ - User questions will begin with the token: ###Question.
+ - If the question is not related to the provided context, respond with "I don't have enough information to answer that question."
+
+ Please adhere to the following guidelines:
+ - Your response should only be about the question asked and the context provided.
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, it is very important for you to respond with "I don't know."
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+ - Do not make up sources. Use only the links provided in the sources section of the context; you are prohibited from providing other links/sources.
+
+ Here is an example of how to structure your response:
+
+ Answer:
+ [Answer]
+
+ Source:
+ [Source]
+ """
+
+ # Define the user message template (the ###Question token matches the system message)
+ qna_user_message_template = """
+ ###Context
+ Here are some documents that are relevant to the question.
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ # Define the predict function that runs when 'Submit' is clicked or when an API request is made
+ def predict(user_input, company):
+     # 'filter' is a Python builtin, so use a more descriptive name for the source path
+     source_filter = f"dataset/{company}-10-k-2023.pdf"
+     relevant_document_chunks = vectorstore_persisted.similarity_search(
+         user_input,
+         k=5,
+         filter={"source": source_filter}
+     )
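+     # Chroma metadata filters are exact matches: the path must be identical to the
+     # 'source' value stored when the PDFs were ingested.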
+
+     # Build context_for_query by joining the retrieved chunks
+     context_list = [d.page_content for d in relevant_document_chunks]
+     context_for_query = ".".join(context_list)
+
+     # Create the chat messages
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
+     # Get the response from the LLM
+     try:
+         response = client.chat.completions.create(
+             model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+             messages=prompt,
+             temperature=0
+         )
+         prediction = response.choices[0].message.content
+     except Exception as e:
+         # Surface the error as text; the raw Exception object would break json.dumps below
+         prediction = f"Sorry, I encountered an error: {e}"
+
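+     # Each interaction is appended as a single JSON line (JSONL), one record per request.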
+     # Log the inputs and outputs to a local log file, holding the commit scheduler's
+     # lock so the background push to the Hub never sees a partially written record
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
+
+     return prediction
+
+
+ def get_predict(question, company):
+     # Map the UI label to the identifier used in the vector store's source paths.
+     # Assumption: the ingested files follow the lowercase pattern seen for aws/meta/msft,
+     # so "ibm" and "google" are used here as well; adjust if the stored paths differ.
+     company_map = {
+         "AWS": "aws",
+         "IBM": "ibm",
+         "Google": "google",
+         "Meta": "meta",
+         "Microsoft": "msft",
+     }
+     selected_company = company_map.get(company)
+     if selected_company is None:
+         return "Invalid company selected"
+
+     return predict(question, selected_company)
+
+ # Set up the Gradio UI: a radio button selects the company whose 10-K report is
+ # searched for context, and a textbox takes the user's question
+ with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
+     with gr.Row():
+         company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
+
+     with gr.Row():
+         question = gr.Textbox(label="Enter your question")
+
+     submit = gr.Button("Submit")
+     output = gr.Textbox(label="Output")
+
+     submit.click(
+         fn=get_predict,
+         inputs=[question, company],
+         outputs=output
+     )
+
+ demo.queue()
+ demo.launch()
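+ # queue() enables Gradio's request queue so concurrent submissions are handled in
+ # order; launch() starts the web server.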