from flask import Flask, request, jsonify, send_from_directory
# from flask_session import Session
from flask_cors import CORS  # <-- New import here
from flask_cors import cross_origin
import openai
import os
from pytube import YouTube
import re
from langchain_openai.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from youtube_transcript_api import YouTubeTranscriptApi
from dotenv import load_dotenv

load_dotenv()

app = Flask(__name__, static_folder="./dist")
# requests in the dist folder are being sent to http://localhost:5000/
CORS(app, resources={r"/*": {"origins": "*"}})

openai.api_key = os.environ["OPENAI_API_KEY"]
llm_name = "gpt-3.5-turbo"

# Single process-wide chain: every client shares the chain built by the most
# recent successful call to /init. None means "not initialized".
qna_chain = None

# The transcript is written to and read from this one file.
TRANSCRIPT_PATH = "./transcript.txt"

# Each transcript entry is written as "~<start-second>~<text> "; this pattern
# recovers the start second from a retrieved chunk.
_TIMESTAMP_RE = re.compile(r'~(\d+)~')

# How many of the most recent chat turns are forwarded to the chain.
_HISTORY_LIMIT = 7


@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve(path):
    """
    Serve the built front-end from the static folder.

    If `path` names an existing file inside the static folder that file is
    returned; every other path falls back to index.html.
    """
    if path != "" and os.path.exists(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, 'index.html')


def load_db(file, chain_type, k):
    """
    Central function that:
    - Loads the transcript text file
    - Splits it into chunks and embeds them into a Chroma vector store
    - Creates a similarity retriever returning `k` chunks
    - Creates and returns the chatbot chain

    The returned chain is configured with return_source_documents and
    return_generated_question enabled; the caller in this file reads
    'answer' and 'source_documents' from its result.

    Usage:
        question_answer_chain = load_db(file, chain_type, k)
        response = question_answer_chain({"question": query, "chat_history": chat_history})
    """
    # The transcript file is written as UTF-8, so read it back the same way.
    transcript = TextLoader(file, encoding="utf-8").load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=70)
    docs = text_splitter.split_documents(transcript)
    embeddings = OpenAIEmbeddings()
    # NOTE(review): no collection name or persist directory is given; confirm
    # that repeated calls do not accumulate documents from earlier videos in
    # the same in-memory Chroma collection.
    db = Chroma.from_documents(docs, embeddings)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})

    # create a chatbot chain. Memory is managed externally.
    qa = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model=llm_name, temperature=0),
        #### Prompt Template is yet to be created
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
        return_generated_question=True,
        # memory=memory
    )
    return qa


def buffer(history, buff):
    """
    Buffer the history. Keeps only the `buff` most recent chats.

    A non-positive `buff` yields an empty history. The input list is never
    mutated.

    Usage: history = buffer(history, buff)
    """
    if buff <= 0:
        # history[-0:] would return the whole list, so handle this explicitly.
        return []
    if len(history) > buff:
        return history[-buff:]
    return history


def is_valid_yt(link):
    """
    Check if a link is a valid YouTube link.

    Accepts youtube.com/watch?v=<id> and youtu.be/<id> forms, with optional
    scheme and "www." prefix. Returns (True, video_id) on a match, otherwise
    (False, None). Non-string input is treated as invalid.

    Usage: boolean, video_id = is_valid_yt(youtube_string)
    """
    if not isinstance(link, str):
        return False, None
    pattern = r'^(?:https?:\/\/)?(?:www\.)?(?:youtube\.com\/watch\?v=|youtu\.be\/)([\w\-_]{11})(?:\S+)?$'
    match = re.match(pattern, link.strip())
    if match:
        return True, match.group(1)
    return False, None


def get_metadata(video_id) -> dict:
    """Get important video information.

    Components are:
        - title
        - description
        - view_count
        - thumbnail_url
        - publish_date
        - length
        - author

    Missing text fields become "Unknown" and missing numeric fields become 0.

    Usage: get_metadata(id)->dict
    """
    yt = YouTube(f"https://www.youtube.com/watch?v={video_id}")
    video_info = {
        "title": yt.title or "Unknown",
        "description": yt.description or "Unknown",
        "view_count": yt.views or 0,
        "thumbnail_url": yt.thumbnail_url or "Unknown",
        "publish_date": yt.publish_date.strftime("%Y-%m-%d %H:%M:%S")
        if yt.publish_date
        else "Unknown",
        "length": yt.length or 0,
        "author": yt.author or "Unknown",
    }
    return video_info


