sreevidya16
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,249 +1,249 @@
|
|
1 |
-
|
2 |
-
from flask import Flask, request, jsonify, send_from_directory
|
3 |
-
# from flask_session import Session
|
4 |
-
from flask_cors import CORS # <-- New import here
|
5 |
-
from flask_cors import cross_origin
|
6 |
-
import openai
|
7 |
-
import os
|
8 |
-
from pytube import YouTube
|
9 |
-
import re
|
10 |
-
from langchain_openai.chat_models import ChatOpenAI
|
11 |
-
from langchain.chains import ConversationalRetrievalChain
|
12 |
-
from langchain_openai import OpenAIEmbeddings
|
13 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
-
from langchain_community.document_loaders import TextLoader
|
15 |
-
from langchain_community.vectorstores import Chroma
|
16 |
-
from youtube_transcript_api import YouTubeTranscriptApi
|
17 |
-
from dotenv import load_dotenv
|
18 |
-
|
19 |
-
load_dotenv()
|
20 |
-
|
21 |
-
app = Flask(__name__, static_folder="./dist") # requests in the dist folder are being sent to http://localhost:5000/<endpoint>
|
22 |
-
CORS(app, resources={r"/*": {"origins": "*"}})
|
23 |
-
openai.api_key = os.environ["OPENAI_API_KEY"]
|
24 |
-
llm_name = "gpt-3.5-turbo"
|
25 |
-
qna_chain = None
|
26 |
-
|
27 |
-
|
28 |
-
@app.route('/', defaults={'path': ''})
|
29 |
-
@app.route('/<path:path>')
|
30 |
-
def serve(path):
|
31 |
-
if path != "" and os.path.exists(app.static_folder + '/' + path):
|
32 |
-
return send_from_directory(app.static_folder, path)
|
33 |
-
else:
|
34 |
-
return send_from_directory(app.static_folder, 'index.html')
|
35 |
-
|
36 |
-
def load_db(file, chain_type, k):
|
37 |
-
"""
|
38 |
-
Central Function that:
|
39 |
-
- Loads the database
|
40 |
-
- Creates the retriever
|
41 |
-
- Creates the chatbot chain
|
42 |
-
- Returns the chatbot chain
|
43 |
-
- A Dictionary containing
|
44 |
-
-- question
|
45 |
-
-- llm answer
|
46 |
-
-- chat history
|
47 |
-
-- source_documents
|
48 |
-
-- generated_question
|
49 |
-
s
|
50 |
-
Usage: question_answer_chain = load_db(file, chain_type, k)
|
51 |
-
response = question_answer_chain({"question": query, "chat_history": chat_history}})
|
52 |
-
"""
|
53 |
-
|
54 |
-
transcript = TextLoader(file).load()
|
55 |
-
|
56 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=70)
|
57 |
-
docs = text_splitter.split_documents(transcript)
|
58 |
-
|
59 |
-
embeddings = OpenAIEmbeddings()
|
60 |
-
|
61 |
-
db = Chroma.from_documents(docs, embeddings)
|
62 |
-
|
63 |
-
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})
|
64 |
-
|
65 |
-
# create a chatbot chain. Memory is managed externally.
|
66 |
-
qa = ConversationalRetrievalChain.from_llm(
|
67 |
-
llm = ChatOpenAI(temperature=0), #### Prompt Template is yet to be created
|
68 |
-
chain_type=chain_type,
|
69 |
-
retriever=retriever,
|
70 |
-
return_source_documents=True,
|
71 |
-
return_generated_question=True,
|
72 |
-
# memory=memory
|
73 |
-
)
|
74 |
-
|
75 |
-
return qa
|
76 |
-
|
77 |
-
|
78 |
-
def buffer(history, buff):
|
79 |
-
"""
|
80 |
-
Buffer the history.
|
81 |
-
Keeps only buff recent chats in the history
|
82 |
-
|
83 |
-
Usage: history = buffer(history, buff)
|
84 |
-
"""
|
85 |
-
|
86 |
-
if len(history) > buff :
|
87 |
-
print(len(history)>buff)
|
88 |
-
return history[-buff:]
|
89 |
-
return history
|
90 |
-
|
91 |
-
|
92 |
-
def is_valid_yt(link):
|
93 |
-
"""
|
94 |
-
Check if a link is a valid YouTube link.
|
95 |
-
|
96 |
-
Usage: boolean, video_id = is_valid_yt(youtube_string)
|
97 |
-
"""
|
98 |
-
|
99 |
-
pattern = r'^(?:https?:\/\/)?(?:www\.)?(?:youtube\.com\/watch\?v=|youtu\.be\/)([\w\-_]{11})(?:\S+)?$'
|
100 |
-
match = re.match(pattern, link)
|
101 |
-
if match:
|
102 |
-
return True, match.group(1)
|
103 |
-
else:
|
104 |
-
return False, None
|
105 |
-
|
106 |
-
|
107 |
-
def get_metadata(video_id) -> dict:
|
108 |
-
"""Get important video information.
|
109 |
-
|
110 |
-
Components are:
|
111 |
-
- title
|
112 |
-
- description
|
113 |
-
- thumbnail url,
|
114 |
-
- publish_date
|
115 |
-
- channel_author
|
116 |
-
- and more.
|
117 |
-
|
118 |
-
Usage: get_metadata(id)->dict
|
119 |
-
"""
|
120 |
-
|
121 |
-
try:
|
122 |
-
from pytube import YouTube
|
123 |
-
|
124 |
-
except ImportError:
|
125 |
-
raise ImportError(
|
126 |
-
"Could not import pytube python package. "
|
127 |
-
"Please install it with `pip install pytube`."
|
128 |
-
)
|
129 |
-
yt = YouTube(f"https://www.youtube.com/watch?v={video_id}")
|
130 |
-
video_info = {
|
131 |
-
"title": yt.title or "Unknown",
|
132 |
-
"description": yt.description or "Unknown",
|
133 |
-
"view_count": yt.views or 0,
|
134 |
-
"thumbnail_url": yt.thumbnail_url or "Unknown",
|
135 |
-
"publish_date": yt.publish_date.strftime("%Y-%m-%d %H:%M:%S")
|
136 |
-
if yt.publish_date
|
137 |
-
else "Unknown",
|
138 |
-
"length": yt.length or 0,
|
139 |
-
"author": yt.author or "Unknown",
|
140 |
-
}
|
141 |
-
return video_info
|
142 |
-
|
143 |
-
|
144 |
-
def save_transcript(video_id):
|
145 |
-
"""
|
146 |
-
Saves the transcript of a valid yt video to a text file.
|
147 |
-
"""
|
148 |
-
|
149 |
-
try:
|
150 |
-
transcript = YouTubeTranscriptApi.get_transcript(video_id)
|
151 |
-
except Exception as e:
|
152 |
-
print(f"Error fetching transcript for video {video_id}: {e}")
|
153 |
-
return None
|
154 |
-
if transcript:
|
155 |
-
with open('transcript.txt', 'w') as file:
|
156 |
-
for entry in transcript:
|
157 |
-
file.write(f"~{int(entry['start'])}~{entry['text']} ")
|
158 |
-
print(f"Transcript saved to: transcript.txt")
|
159 |
-
|
160 |
-
@app.route('/init', methods=['POST'])
|
161 |
-
@cross_origin()
|
162 |
-
def initialize():
|
163 |
-
"""
|
164 |
-
Initialize the qna_chain for a user.
|
165 |
-
"""
|
166 |
-
global qna_chain
|
167 |
-
|
168 |
-
qna_chain = 0
|
169 |
-
|
170 |
-
# NEED to authenticate the user here
|
171 |
-
yt_link = request.json.get('yt_link', '')
|
172 |
-
valid, id = is_valid_yt(yt_link)
|
173 |
-
if valid:
|
174 |
-
metadata = get_metadata(id)
|
175 |
-
try:
|
176 |
-
os.remove('./transcript.txt')
|
177 |
-
except:
|
178 |
-
print("No transcript file to remove.")
|
179 |
-
|
180 |
-
save_transcript(id)
|
181 |
-
|
182 |
-
# Initialize qna_chain for the user
|
183 |
-
qna_chain = load_db("./transcript.txt", 'stuff', 5)
|
184 |
-
|
185 |
-
# os.remove('./transcript.txt')
|
186 |
-
|
187 |
-
return jsonify({"status": "success",
|
188 |
-
"message": "qna_chain initialized.",
|
189 |
-
"metadata": metadata,
|
190 |
-
})
|
191 |
-
else:
|
192 |
-
return jsonify({"status": "error", "message": "Invalid YouTube link."})
|
193 |
-
|
194 |
-
|
195 |
-
@app.route('/response', methods=['POST'])
|
196 |
-
def response():
|
197 |
-
"""
|
198 |
-
- Expects youtube Video Link and chat-history in payload
|
199 |
-
- Returns response on the query.
|
200 |
-
"""
|
201 |
-
global qna_chain
|
202 |
-
|
203 |
-
req = request.get_json()
|
204 |
-
raw = req.get('chat_history', [])
|
205 |
-
|
206 |
-
# raw is a list of list containing two strings convert that into a list of tuples
|
207 |
-
if len(raw) > 0:
|
208 |
-
chat_history = [tuple(x) for x in raw]
|
209 |
-
else:
|
210 |
-
chat_history = []
|
211 |
-
# print(f"Chat History: {chat_history}")
|
212 |
-
|
213 |
-
memory = chat_history
|
214 |
-
query = req.get('query', '')
|
215 |
-
# print(f"Query: {query}")
|
216 |
-
|
217 |
-
if memory is None:
|
218 |
-
memory = []
|
219 |
-
|
220 |
-
if qna_chain is None:
|
221 |
-
return jsonify({"status": "error", "message": "qna_chain not initialized."}), 400
|
222 |
-
|
223 |
-
response = qna_chain({'question': query, 'chat_history': buffer(memory,7)})
|
224 |
-
|
225 |
-
if response['source_documents']:
|
226 |
-
pattern = r'~(\d+)~'
|
227 |
-
backlinked_docs = [response['source_documents'][i].page_content for i in range(len(response['source_documents']))]
|
228 |
-
timestamps = list(map(lambda s: int(re.search(pattern, s).group(1)) if re.search(pattern, s) else None, backlinked_docs))
|
229 |
-
|
230 |
-
return jsonify(dict(timestamps=timestamps, answer=response['answer']))
|
231 |
-
|
232 |
-
return jsonify(response['answer'])
|
233 |
-
|
234 |
-
@app.route('/transcript', methods=['POST'])
|
235 |
-
@cross_origin()
|
236 |
-
def send_transcript():
|
237 |
-
"""
|
238 |
-
Send the transcript of the video.
|
239 |
-
"""
|
240 |
-
try:
|
241 |
-
with open('transcript.txt', 'r') as file:
|
242 |
-
transcript = file.read()
|
243 |
-
return jsonify({"status": "success", "transcript": transcript})
|
244 |
-
except:
|
245 |
-
return jsonify({"status": "error", "message": "Transcript not found."})
|
246 |
-
|
247 |
-
|
248 |
-
if __name__ == '__main__':
|
249 |
-
app.run(debug=
|
|
|
1 |
+
|
2 |
+
from flask import Flask, request, jsonify, send_from_directory
|
3 |
+
# from flask_session import Session
|
4 |
+
from flask_cors import CORS # <-- New import here
|
5 |
+
from flask_cors import cross_origin
|
6 |
+
import openai
|
7 |
+
import os
|
8 |
+
from pytube import YouTube
|
9 |
+
import re
|
10 |
+
from langchain_openai.chat_models import ChatOpenAI
|
11 |
+
from langchain.chains import ConversationalRetrievalChain
|
12 |
+
from langchain_openai import OpenAIEmbeddings
|
13 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
+
from langchain_community.document_loaders import TextLoader
|
15 |
+
from langchain_community.vectorstores import Chroma
|
16 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
17 |
+
from dotenv import load_dotenv
|
18 |
+
|
19 |
+
# Read a local .env file into the process environment before the key lookup below.
load_dotenv()

# Flask application; its static folder is ./dist, which serve() reads via app.static_folder.
app = Flask(__name__, static_folder="./dist") # requests in the dist folder are being sent to http://localhost:5000/<endpoint>
# Enable cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})
# Indexing os.environ raises KeyError at import time when OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]
# Name of the chat model.
llm_name = "gpt-3.5-turbo"
# Module-level chain object: assigned by the /init handler, read by the /response handler.
qna_chain = None
|
26 |
+
|
27 |
+
|
28 |
+
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve(path):
    """
    Serve the built front-end from the static folder.

    A request whose path names an existing regular file inside
    ``app.static_folder`` gets that file. Every other path (the root,
    directories, unknown paths) gets ``index.html``.

    Usage: registered with Flask for '/' and '/<path:path>'
    """
    # isfile rather than exists: a directory passes an existence check and
    # would then make send_from_directory fail instead of falling back to
    # index.html.
    if path and os.path.isfile(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, 'index.html')
|
35 |
+
|
36 |
+
def load_db(file, chain_type, k):
    """
    Central Function that:
    - Loads the transcript text file
    - Splits it into overlapping chunks and embeds them into a Chroma store
    - Creates the retriever
    - Creates and returns the chatbot chain

    Args:
        file: path of the text file to index.
        chain_type: forwarded to ConversationalRetrievalChain.from_llm (e.g. 'stuff').
        k: number of chunks the retriever fetches per question.

    Returns:
        The chain. It is created with return_source_documents and
        return_generated_question enabled; callers in this file read
        'answer' and 'source_documents' from its result.

    Raises:
        ValueError: the file yielded no text chunks to index.

    Usage: question_answer_chain = load_db(file, chain_type, k)
           response = question_answer_chain({"question": query, "chat_history": chat_history})
    """
    import uuid

    # Explicit UTF-8 so reading does not depend on the platform's locale encoding.
    transcript = TextLoader(file, encoding="utf-8").load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=70)
    docs = text_splitter.split_documents(transcript)
    if not docs:
        raise ValueError(f"No text found in {file!r}; nothing to index.")

    embeddings = OpenAIEmbeddings()

    # A fresh collection per call, so chunks indexed for an earlier file can
    # never be returned by this retriever.
    db = Chroma.from_documents(
        docs,
        embeddings,
        collection_name=f"transcript-{uuid.uuid4().hex}",
    )

    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})

    # create a chatbot chain. Memory is managed externally.
    qa = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model=llm_name, temperature=0),  #### Prompt Template is yet to be created
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
        return_generated_question=True,
    )

    return qa
