# RAG_Library_2.py
# Description: This script contains the main RAG pipeline function and related functions for the RAG pipeline.
#
# Import necessary modules and functions
import configparser
import logging
import os
import time
from typing import Dict, Any, List, Optional

from App_Function_Libraries.DB.Character_Chat_DB import get_character_chats, perform_full_text_search_chat, \
    fetch_keywords_for_chats, search_character_chat, search_character_cards, fetch_character_ids_by_keywords
from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_rag_chat, search_rag_notes
#
# Local Imports
from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
from App_Function_Libraries.RAG.RAG_Persona_Chat import perform_vector_search_chat
from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_custom_openai
from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article
from App_Function_Libraries.DB.DB_Manager import fetch_keywords_for_media, search_media_db, get_notes_by_keywords, \
    search_conversations_by_keywords
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
#
# 3rd-Party Imports
import openai
from flashrank import Ranker, RerankRequest
#
########################################################################################################################
#
# Functions:

# Initialize OpenAI client (adjust this based on your API key management)
openai.api_key = "your-openai-api-key"

# Get the directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the config file
config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
# Read the config file
config = configparser.ConfigParser()
# Read the configuration file
config.read('config.txt')


search_functions = {
    "Media DB": search_media_db,
    "RAG Chat": search_rag_chat,
    "RAG Notes": search_rag_notes,
    "Character Chat": search_character_chat,
    "Character Cards": search_character_cards
}

# RAG pipeline function for web scraping
# def rag_web_scraping_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]:
#     try:
#         # Extract content
#         try:
#             article_data = scrape_article(url)
#             content = article_data['content']
#             title = article_data['title']
#         except Exception as e:
#             logging.error(f"Error scraping article: {str(e)}")
#             return {"error": "Failed to scrape article", "details": str(e)}
#
#         # Store the article in the database and get the media_id
#         try:
#             media_id = add_media_to_database(url, title, 'article', content)
#         except Exception as e:
#             logging.error(f"Error adding article to database: {str(e)}")
#             return {"error": "Failed to store article in database", "details": str(e)}
#
#         # Process and store content
#         collection_name = f"article_{media_id}"
#         try:
#             # Assuming you have a database object available, let's call it 'db'
#             db = get_database_connection()
#
#             process_and_store_content(
#                 database=db,
#                 content=content,
#                 collection_name=collection_name,
#                 media_id=media_id,
#                 file_name=title,
#                 create_embeddings=True,
#                 create_contextualized=True,
#                 api_name=api_choice
#             )
#         except Exception as e:
#             logging.error(f"Error processing and storing content: {str(e)}")
#             return {"error": "Failed to process and store content", "details": str(e)}
#
#         # Perform searches
#         try:
#             vector_results = vector_search(collection_name, query, k=5)
#             fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
#         except Exception as e:
#             logging.error(f"Error performing searches: {str(e)}")
#             return {"error": "Failed to perform searches", "details": str(e)}
#
#         # Combine results with error handling for missing 'content' key
#         all_results = []
#         for result in vector_results + fts_results:
#             if isinstance(result, dict) and 'content' in result:
#                 all_results.append(result['content'])
#             else:
#                 logging.warning(f"Unexpected result format: {result}")
#                 all_results.append(str(result))
#
#         context = "\n".join(all_results)
#
#         # Generate answer using the selected API
#         try:
#             answer = generate_answer(api_choice, context, query)
#         except Exception as e:
#             logging.error(f"Error generating answer: {str(e)}")
#             return {"error": "Failed to generate answer", "details": str(e)}
#
#         return {
#             "answer": answer,
#             "context": context
#         }
#
#     except Exception as e:
#         logging.error(f"Unexpected error in rag_pipeline: {str(e)}")
#         return {"error": "An unexpected error occurred", "details": str(e)}


