File size: 1,230 Bytes
1fcb538
9131bec
 
1fcb538
931207e
 
ba93ad8
617fdc7
 
 
f81dc6c
57301c2
96f9a87
 
 
9131bec
96f9a87
 
 
432ab81
96f9a87
 
 
 
 
617fdc7
96f9a87
9131bec
 
96f9a87
 
9131bec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt

# Page header: title, subtitle, a short description and the model equation.
st.title('Fitting simple models with JAX')
st.header('A quadratic regression example')

st.markdown('**This is a simple text** to specify the goal of this simple data app.')
# Raw strings are required here: in a normal literal "\b" (from "\boldsymbol")
# is a backspace character and "\s" / "\p" are invalid escape sequences, which
# would corrupt the LaTeX source handed to st.latex.
st.latex(
    r'h(\boldsymbol x, \boldsymbol w)= '
    r'\sum_{k=1}^{K}\boldsymbol w_{k} \phi_{k}(\boldsymbol x)'
)

# Sidebar controls: how many points to simulate and the standard deviation
# of the Gaussian noise added to them.
number_of_observations = st.sidebar.slider(
    'Number of observations',
    min_value=50,
    max_value=150,
    value=50,
)
noise_standard_deviation = st.sidebar.slider(
    'Standard deviation of the noise',
    min_value=0.0,
    max_value=2.0,
    value=0.25,
)

# Simulate a noisy quadratic: y = w0 + w1*x + w2*x**2 + Gaussian noise,
# plus occasional large positive outliers.
np.random.seed(2)  # fixed seed so every rerun shows the same data set

# Feature x is drawn from [0, 1); the design matrix holds the columns
# [1, x, x**2].
x_values = np.random.random(number_of_observations)
X = np.column_stack((np.ones(number_of_observations), x_values, x_values ** 2))

w = np.array([3.0, -20.0, 32.0])  # true coefficients for 1, x, x**2

# Each observation independently gets an extra +8 with probability 0.03.
# The outliers are drawn before the Gaussian noise, which keeps the sequence
# of random draws (and therefore the data) fixed for a given seed.
additional_noise = 8 * np.random.binomial(1, 0.03, size=number_of_observations)
gaussian_noise = noise_standard_deviation * np.random.randn(number_of_observations)
y = np.dot(X, w) + gaussian_noise + additional_noise


# Scatter plot of the simulated observations against the feature x.
fig, ax = plt.subplots(dpi=320)
ax.set_xlim((0, 1))
# The default window is (-5, 20). Widen it whenever the data fall outside
# that range: the +8 outliers near x = 1 (where the true curve is about 15)
# would otherwise be clipped out of view.
y_low = min(-5.0, float(y.min()) - 1.0)
y_high = max(20.0, float(y.max()) + 1.0)
ax.set_ylim((y_low, y_high))
ax.scatter(X[:, 1], y, c='r', edgecolors='black')

st.pyplot(fig)
# Streamlit re-executes this script on every widget change. Release the
# figure so pyplot does not keep accumulating open figures across reruns.
plt.close(fig)

# Show the first five rows of the design matrix.
st.write(X[:5, :])