Spaces:
Runtime error
Runtime error
File size: 4,533 Bytes
1fcb538 d119a57 cc456a2 9131bec 1fcb538 cc456a2 931207e ba93ad8 61a9e50 617fdc7 1f08540 4c6f386 cc7b71f 96f9a87 cc456a2 80f8995 d119a57 9e06d30 617fdc7 d119a57 cc456a2 96f9a87 9131bec bea8b09 96f9a87 9131bec 16e4f76 e741bd4 cd42de6 e741bd4 16e4f76 b5beaf9 16e4f76 cc456a2 16e4f76 b5beaf9 4555e57 b5beaf9 4555e57 3cff897 fdfc503 3cff897 4555e57 fdfc503 4555e57 8869fec 4555e57 cd42de6 8241a76 4555e57 b5beaf9 7bc22ec 16e4f76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import streamlit as st
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
# Set random key
seed = 321
# NOTE(review): `key` is not used by the code visible here; the data below is
# drawn with np.random instead. Confirm before removing.
key = jax.random.PRNGKey(seed)

# Page header and the parametrised-model formula.
st.title('Fitting simple models with JAX')
st.header('A quadratic regression example')
st.markdown('*\"Parametrised models are simply functions that depend on inputs and trainable parameters. There is no fundamental difference between the two, except that trainable parameters are shared across training samples whereas the input varies from sample to sample.\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')
st.latex(r'''h(\boldsymbol x, \boldsymbol w)= \sum_{k=1}^{K}\boldsymbol w_{k} \phi_{k}(\boldsymbol x)''')

# Sidebar inputs: dataset size, noise level and the cost function to minimise.
number_of_observations = st.sidebar.slider('Number of observations', min_value=50, max_value=150, value=100)
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value=0.0, max_value=2.0, value=1.0)
cost_function = st.sidebar.radio(
    'Which cost function do you want to use for the fitting?',
    options=('RMSE-Loss', 'Huber-Loss'),
)
# Generate random data: y = X @ w + Gaussian noise + sparse positive outliers.
# The draw order (random -> binomial -> randn) is fixed by the seed below, so
# the dataset is reproducible for a given number_of_observations.
np.random.seed(2)
w = jnp.array([3.0, -20.0, 32.0])  # true coefficients of 1, x and x**2
X = np.column_stack((np.ones(number_of_observations),
                     np.random.random(number_of_observations)))  # x drawn from U[0, 1)
X = jnp.column_stack((X, X[:, 1] ** 2))  # add x**2 column
# Each point independently gets an extra +8 offset with probability 0.03.
additional_noise = 8 * np.random.binomial(1, 0.03, size=number_of_observations)
y = jnp.array(np.dot(X, w)
              + noise_standard_deviation * np.random.randn(number_of_observations)
              + additional_noise)

# Plot the data
fig, ax = plt.subplots(dpi=320)
ax.set_xlim((0, 1))
ax.set_ylim((-5, 26))
ax.scatter(X[:, 1], y, c='#e76254', edgecolors='firebrick')
st.pyplot(fig)
# Streamlit re-executes the whole script on every widget change; close the
# figure so pyplot does not keep accumulating open figures across reruns.
plt.close(fig)
# Explain gradient descent and show the (illustrative) training loop.
st.subheader('Train a model')
st.markdown('*\"A Gradient Based Method is a method/algorithm that finds the minima of a function, assuming that one can easily compute the gradient of that function. It assumes that the function is continuous and differentiable almost everywhere (it need not be differentiable everywhere).\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')
st.markdown('Using gradient descent, we find the minimum of the loss by adjusting the weights at each step according to the following update rule:')
# \mathbf{...} bolds only its argument; the old \bf switch bolded the whole
# remainder of the formula (including \ell, \eta and \partial).
st.latex(r'''\mathbf{w}\leftarrow \mathbf{w}-\eta \frac{\partial\ell(\mathbf{X},\mathbf{y}, \mathbf{w})}{\partial \mathbf{w}}''')
st.markdown('The training loop:')
# Displayed snippet only (not executed); it must be valid, correctly indented
# Python that uses the same names as the real loop.
code = '''NUM_ITER = 1000
# initialize parameters
w = np.array([3., -2., -8.])
for i in range(NUM_ITER):
    # update parameters
    w -= learning_rate * grad_loss(w)'''
st.code(code, language='python')
# Fit the quadratic model by gradient descent on the selected cost function.
w = jnp.array(np.random.random(3))  # initial weights, drawn from the seeded numpy RNG
learning_rate = 0.05
NUM_ITER = 1000
# Residual magnitude at which the Huber loss switches from quadratic to linear.
HUBER_DELTA = 1.0

if cost_function == 'RMSE-Loss':
    def loss(w):
        """Return the mean squared error of the predictions ``X @ w`` against ``y``.

        The option is labelled RMSE, but (as in the formulas shown) no square
        root is taken; the root would not change the minimiser.
        """
        residuals = jnp.dot(X, w) - y
        return jnp.mean(residuals ** 2)
    st.write('You selected the RMSE loss function.')
    st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
    st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
    st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_{i=1}^m (\hat{y}_i - y_i)^2''')
else:
    def loss(w):
        """Return the mean Huber loss of the predictions ``X @ w`` against ``y``.

        Uses the form displayed below: r**2 when |r| <= delta, otherwise
        2*delta*|r| - delta**2 (continuous with a continuous derivative at
        |r| = delta). Large residuals are penalised linearly, which limits
        the influence of outliers on the fit.
        """
        residuals = jnp.dot(X, w) - y
        abs_residuals = jnp.abs(residuals)
        per_sample = jnp.where(
            abs_residuals <= HUBER_DELTA,
            residuals ** 2,
            2 * HUBER_DELTA * abs_residuals - HUBER_DELTA ** 2,
        )
        return jnp.mean(per_sample)
    st.write("You selected the Huber loss function.")
    st.latex(r'''
\ell_{H} =
\begin{cases}
(y^{(i)}-\hat{y}^{(i)})^2 & \text{for }\quad |y^{(i)}-\hat{y}^{(i)}|\leq \delta \\
2\delta|y^{(i)}-\hat{y}^{(i)}| - \delta^2 & \text{otherwise}
\end{cases}''')
    st.latex(rf'\delta = {HUBER_DELTA}')

# Perform gradient descent (shared by both cost functions).
grad_loss = jax.grad(loss)
progress_bar = st.progress(0)
status_text = st.empty()
for i in range(1, NUM_ITER + 1):
    if i % 10 == 0:
        # st.progress expects 0..100; derive it from NUM_ITER so the value
        # can never overflow if the iteration count changes.
        progress_bar.progress(int(100 * i / NUM_ITER))
    # Update parameters.
    w = w - learning_rate * grad_loss(w)
    if i % 100 == 0:
        # Report the loss at the current epoch.
        status_text.text(
            'Trained loss at epoch %s is %s' % (i, loss(w)))

# Plot the fitted curve over the data. Order the points by x so the line is
# drawn left to right; ordering by the prediction zig-zags across the vertex
# whenever the fitted parabola is not monotone on [0, 1].
predictions = jnp.dot(X, w)
order = jnp.argsort(X[:, 1])
fig, ax = plt.subplots(dpi=120)
ax.set_xlim((0, 1))
ax.set_ylim((-5, 26))
ax.scatter(X[:, 1], y, c='#e76254', edgecolors='firebrick')
ax.plot(X[order, 1], predictions[order], 'k-', label='Final line')
st.pyplot(fig)
# Streamlit reruns the script on every interaction; release the figure.
plt.close(fig)
status_text.text('Done!')
|