import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict
from transformers import AutoTokenizer
from tools.tools import toolsInfo
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv
load_dotenv()
# Supported backend identifiers; the active one is chosen via the MODEL_TYPE env var.
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]

# Per-backend configuration bundle. "tools_model" is optional — only LLAMA sets it
# (see the .get() fallback for TOOLS_MODEL below).
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

# Backend for this process; defaults to LLAMA when MODEL_TYPE is unset/empty.
modelType: ModelType = os.environ.get("MODEL_TYPE") or "LLAMA"

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        # Community-hosted tokenizers approximating each vendor's; used only for
        # local token counting, never for generation.
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        # "model": "llama-3.2-90b-text-preview",
        # Dedicated Groq model used only when making tool-call requests.
        "tools_model": "llama3-groq-70b-8192-tool-use-preview",
        "max_context": 12800, # intentionally reduced to 1/10th
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
    }
}

# Module-level shortcuts for the active backend's settings.
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Fall back to the main chat model when no dedicated tools model is configured.
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return the token count of *text* (coerced to str) under the active tokenizer."""
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Streamlit page chrome; set_page_config must be the first st.* call that renders.
st.set_page_config(
    page_title="Dr Newtons PG research Ai",
    page_icon=C.AI_ICON,
)
# Empty markdown slot with raw HTML enabled (currently renders nothing;
# presumably a hook for injecting CSS/HTML — confirm intent).
st.markdown('', unsafe_allow_html=True)
def __isInvalidResponse(response: str) -> bool:
    """
    Heuristic gate for malformed LLM output.

    Returns True when *response* matches any known failure signature
    (broken markdown, runaway repetition, an upstream exception marker, or
    JSON payloads missing the expected separator), otherwise False.

    Fix: the original fell off the end and implicitly returned None for the
    "valid" case; this now returns an explicit False so callers always get a bool.
    """
    # A newline followed by a lowercase (non-"http") character usually means
    # broken markdown — unless a fenced code block legitimately explains it.
    if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True
    # The same word repeated 3+ times in a row, occurring in more than one spot.
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True
    # Runaway generation: an excessive number of paragraph breaks.
    if len(re.findall(r'\n\n', response)) > 20:
        U.pprint("lots of paragraphs")
        return True
    # The LLM wrapper embeds this keyword when the upstream API call failed.
    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        return True
    # JSON payloads must be preceded by the separator so the UI can split them off.
    if ('{\n "questions"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True
    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True
    # A response that is ONLY the JSON options, with no leading prose.
    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str) -> int:
    """Return how many of *keywords* occur as substrings of *text*.

    Duplicated keywords are counted once per occurrence in *keywords* (same as
    the original); the check is a plain substring test, not word-boundary aware.
    """
    # bool is an int subclass, so summing the membership tests counts the hits
    # directly — replaces the original 1/0 conditional comprehension.
    return sum(keyword in text for keyword in keywords)
def __getMessages():
    """Return the session chat history, evicting oldest messages until the
    prompt (system message + history + headroom) fits MAX_CONTEXT."""
    msgs = st.session_state.messages

    def getContextSize():
        # +100 leaves headroom beyond the raw token counts.
        currContextSize = 100 + __countTokens(C.SYSTEM_MSG) + __countTokens(msgs)
        U.pprint(f"{currContextSize=}")
        return currContextSize

    while getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        msgs.pop(0)
    return msgs
def __logLlmRequest(messagesFormatted: list, model: str):
    """Log the token footprint of an outgoing LLM request for *model*."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {model}")
    # U.pprint(f"{messagesFormatted=}")
# Tool schemas advertised to the LLM (OpenAI function-calling format);
# translated for Claude by __getClaudeTools below.
tools = [
    toolsInfo["getGoogleSearchResults"]["schema"],
]
def __showToolResponse(toolResponseDisplay: dict):
    """
    Render a tool call's display payload as an icon + message row.

    *toolResponseDisplay* keys: "text" (message, optional) and "icon"
    (image path/URL, optional — falls back to C.TOOL_ICON).

    Fix: .get("text") can return None, which made the `"`" not in msg`
    containment test raise TypeError; default to an empty string instead.
    """
    msg = toolResponseDisplay.get("text") or ""
    icon = toolResponseDisplay.get("icon")
    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(
            icon or C.TOOL_ICON,
            width=30
        )
    with col2:
        # Wrap plain messages in inline-code formatting; leave messages that
        # already contain backticks untouched so their own markup survives.
        if "`" not in msg:
            st.markdown(f"`{msg}`")
        else:
            st.markdown(msg)
def __addToolCallToMsgs(toolCall: dict):
    """Append *toolCall* to the session history in the active backend's wire format."""
    if isClaudeModel:
        # Claude tool calls arrive as ready-made message dicts.
        st.session_state.messages.append(toolCall)
        return
    # OpenAI/Groq: wrap the SDK tool-call object into an assistant message.
    assistantMsg = {
        "role": "assistant",
        "tool_calls": [{
            "id": toolCall.id,
            "function": {
                "name": toolCall.function.name,
                "arguments": toolCall.function.arguments,
            },
            "type": toolCall.type,
        }],
    }
    st.session_state.messages.append(assistantMsg)
