"""Streamlit chat app that talks to a configurable LLM backend (GPT-4o,
Claude 3.5 Sonnet, or Llama via Groq) with tool-calling support.

NOTE(review): this file reached review with its newlines stripped (the whole
module was collapsed onto a handful of physical lines). The code below is a
faithful re-formatting of the recoverable text; one span inside ``predict``
is textually missing from the source and is flagged inline rather than
reinvented — restore it from version control.
"""
import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict
from transformers import AutoTokenizer
from tools.tools import toolsInfo
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv

load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
# NOTE(review): the LLAMA entry below also carries a "tools_model" key that is
# not declared here — confirm whether ModelConfig should declare it (total=False).
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

# Backend selection comes from the environment; defaults to LLAMA/Groq.
modelType: ModelType = os.environ.get("MODEL_TYPE") or "LLAMA"

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        # "model": "llama-3.2-90b-text-preview",
        # Separate model used only for tool-calling requests on Groq.
        "tools_model": "llama3-groq-70b-8192-tool-use-preview",
        "max_context": 12800,  # intentionally reduced to 1/10th
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
    }
}

# Resolve the active backend's pieces once at import time.
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"


def __countTokens(text):
    """Return the token count of str(text) under the active model's tokenizer."""
    text = str(text)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return len(tokens)


st.set_page_config(
    page_title="Dr Newtons PG research Ai",
    page_icon=C.AI_ICON,
)

# NOTE(review): the markdown payload is empty in the source — presumably the
# original injected CSS/HTML that was stripped by the same garbling. Confirm
# against version control.
st.markdown('', unsafe_allow_html=True)


def __isInvalidResponse(response: str):
    """Heuristically detect degenerate/invalid LLM output.

    Returns True when a known failure pattern is found; returns None (falsy)
    otherwise. Each branch logs the reason via U.pprint.
    """
    # Many newlines followed by a lowercase (non-URL) char, outside code fences:
    # typical of a run-on, badly formatted completion.
    if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True
    # Same word repeated 3+ times in a row — looping output.
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True
    if len(re.findall(r'\n\n', response)) > 20:
        U.pprint("lots of paragraphs")
        return True
    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        return True
    # JSON payloads must be preceded by the separator marker.
    if ('{\n "questions"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True
    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True
    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True


def __matchingKeywordsCount(keywords: List[str], text: str):
    """Return how many of the given keywords occur (as substrings) in text."""
    return sum([
        1 if keyword in text else 0
        for keyword in keywords
    ])


def __getMessages():
    """Return session messages, evicting oldest ones until they fit MAX_CONTEXT."""
    def getContextSize():
        # +100 is headroom for formatting overhead around the raw tokens.
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    while getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)

    return st.session_state.messages


def __logLlmRequest(messagesFormatted: list, model: str):
    """Log the request's context size and target model."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {model}")
    # U.pprint(f"{messagesFormatted=}")


# OpenAI-style tool schemas exposed to the LLM.
tools = [
    toolsInfo["getGoogleSearchResults"]["schema"],
]


def __showToolResponse(toolResponseDisplay: dict):
    """Render a tool call's display payload (icon + text) in the chat UI."""
    msg = toolResponseDisplay.get("text")
    icon = toolResponseDisplay.get("icon")
    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(
            icon or C.TOOL_ICON,
            width=30
        )
    with col2:
        # NOTE(review): crashes if "text" is absent (msg is None) — confirm
        # every tool display payload always carries "text".
        if "`" not in msg:
            st.markdown(f"`{msg}`")  # wrap plain text in code styling
        else:
            st.markdown(msg)


def __addToolCallToMsgs(toolCall: dict):
    """Append the assistant's tool call to session messages.

    Claude tool calls are stored as-is; OpenAI/Groq calls are re-shaped into
    the chat-completions "tool_calls" message format.
    """
    if isClaudeModel:
        st.session_state.messages.append(toolCall)
    else:
        st.session_state.messages.append(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": toolCall.id,
                        "function": {
                            "name": toolCall.function.name,
                            "arguments": toolCall.function.arguments,
                        },
                        "type": toolCall.type,
                    }
                ],
            }
        )


def __processToolCalls(toolCalls):
    """Execute OpenAI/Groq-style tool calls and record results in the chat."""
    for toolCall in toolCalls:
        functionName = toolCall.function.name
        functionToCall = toolsInfo[functionName]["func"]
        functionArgsStr = toolCall.function.arguments
        U.pprint(f"{functionName=} | {functionArgsStr=}")
        functionArgs = json.loads(functionArgsStr)
        functionResult = functionToCall(**functionArgs)
        functionResponse = functionResult.get("response")
        responseDisplay = functionResult.get("display")
        U.pprint(f"{functionResponse=}")

        if responseDisplay:
            __showToolResponse(responseDisplay)
            st.session_state.toolResponseDisplay = responseDisplay

        __addToolCallToMsgs(toolCall)
        st.session_state.messages.append({
            "role": "tool",
            "tool_call_id": toolCall.id,
            "name": functionName,
            "content": functionResponse,
        })


def __processClaudeToolCalls(toolResponse):
    """Execute a Claude tool_use block and append the tool_result message."""
    # Claude's content list: [text_block, tool_use_block] — take the tool_use.
    toolCall = toolResponse[1]
    functionName = toolCall.name
    functionToCall = toolsInfo[functionName]["func"]
    functionArgs = toolCall.input
    functionResult = functionToCall(**functionArgs)
    functionResponse = functionResult.get("response")
    responseDisplay = functionResult.get("display")
    U.pprint(f"{functionResponse=}")

    if responseDisplay:
        __showToolResponse(responseDisplay)
        st.session_state.toolResponseDisplay = responseDisplay

    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse
    })
    st.session_state.messages.append({
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": toolCall.id,
            "content": functionResponse,
        }],
    })