# RAG Search with keyword filtering
# FIXME - Update each called function to support modifiable top-k results
def enhanced_rag_pipeline(
    query: str,
    api_choice: str,
    keywords: Optional[str] = None,
    fts_top_k: int = 10,
    apply_re_ranking: bool = True,
    database_types: List[str] = ["Media DB"]
) -> Dict[str, Any]:
    """
    Perform full text search across specified database type.

    Args:
        query: Search query string
        api_choice: API to use for generating the response
        keywords: Optional list of media IDs to filter results
        fts_top_k: Maximum number of results to return
        apply_re_ranking: Whether to apply re-ranking to results
        database_types: Type of database to search

    Returns:
        Dictionary containing search results with content
    """
    log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
    start_time = time.time()

    try:
        # Load embedding provider from config, or fallback to 'openai'
        embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
        logging.debug(f"Using embedding provider: {embedding_provider}")

        # Initialize relevant IDs dictionary
        relevant_ids: Dict[str, Optional[List[str]]] = {}

        # Process keywords if provided
        if keywords:
            keyword_list = [k.strip().lower() for k in keywords.split(',')]
            logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")

            try:
                for db_type in database_types:
                    if db_type == "Media DB":
                        media_ids = fetch_relevant_media_ids(keyword_list)
                        relevant_ids[db_type] = [str(id_) for id_ in media_ids]
                    elif db_type == "RAG Chat":
                        conversations, _, _ = search_conversations_by_keywords(keywords=keyword_list)
                        relevant_ids[db_type] = [str(conv['conversation_id']) for conv in conversations]
                    elif db_type == "RAG Notes":
                        notes, _, _ = get_notes_by_keywords(keyword_list)
                        relevant_ids[db_type] = [str(note_id) for note_id, _, _, _ in notes]
                    elif db_type == "Character Chat":
                        relevant_ids[db_type] = [str(id_) for id_ in fetch_keywords_for_chats(keyword_list)]
                    elif db_type == "Character Cards":
                        relevant_ids[db_type] = [str(id_) for id_ in fetch_character_ids_by_keywords(keyword_list)]
                    else:
                        logging.error(f"Unsupported database type: {db_type}")

                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
            except Exception as e:
                logging.error(f"Error fetching relevant IDs: {str(e)}")
                relevant_ids = {db_type: None for db_type in database_types}
        else:
            relevant_ids = {db_type: None for db_type in database_types}

        # Perform vector search
        vector_results = []
        for db_type in database_types:
            try:
                db_relevant_ids = relevant_ids.get(db_type)
                results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
                vector_results.extend(results)
                logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
            except Exception as e:
                logging.error(f"Error performing vector search on {db_type}: {str(e)}")

        # Perform vector search
        # FIXME
        #vector_results = perform_vector_search(query, relevant_media_ids)
        #ogging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")

        # Perform full-text search
        #v1
        #fts_results = perform_full_text_search(query, database_type, relevant_media_ids, fts_top_k)

        # v2
        # Perform full-text search across specified databases
        fts_results = []
        for db_type in database_types:
            try:
                db_relevant_ids = relevant_ids.get(db_type)
                db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
                fts_results.extend(db_results)
                logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
            except Exception as e:
                logging.error(f"Error performing full-text search on {db_type}: {str(e)}")

        #logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
        logging.debug(
            "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join(
                [str(item) for item in fts_results]) + "\n"
        )

        # Combine results
        all_results = vector_results + fts_results

        # FIXME - specify model + add param to modify at call time
        # You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
        # Apply re-ranking if enabled and results exist
        if apply_re_ranking and all_results:
            logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")

            if all_results:
                ranker = Ranker()

                # Prepare passages for re-ranking
                passages = [{"id": i, "text": result['content']} for i, result in enumerate(all_results)]
                rerank_request = RerankRequest(query=query, passages=passages)

                # Rerank the results
                reranked_results = ranker.rerank(rerank_request)

                # Sort results based on the re-ranking score
                reranked_results = sorted(reranked_results, key=lambda x: x['score'], reverse=True)

                # Log reranked results
                logging.debug(f"\n\nenhanced_rag_pipeline - Reranked results: {reranked_results}")

                # Update all_results based on reranking
                all_results = [all_results[result['id']] for result in reranked_results]

        # Extract content from results (top fts_top_k by default)
        context = "\n".join([result['content'] for result in all_results[:fts_top_k]])
        #logging.debug(f"Context length: {len(context)}")
        logging.debug(f"Context: {context[:200]}")

        # Generate answer using the selected API
        answer = generate_answer(api_choice, context, query)

        if not all_results:
            logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
            return {
                "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
                "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
            }

        # Log metrics
        pipeline_duration = time.time() - start_time
        log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
        log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})

        return {
            "answer": answer,
            "context": context
        }

    except Exception as e:
        log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
        logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
        logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
        return {
            "answer": "An error occurred while processing your request.",
            "context": ""
        }



