# classification_chain.py
"""Build and run the LLM chain that classifies user queries."""

import functools
import os
from typing import Any, Mapping, Optional, Sequence

from langchain.chains import LLMChain
from langchain_groq import ChatGroq

from prompts import classification_prompt


def get_classification_chain() -> LLMChain:
    """Build a fresh classification chain (an ``LLMChain``).

    The chain pairs a ``ChatGroq`` model ("Gemma2-9b-It") with
    ``classification_prompt`` from the local ``prompts`` module.

    Returns:
        A newly constructed ``LLMChain``; every call builds a new one.

    Raises:
        KeyError: If the ``GROQ_API_KEY`` environment variable is not set.
    """
    try:
        api_key = os.environ["GROQ_API_KEY"]
    except KeyError:
        # Keep the KeyError type so existing handlers still match, but make
        # the failure actionable instead of a bare KeyError('GROQ_API_KEY').
        raise KeyError(
            "GROQ_API_KEY environment variable must be set to build the "
            "classification chain"
        ) from None

    chat_groq_model = ChatGroq(model="Gemma2-9b-It", groq_api_key=api_key)
    return LLMChain(llm=chat_groq_model, prompt=classification_prompt)


@functools.lru_cache(maxsize=1)
def _shared_classification_chain() -> LLMChain:
    """Return one reusable chain, built on first successful use."""
    # Avoids rebuilding the model client on every classify_with_history()
    # call. lru_cache does not cache exceptions, so a missing API key is
    # retried on the next call. Once built, the API key read at that time
    # stays in use for the life of the process.
    return get_classification_chain()


def classify_with_history(
    query: str, chat_history: Optional[Sequence[Mapping[str, Any]]]
) -> str:
    """Classify *query* in the context of the earlier conversation.

    Args:
        query: The latest user message to classify.
        chat_history: Earlier messages, in the order they should appear in
            the transcript; each must have a ``"content"`` key. ``None`` or
            an empty sequence means there is no prior context.

    Returns:
        The classification chain's output text.
    """
    # NOTE(review): every history entry is labelled "User:" regardless of any
    # role field it may carry; confirm whether assistant turns should be
    # labelled differently.
    lines = [f"User: {msg['content']}" for msg in chat_history or ()]
    lines.append(f"User: {query}")
    # Joining a single list avoids the stray leading newline that the old
    # string concatenation produced when chat_history was empty.
    context = "\n".join(lines)

    # The prompt's "query" variable receives the whole transcript, with the
    # new query as its last line.
    # NOTE(review): LLMChain.run is deprecated in recent LangChain releases in
    # favour of invoke(); migrate once the pinned LangChain version allows.
    return _shared_classification_chain().run({"query": context})