Phoenix21 commited on
Commit
cfeeb60
·
verified ·
1 Parent(s): f83c037

Update classification_chain.py

Browse files
Files changed (1) hide show
  1. classification_chain.py +10 -2
classification_chain.py CHANGED
@@ -2,8 +2,6 @@
2
  import os
3
  from langchain.chains import LLMChain
4
  from langchain_groq import ChatGroq
5
-
6
- # We'll import the classification_prompt from prompts.py
7
  from prompts import classification_prompt
8
 
9
  def get_classification_chain() -> LLMChain:
@@ -22,3 +20,13 @@ def get_classification_chain() -> LLMChain:
22
  prompt=classification_prompt
23
  )
24
  return classification_chain
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  from langchain.chains import LLMChain
4
  from langchain_groq import ChatGroq
 
 
5
  from prompts import classification_prompt
6
 
7
  def get_classification_chain() -> LLMChain:
 
20
  prompt=classification_prompt
21
  )
22
  return classification_chain
23
+
24
+ def classify_with_history(query: str, chat_history: list) -> str:
25
+ """
26
+ Classifies a user query based on the context of previous conversation (chat_history).
27
+ """
28
+ # Add the history into the query context if needed (depending on the type of model)
29
+ context = "\n".join([f"User: {msg['content']}" for msg in chat_history]) + "\nUser: " + query
30
+ # Update the prompt with both the context and the query
31
+ classification_result = get_classification_chain().run({"query": context})
32
+ return classification_result