# Need to write a test for this function FIXME
def generate_answer(api_choice: str, context: str, query: str) -> str:
    # Metrics
    log_counter("generate_answer_attempt", labels={"api_choice": api_choice})
    start_time = time.time()
    logging.debug("Entering generate_answer function")
    config = load_comprehensive_config()
    logging.debug(f"Config sections: {config.sections()}")
    prompt = f"Context: {context}\n\nQuestion: {query}"
    try:
        if api_choice == "OpenAI":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openai
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_openai(config['API']['openai_api_key'], prompt, "")

        elif api_choice == "Anthropic":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_anthropic
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")

        elif api_choice == "Cohere":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_cohere
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")

        elif api_choice == "Groq":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_groq
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_groq(config['API']['groq_api_key'], prompt, "")

        elif api_choice == "OpenRouter":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openrouter
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")

        elif api_choice == "HuggingFace":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_huggingface
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")

        elif api_choice == "DeepSeek":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_deepseek
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")

        elif api_choice == "Mistral":
            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_mistral
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")

        # Local LLM APIs
        elif api_choice == "Local-LLM":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_local_llm
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            # FIXME
            return summarize_with_local_llm(config['Local-API']['local_llm_path'], prompt, "")

        elif api_choice == "Llama.cpp":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_llama
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_llama(prompt, "", config['Local-API']['llama_api_key'], None, None)
        elif api_choice == "Kobold":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_kobold
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_kobold(prompt, config['Local-API']['kobold_api_key'], "", system_message=None, temp=None)

        elif api_choice == "Ooba":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_oobabooga
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_oobabooga(prompt, config['Local-API']['ooba_api_key'], custom_prompt="", system_message=None, temp=None)

        elif api_choice == "TabbyAPI":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_tabbyapi
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_tabbyapi(prompt, None, None, None, None, )

        elif api_choice == "vLLM":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_vllm
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_vllm(prompt, "", config['Local-API']['vllm_api_key'], None, None)

        elif api_choice.lower() == "ollama":
            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_ollama
            answer_generation_duration = time.time() - start_time
            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
            log_counter("generate_answer_success", labels={"api_choice": api_choice})
            return summarize_with_ollama(prompt, "", config['Local-API']['ollama_api_IP'], config['Local-API']['ollama_api_key'], None, None, None)

        elif api_choice.lower() == "custom_openai_api":
            logging.debug(f"RAG Answer Gen: Trying with Custom_OpenAI API")
            summary = summarize_with_custom_openai(prompt, "", config['API']['custom_openai_api_key'], None,
                                                   None)
        else:
            log_counter("generate_answer_error", labels={"api_choice": api_choice, "error": str()})
            raise ValueError(f"Unsupported API choice: {api_choice}")
    except Exception as e:
        log_counter("generate_answer_error", labels={"api_choice": api_choice, "error": str(e)})
        logging.error(f"Error in generate_answer: {str(e)}")
        return "An error occurred while generating the answer."


