Commit 69b34c4 • LennardZuendorf committed
Parent: 1a96c54

fix: cleanup build config again, reverting reverted changes
Changed files:
- Dockerfile +5 -5
- Dockerfile-Base +2 -2
- README.md +2 -2
- backend/controller.py +0 -1
- entrypoint.sh +0 -8
- explanation/interpret.py +21 -27
- explanation/visualize.py +30 -23
- main.py +17 -5
Dockerfile
CHANGED
@@ -3,8 +3,8 @@
 # complete build based on clean python (slower)
 #FROM python:3.11.6
 
-# build based on
-FROM thesis:0.
+# build based on thesis base with dependencies (quicker) - for dev
+FROM thesis-base:0.1.1
 
 # install dependencies and copy files into image folder
 COPY requirements.txt .
@@ -16,8 +16,8 @@ COPY . .
 RUN ls --recursive .
 
 # setting config and run command
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
 
 # build and run commands:
-## docker build -t thesis:0.
-## docker run -d --name thesis -p 8080:8080 thesis:0.
+## docker build -t thesis:0.1.4 -f Dockerfile .
+## docker run -d --name thesis -p 8080:8080 thesis:0.1.4
Dockerfile-Base
CHANGED
@@ -2,11 +2,11 @@
 # because all dependencies are already installed, the next webapp build using this base image is much quicker
 
 # using newest python as a base image
-FROM
+FROM thesis-base:0.1.1
 
 # install dependencies based on requirements
 COPY requirements.txt ./
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
 # build and run commands
-## docker build -t thesis:0.1
+## docker build -t thesis-base:1.0.1 -f Dockerfile-Base .
README.md
CHANGED
@@ -3,12 +3,12 @@ title: Thesis
 emoji: 🎓
 colorFrom: red
 colorTo: yellow
-sdk:
+sdk: gradio
 sdk_version: 4.7.1
 app_file: main.py
 pinned: true
 license: mit
-app_port:
+app_port: 8080
 ---
 
 # Bachelor Thesis
backend/controller.py
CHANGED
@@ -10,7 +10,6 @@ from explanation import interpret, visualize
 
 
 # main interference function that that calls chat functions depending on selections
-# TODO: Limit maximum tokens/model input
 def interference(
     prompt: str,
     history: list,
entrypoint.sh
DELETED
@@ -1,8 +0,0 @@
-#!/bin/bash
-# entrypoint script for the docker container to run at start
-
-# installing all the dependencies
-pip install --no-cache-dir --upgrade -r requirements.txt
-
-# running the fastapi app
-uvicorn main:app --host 0.0.0.0 --port 8080
explanation/interpret.py
CHANGED
@@ -62,35 +62,29 @@ def create_graphic(shap_values):
     return str(graphic_html)
 
 
-#
+# creating an attention heatmap plot using matplotlib/seaborn
+# CREDIT: adopted from official Matplotlib documentation
+## see https://matplotlib.org/stable/
 def create_plot(shap_values):
     values = shap_values.values[0]
     output_names = shap_values.output_names
     input_names = shap_values.data[0]
 
-    # Transpose the values for horizontal input names
-    transposed_values = np.transpose(values)
-
     # Set seaborn style to dark
-    sns.set(style="
-
+    sns.set(style="white")
     fig, ax = plt.subplots()
 
-    # Making background transparent
-    ax.set_alpha(0)
-    fig.patch.set_alpha(0)
-
     # Setting figure size
     fig.set_size_inches(
-        max(
-        max(
+        max(values.shape[1] * 2, 10),
+        max(values.shape[0] * 1, 5),
     )
 
     # Plotting the heatmap with Seaborn's color palette
     im = ax.imshow(
-
-        vmax=
-        vmin
+        values,
+        vmax=values.max(),
+        vmin=values.min(),
         cmap=sns.color_palette("vlag_r", as_cmap=True),
         aspect="auto",
     )
@@ -98,25 +92,25 @@ def create_plot(shap_values):
     # Creating colorbar
     cbar = ax.figure.colorbar(im, ax=ax)
     cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
-    cbar.ax.yaxis.set_tick_params(color="
-    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="
+    cbar.ax.yaxis.set_tick_params(color="black")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
 
     # Setting ticks and labels with white color for visibility
-    ax.
-    ax.
-    plt.setp(ax.get_xticklabels(), color="
-    plt.setp(ax.get_yticklabels(), color="
+    ax.set_yticks(np.arange(len(input_names)), labels=input_names)
+    ax.set_xticks(np.arange(len(output_names)), labels=output_names)
+    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="black")
 
     # Adjusting tick labels
     ax.tick_params(
         top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
    )
 
-    # Adding text annotations
-
-
-
-
-
+    # Adding text annotations with appropriate contrast
+    for i in range(values.shape[0]):
+        for j in range(values.shape[1]):
+            val = values[i, j]
+            color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
+            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
 
     return plt
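For reference, the cell-annotation logic added to create_plot follows the standard Matplotlib annotated-heatmap recipe: text is drawn in white on dark cells and black on light cells by comparing the normalized value against half of the normalized maximum. Below is a minimal, self-contained sketch of that technique with dummy data; the array shape, token names, and output filename are illustrative assumptions, not values from the real shap_values, and the labels= keyword of set_xticks/set_yticks needs Matplotlib 3.5 or newer.

# sketch: contrast-aware annotated heatmap with dummy attribution values
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

values = np.random.uniform(-1, 1, size=(4, 6))  # rows: input tokens, cols: output tokens (dummy)
input_names = [f"in_{i}" for i in range(values.shape[0])]
output_names = [f"out_{j}" for j in range(values.shape[1])]

sns.set(style="white")
fig, ax = plt.subplots(figsize=(max(values.shape[1] * 2, 10), max(values.shape[0], 5)))
im = ax.imshow(
    values,
    vmin=values.min(),
    vmax=values.max(),
    cmap=sns.color_palette("vlag_r", as_cmap=True),
    aspect="auto",
)
ax.set_yticks(np.arange(len(input_names)), labels=input_names)
ax.set_xticks(np.arange(len(output_names)), labels=output_names)

# white text on dark cells, black text on light cells
for i in range(values.shape[0]):
    for j in range(values.shape[1]):
        val = values[i, j]
        color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
        ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)

fig.savefig("attribution_heatmap.png")  # hypothetical output path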
explanation/visualize.py
CHANGED
@@ -57,28 +57,27 @@ def create_graphic(attention_output, enc_dec_texts: tuple):
     return str(hview.data)
 
 
-# creating an attention heatmap plot using seaborn
+# creating an attention heatmap plot using matplotlib/seaborn
+# CREDIT: adopted from official Matplotlib documentation
+## see https://matplotlib.org/stable/
 def create_plot(attention_output, enc_dec_texts: tuple):
     # get the averaged attention weights
     attention = attention_output.cross_attentions[0][0].detach().numpy()
     averaged_attention_weights = np.mean(attention, axis=0)
+    averaged_attention_weights = np.transpose(averaged_attention_weights)
 
-    # get the encoder and decoder tokens
+    # get the encoder and decoder tokens in text form
     encoder_tokens = enc_dec_texts[0]
     decoder_tokens = enc_dec_texts[1]
 
     # set seaborn style to dark and initialize figure and axis
-    sns.set(style="
+    sns.set(style="white")
     fig, ax = plt.subplots()
 
-    # Making background transparent
-    ax.set_alpha(0)
-    fig.patch.set_alpha(0)
-
     # Setting figure size
     fig.set_size_inches(
         max(averaged_attention_weights.shape[1] * 2, 10),
-        max(averaged_attention_weights.shape[0]
+        max(averaged_attention_weights.shape[0] * 1, 5),
     )
 
     # Plotting the heatmap with seaborn's color palette
@@ -92,19 +91,27 @@ def create_plot(attention_output, enc_dec_texts: tuple):
 
     # Creating colorbar
     cbar = ax.figure.colorbar(im, ax=ax)
-    cbar.ax.set_ylabel("
-    cbar.ax.yaxis.set_tick_params(color="
-    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="
-
-    # Setting ticks and labels with
-    ax.
-    ax.
-
-    plt.setp(ax.
-
-
-
-
-
-
+    cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
+    cbar.ax.yaxis.set_tick_params(color="black")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
+
+    # Setting ticks and labels with black color for visibility
+    ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
+    ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
+    ax.set_title("Attention Weights by Token")
+    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="black")
+
+    # Adding text annotations with appropriate contrast
+    for i in range(averaged_attention_weights.shape[0]):
+        for j in range(averaged_attention_weights.shape[1]):
+            val = averaged_attention_weights[i, j]
+            color = (
+                "white"
+                if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
+                else "black"
+            )
+            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
+
+    # return the plot
     return plt
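The only data handling added to visualize.create_plot is averaging the cross-attention weights over the head dimension and transposing the result so that encoder tokens end up on the y-axis. A small sketch with a dummy attention tensor in place of attention_output.cross_attentions; the shapes (8 heads, 5 decoder tokens, 12 encoder tokens) are assumptions for illustration:

# sketch: average cross-attention over heads and transpose for plotting
import numpy as np

cross_attention = np.random.rand(8, 5, 12)   # (heads, decoder_len, encoder_len), dummy
averaged = np.mean(cross_attention, axis=0)  # average over heads -> (5, 12)
averaged = np.transpose(averaged)            # -> (12, 5): encoder tokens become rows

# rows now map to the y-axis (encoder tokens) and columns to the x-axis
# (decoder tokens) of the heatmap built in create_plot above
print(averaged.shape)  # (12, 5)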
main.py
CHANGED
@@ -1,12 +1,19 @@
 # main application file initializing the gradio based ui and calling other
+
+# standard imports
+import os
+
 # external imports
 from fastapi import FastAPI
 import markdown
 import gradio as gr
+from uvicorn import run
+
 
 # internal imports
 from backend.controller import interference
 
+
 # Global Variables and css
 app = FastAPI()
 css = "body {text-align: start !important;}"
@@ -187,7 +194,7 @@ with gr.Blocks(
     Values have been excluded for readability. See colorbar for value indication.
     """)
     # plot component that takes a matplotlib figure as input
-    xai_plot = gr.Plot(label="Token Level Explanation"
+    xai_plot = gr.Plot(label="Token Level Explanation")
 
     # functions to trigger the controller
     ## takes information for the chat and the xai selection
@@ -207,16 +214,21 @@ with gr.Blocks(
 
     # final row to show legal information
     ## - credits, data protection and link to the License
-    with gr.Tab(label="
-    gr.Markdown(value=load_md("public/
+    with gr.Tab(label="About"):
+        gr.Markdown(value=load_md("public/about.md"))
+        with gr.Accordion(label="Credits, Data Protection, License"):
+            gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
 app = gr.mount_gradio_app(app, ui, path="/")
 
 # launch function using uvicorn to launch the fastAPI application
 if __name__ == "__main__":
-    from uvicorn import run
 
-    #
+    # use standard gradio launch option for hgf spaces
+    if os.environ["HOSTING"].lower() == "spaces":
+        ui.launch(auth=("htw", "berlin@123"))
+
+    # otherwise run the application on port 8080 in reload mode
     ## for local development, uses Docker for Prod deployment
     run("main:app", port=8080, reload=True)
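The new launch block switches on a HOSTING environment variable: on Hugging Face Spaces the Gradio UI is launched directly (with basic auth), otherwise uvicorn runs the FastAPI app on port 8080 in reload mode. A minimal sketch of that switch, assuming the same variable name; it uses os.environ.get with a default so the check does not raise a KeyError when HOSTING is unset, whereas the committed code reads os.environ["HOSTING"] directly and therefore expects the variable to be defined:

# sketch: choose the launch mode from the HOSTING environment variable
import os


def select_launch_mode() -> str:
    hosting = os.environ.get("HOSTING", "").lower()
    if hosting == "spaces":
        return "gradio-launch"   # would call ui.launch(...) on Spaces
    return "uvicorn-reload"      # would call run("main:app", port=8080, reload=True)


if __name__ == "__main__":
    print(select_launch_mode())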