|
76 |
+
|
77 |
+
|
78 |
+
def buffer(history, buff):
    """
    Buffer the history.
    Keeps only the `buff` most recent chats in the history.

    Args:
        history: sequence of chat entries, oldest first.
        buff: maximum number of entries to keep.

    Returns:
        The last `buff` entries when the history is longer than that, the
        history itself when it already fits, and an empty slice when
        `buff` is zero or negative.

    Usage: history = buffer(history, buff)
    """
    if buff <= 0:
        # history[-0:] is the whole sequence and a negative buff would slice
        # from the wrong end, so return an explicit empty slice instead.
        return history[:0]
    if len(history) > buff:
        return history[-buff:]
    return history
|
90 |
+
|
91 |
+
|
92 |
+
# Accepts watch, short-link, shorts and embed URLs (optionally on the www. or
# m. host). The id is exactly 11 ASCII characters and must not be followed by
# a further id character, so over-long ids are rejected instead of truncated.
_YT_LINK_RE = re.compile(
    r'^(?:https?://)?(?:www\.|m\.)?'
    r'(?:youtube\.com/(?:watch\?(?:\S*?&)?v=|shorts/|embed/)|youtu\.be/)'
    r'([A-Za-z0-9_-]{11})'
    r'(?![A-Za-z0-9_-])\S*$'
)


