Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import time | |
import json | |
import re | |
from typing import List, Literal, TypedDict | |
from transformers import AutoTokenizer | |
from tools.tools import toolsInfo | |
from gradio_client import Client | |
import constants as C | |
import utils as U | |
from openai import OpenAI | |
import anthropic | |
from groq import Groq | |
from dotenv import load_dotenv | |
load_dotenv() | |
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]

# Settings bundle for one LLM backend. "tools_model" is only supplied by
# backends that route tool calls to a separate model, so the dict is declared
# with total=False to keep that key optional.
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "tools_model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer,
}, total=False)


def _buildGpt4Config() -> ModelConfig:
    """Create the client, model name, context budget and tokenizer for GPT4."""
    return {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    }


def _buildClaudeConfig() -> ModelConfig:
    """Create the client, model name, context budget and tokenizer for CLAUDE."""
    return {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    }


def _buildLlamaConfig() -> ModelConfig:
    """Create the client, model names, context budget and tokenizer for LLAMA."""
    return {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        # "model": "llama-3.2-90b-text-preview",
        "tools_model": "llama3-groq-70b-8192-tool-use-preview",
        "max_context": 12800,  # intentionally reduced to 1/10th
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
    }


# Builders are stored instead of ready-made configs so that only the selected
# backend is instantiated: constructing a client for an unused provider needs
# that provider's API key (the OpenAI client refuses to be created without
# one) and loading its tokenizer is wasted work on every script run.
_MODEL_CONFIG_BUILDERS = {
    "GPT4": _buildGpt4Config,
    "CLAUDE": _buildClaudeConfig,
    "LLAMA": _buildLlamaConfig,
}

modelType: ModelType = os.environ.get("MODEL_TYPE") or "LLAMA"
if modelType not in _MODEL_CONFIG_BUILDERS:
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; "
        f"expected one of {sorted(_MODEL_CONFIG_BUILDERS)}"
    )

# Holds the configuration of the active backend only, keyed by its model type.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    modelType: _MODEL_CONFIG_BUILDERS[modelType](),
}

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Fall back to the main model when the backend has no dedicated tools model.
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return how many tokens the active tokenizer produces for ``str(text)``.

    Any object is accepted; it is stringified first, and special tokens are
    not added to the count.
    """
    encoded = tokenizer.encode(str(text), add_special_tokens=False)
    return len(encoded)
# Browser-tab title and icon for the app.
pageSettings = {
    "page_title": "Dr Newtons PG research Ai",
    "page_icon": C.AI_ICON,
}
st.set_page_config(**pageSettings)

# Inject a link to the web-app manifest as raw HTML.
manifestLinkTag = '<link rel="manifest" href="manifest.json">'
st.markdown(manifestLinkTag, unsafe_allow_html=True)
def __isInvalidResponse(response: str) -> bool:
    """Tell whether an LLM response looks malformed and should be discarded.

    Each heuristic logs its reason through ``U.pprint`` before rejecting.

    Args:
        response: The response text (or a prefix of it) to inspect.

    Returns:
        True when any heuristic flags the text, otherwise False.
    """
    # Several lines starting with a lowercase letter (URLs excluded) outside
    # of a fenced code block.
    if "```" not in response and len(re.findall(r'\n((?!http)[a-z])', response)) > 3:
        U.pprint("new line followed by small case char")
        return True

    # The same word repeated three or more times in a row, more than once.
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True

    if len(re.findall(r'\n\n', response)) > 20:
        U.pprint("lots of paragraphs")
        return True

    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        return True

    # A JSON payload ("questions" or "action") must be preceded by the separator.
    hasJsonPayload = ('{\n "questions"' in response) or ('{\n "action"' in response)
    if hasJsonPayload and C.JSON_SEPARATOR not in response:
        U.pprint("JSON response without json separator")
        return True

    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True

    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count the entries of ``keywords`` that occur as substrings of ``text``.

    Duplicate keywords are counted once per occurrence in the list.
    """
    return sum(1 for keyword in keywords if keyword in text)
