from flask import Flask, request, jsonify, send_from_directory |
from flask_cors import CORS |
from flask_cors import cross_origin |
import openai |
import os |
from pytube import YouTube |
import re |
from langchain_openai.chat_models import ChatOpenAI |
from langchain.chains import ConversationalRetrievalChain |
from langchain_openai import OpenAIEmbeddings |
from langchain.text_splitter import RecursiveCharacterTextSplitter |
from langchain_community.document_loaders import TextLoader |
from langchain_community.vectorstores import Chroma |
from youtube_transcript_api import YouTubeTranscriptApi |
from dotenv import load_dotenv |
# Load variables from a local .env file into os.environ before anything below reads them.
load_dotenv()
# Flask app; the built frontend lives in ./dist and is served by serve().
app = Flask(__name__, static_folder="./dist")
# NOTE(review): allows requests from any origin on every route — confirm this is intended outside development.
CORS(app, resources={r"/*": {"origins": "*"}})
# Raises KeyError at import time if OPENAI_API_KEY is not set (neither in the environment nor in .env).
openai.api_key = os.environ["OPENAI_API_KEY"]
# Chat model name for the QnA chain. NOTE(review): make sure load_db actually passes it to ChatOpenAI.
llm_name = "gpt-3.5-turbo"
# Process-wide QnA chain shared by every client; starts as None and is replaced by /init.
qna_chain = None
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve(path):
    """Serve the built frontend from ``app.static_folder``.

    A request for an existing static *file* is served directly. Everything else
    (the root, client-side routes, directory paths) falls back to
    ``index.html`` so the frontend can handle the route itself.
    """
    # isfile rather than exists: a directory such as "assets" would otherwise be
    # handed to send_from_directory, which 404s instead of falling back.
    # send_from_directory still performs its own safe-join, so paths that
    # escape the static folder (e.g. "../x") are rejected there.
    if path and os.path.isfile(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, 'index.html')
def load_db(file, chain_type, k):
    """Build a conversational retrieval chain over a transcript text file.

    Steps:
    - load ``file`` with ``TextLoader`` and split it into 500-character chunks
      with 70 characters of overlap;
    - embed the chunks with ``OpenAIEmbeddings`` into a fresh Chroma store;
    - wrap a similarity retriever (top ``k`` chunks) and ``ChatOpenAI``
      (model ``llm_name``, temperature 0) in a ``ConversationalRetrievalChain``.

    Args:
        file: path of the text file to index.
        chain_type: combine-documents chain type passed to ``from_llm``
            (e.g. ``'stuff'``).
        k: number of chunks the retriever returns per question.

    Returns:
        The chain. Running it returns a dict containing:
        - question
        - chat_history
        - answer (the LLM answer)
        - source_documents
        - generated_question

    Usage:
        question_answer_chain = load_db(file, chain_type, k)
        response = question_answer_chain({"question": query, "chat_history": chat_history})
    """
    import uuid

    transcript = TextLoader(file).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=70)
    docs = text_splitter.split_documents(transcript)
    embeddings = OpenAIEmbeddings()
    # NOTE(review): in-process (ephemeral) Chroma clients may share storage, so
    # reusing the default collection name could mix chunks from videos indexed
    # by earlier calls. A unique name per call keeps each chain isolated; the
    # trade-off is that earlier collections are not freed.
    db = Chroma.from_documents(
        docs,
        embeddings,
        collection_name=f"transcript-{uuid.uuid4().hex}",
    )
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})
    return ConversationalRetrievalChain.from_llm(
        # Use the module-level model name instead of ChatOpenAI's library default.
        llm=ChatOpenAI(model=llm_name, temperature=0),
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
        return_generated_question=True,
    )