def __dedupeToolCalls(toolCalls: list):
    """Collapse multiple calls to the same function, keeping the last one."""
    toolCallsDict = {}
    for toolCall in toolCalls:
        funcName = toolCall.name if isClaudeModel else toolCall.function.name
        toolCallsDict[funcName] = toolCall
    dedupedToolCalls = list(toolCallsDict.values())

    if len(toolCalls) != len(dedupedToolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"{toolCalls=} -> {dedupedToolCalls=}")

    return dedupedToolCalls


def __getClaudeTools():
    """Convert the OpenAI-style tool schemas into Claude's input_schema shape."""
    claudeTools = []
    for tool in tools:
        funcInfo = tool["function"]
        name = funcInfo["name"]
        description = funcInfo["description"]
        schema = funcInfo["parameters"]
        claudeTools.append({
            "name": name,
            "description": description,
            "input_schema": schema,
        })
    return claudeTools


def __removeFunctionCall(response: str):
    """Strip inline {"query": "..."} function-call artifacts from a response."""
    pattern = r'\{"query": ".*?"\}'
    return re.sub(pattern, '', response)


def predict(model: str = None, attempts=0):
    """Send the current conversation to the configured LLM and return its reply.

    NOTE(review): the tail of this function was lost in the file garbling (see
    the inline note below). `attempts` is presumably consumed by the missing
    retry logic — confirm against version control.
    """
    model = model or MODEL
    messagesFormatted = []
    try:
        if isClaudeModel:
            # Claude: system prompt passed separately, tools in Claude's shape.
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)
            responseMessage = client.messages.create(
                model=model,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.5,
                max_tokens=4000,
                tools=__getClaudeTools()
            )
            responseMessageContent = responseMessage.content
            responseContent = responseMessageContent[0].text
            toolCalls = []
            if len(responseMessageContent) > 1:
                # A second content block means Claude requested a tool.
                toolCalls = [responseMessageContent[1]]
        else:
            # OpenAI/Groq: system prompt is the first chat message.
            messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)
            response = client.chat.completions.create(
                model=model,
                messages=messagesFormatted,
                temperature=0.6,
                max_tokens=4000,
                stream=False,
                tools=tools
            )
            responseMessage = response.choices[0].message
            responseContent = responseMessage.content
        # ------------------------------------------------------------------
        # NOTE(review): SOURCE CORRUPTED FROM THIS POINT. The original text
        # jumps from `if responseContent and '` straight to
        # `' 1: [response, jsonStr] = responseParts`, so everything in between
        # is missing: the tool-call dispatch (presumably via
        # __dedupeToolCalls / __processToolCalls / __processClaudeToolCalls),
        # the __isInvalidResponse retry path, this `try`'s original `except`
        # handler, and the start of the chat-input main loop. Restore the lost
        # code from version control; do NOT treat the handler below as
        # original behavior.
        # ------------------------------------------------------------------
    except Exception as e:
        # Placeholder only — keeps the dangling `try` parseable (original
        # handler lost; C.EXCEPTION_KEYWORD suggests it tagged the response).
        U.pprint(e)
        raise


# NOTE(review): ORPHANED FRAGMENT from the lost chat-handling block. It
# references names defined only in the missing code (`responseParts`,
# `prompt`, `response`), so it is kept commented out to leave the module
# importable. Re-home it when the lost block is restored:
# [response, jsonStr] = responseParts
# imagePath = None
# imageContainer = st.empty()
# try:
#     (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
#     if imagePrompt:
#         imgContainer = imageContainer.container()
#         imgContainer.write(
#             f"""
#
# NOTE(review): the remainder of the file is still in its collapsed, garbled
# form. Both physical lines below begin with '#', so everything on them —
# including the live response-handling code buried later in the second line
# (chatHistory/messages appends and the follow-up-question buttons built from
# the model's JSON payload) — is currently swallowed by the leading comment
# marker. Restore the original line structure from version control rather
# than editing in place.
# NOTE(review): when restoring, `json.loads(jsonStr)` is called twice in a
# row with the first result discarded — drop the redundant first call.
# {loaderText} #
# """, # unsafe_allow_html=True # ) # # imgContainer.markdown(f"`{loaderText}`") # imgContainer.image(C.IMAGE_LOADER) # (imagePath, seed) = __generateImage(imagePrompt) # imageContainer.image(imagePath) # except Exception as e: # U.pprint(e) # imageContainer.empty() toolResponseDisplay = st.session_state.toolResponseDisplay st.session_state.chatHistory.append({ "role": "assistant", "content": response, "image": imagePath, "toolResponseDisplay": toolResponseDisplay }) st.session_state.messages.append({ "role": "assistant", "content": rawResponse, }) if jsonStr: try: jsonStr = jsonStr.replace("```", "") json.loads(jsonStr) jsonObj = json.loads(jsonStr) questions = jsonObj.get("questions") action = jsonObj.get("action") if questions: for option in questions: st.button( option["label"], key=option["id"], on_click=lambda label=option["label"]: selectButton(label) ) elif action: pass except Exception as e: U.pprint(e) # if st.button("Rerun"): # # __resetButtonState() # st.session_state.chatHistory = [] # st.session_state.messages = [] # st.rerun()