def __getMessages():
    """Return the session's messages, trimmed to fit within ``MAX_CONTEXT``.

    The oldest messages are removed in place from
    ``st.session_state.messages`` until the estimated context size (system
    prompt + messages + a 100-token margin) fits, or until no messages remain.

    Returns:
        The (possibly shortened) ``st.session_state.messages`` list.
    """
    def getContextSize():
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    def startsWithOrphanedToolResult():
        # A tool result whose originating tool call was trimmed away is
        # rejected by the chat APIs, so it has to be dropped as well.
        if not st.session_state.messages:
            return False
        first = st.session_state.messages[0]
        if first.get("role") == "tool":
            return True
        content = first.get("content")
        return isinstance(content, list) and any(
            isinstance(part, dict) and part.get("type") == "tool_result"
            for part in content
        )

    # Stop once the list is empty: popping from an empty list would raise
    # IndexError when the system prompt alone exceeds the budget.
    while getContextSize() > MAX_CONTEXT and st.session_state.messages:
        U.pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)
        while startsWithOrphanedToolResult():
            st.session_state.messages.pop(0)

    return st.session_state.messages
def __logLlmRequest(messagesFormatted: list, model: str):
    """Log the token count of an outgoing request together with the model name."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {model}")
# Schemas of the tools offered to the model, looked up by tool name.
tools = [
    toolsInfo[toolName]["schema"]
    for toolName in ("getGoogleSearchResults",)
]
def __showToolResponse(toolResponseDisplay: dict):
    """Render a tool's display payload as an icon next to a short message.

    Args:
        toolResponseDisplay: Mapping with optional ``"text"`` and ``"icon"``
            entries. ``C.TOOL_ICON`` is used when no icon is given.
    """
    msg = toolResponseDisplay.get("text")
    icon = toolResponseDisplay.get("icon")

    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(
            icon or C.TOOL_ICON,
            width=30
        )
    with col2:
        # A payload without text would make the backtick check below raise
        # TypeError, so only the icon is shown in that case.
        if msg:
            if "`" not in msg:
                # Plain text is shown in inline-code style.
                st.markdown(f"`{msg}`")
            else:
                st.markdown(msg)
def __addToolCallToMsgs(toolCall: dict):
    """Record the assistant's tool call in ``st.session_state.messages``.

    For Claude the argument is appended unchanged. For the other backends it
    is read through attributes (id, type, function.name, function.arguments)
    and wrapped in an assistant message carrying a single ``tool_calls`` entry.
    """
    history = st.session_state.messages

    if isClaudeModel:
        history.append(toolCall)
        return

    callEntry = {
        "id": toolCall.id,
        "function": {
            "name": toolCall.function.name,
            "arguments": toolCall.function.arguments,
        },
        "type": toolCall.type,
    }
    history.append({
        "role": "assistant",
        "tool_calls": [callEntry],
    })
def __processToolCalls(toolCalls):
    """Run each requested tool and append the call and its result to the messages.

    For every tool call the matching function from ``toolsInfo`` is invoked
    with the JSON-decoded arguments. Its ``"display"`` payload, when present,
    is rendered and stored in ``st.session_state.toolResponseDisplay``.
    """
    for toolCall in toolCalls:
        functionName = toolCall.function.name
        functionArgsStr = toolCall.function.arguments
        U.pprint(f"{functionName=} | {functionArgsStr=}")

        handler = toolsInfo[functionName]["func"]
        outcome = handler(**json.loads(functionArgsStr))
        functionResponse = outcome.get("response")
        U.pprint(f"{functionResponse=}")

        display = outcome.get("display")
        if display:
            __showToolResponse(display)
            st.session_state.toolResponseDisplay = display

        # The assistant's call is recorded first, then the tool's answer.
        __addToolCallToMsgs(toolCall)
        toolMessage = {
            "role": "tool",
            "tool_call_id": toolCall.id,
            "name": functionName,
            "content": functionResponse,
        }
        st.session_state.messages.append(toolMessage)
def __processClaudeToolCalls(toolResponse):
    """Run the tools requested in a Claude response and record the exchange.

    Every ``tool_use`` block in the response is executed, wherever it sits in
    the content list. The assistant content is then stored as-is, followed by
    one user message holding a ``tool_result`` for each executed block, so
    every tool_use is answered.

    Args:
        toolResponse: The list of content blocks of the Claude response.

    Raises:
        ValueError: If the response contains no ``tool_use`` block.
    """
    toolUseBlocks = [
        block for block in toolResponse
        if getattr(block, "type", None) == "tool_use"
    ]
    if not toolUseBlocks:
        raise ValueError("Claude response contains no tool_use block")

    # All tools run before anything is appended, so a failing tool leaves the
    # message history untouched.
    toolResults = []
    for toolCall in toolUseBlocks:
        functionName = toolCall.name
        functionToCall = toolsInfo[functionName]["func"]
        functionArgs = toolCall.input
        functionResult = functionToCall(**functionArgs)
        functionResponse = functionResult.get("response")
        responseDisplay = functionResult.get("display")
        U.pprint(f"{functionResponse=}")

        if responseDisplay:
            __showToolResponse(responseDisplay)
            st.session_state.toolResponseDisplay = responseDisplay

        toolResults.append({
            "type": "tool_result",
            "tool_use_id": toolCall.id,
            "content": functionResponse,
        })

    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse
    })
    st.session_state.messages.append({
        "role": "user",
        "content": toolResults,
    })
def __dedupeToolCalls(toolCalls: list):
    """Keep one tool call per function name; the last call for a name wins.

    The function name is read from ``.name`` for Claude and from
    ``.function.name`` otherwise. A log entry is written when calls were dropped.
    """
    latestByName = {
        (toolCall.name if isClaudeModel else toolCall.function.name): toolCall
        for toolCall in toolCalls
    }
    dedupedToolCalls = list(latestByName.values())

    if len(dedupedToolCalls) != len(toolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"{toolCalls=} -> {dedupedToolCalls=}")

    return dedupedToolCalls
def __getClaudeTools():
    """Convert the module-level ``tools`` schemas into Claude's tool format.

    Each schema's ``"function"`` entry supplies the name and description, and
    its ``"parameters"`` become the ``"input_schema"``.
    """
    return [
        {
            "name": tool["function"]["name"],
            "description": tool["function"]["description"],
            "input_schema": tool["function"]["parameters"],
        }
        for tool in tools
    ]
def __removeFunctionCall(response: str):
    """Strip inline ``getGoogleSearchResults`` function-call markup from text.

    Both the properly closed form (``...</function>``) and the unclosed-slash
    form (``...<function>``) are removed; all other text is returned unchanged.
    """
    # The closing tag is matched with an optional slash; without it the
    # well-formed "</function>" ending was never stripped.
    pattern = r'<function=getGoogleSearchResults>\{"query": ".*?"\}</?function>'
    return re.sub(pattern, '', response)
def predict(model: str | None = None, attempts=0):
    """Ask the active LLM for the next reply, executing tool calls on the way.

    The conversation from ``st.session_state.messages`` (trimmed to the
    context budget) is sent to the backend. If the model requests tools, they
    are executed, their results are appended to the messages and the model is
    queried again with the default model.

    Args:
        model: Model name to use; defaults to ``MODEL``.
        attempts: Number of retries already made because the reply contained
            inline ``<function=`` markup instead of a structured tool call.

    Returns:
        The reply text (possibly empty), or a string starting with
        ``C.EXCEPTION_KEYWORD`` when the API call or a tool failed.
    """
    model = model or MODEL
    messagesFormatted = []

    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)

            responseMessage = client.messages.create(
                model=model,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.5,
                max_tokens=4000,
                tools=__getClaudeTools()
            )

            # Blocks are picked by their type rather than by position: a
            # reply may consist of a tool_use block only, in which case the
            # first block has no text.
            contentBlocks = responseMessage.content
            responseContent = "".join(
                block.text for block in contentBlocks
                if getattr(block, "type", None) == "text"
            )
            toolCalls = [
                block for block in contentBlocks
                if getattr(block, "type", None) == "tool_use"
            ]
        else:
            messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)

            response = client.chat.completions.create(
                model=model,
                messages=messagesFormatted,
                temperature=0.6,
                max_tokens=4000,
                stream=False,
                tools=tools
            )

            responseMessage = response.choices[0].message
            responseContent = responseMessage.content

            if responseContent and '<function=' in responseContent:
                U.pprint(f"Wrong toolCall response: {responseContent}")
                if attempts < 3:
                    U.pprint(f"Retrying...{attempts + 1}/3")
                    time.sleep(0.2)
                    return predict(model, attempts + 1)

                responseContent = __removeFunctionCall(responseContent)
                # Switch models only once: if the tools model itself keeps
                # emitting the markup, switching again would recurse forever.
                if "<function=" in responseContent and model != TOOLS_MODEL:
                    U.pprint("Switching to TOOLS_MODEL")
                    return predict(TOOLS_MODEL)

            toolCalls = responseMessage.tool_calls

        U.pprint(f"{toolCalls=}")

        if not toolCalls:
            # Always hand back a string so callers can iterate over it.
            return responseContent or ""

        toolCalls = __dedupeToolCalls(toolCalls)
        U.pprint("Deduping done!")
        try:
            if isClaudeModel:
                __processClaudeToolCalls(responseMessage.content)
            else:
                __processToolCalls(toolCalls)
        except Exception as e:
            # Report the failure through the exception marker instead of
            # falling through and returning None.
            U.pprint(e)
            return f"{C.EXCEPTION_KEYWORD} | {e}"

        # Tool results are now part of the messages; ask for the final reply.
        return predict()
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        return f"{C.EXCEPTION_KEYWORD} | {e}"
def __generateImage(prompt: str):
    """Request an image for ``prompt`` from the FLUX.1-schnell Gradio endpoint.

    The raw result of the ``/infer`` call is logged and returned unchanged.
    """
    inferenceArgs = {
        "prompt": prompt,
        "seed": 0,
        "randomize_seed": True,
        "width": 1024,
        "height": 768,
        "num_inference_steps": 4,
        "api_name": "/infer",
    }
    result = Client("black-forest-labs/FLUX.1-schnell").predict(**inferenceArgs)
    U.pprint(f"imageResult={result}")
    return result
def __resetButtonState():
    """Clear the stored button selection in the session state."""
    st.session_state["buttonValue"] = ""
# One-time session defaults; values already present are left untouched.
if "ipAddress" not in st.session_state:
    st.session_state["ipAddress"] = st.context.headers.get("x-forwarded-for")

for listKey in ("chatHistory", "messages"):
    if listKey not in st.session_state:
        st.session_state[listKey] = []

if "buttonValue" not in st.session_state:
    __resetButtonState()

# NOTE(review): the source's indentation is lost; this reset is assumed to
# run on every script execution rather than only on first load - confirm.
st.session_state["toolResponseDisplay"] = {}

# Two blank log entries, then shared styling and the page heading.
for _ in range(2):
    U.pprint("\n")

U.applyCommonStyles()
st.title("Dr Newtons Ai Client")
# Replay the stored conversation: text first, then any tool display and image.
for chat in st.session_state.chatHistory:
    role, content = chat["role"], chat["content"]
    imagePath = chat.get("image")
    toolResponseDisplay = chat.get("toolResponseDisplay")

    if role == "assistant":
        avatar = C.AI_ICON
    else:
        avatar = C.USER_ICON

    with st.chat_message(role, avatar=avatar):
        st.markdown(content)
        if toolResponseDisplay:
            __showToolResponse(toolResponseDisplay)
        if imagePath:
            st.image(imagePath)
# U.pprint(f"{st.session_state.buttonValue=}") | |
# U.pprint(f"{st.session_state.selectedStory=}") | |
# U.pprint(f"{st.session_state.startMsg=}") | |
# Handle a new user turn: the prompt comes either from the chat box or from
# a previously clicked option button.
if prompt := (
    st.chat_input("Ask anything ...")
    or st.session_state["buttonValue"]
):
    __resetButtonState()

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt })
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Fetch a reply, render it progressively, and return it.

            Returns None when any prefix of the reply is judged invalid.
            """
            response = ""
            responseContainer.image(C.TEXT_LOADER)
            responseGenerator = predict()

            # predict() returns a complete string, so this walks it one
            # character at a time and validates every prefix.
            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return

                # Rendering stops once the JSON separator has appeared.
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)

            return response

        response = __printAndGetResponse()
        # NOTE(review): this retry loop has no upper bound; it keeps calling
        # the API for as long as replies are empty or invalid.
        while not response:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()

        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            """Store the clicked option so the next run uses it as the prompt."""
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        rawResponse = response
        # Split at the first separator only: unpacking the result of an
        # unbounded split raised ValueError when the separator occurred more
        # than once. jsonStr is empty when there is no separator.
        response, _, jsonStr = rawResponse.partition(C.JSON_SEPARATOR)

        # Image generation is currently disabled, so no image is attached.
        imagePath = None

        toolResponseDisplay = st.session_state.toolResponseDisplay
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
            "toolResponseDisplay": toolResponseDisplay
        })
        # The model history keeps the unsplit reply, JSON part included.
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            try:
                jsonStr = jsonStr.replace("```", "")
                jsonObj = json.loads(jsonStr)
                questions = jsonObj.get("questions")
                # Only "questions" payloads are rendered; an "action"
                # payload is currently ignored.
                if questions:
                    for option in questions:
                        st.button(
                            option["label"],
                            key=option["id"],
                            on_click=lambda label=option["label"]: selectButton(label)
                        )
            except Exception as e:
                U.pprint(e)
# if st.button("Rerun"): | |
# # __resetButtonState() | |
# st.session_state.chatHistory = [] | |
# st.session_state.messages = [] | |
# st.rerun() | |