|
|
|
from flask import Flask, request, jsonify, send_from_directory
|
|
|
|
from flask_cors import CORS
|
|
from flask_cors import cross_origin
|
|
import openai
|
|
import os
|
|
from pytube import YouTube
|
|
import re
|
|
from langchain_openai.chat_models import ChatOpenAI
|
|
from langchain.chains import ConversationalRetrievalChain
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain_community.document_loaders import TextLoader
|
|
from langchain_community.vectorstores import Chroma
|
|
from youtube_transcript_api import YouTubeTranscriptApi
|
|
from dotenv import load_dotenv
|
|
|
|
# Load environment variables from a .env file, if one is present.
load_dotenv()

# The Flask application; its static folder is the ./dist directory.
app = Flask(__name__, static_folder="./dist")

# Allow cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})

# Raises KeyError at import time when OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]

llm_name = "gpt-3.5-turbo"

# Module-level slot for the active question-answering chain: the /init
# handler assigns it and the /response handler reads it.
qna_chain = None
|
|
|
|
|
|
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve(path):
    """Serve the built front-end from the static folder.

    A request whose path names a regular file inside the static folder gets
    that file; every other path (including the site root) gets
    ``index.html``.

    Args:
        path: URL path relative to the site root ('' for the root itself).

    Returns:
        The response produced by ``send_from_directory``.
    """
    # isfile rather than exists: a path that names a directory cannot be sent
    # as a file, so it should fall through to index.html like any other
    # client-side route instead of ending in a 404.
    if path and os.path.isfile(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, 'index.html')
|
|
|
|
def load_db(file, chain_type, k):
    """Build a conversational retrieval chain over a text file.

    Steps:
      1. Load the file and split it into overlapping chunks
         (500 characters per chunk, 70 characters of overlap).
      2. Embed the chunks with OpenAI embeddings into a Chroma vector store.
      3. Wrap the store in a similarity retriever that returns ``k`` chunks.
      4. Create a ``ConversationalRetrievalChain`` on top of that retriever.

    Args:
        file: Path of the text file to index.
        chain_type: Document-combining strategy handed to the chain
            (the /init handler passes 'stuff').
        k: Number of chunks the retriever returns per query.

    Returns:
        The chain. It is configured to return the source documents and the
        generated question alongside the answer.

    Usage:
        question_answer_chain = load_db(file, chain_type, k)
        response = question_answer_chain(
            {"question": query, "chat_history": chat_history}
        )
    """
    # Local import: uuid is only needed here to name the collection.
    import uuid

    transcript = TextLoader(file).load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=70)
    docs = text_splitter.split_documents(transcript)

    embeddings = OpenAIEmbeddings()

    # Use a fresh collection for every call. Without an explicit name Chroma
    # falls back to one default collection, and in-memory stores created in
    # the same process share it, so indexing a second file would leave the
    # previous file's chunks retrievable next to the new ones.
    # Older collections are not deleted here.
    db = Chroma.from_documents(
        docs,
        embeddings,
        collection_name=f"transcript-{uuid.uuid4().hex}",
    )

    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})

    qa = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(temperature=0),
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
        return_generated_question=True,
    )

    return qa
|
|
|
|
|
|
def buffer(history, buff):
    """Keep only the ``buff`` most recent entries of ``history``.

    Args:
        history: Sequence of chat entries, oldest first.
        buff: Maximum number of entries to keep.

    Returns:
        ``history`` itself when it already fits, otherwise a slice holding
        its last ``buff`` entries. A ``buff`` of zero or less yields an
        empty sequence of the same type.

    Usage: history = buffer(history, buff)
    """
    # history[-0:] is the whole sequence and a negative count would slice
    # from the front, so a non-positive limit is handled explicitly.
    if buff <= 0:
        return history[:0]
    if len(history) > buff:
        return history[-buff:]
    return history