def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]:
    log_counter("perform_vector_search_attempt")
    start_time = time.time()
    all_collections = chroma_client.list_collections()
    vector_results = []
    try:
        for collection in all_collections:
            collection_results = vector_search(collection.name, query, k=top_k)
            if not collection_results:
                continue  # Skip empty results
            filtered_results = [
                result for result in collection_results
                if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
            ]
            vector_results.extend(filtered_results)
        search_duration = time.time() - start_time
        log_histogram("perform_vector_search_duration", search_duration)
        log_counter("perform_vector_search_success", labels={"result_count": len(vector_results)})
        return vector_results
    except Exception as e:
        log_counter("perform_vector_search_error", labels={"error": str(e)})
        logging.error(f"Error in perform_vector_search: {str(e)}")
        raise


# V2
def perform_full_text_search(query: str, database_type: str, relevant_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
    """
    Perform full-text search on a specified database type.

    Args:
        query: Search query string
        database_type: Type of database to search ("Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards")
        relevant_ids: Optional list of media IDs to filter results
        fts_top_k: Maximum number of results to return

    Returns:
        List of search results with content and metadata
    """
    log_counter("perform_full_text_search_attempt", labels={"database_type": database_type})
    start_time = time.time()

    try:
        # Set default for fts_top_k
        if fts_top_k is None:
            fts_top_k = 10

        # Call appropriate search function based on database type
        if database_type not in search_functions:
            raise ValueError(f"Unsupported database type: {database_type}")

        # Call the appropriate search function
        results = search_functions[database_type](query, fts_top_k, relevant_ids)

        search_duration = time.time() - start_time
        log_histogram("perform_full_text_search_duration", search_duration,
                      labels={"database_type": database_type})
        log_counter("perform_full_text_search_success",
                    labels={"database_type": database_type, "result_count": len(results)})

        return results

    except Exception as e:
        log_counter("perform_full_text_search_error",
                    labels={"database_type": database_type, "error": str(e)})
        logging.error(f"Error in perform_full_text_search ({database_type}): {str(e)}")
        raise


# v1
# def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
#     log_counter("perform_full_text_search_attempt")
#     start_time = time.time()
#     try:
#         fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10)
#         filtered_fts_results = [
#             {
#                 "content": result['content'],
#                 "metadata": {"media_id": result['id']}
#             }
#             for result in fts_results
#             if relevant_media_ids is None or result['id'] in relevant_media_ids
#         ]
#         search_duration = time.time() - start_time
#         log_histogram("perform_full_text_search_duration", search_duration)
#         log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)})
#         return filtered_fts_results
#     except Exception as e:
#         log_counter("perform_full_text_search_error", labels={"error": str(e)})
#         logging.error(f"Error in perform_full_text_search: {str(e)}")
#         raise


def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]:
    log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)})
    start_time = time.time()
    relevant_ids = set()
    for keyword in keywords:
        try:
            media_ids = fetch_keywords_for_media(keyword)
            relevant_ids.update(media_ids)
        except Exception as e:
            log_counter("fetch_relevant_media_ids_error", labels={"error": str(e)})
            logging.error(f"Error fetching relevant media IDs for keyword '{keyword}': {str(e)}")
            # Continue processing other keywords

    fetch_duration = time.time() - start_time
    log_histogram("fetch_relevant_media_ids_duration", fetch_duration)
    log_counter("fetch_relevant_media_ids_success", labels={"result_count": len(relevant_ids)})
    return list(relevant_ids)


def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str]) -> List[Dict[str, Any]]:
    log_counter("filter_results_by_keywords_attempt", labels={"result_count": len(results), "keyword_count": len(keywords)})
    start_time = time.time()
    if not keywords:
        return results

    filtered_results = []
    for result in results:
        try:
            metadata = result.get('metadata', {})
            if metadata is None:
                logging.warning(f"No metadata found for result: {result}")
                continue
            if not isinstance(metadata, dict):
                logging.warning(f"Unexpected metadata type: {type(metadata)}. Expected dict.")
                continue

            media_id = metadata.get('media_id')
            if media_id is None:
                logging.warning(f"No media_id found in metadata: {metadata}")
                continue

            media_keywords = fetch_keywords_for_media(media_id)
            if any(keyword.lower() in [mk.lower() for mk in media_keywords] for keyword in keywords):
                filtered_results.append(result)
        except Exception as e:
            logging.error(f"Error processing result: {result}. Error: {str(e)}")

    filter_duration = time.time() - start_time
    log_histogram("filter_results_by_keywords_duration", filter_duration)
    log_counter("filter_results_by_keywords_success", labels={"filtered_count": len(filtered_results)})
    return filtered_results

