# Standard typing helpers used throughout the module.
from typing import Optional, List

import dspy

import copy

# Imported for its side effects on the dspy package namespace.
import dspy.evaluate

from pydantic import BaseModel

from dotenv import load_dotenv

import os

# NOTE(review): imported but unused in this file — possibly used by a
# chunk outside this view, or left over from earlier experimentation.
from dspy.teleprompt import BootstrapFewShotWithRandomSearch


# Load GROQ_API_KEY (and any other settings) from a local .env file into
# the process environment before any Agent is constructed below.
load_dotenv()
|
|
|
|
class Agent(dspy.Module):
    """Base agent module.

    Builds a Groq-hosted language model client and installs it as the
    process-wide default LM for all dspy predictors.
    """

    def __init__(
        self,
        model: Optional[str] = "llama3",
        client: Optional[str] = "ollama",
        max_tokens: Optional[int] = 4096,
        temperature: Optional[float] = 0.5,
    ) -> None:
        """Initialise the Agent module.

        Args:
            model: str -> default = llama3
            client: str -> default = ollama
            max_tokens: int -> default = 4096
            temperature: float -> default = 0.5
        """
        # NOTE(review): `model` and `client` are accepted for interface
        # compatibility but are currently ignored — the LM is hard-coded to
        # Groq's llama3-8b-8192. Confirm whether they should be wired in.
        self.model = dspy.GROQ(
            model="llama3-8b-8192",
            temperature=temperature,
            api_key=os.getenv("GROQ_API_KEY"),
            max_tokens=max_tokens,
            # Strong repetition penalties to discourage the LM from echoing
            # the prompt/schema back in its output.
            frequency_penalty=1.5,
            presence_penalty=1.5,
        )

        # Make this LM the default for every dspy predictor in the process.
        dspy.settings.configure(
            lm=self.model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    def __deepcopy__(self, memo):
        """Deep-copy all state except the LM client, which is shared by
        reference (clients hold network resources and need not be cloned)."""
        new_instance = self.__class__.__new__(self.__class__)
        memo[id(self)] = new_instance
        for k, v in self.__dict__.items():
            if k != 'model':
                setattr(new_instance, k, copy.deepcopy(v, memo))
        # Share the LM client instead of copying it.
        new_instance.model = self.model
        return new_instance
|
|
|
|
class OutputFormat(BaseModel):
    """Structured LLM output: an expanded question plus a topic tag."""

    # Expanded version of the user question; per BaseSignature's prompt the
    # LM emits the literal string "None" when no expansion applies.
    expand: Optional[str]
    # 2-level hierarchy topic tag (e.g. "Sports - Football"); the prompt
    # forbids "None" here.
    topic: str
|
|
|
|
class Conversation(BaseModel):
    """A single chat turn from the conversation history."""

    # Speaker role — populated from the "role" key of the caller's memory
    # dicts; exact values (e.g. "user"/"assistant") depend on the caller.
    role: str
    # Message text, from the "content" key of the memory dicts.
    content: str
|
|
|
|
class Memory(BaseModel):
    """Ordered list of prior conversation turns, fed to the LM as context."""

    conversations: List[Conversation]
|
|
|
|
# NOTE: the docstring and the field prefix/desc strings below are consumed
# by dspy at runtime as the LM prompt — they are behavior, not documentation,
# and must not be edited casually.
class BaseSignature(dspy.Signature):
    """
    You are an expert in expanding the user question and generating suitable tags for the question.
    Follow the exact instructions given:
    1. Expand with only single question.
    2. Try to keep the actual content in the expand question. Example: User question: What is math ?, expand: What is mathematics ?
    3. Tags should be 2-level hierarchy topics. Eg - India - Politics, Sports- Football. Tags should be as specific as possible. If it is a general question topic: GENERAL
    4. Do not give the reference of the previous question in the expanded question.
    5. If there is no expanded version of the user question, then give it as expand = "None"
    6. If there is a general question asked, do not expand the question, just give it as expand="None"
    7. topic can not be "None"
    8. Use the provided memory to understand context and provide more relevant expansions and topics.
    """

    # Current user question.
    query: str = dspy.InputField(prefix = "Question: ")
    # Conversation history, serialized into the prompt for context.
    memory: Memory = dspy.InputField(prefix = "Previous conversations: ", desc="This is a list of previous conversations.")
    # TypedPredictor parses the LM's JSON reply into an OutputFormat instance;
    # the embedded JSON Schema in `desc` steers the LM toward valid output.
    output: OutputFormat = dspy.OutputField(desc='''Expanded user question and tags are generated as output. Respond with a single JSON object. JSON Schema: {"properties": {"expand": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Expand"}, "topic": {"title": "Topic", "type": "string"}}, "required": ["expand", "topic"], "title": "OutputFormat", "type": "object"}''')
|
|
|
|
class OutputAgent(Agent):
    """
    Multi-output Agent Module. Inherited from Agent Module
    """

    # Maximum attempts before giving up on a failing/malformed LM response.
    MAX_RETRIES = 3

    def __init__(self, model: str | None = "llama3", client: str | None = "ollama", max_tokens: int | None = 8192) -> None:
        super().__init__(
            model=model,
            client=client,
            max_tokens=max_tokens,
        )

    def __call__(self, query: str, memory: List[dict]) -> dspy.Prediction:
        """
        This function expands the user question and generates the tags for the user question.

        Args:
            query: str -> The current user query
            memory: List[dict] -> List of previous conversations

        Returns:
            dspy.Prediction: Expanded question and topic

        Raises:
            Exception: the last LM/parsing error, after MAX_RETRIES failed attempts.
        """
        # Wrap the raw history dicts in the pydantic models the signature
        # expects; built once, reused across retries.
        conversations = [Conversation(role=m["role"], content=m["content"]) for m in memory]
        memory_model = Memory(conversations=conversations)

        output_generator = dspy.TypedPredictor(BaseSignature)

        # Bounded retry loop. The original retried via unbounded recursion
        # (self.__call__ inside except), which recurses to RecursionError on a
        # persistently failing LM; here we retry a fixed number of times and
        # re-raise the last error so callers see the real failure.
        last_error: Optional[Exception] = None
        for _attempt in range(self.MAX_RETRIES):
            try:
                return output_generator(query=query, memory=memory_model)
            except Exception as e:  # TypedPredictor raises on invalid LM output
                last_error = e
                print("Retrying...", e)
        raise last_error
|
|
|
|
|
|
def get_expanded_query_and_topic(query: str, conversation_context: List[dict]):
    """Run a fresh OutputAgent over *query* with the given conversation
    history and return its structured output (expanded question + topic)."""
    prediction = OutputAgent()(query, conversation_context)
    return prediction.output