File size: 1,249 Bytes
7a6639b
 
 
 
 
47b1df8
7a6639b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfeeb60
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# classification_chain.py
import os
from langchain.chains import LLMChain
from langchain_groq import ChatGroq
from prompts import classification_prompt
# classification_chain.py

def get_classification_chain() -> LLMChain:
    """
    Construct and return an LLMChain that classifies user queries.

    Pairs the Gemma2-9b-It model served via Groq with the shared
    classification prompt. The GROQ_API_KEY environment variable must be
    set; a missing key raises KeyError.
    """
    # Missing GROQ_API_KEY surfaces as a KeyError here, before any request.
    llm = ChatGroq(
        model="Gemma2-9b-It",
        groq_api_key=os.environ["GROQ_API_KEY"],
    )

    # Wire the model and the classification prompt into a single chain.
    return LLMChain(llm=llm, prompt=classification_prompt)

def classify_with_history(query: str, chat_history: list) -> str:
    """
    Classify a user query in the context of the previous conversation.

    Parameters
    ----------
    query : str
        The new user message to classify.
    chat_history : list
        Prior messages; each item is a dict with a 'content' key.
        # assumes each entry has a 'content' key — TODO confirm against caller

    Returns
    -------
    str
        The raw classification string produced by the chain.
    """
    # Build the transcript as a list and join once: this avoids the stray
    # leading newline the old code produced when chat_history was empty
    # ("" + "\nUser: ..." began the prompt with a blank line).
    turns = [f"User: {msg['content']}" for msg in chat_history]
    turns.append(f"User: {query}")
    context = "\n".join(turns)

    # NOTE(review): every history entry is labeled "User:" regardless of
    # speaker — confirm whether assistant turns should carry a different label.
    classification_result = get_classification_chain().run({"query": context})
    return classification_result