Fix code
app.py CHANGED
@@ -49,18 +49,22 @@ st.markdown('Using gradient descent we find the minima of the loss adjusting the
 
 st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{w})}{\partial \bf{w}}''')
 
+st.markdown('The training loop:')
+
+code = '''NUM_ITER = 1000
+# initialize parameters
+w = np.array([3., -2., -8.])
+for i in range(NUM_ITER):
+    # update parameters
+    w -= learning_rate * grad_loss(w)'''
+
+st.code(code, language='python')
 
 # Fitting by the respective cost_function
 w = jnp.array(np.random.random(3))
 learning_rate = 0.05
 NUM_ITER = 1000
 
-fig, ax = plt.subplots(dpi=120)
-ax.set_xlim((0,1))
-ax.set_ylim((-5,26))
-ax.scatter(X[:,1], y, c='#e76254' ,edgecolors='firebrick')
-st.pyplot(fig)
-
 if cost_function == 'RMSE-Loss':
 
     def loss(w):
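The hunk above moves the training-loop explainer up next to the update-rule formula. For readers who want to run that loop end to end, here is a minimal self-contained sketch; the toy data and the RMSE-style loss are illustrative assumptions (the app builds its own X and y outside this diff), while grad_loss = jax.grad(loss) reflects the standard JAX way to obtain the gradient the formula calls for.

import jax
import jax.numpy as jnp
import numpy as np

# Toy stand-ins for the app's X and y (assumption: X carries a bias column).
rng = np.random.default_rng(0)
X = jnp.array(np.column_stack([np.ones(50), rng.random(50), rng.random(50)]))
y = X @ jnp.array([3.0, -2.0, -8.0]) + 0.1 * jnp.array(rng.normal(size=50))

def loss(w):
    # RMSE between the linear model's predictions and the targets
    return jnp.sqrt(jnp.mean((X @ w - y) ** 2))

grad_loss = jax.grad(loss)  # gradient of the scalar loss w.r.t. w

w = jnp.array(np.random.random(3))
learning_rate = 0.05
NUM_ITER = 1000
for i in range(NUM_ITER):
    # gradient-descent step: w <- w - eta * dloss/dw
    w = w - learning_rate * grad_loss(w)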
@@ -92,12 +96,12 @@ if cost_function == 'RMSE-Loss':
         # report the loss at the current epoch
         status_text.text(
             'Trained loss at epoch %s is %s' % (i, loss(w)))
-
-        # add a line to the plot
-        plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), 'c--')
-        st.pyplot(fig)
     # Plot the final line
-
+    fig, ax = plt.subplots(dpi=120)
+    ax.set_xlim((0,1))
+    ax.set_ylim((-5,26))
+    ax.scatter(X[:,1], y, c='#e76254', edgecolors='firebrick')
+    ax.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), 'k-', label='Final line')
     st.pyplot(fig)
     status_text.text('Done!')
 
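One detail worth spelling out in the added plotting line: the argsort/sort pair reorders the points by predicted value before drawing, so matplotlib connects them as a single monotone polyline instead of zigzagging across the scatter. An equivalent expanded version, assuming X, w, and ax as in the hunk above:

preds = jnp.dot(X, w)    # the model's prediction for every sample
order = preds.argsort()  # indices that put the predictions in ascending order
# Drawing the (x, prediction) pairs in increasing-prediction order is
# exactly what the one-liner in the hunk does.
ax.plot(X[order, 1], preds[order], 'k-', label='Final line')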
@@ -110,14 +114,3 @@ else:
     (y^{(i)}-\hat{y}^{(i)})^2 & \text{for }\quad |y^{(i)}-\hat{y}^{(i)}|\leq \delta \\
     2\delta|y^{(i)}-\hat{y}^{(i)}| - \delta^2 & \text{otherwise}
     \end{cases}''')
-
-st.markdown('The training loop:')
-
-code = '''NUM_ITER = 1000
-# initialize parameters
-w = np.array([3., -2., -8.])
-for i in range(NUM_ITER):
-    # update parameters
-    w -= learning_rte * grad_loss(w)'''
-
-st.code(code, language='python')
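The context lines here render the Huber loss used by the app's other branch. A minimal JAX sketch of that piecewise definition; the function name, the mean reduction, and the default delta are assumptions for illustration (the diff does not show how the app aggregates or parameterizes it):

import jax.numpy as jnp

def huber_loss(w, X, y, delta=1.0):
    residual = y - jnp.dot(X, w)                          # y - y_hat per sample
    quadratic = residual ** 2                             # branch for |residual| <= delta
    linear = 2 * delta * jnp.abs(residual) - delta ** 2   # branch otherwise
    return jnp.mean(jnp.where(jnp.abs(residual) <= delta, quadratic, linear))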