# controller for the application that calls the model and explanation functions
# returns the updated conversation history and extra elements

# external imports
import gradio as gr

# internal imports
from model import godel
from model import mistral
from explanation import (
    attention as attention_viz,
    interpret_shap as shap_int,
    interpret_captum as cpt_int,
)


# simple chat function that calls the model
# formats the prompt, generates an answer and returns the updated conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    print(f"Running normal chat with {model}.")

    # formatting the prompt using the model's format_prompt function
    prompt = model.format_prompt(message, history, system_prompt, knowledge)

    # generating an answer using the model's respond function
    answer = model.respond(prompt)

    # updating the chat history with the new answer
    history.append((message, answer))

    # returning an empty message textbox value and the updated history
    return "", history
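
# note: a minimal sketch (assumption, not part of the original modules) of the interface
# the controller expects from a model module such as godel or mistral:
#
#     def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = "") -> str:
#         # combine the system prompt, knowledge, history and message into one prompt string
#         ...
#
#     def respond(prompt: str) -> str:
#         # run text generation for the prompt and return the answer text
#         ...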


def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    print(f"Running explained chat with {xai} and {model}.")

    # formatting the prompt using the model's format_prompt function
    # message, history, system_prompt, knowledge = mdl.prompt_limiter(
    #     message, history, system_prompt, knowledge
    # )
    prompt = model.format_prompt(message, history, system_prompt, knowledge)

    # generating an answer and explanation elements using the XAI module's chat_explained function
    answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)

    # updating the chat history with the new answer
    history.append((message, answer))

    # returning an empty message textbox value, the updated history and the xai elements
    return "", history, xai_graphic, xai_markup, xai_plot
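
# note: a sketch (assumption, inferred from how the values are used in this file) of what
# each XAI module's chat_explained function is expected to return:
#
#     def chat_explained(model, prompt: str):
#         # returns the answer text, an interactive html graphic (str), a list of
#         # (token, label) tuples for highlighted markup and a plot object (or None)
#         return answer, xai_graphic, xai_markup, xai_plot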


# main interference function that calls the chat functions depending on the selections
def interference(
    prompt: str,
    history: list,
    knowledge: str,
    system_prompt: str,
    xai_selection: str,
    model_selection: str,
):
    # if no proper system prompt is given, use a default one
    if system_prompt in ("", " "):
        system_prompt = (
            "You are a helpful, respectful and honest assistant."
            " Always answer as helpfully as possible, while being safe."
        )

    # grab the model instance matching the model selection
    if model_selection.lower() == "mistral":
        model = mistral
        print("Identified model as Mistral")
    else:
        model = godel
        print("Identified model as GODEL")

    # if an XAI approach is selected, grab the matching XAI module instance
    # and call the explained chat function
    if xai_selection in ("SHAP", "Attention"):
        # matching the selection to the XAI module
        match xai_selection.lower():
            case "shap":
                if model_selection.lower() == "mistral":
                    xai = cpt_int
                else:
                    xai = shap_int
            case "attention":
                xai = attention_viz
            case _:
                # use a Gradio warning to display the error message
                gr.Warning(f"""
                    There was an error in the selected XAI approach.
                    It is "{xai_selection}"
                """)
                # raise a runtime exception
                raise RuntimeError("There was an error in the selected XAI approach.")

        # call the explained chat function with the model and XAI instances
        prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
            explained_chat(
                model=model,
                xai=xai,
                message=prompt,
                history=history,
                system_prompt=system_prompt,
                knowledge=knowledge,
            )
        )
    # if no XAI approach is selected, call the vanilla chat function
    else:
        # calling the vanilla chat function
        prompt_output, history_output = vanilla_chat(
            model=model,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
        # set the XAI outputs to a disclaimer html / none
        xai_interactive, xai_markup, xai_plot = (
            """
            <div style="text-align: center"><h4>Without Selected XAI Approach,
            no graphic will be displayed</h4></div>
            """,
            [("", "")],
            None,
        )

    # return the outputs
    return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
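

# a minimal usage sketch (assumption, not part of the original file): calling interference
# directly with example values, the same way the Gradio UI callback would invoke it
if __name__ == "__main__":
    _, chat_history, xai_html, xai_tokens, xai_figure = interference(
        prompt="What is attention in a transformer model?",
        history=[],
        knowledge="",
        system_prompt="",
        xai_selection="Attention",
        model_selection="GODEL",
    )
    # print the latest answer from the updated conversation history
    print(chat_history[-1][1])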