import gradio as gr
import os
import json

from groq import Groq
from search import answer_query
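
# Load GROQ_API_KEY from a local .env file when python-dotenv is installed;
# deployed environments are expected to set the variable directly.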
try:
    from dotenv import load_dotenv
    load_dotenv(dotenv_path="./.env")
except ImportError:
    pass

client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
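
# Tool schema advertised to the model: a single function the LLM may call
# to retrieve docstrings relevant to the user's coding question.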
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_related_functions",
            "description": "Get docstrings for internal functions of any library on PyPI.",
            "parameters": {
                "type": "object",
                "properties": {
                    "user_query": {
                        "type": "string",
                        "description": "A query used to retrieve docstrings and find useful information.",
                    }
                },
                "required": ["user_query"],
            },
        },
    }
]


def user(user_message, history):
    return "", history + [[user_message, None]]
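

# Tool implementation: fetch docstrings for the query via MongooseMiner's
# answer_query and hand the top match back to the model.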
def get_related_functions(user_query: str) -> str:
    # answer_query is assumed to return a ranked list of docstring strings;
    # the top hit is returned since tool message content must be a string.
    docstring_top10 = answer_query(user_query)
    return docstring_top10[0]
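

# Search-augmented generation: the model may call get_related_functions
# several times before composing its final answer.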
def generate_rag(history):
    messages = [
        {
            "role": "system",
            "content": "You are a function calling LLM that uses the data extracted from the get_related_functions function to answer questions around writing Python code. Use the extracted docstrings to write better code."
        },
        {
            "role": "user",
            "content": history[-1][0],
        }
    ]
    history[-1][1] = ""
    tool_call_count = 0
    max_tool_calls = 3
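    # Budget the model's tool use; once the budget is spent the tools are
    # withheld so the model must answer in plain text.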
    while tool_call_count <= max_tool_calls:
        response = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=messages,
            tools=tools if tool_call_count < max_tool_calls else None,
            tool_choice="auto",
            max_tokens=4096
        )
        tool_call_count += 1
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls

        if tool_calls:
            available_functions = {
                "get_related_functions": get_related_functions,
            }
            messages.append(response_message)
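
            # Run each requested call and feed its output back to the model
            # as a "tool" message tied to the originating call id.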
            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions[function_name]
                function_args = json.loads(tool_call.function.arguments)
                function_response = function_to_call(
                    user_query=function_args.get("user_query")
                )
                messages.append(
                    {
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": function_response,
                    }
                )
        else:
            break

    history[-1][1] += response_message.content
    return history
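

# Baseline for comparison: stream a plain llama3-8b completion with no
# retrieval augmentation.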
def generate_llama3(history):
    history[-1][1] = ""
    stream = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": history[-1][0],
            }
        ],
        stream=True,
        model="llama3-8b-8192",
        max_tokens=1024,
        temperature=0
    )
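
    # Some chunks (e.g. the role-only first chunk) carry a None delta;
    # returning on them would cut the stream short, so skip them instead.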
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            history[-1][1] += chunk.choices[0].delta.content
            yield history
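

# Gradio UI: two chatbots side by side sharing one textbox, so the same
# prompt can be compared across the plain and augmented models.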
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column():
            gr.Markdown("# MongooseMiner Search Demo")
            gr.Markdown(
                "Augmenting LLM code generation with function-level search across all of PyPI.")

    with gr.Row():
        chatbot = gr.Chatbot(height="35rem", label="Llama3 unaugmented")
        chatbot2 = gr.Chatbot(
            height="35rem", label="Llama3 with MongooseMiner Search")
    msg = gr.Textbox()

    clear = gr.Button("Clear")
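
    # A single submit fans out to both pipelines: each listener appends the
    # prompt to its own chatbot, then its generator fills in the reply.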
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_llama3, chatbot, chatbot
    )
    msg.submit(user, [msg, chatbot2], [msg, chatbot2], queue=False).then(
        generate_rag, chatbot2, chatbot2
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    clear.click(lambda: None, None, chatbot2, queue=False)


demo.queue()
demo.launch()