def buffer(history, buff):
    """Keep only the ``buff`` most recent entries of a chat history.

    Args:
        history: sliceable sequence of past exchanges, oldest first.
        buff: maximum number of entries to keep; values <= 0 keep nothing.

    Returns:
        ``history`` itself when it already fits, otherwise a slice holding its
        last ``buff`` entries.

    Usage: history = buffer(history, buff)
    """
    if buff <= 0:
        # history[-0:] would be the *whole* sequence, not an empty one.
        return history[:0]
    if len(history) > buff:
        return history[-buff:]
    return history
# Matches youtube.com/watch?v=ID, youtube.com/embed/ID, youtube.com/shorts/ID and
# youtu.be/ID (optional scheme, optional www./m. prefix). The 11-character id must
# be followed by the end of the string or by a query/fragment separator, so that
# over-long ids are rejected instead of silently truncated.
_YT_LINK_PATTERN = re.compile(
    r'(?:https?://)?(?:www\.|m\.)?'
    r'(?:youtube\.com/(?:watch\?v=|embed/|shorts/)|youtu\.be/)'
    r'([A-Za-z0-9_-]{11})'
    r'(?:[?&#]\S*)?'
)


def is_valid_yt(link):
    """
    Check if a link is a valid YouTube link.

    Surrounding whitespace is ignored; non-string input is treated as invalid.

    Returns:
        ``(True, video_id)`` for a recognized link, otherwise ``(False, None)``.

    Usage: boolean, video_id = is_valid_yt(youtube_string)
    """
    if not isinstance(link, str):
        return False, None
    match = _YT_LINK_PATTERN.fullmatch(link.strip())
    if match:
        return True, match.group(1)
    return False, None
def get_metadata(video_id) -> dict:
    """Get display metadata for a YouTube video using pytube.

    Returned keys and their fallback when the value is missing/empty:
    - title, description, thumbnail_url, author -> "Unknown"
    - view_count, length -> 0
    - publish_date -> formatted as "%Y-%m-%d %H:%M:%S", or "Unknown"

    Each field is read independently. pytube properties can raise on access,
    so a failing field falls back to its default (and is logged) instead of
    failing the whole request. Errors from constructing ``YouTube`` itself
    still propagate.

    Usage: get_metadata(id)->dict
    """
    yt = YouTube(f"https://www.youtube.com/watch?v={video_id}")

    def field(name, default):
        # Best-effort read of one pytube property; falsy values also map to default.
        try:
            value = getattr(yt, name)
        except Exception as e:
            print(f"Could not read '{name}' for video {video_id}: {e}")
            return default
        return value or default

    publish_date = field("publish_date", None)
    return {
        "title": field("title", "Unknown"),
        "description": field("description", "Unknown"),
        "view_count": field("views", 0),
        "thumbnail_url": field("thumbnail_url", "Unknown"),
        "publish_date": publish_date.strftime("%Y-%m-%d %H:%M:%S")
        if publish_date
        else "Unknown",
        "length": field("length", 0),
        "author": field("author", "Unknown"),
    }
def save_transcript(video_id, path='transcript.txt'):
    """
    Saves the transcript of a valid yt video to a text file.

    Every caption is written as ``~<start seconds>~<text> `` (all on one
    line), which lets the ``/response`` handler recover timestamps from the
    retrieved chunks.

    Args:
        video_id: YouTube video id.
        path: destination file; defaults to ``transcript.txt`` in the working
            directory.

    Returns:
        ``path`` when a transcript was written. ``None`` when fetching failed
        or the transcript was empty; no file is written in those cases.
    """
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {e}")
        return None
    if not transcript:
        return None
    # Build the content first so a malformed entry cannot leave a half-written file.
    content = "".join(f"~{int(entry['start'])}~{entry['text']} " for entry in transcript)
    with open(path, 'w') as file:
        file.write(content)
    print(f"Transcript saved to: {path}")
    return path
