query-expansion-and-tagging / dspy_inference.py
likhithv's picture
Upload 3 files
b987573 verified
raw
history blame
4.98 kB
from typing import Optional, List
import dspy
import copy
import dspy.evaluate
from pydantic import BaseModel
from dotenv import load_dotenv
import os
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
# Load variables from a local .env file into the process environment at import
# time; Agent.__init__ later reads GROQ_API_KEY through os.getenv.
load_dotenv()
class Agent(dspy.Module):
    """
    Base Agent Module.

    Creates a Groq-backed language model and registers it in the global
    DSPy settings so that predictors created afterwards use it.
    """

    def __init__(
        self,
        model: Optional[str] = "llama3",
        client: Optional[str] = "ollama",
        max_tokens: Optional[int] = 4096,
        temperature: Optional[float] = 0.5,
    ) -> None:
        """
        Initialising Agent Module

        Args:
            model: str -> default = llama3. Accepted for interface
                compatibility only; it is currently NOT used, the Groq
                model name is fixed to "llama3-8b-8192".
            client: str -> default = ollama. Accepted for interface
                compatibility only; it is currently NOT used, the backend
                is always dspy.GROQ.
            max_tokens: int -> default = 4096. Passed to the language model
                and to dspy.settings.configure.
            temperature: float -> default = 0.5. Passed to the language
                model and to dspy.settings.configure.
        """
        # Run the base-class initialiser first so that whatever state
        # dspy.Module sets up in its own __init__ exists on this instance.
        super().__init__()

        # The API key is read from the environment (GROQ_API_KEY).
        self.model = dspy.GROQ(
            model="llama3-8b-8192",
            temperature=temperature,
            api_key=os.getenv("GROQ_API_KEY"),
            max_tokens=max_tokens,
            frequency_penalty=1.5,
            presence_penalty=1.5,
        )

        # NOTE: this mutates process-wide DSPy settings, so the most recently
        # constructed Agent determines the active language model.
        dspy.settings.configure(
            lm=self.model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    def __deepcopy__(self, memo):
        """
        Deep-copy the agent while sharing the language model by reference.

        Every instance attribute except `model` is deep-copied; `model` is
        assigned to the copy as the very same object.
        NOTE(review): presumably the LM client is shared because it should not
        (or cannot) be duplicated -- confirm.

        Args:
            memo: the memo dictionary supplied by copy.deepcopy.

        Returns:
            A new instance of the same class, created without calling __init__.
        """
        # Bypass __init__ so copying does not build a new LM client or touch
        # the global DSPy settings.
        new_instance = self.__class__.__new__(self.__class__)
        # Register the copy before recursing so reference cycles resolve to it.
        memo[id(self)] = new_instance
        for k, v in self.__dict__.items():
            if k != 'model':
                setattr(new_instance, k, copy.deepcopy(v, memo))
        new_instance.model = self.model
        return new_instance
class OutputFormat(BaseModel):
    # Structured result of one inference: the expanded question and its topic.
    # Documented with comments instead of a class docstring on purpose:
    # pydantic copies a class docstring into the generated JSON schema.

    # Expanded form of the user question. May be None, but it has no default,
    # so a value must always be supplied.
    expand: Optional[str]

    # Topic tag for the question; required.
    topic: str
class Conversation(BaseModel):
    # One entry of the conversation history.
    # Documented with comments instead of a class docstring on purpose:
    # pydantic copies a class docstring into the generated JSON schema.

    # Role of the entry; taken from the "role" key of a memory dict.
    role: str

    # Text of the entry; taken from the "content" key of a memory dict.
    content: str
class Memory(BaseModel):
    # Container for the previous conversation entries handed to the signature.
    # Documented with comments instead of a class docstring on purpose:
    # pydantic copies a class docstring into the generated JSON schema.

    # Previous conversation entries, in the order they were supplied.
    conversations: List[Conversation]
class BaseSignature(dspy.Signature):
    """
    You are an expert in expanding the user question and generating suitable tags for the question.
    Follow the exact instructions given:
    1. Expand with only single question.
    2. Try to keep the actual content in the expand question. Example: User question: What is math ?, expand: What is mathematics ?
    3. Tags should be 2-level hierarchy topics. Eg - India - Politics, Sports- Football. Tags should be as specific as possible. If it is a general question topic: GENERAL
    4. Do not give the reference of the previous question in the expanded question.
    5. If there is no expanded version of the user question, then give it as expand = "None"
    6. If there is a general question asked, do not expand the question, just give it as expand="None"
    7. topic can not be "None"
    8. Use the provided memory to understand context and provide more relevant expansions and topics.
    """
    # NOTE: the docstring above is the signature's instruction text (part of
    # the prompt), so it is runtime text and must not be edited casually.

    # The current user question.
    query: str = dspy.InputField(prefix="Question: ")

    # Earlier conversation entries, wrapped in the Memory model.
    memory: Memory = dspy.InputField(
        prefix="Previous conversations: ",
        desc="This is a list of previous conversations.",
    )

    # Expanded question plus topic. The description embeds the JSON schema of
    # OutputFormat verbatim; the adjacent literals below concatenate into one
    # single-line string.
    output: OutputFormat = dspy.OutputField(
        desc=(
            'Expanded user question and tags are generated as output. '
            'Respond with a single JSON object. '
            'JSON Schema: {"properties": {"expand": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Expand"}, '
            '"topic": {"title": "Topic", "type": "string"}}, '
            '"required": ["expand", "topic"], "title": "OutputFormat", "type": "object"}'
        )
    )
class OutputAgent(Agent):
    """
    Multi-output Agent Module. Inherited from Agent Module.

    Calling an instance expands the user question and produces a topic tag,
    using previous conversation entries as context.
    """

    def __init__(
        self,
        model: Optional[str] = "llama3",
        client: Optional[str] = "ollama",
        max_tokens: Optional[int] = 8192,
    ) -> None:
        """
        Initialise the agent; all arguments are forwarded to Agent.__init__.

        Args:
            model: str -> default = llama3
            client: str -> default = ollama
            max_tokens: int -> default = 8192
        """
        super().__init__(
            model=model,
            client=client,
            max_tokens=max_tokens,
        )

    def __call__(
        self,
        query: str,
        memory: List[dict],
        max_retries: int = 5,
    ) -> dspy.Prediction:
        """
        This function expands the user question and generates the tags for the user question.

        Args:
            query: str -> The current user query
            memory: List[dict] -> List of previous conversations; each dict
                must provide the keys "role" and "content". None is treated
                as an empty history.
            max_retries: int -> How many times a failed inference is retried
                before giving up. Values below 0 behave like 0 (one attempt).

        Returns:
            dspy.Prediction: Expanded question and topic

        Raises:
            Exception: the error raised by the last attempt, once all
                attempts have failed.
        """
        # Convert the memory list to the Memory model
        conversations = [
            Conversation(role=m["role"], content=m["content"])
            for m in (memory or [])
        ]
        memory_model = Memory(conversations=conversations)

        # modules
        output_generator = dspy.TypedPredictor(BaseSignature)

        # infer -- retry a bounded number of times instead of recursing
        # without limit, so a persistent failure (bad API key, network down)
        # surfaces as the real error rather than an endless loop.
        total_attempts = max(0, max_retries) + 1
        for attempt in range(1, total_attempts + 1):
            try:
                return output_generator(query=query, memory=memory_model)
            except Exception as e:
                if attempt == total_attempts:
                    raise
                print("Retrying...", e)
def get_expanded_query_and_topic(query: str, conversation_context: List[dict]):
    """
    Return the expanded question and topic for a query.

    This function can be called from app.py to get the expanded question and
    topic. A fresh OutputAgent is constructed on every call.

    Args:
        query: the current user question.
        conversation_context: previous conversation entries, passed to the
            agent as its memory.

    Returns:
        The `output` attribute of the prediction produced by the agent.
    """
    prediction = OutputAgent()(query, conversation_context)
    return prediction.output