def save_transcript(video_id):
    """
    Saves the transcript of a valid yt video to a text file.

    Each entry is written as "~<start-second>~<text> " so that the start time
    can later be recovered from any retrieved chunk.

    Returns True when a transcript was written, False when it could not be
    fetched or was empty (in which case no file is written).
    """
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
    except Exception as e:
        # Best effort: the transcript API raises many different error types.
        print(f"Error fetching transcript for video {video_id}: {e}")
        return False

    if not transcript:
        return False

    with open(TRANSCRIPT_PATH, 'w', encoding="utf-8") as file:
        file.writelines(
            f"~{int(entry['start'])}~{entry['text']} " for entry in transcript
        )
    print("Transcript saved to: transcript.txt")
    return True


@app.route('/init', methods=['POST'])
@cross_origin()
def initialize():
    """
    Initialize the qna_chain for a user.

    Expects a JSON payload with 'yt_link'. On success the global chain is
    rebuilt from the video's transcript and the video metadata is returned.
    """
    global qna_chain
    # Reset to None (not 0) so /response can detect the uninitialized state.
    qna_chain = None

    # NEED to authenticate the user here
    payload = request.get_json(silent=True) or {}
    yt_link = payload.get('yt_link', '')
    valid, video_id = is_valid_yt(yt_link)
    if not valid:
        return jsonify({"status": "error", "message": "Invalid YouTube link."})

    metadata = get_metadata(video_id)

    # Drop any stale transcript so a failed fetch cannot reuse an old video's text.
    try:
        os.remove(TRANSCRIPT_PATH)
    except FileNotFoundError:
        print("No transcript file to remove.")

    if not save_transcript(video_id):
        return jsonify({
            "status": "error",
            "message": "Could not fetch a transcript for this video.",
        }), 502

    # Initialize qna_chain for the user
    qna_chain = load_db(TRANSCRIPT_PATH, 'stuff', 5)
    # os.remove('./transcript.txt')

    return jsonify({
        "status": "success",
        "message": "qna_chain initialized.",
        "metadata": metadata,
    })


@app.route('/response', methods=['POST'])
def response():
    """
    - Expects 'query' and 'chat_history' (a list of [question, answer] pairs)
      in the JSON payload
    - Returns the answer to the query; when source documents were retrieved
      the reply is an object with 'answer' and 'timestamps' (the start second
      found in each source chunk, or null), otherwise the bare answer.
    """
    req = request.get_json(silent=True) or {}

    # raw is a list of lists containing two strings; the chain wants tuples.
    raw = req.get('chat_history') or []
    chat_history = [tuple(turn) for turn in raw]
    query = req.get('query', '')

    if qna_chain is None:
        return jsonify({"status": "error", "message": "qna_chain not initialized."}), 400

    result = qna_chain({
        'question': query,
        'chat_history': buffer(chat_history, _HISTORY_LIMIT),
    })

    source_documents = result['source_documents']
    if source_documents:
        matches = (_TIMESTAMP_RE.search(doc.page_content) for doc in source_documents)
        timestamps = [int(m.group(1)) if m else None for m in matches]
        return jsonify(dict(timestamps=timestamps, answer=result['answer']))

    return jsonify(result['answer'])


@app.route('/transcript', methods=['POST'])
@cross_origin()
def send_transcript():
    """
    Send the transcript of the video.

    Returns the contents of the saved transcript file, or an error payload
    when the file cannot be read.
    """
    try:
        with open(TRANSCRIPT_PATH, 'r', encoding="utf-8") as file:
            transcript = file.read()
    except OSError:
        return jsonify({"status": "error", "message": "Transcript not found."})
    return jsonify({"status": "success", "transcript": transcript})


if __name__ == '__main__':
    app.run(debug=False)