import os
import subprocess
import sys

from haystack import Document
from haystack import Pipeline
from haystack.utils import Secret
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.builders import PromptBuilder
from haystack.components.generators import HuggingFaceTGIGenerator

def install(name):
    """Install a package with pip into the interpreter running this script.

    Invoking ``sys.executable -m pip`` (rather than a bare ``pip``) targets the
    same interpreter/virtualenv that is executing this code.

    Args:
        name: Requirement string handed to ``pip install``.

    Returns:
        pip's exit status (0 on success), so callers can detect a failed
        install instead of having it silently ignored.
    """
    return subprocess.call([sys.executable, '-m', 'pip', 'install', name])

def init_doc_store(path, files):
    """Read text files into a new in-memory document store.

    Args:
        path: Directory containing the files. A trailing separator is
            optional; ``os.path.join`` inserts one when needed.
        files: File names, relative to ``path``, to load. Each name is stored
            as ``meta['name']`` on its ``Document``.

    Returns:
        An ``InMemoryDocumentStore`` containing one ``Document`` per file,
        with the file's full text as ``content``.
    """
    docs = []
    for file in files:
        # os.path.join instead of `path + file`, so both "data" and "data/"
        # resolve to the same file. Explicit UTF-8 keeps decoding independent
        # of the platform's default locale encoding.
        with open(os.path.join(path, file), 'r', encoding='utf-8') as f:
            docs.append(Document(content=f.read(), meta={'name': file}))

    document_store = InMemoryDocumentStore()
    document_store.write_documents(docs)
    return document_store

def define_components(
    document_store,
    api_key,
    *,
    model="mistralai/Mistral-7B-Instruct-v0.1",
    top_k=3,
    max_new_tokens=50,
    temperature=0.7,
):
    """Create the retriever, prompt builder and generator for the RAG pipeline.

    The keyword-only arguments default to the values this function originally
    hard-coded, so existing two-argument calls behave the same.

    Args:
        document_store: Document store the BM25 retriever searches.
        api_key: Token wrapped in ``Secret.from_token`` and passed to the
            HuggingFace TGI generator.
        model: Model id handed to ``HuggingFaceTGIGenerator``.
        top_k: Number of documents the retriever returns per query.
        max_new_tokens: Generation cap forwarded in ``generation_kwargs``.
        temperature: Sampling temperature forwarded in ``generation_kwargs``.

    Returns:
        Tuple ``(retriever, prompt_builder, generator)``; ``warm_up()`` has
        already been called on the generator.
    """
    retriever = InMemoryBM25Retriever(document_store, top_k=top_k)

    # Jinja-style template: loops over the `documents` input and inserts the
    # `question` input; both must be provided when the prompt is built.
    template = """
    You are a Chatbot designed to spread Awareness about Alzheimer's Disease. You are AI Chaperone.
    You will be provided information about Alzheimer's Disease as context for each question. Given the following information, answer the question.
    
    Context:
    {% for document in documents %}
        {{ document.content }}
    {% endfor %}
    
    Question: {{question}}
    Answer:
    """
    prompt_builder = PromptBuilder(template=template)

    generator = HuggingFaceTGIGenerator(
        model=model,
        token=Secret.from_token(api_key),
        generation_kwargs={
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
        },
    )
    # NOTE(review): warm_up() presumably prepares the generator (e.g. fetches
    # model/tokenizer info) and may need network access — confirm.
    generator.warm_up()
    return retriever, prompt_builder, generator

def define_pipeline(retreiver, prompt_builder, generator):
    """Wire the given components into a retrieval-augmented generation pipeline.

    The components are registered as ``retriever``, ``prompt_builder`` and
    ``llm``. The retriever's output feeds ``prompt_builder.documents``, and the
    prompt builder's output feeds the LLM.

    Args:
        retreiver: Retriever component. The misspelled name is kept because
            callers may pass it by keyword.
        prompt_builder: Prompt builder that accepts a ``documents`` input.
        generator: Generator component that consumes the built prompt.

    Returns:
        The assembled ``Pipeline``.
    """
    pipeline = Pipeline()

    named_components = (
        ("retriever", retreiver),
        ("prompt_builder", prompt_builder),
        ("llm", generator),
    )
    for component_name, component in named_components:
        pipeline.add_component(component_name, component)

    # (sender, receiver) edges. The second edge names no sockets, so Haystack
    # resolves the matching output/input pair itself.
    edges = (
        ("retriever", "prompt_builder.documents"),
        ("prompt_builder", "llm"),
    )
    for sender, receiver in edges:
        pipeline.connect(sender, receiver)

    return pipeline