def __processToolCalls(toolCalls):
    """
    Execute each OpenAI/Groq-style tool call and record it in session history.

    For every call: looks up the mapped function in toolsInfo, runs it with the
    model-supplied arguments, optionally renders a display payload, then appends
    both the assistant tool-call message and the "tool" result message so the
    conversation stays consistent for the follow-up LLM request.
    """
    for toolCall in toolCalls:
        functionName = toolCall.function.name
        functionToCall = toolsInfo[functionName]["func"]
        functionArgsStr = toolCall.function.arguments
        U.pprint(f"{functionName=} | {functionArgsStr=}")
        # Arguments arrive from the model as a JSON string.
        functionArgs = json.loads(functionArgsStr)
        functionResult = functionToCall(**functionArgs)
        functionResponse = functionResult.get("response")
        responseDisplay = functionResult.get("display")
        U.pprint(f"{functionResponse=}")
        if responseDisplay:
            __showToolResponse(responseDisplay)
            # Persisted so the display can be re-rendered alongside the final answer.
            st.session_state.toolResponseDisplay = responseDisplay
        __addToolCallToMsgs(toolCall)
        st.session_state.messages.append({
            "role": "tool",
            "tool_call_id": toolCall.id,
            "name": functionName,
            "content": functionResponse,
        })
def __processClaudeToolCalls(toolResponse):
    """
    Execute a Claude tool call and record the results in session history.

    NOTE(review): assumes *toolResponse* is Claude's content list with the
    tool_use block at index 1 (text block at index 0) — confirm with the caller.
    """
    toolCall = toolResponse[1]
    functionName = toolCall.name
    functionToCall = toolsInfo[functionName]["func"]
    # Unlike OpenAI/Groq, Claude supplies arguments as an already-parsed dict.
    functionArgs = toolCall.input
    functionResult = functionToCall(**functionArgs)
    functionResponse = functionResult.get("response")
    responseDisplay = functionResult.get("display")
    U.pprint(f"{functionResponse=}")
    if responseDisplay:
        __showToolResponse(responseDisplay)
        # Persisted so the display can be re-rendered alongside the final answer.
        st.session_state.toolResponseDisplay = responseDisplay
    # Claude wire format: assistant message carries the raw content list, and
    # the tool result goes back as a user message holding a tool_result block.
    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse
    })
    st.session_state.messages.append({
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": toolCall.id,
            "content": functionResponse,
        }],
    })
def __dedupeToolCalls(toolCalls: list):
    """Collapse duplicate tool calls by function name, keeping the last one per name."""
    # Dict insertion order preserves first-seen position while later duplicates
    # overwrite earlier ones — same outcome as the original loop.
    callsByName = {
        (call.name if isClaudeModel else call.function.name): call
        for call in toolCalls
    }
    dedupedToolCalls = list(callsByName.values())
    if len(toolCalls) != len(dedupedToolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"{toolCalls=} -> {dedupedToolCalls=}")
    return dedupedToolCalls
def __getClaudeTools():
    """Translate the module-level OpenAI-format tool schemas into Claude's format."""
    # Claude expects name/description at the top level and the JSON schema
    # under "input_schema" rather than "parameters".
    return [
        {
            "name": tool["function"]["name"],
            "description": tool["function"]["description"],
            "input_schema": tool["function"]["parameters"],
        }
        for tool in tools
    ]
def __removeFunctionCall(response: str):
    """Strip any inline {"query": "..."} function-call fragment from *response*."""
    # Non-greedy match so multiple fragments are each removed individually.
    return re.sub(r'\{"query": ".*?"\}', '', response)
def predict(model: str = None, attempts=0):
model = model or MODEL
messagesFormatted = []
try:
if isClaudeModel:
messagesFormatted.extend(__getMessages())
__logLlmRequest(messagesFormatted, model)
responseMessage = client.messages.create(
model=model,
messages=messagesFormatted,
system=C.SYSTEM_MSG,
temperature=0.5,
max_tokens=4000,
tools=__getClaudeTools()
)
responseMessageContent = responseMessage.content
responseContent = responseMessageContent[0].text
toolCalls = []
if len(responseMessageContent) > 1:
toolCalls = [responseMessageContent[1]]
else:
messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
messagesFormatted.extend(__getMessages())
__logLlmRequest(messagesFormatted, model)
response = client.chat.completions.create(
model=model,
messages=messagesFormatted,
temperature=0.6,
max_tokens=4000,
stream=False,
tools=tools
)
responseMessage = response.choices[0].message
responseContent = responseMessage.content
if responseContent and ' 1:
[response, jsonStr] = responseParts
imagePath = None
# imageContainer = st.empty()
# try:
# (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
# if imagePrompt:
# imgContainer = imageContainer.container()
# imgContainer.write(
# f"""
#
# {loaderText}
#
# """,
# unsafe_allow_html=True
# )
# # imgContainer.markdown(f"`{loaderText}`")
# imgContainer.image(C.IMAGE_LOADER)
# (imagePath, seed) = __generateImage(imagePrompt)
# imageContainer.image(imagePath)
# except Exception as e:
# U.pprint(e)
# imageContainer.empty()
toolResponseDisplay = st.session_state.toolResponseDisplay
st.session_state.chatHistory.append({
"role": "assistant",
"content": response,
"image": imagePath,
"toolResponseDisplay": toolResponseDisplay
})
st.session_state.messages.append({
"role": "assistant",
"content": rawResponse,
})
if jsonStr:
try:
jsonStr = jsonStr.replace("```", "")
json.loads(jsonStr)
jsonObj = json.loads(jsonStr)
questions = jsonObj.get("questions")
action = jsonObj.get("action")
if questions:
for option in questions:
st.button(
option["label"],
key=option["id"],
on_click=lambda label=option["label"]: selectButton(label)
)
elif action:
pass
except Exception as e:
U.pprint(e)
# if st.button("Rerun"):
# # __resetButtonState()
# st.session_state.chatHistory = []
# st.session_state.messages = []
# st.rerun()