alkzar90 commited on
Commit
b5beaf9
·
1 Parent(s): 80f8995

Add training loop for RMSE-loss

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py CHANGED
@@ -51,6 +51,15 @@ st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{
51
 
52
 
53
  # Fitting by the respective cost_function
 
 
 
 
 
 
 
 
 
54
  if cost_function == 'RMSE-Loss':
55
 
56
  def loss(w):
@@ -60,6 +69,33 @@ if cost_function == 'RMSE-Loss':
60
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
61
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
62
  st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  else:
64
  st.write("You selected the Huber loss function.")
65
  st.latex(r'''
 
51
 
52
 
53
  # Fitting by the respective cost_function
54
+ w = jnp.array(np.random.random(3))
55
+ learning_rate = 0.05
56
+ NUM_ITER = 1000
57
+
58
+ fig, ax = plt.subplots(dpi=120)
59
+ ax.set_xlim((0,1))
60
+ ax.set_ylim((-5,26))
61
+ ax.scatter(X[:,1], y, c='#e76254' ,edgecolors='firebrick')
62
+
63
  if cost_function == 'RMSE-Loss':
64
 
65
  def loss(w):
 
69
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
70
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
71
  st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
72
+
73
+
74
+ progress_bar = st.progress(0)
75
+ status_text = st.empty()
76
+ grad_loss = jax.grad(loss)
77
+
78
+ # Perform gradient descent
79
+ for i in range(NUM_ITER):
80
+ # Update progress bar.
81
+ progress_bar.progress(i + 1)
82
+
83
+ # Update parameters.
84
+ w -= learning_rate * grad_loss(w)
85
+
86
+ # Update status text.
87
+ if (i+1)%100==0:
88
+ status_text.text(
89
+ 'Trained loss is: %s' % loss(w))
90
+
91
+ # add a line to the plot
92
+ plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), '--')
93
+ print(f"Trained loss at epoch #{i}:", loss(w))
94
+ plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), '--', label='Final line')
95
+
96
+ status_text.text('Done!')
97
+
98
+
99
  else:
100
  st.write("You selected the Huber loss function.")
101
  st.latex(r'''