|
|
|
|
|
|
def is_valid_yt(link):
    """Check whether ``link`` is a YouTube watch link and extract its video id.

    Accepted forms are ``youtube.com/watch?v=<id>`` and ``youtu.be/<id>``,
    with an optional ``http(s)://`` scheme, an optional ``www.`` prefix and
    optional trailing parameters. Surrounding whitespace is ignored.

    Args:
        link: The candidate link. Anything that is not a string is rejected.

    Returns:
        ``(True, video_id)`` for a valid link, ``(False, None)`` otherwise.

    Usage: boolean, video_id = is_valid_yt(youtube_string)
    """
    # A missing or non-string value (e.g. JSON null) is simply not a link.
    if not isinstance(link, str):
        return False, None

    pattern = (
        r'^(?:https?://)?(?:www\.)?'
        r'(?:youtube\.com/watch\?v=|youtu\.be/)'
        # Video ids are exactly 11 characters from this ASCII alphabet.
        r'([A-Za-z0-9_-]{11})'
        # Anything after the id must start a new parameter or fragment, so an
        # over-long id is rejected instead of being silently truncated.
        r'(?:[&?#]\S*)?$'
    )
    match = re.match(pattern, link.strip())
    if match is None:
        return False, None
    return True, match.group(1)
|
|
|
|
|
|
def get_metadata(video_id) -> dict:
    """Collect descriptive details about a YouTube video through pytube.

    The returned dictionary has these keys:
      - title, description, thumbnail_url, author ("Unknown" when empty)
      - view_count, length (0 when empty)
      - publish_date, formatted as "%Y-%m-%d %H:%M:%S" ("Unknown" when empty)

    Raises:
        ImportError: If the pytube package cannot be imported.

    Usage: get_metadata(id)->dict
    """
    try:
        from pytube import YouTube
    except ImportError:
        raise ImportError(
            "Could not import pytube python package. "
            "Please install it with `pip install pytube`."
        )

    video = YouTube(f"https://www.youtube.com/watch?v={video_id}")

    # Fields are filled in one at a time, each falling back to a placeholder
    # when pytube reports an empty value.
    details = {}
    details["title"] = video.title or "Unknown"
    details["description"] = video.description or "Unknown"
    details["view_count"] = video.views or 0
    details["thumbnail_url"] = video.thumbnail_url or "Unknown"
    details["publish_date"] = (
        video.publish_date.strftime("%Y-%m-%d %H:%M:%S")
        if video.publish_date
        else "Unknown"
    )
    details["length"] = video.length or 0
    details["author"] = video.author or "Unknown"
    return details
|
|
|
|
|
|
def save_transcript(video_id):
    """Fetch a video's transcript and write it to ``transcript.txt``.

    Every transcript entry is written as ``~<start>~<text> ``, where
    ``<start>`` is the entry's start value converted to an int. If fetching
    fails the error is printed and nothing is written; an empty transcript
    is not written either. The function always returns None.
    """
    try:
        entries = YouTubeTranscriptApi.get_transcript(video_id)
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {e}")
        return None

    # Nothing to save for an empty transcript.
    if not entries:
        return None

    with open('transcript.txt', 'w') as out:
        # One write per entry, each tagged with its start marker.
        out.writelines(
            f"~{int(entry['start'])}~{entry['text']} " for entry in entries
        )
    print("Transcript saved to: transcript.txt")
    return None
