Spaces:
Running
Running
File size: 750 Bytes
7fb6500 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
# classification_chain.py
import os
from langchain.chains import LLMChain
from langchain_groq import ChatGroq
# We'll import the classification_prompt from prompts.py
from prompts import classification_prompt
def get_classification_chain(
    *,
    model: str = "Gemma2-9b-It",
    groq_api_key: "str | None" = None,
) -> LLMChain:
    """Build the classification chain (an ``LLMChain``) from ChatGroq and the classification prompt.

    Calling with no arguments behaves exactly like the original: it uses the
    ``"Gemma2-9b-It"`` model and reads the API key from ``GROQ_API_KEY``.

    Args:
        model: Groq model identifier passed to ``ChatGroq``.
        groq_api_key: API key to use. When ``None`` (the default), the key is
            read from the ``GROQ_API_KEY`` environment variable.

    Returns:
        An ``LLMChain`` that pairs the ChatGroq model with
        ``classification_prompt`` (imported from ``prompts``).

    Raises:
        KeyError: if no key was passed and ``GROQ_API_KEY`` is not set. The
            exception type matches the original ``os.environ[...]`` lookup,
            so existing ``except KeyError`` handlers keep working.
    """
    if groq_api_key is None:
        try:
            groq_api_key = os.environ["GROQ_API_KEY"]
        except KeyError:
            # Re-raise with an actionable message instead of the bare key name.
            raise KeyError(
                "GROQ_API_KEY environment variable is not set; "
                "export it or pass groq_api_key explicitly"
            ) from None

    chat_groq_model = ChatGroq(model=model, groq_api_key=groq_api_key)

    # Pair the chat model with the shared classification prompt.
    return LLMChain(llm=chat_groq_model, prompt=classification_prompt)
|