@app.route('/init', methods=['POST'])
@cross_origin()
def initialize():
    """
    Initialize the qna_chain for a user.

    Expects a JSON object ``{"yt_link": "<YouTube URL>"}``. On success, the
    transcript is (re)written to ``./transcript.txt`` and indexed with
    ``load_db``, and the video metadata is returned. Failures are reported as
    ``{"status": "error", "message": ...}`` with the global chain left unset.
    """
    global qna_chain
    # Drop any previous chain up front. None is the "not initialized" value the
    # /response handler checks for; storing 0 here made /response try to call it.
    qna_chain = None

    payload = request.get_json(silent=True)
    yt_link = payload.get('yt_link') if isinstance(payload, dict) else None
    valid, video_id = is_valid_yt(yt_link) if isinstance(yt_link, str) else (False, None)
    if not valid:
        return jsonify({"status": "error", "message": "Invalid YouTube link."})

    metadata = get_metadata(video_id)
    try:
        os.remove('./transcript.txt')
    except FileNotFoundError:
        print("No transcript file to remove.")
    save_transcript(video_id)
    # save_transcript only writes the file when a transcript was fetched. The old
    # file was removed above, so a missing file means the fetch failed and
    # load_db would crash on it.
    if not os.path.isfile('./transcript.txt'):
        return jsonify({"status": "error",
                        "message": "Transcript not available for this video."})
    qna_chain = load_db("./transcript.txt", 'stuff', 5)
    return jsonify({"status": "success",
                    "message": "qna_chain initialized.",
                    "metadata": metadata,
                    })
# Marker written by save_transcript in front of each caption: ~<start seconds>~
_TIMESTAMP_MARKER = re.compile(r'~(\d+)~')


def _first_timestamp(text):
    """Return the first ``~<seconds>~`` marker in *text* as an int, or None."""
    match = _TIMESTAMP_MARKER.search(text)
    return int(match.group(1)) if match else None


@app.route('/response', methods=['POST'])
def response():
    """
    Answer a question about the currently initialized video.

    Expects a JSON object with ``query`` (str) and optionally ``chat_history``
    (a list of ``[question, answer]`` pairs; only the 7 most recent are used).

    Returns:
    - ``{"timestamps": [...], "answer": ...}`` when source chunks were
      retrieved. Each timestamp is the first ``~<seconds>~`` marker in that
      chunk, or null.
    - the bare answer string when no source documents came back.
    - HTTP 400 ``{"status": "error", ...}`` if no chain has been initialized.
    """
    req = request.get_json(silent=True)
    if not isinstance(req, dict):
        req = {}
    query = req.get('query', '')
    # JSON delivers the turns as lists; the chain is given (question, answer) tuples.
    chat_history = [tuple(turn) for turn in req.get('chat_history') or []]

    # Covers None (never initialized) as well as the 0 an older /init left behind.
    if not hasattr(qna_chain, 'invoke'):
        return jsonify({"status": "error", "message": "qna_chain not initialized."}), 400

    result = qna_chain.invoke({'question': query, 'chat_history': buffer(chat_history, 7)})
    source_docs = result['source_documents']
    if source_docs:
        timestamps = [_first_timestamp(doc.page_content) for doc in source_docs]
        return jsonify(dict(timestamps=timestamps, answer=result['answer']))
    # Response shape kept as-is for existing clients: a bare JSON string.
    return jsonify(result['answer'])
@app.route('/transcript', methods=['POST'])
@cross_origin()
def send_transcript():
    """
    Send the transcript of the video.

    Responds with ``{"status": "success", "transcript": <file contents>}``. If
    ``transcript.txt`` cannot be opened or decoded, it responds with
    ``{"status": "error", "message": "Transcript not found."}``.
    """
    try:
        with open('transcript.txt', 'r') as file:
            transcript = file.read()
    except (OSError, UnicodeDecodeError):
        # Narrowed from a bare except: only read/decode failures mean "not found".
        return jsonify({"status": "error", "message": "Transcript not found."})
    return jsonify({"status": "success", "transcript": transcript})
# Entry point when this file is executed directly: start Flask's built-in server with debug mode off.
if __name__ == '__main__':
    app.run(debug=False)