Spaces:
Runtime error
Runtime error
Change numpy code by jax.numpy
Browse files
app.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
import streamlit as st
|
2 |
-
import numpy as
|
|
|
3 |
import matplotlib.pyplot as plt
|
4 |
|
|
|
|
|
|
|
|
|
5 |
st.title('Fitting simple models with JAX')
|
6 |
st.header('A quadratric regression example')
|
7 |
|
@@ -15,19 +20,18 @@ number_of_observations = st.sidebar.slider('Number of observations', min_value=5
|
|
15 |
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
|
16 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
17 |
|
18 |
-
|
19 |
-
|
20 |
X = np.column_stack((np.ones(number_of_observations),
|
21 |
np.random.random(number_of_observations)))
|
22 |
|
23 |
w = np.array([3.0, -20.0, 32.0]) # coefficients
|
24 |
-
|
25 |
X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
26 |
-
additional_noise = 8 *
|
27 |
-
y =
|
28 |
+ additional_noise
|
29 |
|
30 |
|
|
|
31 |
fig, ax = plt.subplots(dpi=320)
|
32 |
ax.set_xlim((0,1))
|
33 |
ax.set_ylim((-5,26))
|
@@ -46,6 +50,10 @@ st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{
|
|
46 |
|
47 |
# Fitting by the respective cost_function
|
48 |
if cost_function == 'RMSE-Loss':
|
|
|
|
|
|
|
|
|
49 |
st.write('You selected the RMSE loss function.')
|
50 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
|
51 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
|
|
|
1 |
import streamlit as st
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import jax
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
6 |
+
# Set random key
|
7 |
+
seed=321
|
8 |
+
key = jax.random.PRNGKey(seed)
|
9 |
+
|
10 |
st.title('Fitting simple models with JAX')
|
11 |
st.header('A quadratric regression example')
|
12 |
|
|
|
20 |
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
|
21 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
22 |
|
23 |
+
# Generate random data
|
|
|
24 |
X = np.column_stack((np.ones(number_of_observations),
|
25 |
np.random.random(number_of_observations)))
|
26 |
|
27 |
w = np.array([3.0, -20.0, 32.0]) # coefficients
|
|
|
28 |
X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
29 |
+
additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
|
30 |
+
y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
|
31 |
+ additional_noise
|
32 |
|
33 |
|
34 |
+
# Plot the data
|
35 |
fig, ax = plt.subplots(dpi=320)
|
36 |
ax.set_xlim((0,1))
|
37 |
ax.set_ylim((-5,26))
|
|
|
50 |
|
51 |
# Fitting by the respective cost_function
|
52 |
if cost_function == 'RMSE-Loss':
|
53 |
+
|
54 |
+
def loss(w):
|
55 |
+
return 1/X.shape[0] * jax.numpy.linalg.norm(jnp.dot(X, w) - y)**2
|
56 |
+
|
57 |
st.write('You selected the RMSE loss function.')
|
58 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
|
59 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
|