def is_valid_yt(link):
    """
    Check if a link is a valid YouTube link.

    Args:
        link: the candidate URL; anything that is not a string is invalid.

    Returns:
        (True, video_id) when the link matches a supported YouTube URL
        form, otherwise (False, None). Surrounding whitespace is ignored.

    Usage: boolean, video_id = is_valid_yt(youtube_string)
    """
    if not isinstance(link, str):
        return False, None

    match = _YT_LINK_RE.match(link.strip())
    if match:
        return True, match.group(1)
    return False, None
|
105 |
+
|
106 |
+
|
107 |
+
def get_metadata(video_id) -> dict:
    """Get important video information.

    Components are:
        - title
        - description
        - view_count
        - thumbnail_url
        - publish_date (formatted '%Y-%m-%d %H:%M:%S')
        - length
        - author

    A field that is empty or cannot be read falls back to "Unknown"
    (0 for view_count and length), so one failing attribute does not
    discard the rest.

    Raises:
        ImportError: pytube is not installed.

    Usage: get_metadata(id)->dict
    """
    try:
        from pytube import YouTube
    except ImportError as err:
        raise ImportError(
            "Could not import pytube python package. "
            "Please install it with `pip install pytube`."
        ) from err

    yt = YouTube(f"https://www.youtube.com/watch?v={video_id}")

    def read(name, default):
        # Reading a pytube attribute can raise (for instance when the watch
        # page cannot be fetched or parsed); metadata is informational, so
        # log the failure and use the default.
        try:
            return getattr(yt, name) or default
        except Exception as err:
            print(f"Could not read {name} for video {video_id}: {err}")
            return default

    publish_date = read("publish_date", None)
    video_info = {
        "title": read("title", "Unknown"),
        "description": read("description", "Unknown"),
        "view_count": read("views", 0),
        "thumbnail_url": read("thumbnail_url", "Unknown"),
        "publish_date": publish_date.strftime("%Y-%m-%d %H:%M:%S")
        if publish_date
        else "Unknown",
        "length": read("length", 0),
        "author": read("author", "Unknown"),
    }
    return video_info