# FIXME: to be implememted
def extract_media_id_from_result(result: str) -> Optional[int]:
    # Implement this function based on how you store the media_id in your results
    # For example, if it's stored at the beginning of each result:
    try:
        return int(result.split('_')[0])
    except (IndexError, ValueError):
        logging.error(f"Failed to extract media_id from result: {result}")
        return None

#
#
########################################################################################################################


############################################################################################################
#
# Chat RAG

def enhanced_rag_pipeline_chat(query: str, api_choice: str, character_id: int, keywords: Optional[str] = None) -> Dict[str, Any]:
    """
    Enhanced RAG pipeline tailored for the Character Chat tab.

    Args:
        query (str): The user's input query.
        api_choice (str): The API to use for generating the response.
        character_id (int): The ID of the character being interacted with.
        keywords (Optional[str]): Comma-separated keywords to filter search results.

    Returns:
        Dict[str, Any]: Contains the generated answer and the context used.
    """
    log_counter("enhanced_rag_pipeline_chat_attempt", labels={"api_choice": api_choice, "character_id": character_id})
    start_time = time.time()
    try:
        # Load embedding provider from config, or fallback to 'openai'
        embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
        logging.debug(f"Using embedding provider: {embedding_provider}")

        # Process keywords if provided
        keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
        logging.debug(f"enhanced_rag_pipeline_chat - Keywords: {keyword_list}")

        # Fetch relevant chat IDs based on character_id and keywords
        if keyword_list:
            relevant_chat_ids = fetch_keywords_for_chats(keyword_list)
        else:
            relevant_chat_ids = fetch_all_chat_ids(character_id)
        logging.debug(f"enhanced_rag_pipeline_chat - Relevant chat IDs: {relevant_chat_ids}")

        if not relevant_chat_ids:
            logging.info(f"No chats found for the given keywords and character ID: {character_id}")
            # Fallback to generating answer without context
            answer = generate_answer(api_choice, "", query)
            # Metrics
            pipeline_duration = time.time() - start_time
            log_histogram("enhanced_rag_pipeline_chat_duration", pipeline_duration, labels={"api_choice": api_choice})
            log_counter("enhanced_rag_pipeline_chat_success",
                        labels={"api_choice": api_choice, "character_id": character_id})
            return {
                "answer": answer,
                "context": ""
            }

        # Perform vector search within the relevant chats
        vector_results = perform_vector_search_chat(query, relevant_chat_ids)
        logging.debug(f"enhanced_rag_pipeline_chat - Vector search results: {vector_results}")

        # Perform full-text search within the relevant chats
        # FIXME - Update for DB Selection
        fts_results = perform_full_text_search_chat(query, relevant_chat_ids)
        logging.debug("enhanced_rag_pipeline_chat - Full-text search results:")
        logging.debug("\n".join([str(item) for item in fts_results]))

        # Combine results
        all_results = vector_results + fts_results

        apply_re_ranking = True
        if apply_re_ranking:
            logging.debug("enhanced_rag_pipeline_chat - Applying Re-Ranking")
            ranker = Ranker()

            # Prepare passages for re-ranking
            passages = [{"id": i, "text": result['content']} for i, result in enumerate(all_results)]
            rerank_request = RerankRequest(query=query, passages=passages)

            # Rerank the results
            reranked_results = ranker.rerank(rerank_request)

            # Sort results based on the re-ranking score
            reranked_results = sorted(reranked_results, key=lambda x: x['score'], reverse=True)

            # Log reranked results
            logging.debug(f"enhanced_rag_pipeline_chat - Reranked results: {reranked_results}")

            # Update all_results based on reranking
            all_results = [all_results[result['id']] for result in reranked_results]

        # Extract context from top results (limit to top 10)
        context = "\n".join([result['content'] for result in all_results[:10]])
        logging.debug(f"Context length: {len(context)}")
        logging.debug(f"Context: {context[:200]}")  # Log only the first 200 characters for brevity

        # Generate answer using the selected API
        answer = generate_answer(api_choice, context, query)

        if not all_results:
            logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
            return {
                "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
                "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
            }

        return {
            "answer": answer,
            "context": context
        }

    except Exception as e:
        log_counter("enhanced_rag_pipeline_chat_error", labels={"api_choice": api_choice, "character_id": character_id, "error": str(e)})
        logging.error(f"Error in enhanced_rag_pipeline_chat: {str(e)}")
        return {
            "answer": "An error occurred while processing your request.",
            "context": ""
        }