|
|
|
|
@app.route('/init', methods=['POST'])
@cross_origin()
def initialize():
    """Build the question-answering chain for the video named in the request.

    Expects a JSON body with a ``yt_link`` string. For a valid link the
    video's metadata is collected, any previous ``transcript.txt`` is
    removed, the new transcript is saved and indexed, and the resulting
    chain is stored in the module-level ``qna_chain``.

    Returns:
        JSON ``{"status": "success", "message": ..., "metadata": ...}`` on
        success; JSON ``{"status": "error", "message": ...}`` for an invalid
        link, or (with HTTP 500) when no transcript could be saved.
    """
    global qna_chain

    # Reset to None, not 0: the /response handler tests ``qna_chain is None``
    # to detect a missing chain, and any other placeholder would slip past
    # that test and be called.
    qna_chain = None

    # Tolerate a missing, non-JSON or non-object body; it is then treated
    # like a request without a link.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    yt_link = payload.get('yt_link', '')
    if not isinstance(yt_link, str):
        yt_link = ''

    valid, video_id = is_valid_yt(yt_link)
    if not valid:
        return jsonify({"status": "error", "message": "Invalid YouTube link."})

    metadata = get_metadata(video_id)

    try:
        os.remove('./transcript.txt')
    except FileNotFoundError:
        print("No transcript file to remove.")

    save_transcript(video_id)

    # save_transcript writes nothing when the fetch fails or the transcript
    # is empty. The old file was removed above, so a missing file here means
    # there is nothing to index. This used to surface as an unhandled error
    # (HTTP 500); the status code is kept and a JSON body is added.
    if not os.path.isfile('./transcript.txt'):
        return jsonify({
            "status": "error",
            "message": "Transcript could not be fetched for this video.",
        }), 500

    qna_chain = load_db("./transcript.txt", 'stuff', 5)

    return jsonify({
        "status": "success",
        "message": "qna_chain initialized.",
        "metadata": metadata,
    })
|
|
|
|
|
|
@app.route('/response', methods=['POST'])
def response():
    """Answer a question using the chain created by the /init handler.

    Expects a JSON body with:
      - ``query``: the question to ask.
      - ``chat_history``: list of earlier exchanges; every item is converted
        to a tuple before it is handed to the chain. Only the 7 most recent
        items are used.

    Returns:
        - HTTP 400 with ``{"status": "error", ...}`` when no chain is
          available.
        - ``{"timestamps": [...], "answer": ...}`` when the chain reports
          source documents. Each timestamp is the first ``~<number>~`` marker
          found in the corresponding document, or None if it has none.
        - The bare answer when there are no source documents.
    """
    # Tolerate a missing, non-JSON or non-object body.
    req = request.get_json(silent=True)
    if not isinstance(req, dict):
        req = {}

    # ``or []`` also covers an explicit null chat history.
    raw = req.get('chat_history') or []
    chat_history = [tuple(x) for x in raw]
    query = req.get('query', '')

    # callable() rejects None as well as any non-chain placeholder value that
    # may have been left in the module-level slot.
    if not callable(qna_chain):
        return jsonify({"status": "error", "message": "qna_chain not initialized."}), 400

    result = qna_chain({'question': query, 'chat_history': buffer(chat_history, 7)})

    sources = result['source_documents']
    if not sources:
        # NOTE(review): this branch returns a bare JSON string while the
        # branch below returns an object - confirm clients handle both.
        return jsonify(result['answer'])

    marker = re.compile(r'~(\d+)~')
    timestamps = []
    for doc in sources:
        found = marker.search(doc.page_content)
        timestamps.append(int(found.group(1)) if found else None)

    return jsonify(dict(timestamps=timestamps, answer=result['answer']))
|
|
|
|
@app.route('/transcript', methods=['POST'])
@cross_origin()
def send_transcript():
    """Return the contents of ``transcript.txt``.

    Returns:
        JSON ``{"status": "success", "transcript": <text>}`` when the file
        can be read, otherwise
        ``{"status": "error", "message": "Transcript not found."}``.
    """
    try:
        with open('transcript.txt', 'r') as file:
            transcript = file.read()
    # Only failures to open, read or decode the file count as "not found";
    # anything else is a real error and is allowed to propagate.
    except (OSError, UnicodeError):
        return jsonify({"status": "error", "message": "Transcript not found."})
    return jsonify({"status": "success", "transcript": transcript})
|
|
|
|
|
|
# Start the Flask server only when this file is executed as a script.
if __name__ == "__main__":
    # NOTE(review): debug=True is a development setting - confirm this entry
    # point is never used to run a deployed instance.
    app.run(debug=True)
|
|
|