|
142 |
+
|
143 |
+
|
144 |
+
def save_transcript(video_id):
    """
    Saves the transcript of a valid yt video to a text file.

    Every caption is written to 'transcript.txt' as
    ``~<start second>~<caption text> `` so the start time can later be
    recovered from any chunk of the file.

    Returns:
        'transcript.txt' when the file was written, None when the
        transcript could not be fetched or was empty (nothing is written
        in that case).
    """
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {e}")
        return None
    if not transcript:
        return None

    # Explicit UTF-8 so non-ASCII captions do not depend on the platform's
    # locale encoding.
    with open('transcript.txt', 'w', encoding='utf-8') as file:
        file.write("".join(
            f"~{int(entry['start'])}~{entry['text']} " for entry in transcript
        ))
    print("Transcript saved to: transcript.txt")
    return 'transcript.txt'
|
159 |
+
|
160 |
+
@app.route('/init', methods=['POST'])
@cross_origin()
def initialize():
    """
    Initialize the qna_chain for a user.

    Expects a JSON body with 'yt_link'. Validates the link, collects the
    video metadata, saves the transcript to ./transcript.txt and builds
    the question-answering chain from it.

    Returns:
        JSON {"status": "success", "message", "metadata"} once the chain
        is ready; JSON {"status": "error", "message"} for an invalid link
        or (with HTTP 404) when no transcript could be saved.
    """
    global qna_chain

    # None is the "not initialized" marker the /response handler tests for;
    # any other placeholder would slip past that check.
    qna_chain = None

    # NEED to authenticate the user here
    payload = request.get_json(silent=True)
    yt_link = payload.get('yt_link', '') if isinstance(payload, dict) else ''
    valid, video_id = is_valid_yt(yt_link) if isinstance(yt_link, str) else (False, None)
    if not valid:
        return jsonify({"status": "error", "message": "Invalid YouTube link."})

    # Metadata is informational only; a lookup failure must not block the chat.
    try:
        metadata = get_metadata(video_id)
    except Exception as err:
        print(f"Could not fetch metadata for video {video_id}: {err}")
        metadata = {}

    # Drop the previous video's transcript so a failed download below cannot
    # leave a stale file to be indexed.
    try:
        os.remove('./transcript.txt')
    except FileNotFoundError:
        print("No transcript file to remove.")
    except OSError as err:
        print(f"Could not remove the old transcript file: {err}")

    save_transcript(video_id)
    if not os.path.isfile('./transcript.txt'):
        return jsonify({"status": "error",
                        "message": "Could not fetch a transcript for this video."}), 404

    # Initialize qna_chain for the user
    qna_chain = load_db("./transcript.txt", 'stuff', 5)

    return jsonify({"status": "success",
                    "message": "qna_chain initialized.",
                    "metadata": metadata,
                    })