def fetch_relevant_chat_ids(character_id: int, keywords: List[str]) -> List[int]:
    """
    Fetch chat IDs associated with a character and filtered by keywords.

    Args:
        character_id (int): The ID of the character.
        keywords (List[str]): List of keywords to filter chats.

    Returns:
        List[int]: List of relevant chat IDs.
    """
    log_counter("fetch_relevant_chat_ids_attempt", labels={"character_id": character_id, "keyword_count": len(keywords)})
    start_time = time.time()
    relevant_ids = set()
    try:
        media_ids = fetch_keywords_for_chats(keywords)
        fetch_duration = time.time() - start_time
        log_histogram("fetch_relevant_chat_ids_duration", fetch_duration)
        log_counter("fetch_relevant_chat_ids_success",
                    labels={"character_id": character_id, "result_count": len(relevant_ids)})
        relevant_ids.update(media_ids)
        return list(relevant_ids)
    except Exception as e:
        log_counter("fetch_relevant_chat_ids_error", labels={"character_id": character_id, "error": str(e)})
        logging.error(f"Error fetching relevant chat IDs: {str(e)}")
        return []


def fetch_all_chat_ids(character_id: int) -> List[int]:
    """
    Fetch all chat IDs associated with a specific character.

    Args:
        character_id (int): The ID of the character.

    Returns:
        List[int]: List of all chat IDs for the character.
    """
    log_counter("fetch_all_chat_ids_attempt", labels={"character_id": character_id})
    start_time = time.time()
    try:
        chats = get_character_chats(character_id=character_id)
        chat_ids = [chat['id'] for chat in chats]
        fetch_duration = time.time() - start_time
        log_histogram("fetch_all_chat_ids_duration", fetch_duration)
        log_counter("fetch_all_chat_ids_success", labels={"character_id": character_id, "chat_count": len(chat_ids)})
        return chat_ids
    except Exception as e:
        log_counter("fetch_all_chat_ids_error", labels={"character_id": character_id, "error": str(e)})
        logging.error(f"Error fetching all chat IDs for character {character_id}: {str(e)}")
        return []

#
# End of Chat RAG
############################################################################################################

# Function to preprocess and store all existing content in the database
# def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"):
#     unprocessed_media = get_unprocessed_media()
#     total_media = len(unprocessed_media)
#
#     for index, row in enumerate(unprocessed_media, 1):
#         media_id, content, media_type, file_name = row
#         collection_name = f"{media_type}_{media_id}"
#
#         logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}")
#
#         try:
#             process_and_store_content(
#                 database=database,
#                 content=content,
#                 collection_name=collection_name,
#                 media_id=media_id,
#                 file_name=file_name or f"{media_type}_{media_id}",
#                 create_embeddings=True,
#                 create_contextualized=create_contextualized,
#                 api_name=api_name
#             )
#
#             # Mark the media as processed in the database
#             mark_media_as_processed(database, media_id)
#
#             logger.info(f"Successfully processed media ID {media_id}")
#         except Exception as e:
#             logger.error(f"Error processing media ID {media_id}: {str(e)}")
#
#     logger.info("Finished preprocessing all unprocessed content")

############################################################################################################
#
# ElasticSearch Retriever

# https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-elasticsearch
#
# https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-self-query

#
# End of RAG_Library_2.py
############################################################################################################