|
193 |
+
|
194 |
+
|
195 |
+
@app.route('/response', methods=['POST'])
def response():
    """
    - Expects 'query' and 'chat_history' (a list of [question, answer]
      pairs) in the JSON payload
    - Returns response on the query.

    Returns:
        JSON {"timestamps": [...], "answer": ...} when the chain returned
        source documents (one timestamp per document, None when a document
        carries no ~<seconds>~ marker); otherwise the bare answer as JSON.
        HTTP 400 with an error object when the chain is not initialized or
        the query is empty.
    """
    # callable() also rejects any non-chain placeholder left in the global.
    if qna_chain is None or not callable(qna_chain):
        return jsonify({"status": "error", "message": "qna_chain not initialized."}), 400

    req = request.get_json(silent=True)
    if not isinstance(req, dict):
        req = {}

    # raw is a list of list containing two strings convert that into a list of tuples
    raw = req.get('chat_history')
    if not isinstance(raw, list):
        raw = []
    chat_history = [tuple(x) for x in raw if isinstance(x, (list, tuple)) and len(x) == 2]

    query = req.get('query', '')
    if not isinstance(query, str) or not query.strip():
        return jsonify({"status": "error", "message": "Query is empty."}), 400

    result = qna_chain({'question': query, 'chat_history': buffer(chat_history, 7)})

    sources = result['source_documents']
    if sources:
        marker = re.compile(r'~(\d+)~')
        timestamps = []
        for doc in sources:
            found = marker.search(doc.page_content)
            timestamps.append(int(found.group(1)) if found else None)

        return jsonify(dict(timestamps=timestamps, answer=result['answer']))

    return jsonify(result['answer'])
|
233 |
+
|
234 |
+
@app.route('/transcript', methods=['POST'])
@cross_origin()
def send_transcript():
    """
    Send the transcript of the video.

    Returns:
        JSON {"status": "success", "transcript": <file contents>} when
        transcript.txt can be read, otherwise
        JSON {"status": "error", "message": "Transcript not found."}.
    """
    try:
        # Explicit UTF-8 so reading does not depend on the platform's locale encoding.
        with open('transcript.txt', 'r', encoding='utf-8') as file:
            transcript = file.read()
    except (OSError, UnicodeDecodeError) as err:
        # Only read failures are treated as "not found"; anything else is a
        # real bug and should surface.
        print(f"Could not read transcript.txt: {err}")
        return jsonify({"status": "error", "message": "Transcript not found."})
    return jsonify({"status": "success", "transcript": transcript})
|
246 |
+
|
247 |
+
|
248 |
+
# Start Flask's built-in server only when this file is run directly; debug mode is off.
if __name__ == '__main__':
    app.run(debug=False)
|