#| '!! shinylive warning !!': |
#| shinylive does not work in self-contained HTML documents.
#| Please set `embed-resources: false` in your metadata.
#| standalone: true
#| viewerHeight: 620
from shiny import App, render, ui, reactive
import numpy as np
import matplotlib.pyplot as plt
def vandermonde(t, n):
    """Return the Vandermonde matrix of the sample points ``t``.

    Entry ``[j, i]`` of the result is ``t[j] ** i``, so column 0 is all
    ones and column ``i`` holds the ``i``-th powers of the sample points.

    Parameters
    ----------
    t : array_like
        One-dimensional collection of sample points.
    n : int
        Number of columns, i.e. powers ``0 .. n-1`` are generated.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(t), n)``.

    Raises
    ------
    ValueError
        If ``t`` is not one-dimensional.
    """
    t = np.asarray(t, dtype=float)
    if t.ndim != 1:
        raise ValueError("t must be one-dimensional")
    # Broadcast a column of points against a row of exponents so every
    # power is computed in a single vectorised call instead of m*n scalar
    # assignments.
    exponents = np.arange(n)
    return t[:, np.newaxis] ** exponents
def rms(y):
    """Return the root-mean-square value of the array ``y``.

    Computed as the norm of ``y`` (``np.linalg.norm`` with its default
    settings) divided by the square root of the length of ``y`` along its
    first axis.
    """
    count = y.shape[0]
    magnitude = np.linalg.norm(y)
    return magnitude / np.sqrt(count)
# ---------------------------------------------------------------------------
# Synthetic data.
# The order of the random draws below (train x, train noise, test x, test
# noise) determines the generated numbers for the fixed seed, so it must not
# be rearranged.
# ---------------------------------------------------------------------------
np.random.seed(10)

true = np.array([-3, -2.7, -2, 0.8, 0.5])  # coefficients of the generating polynomial
deg1 = true.shape[0]                       # number of coefficients

N_train = 10
x_train = np.random.uniform(low=-3, high=3, size=N_train)
A_train = vandermonde(x_train, deg1)
y_train = A_train @ true + np.random.normal(scale=3, size=N_train)

N_test = 20
x_test = np.random.uniform(low=-3, high=3, size=N_test)
A_test = vandermonde(x_test, deg1)
y_test = A_test @ true + np.random.normal(scale=3, size=N_test)

# Plotting grid spanning every sampled x (train and test together).
x_all = np.concatenate([x_train, x_test])
x_plot = np.linspace(x_all.min(), x_all.max())
A_plot = vandermonde(x_plot, deg1)
y_true = A_plot @ true

# ---------------------------------------------------------------------------
# Regularisation path: solve the stacked least-squares problem for each of
# L values of lambda, logarithmically spaced between 1e-4 and 1e7.
# ---------------------------------------------------------------------------
L = 1000
lambdas = 10**np.linspace(-4, 7, L)
theta_hats = np.zeros((L, deg1))  # one coefficient vector per lambda
rms_train = np.zeros(L)
rms_test = np.zeros(L)
J1 = np.zeros(L)  # squared norm of the training residual
J2 = np.zeros(L)  # squared norm of the coefficients, first one excluded

# Identity with its first row removed: the first coefficient gets no
# penalty row. These pieces do not depend on lambda, so build them once.
penalty_rows = np.eye(deg1)[1:]
y_tilde = np.concatenate((y_train, np.zeros(deg1 - 1)))

for i, lamb in enumerate(lambdas):
    A_tilde = np.vstack((A_train, np.sqrt(lamb) * penalty_rows))
    theta_hat = np.linalg.pinv(A_tilde) @ y_tilde

    train_resid = A_train @ theta_hat - y_train
    test_resid = A_test @ theta_hat - y_test

    theta_hats[i] = theta_hat
    rms_train[i] = rms(train_resid)
    rms_test[i] = rms(test_resid)
    J1[i] = np.linalg.norm(train_resid) ** 2
    J2[i] = np.linalg.norm(theta_hat[1:]) ** 2
# Sidebar control. The slider value is used as an exponent: the server
# computes the regularisation weight as 10 ** input.lamb().
_lambda_slider = ui.input_slider('lamb', '10^λ', min=-4, max=7, value=-4, step=0.1)

# First card: the "mainplot" and "pareto_plot" outputs, stacked.
_fit_card = ui.card(
    ui.output_plot("mainplot", height="240px"),
    ui.output_plot("pareto_plot", height="240px"),
)

# Second card: the "shrinkage_plot" and "error_plot" outputs, stacked.
_path_card = ui.card(
    ui.output_plot("shrinkage_plot", height="240px"),
    ui.output_plot("error_plot", height="240px"),
)

# Page layout: a sidebar holding the slider next to the two cards.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.sidebar(_lambda_slider),
        ui.layout_columns(_fit_card, _path_card),
    ),
)
def server(input, output, session):
    """Shiny server function: registers the four plot outputs.

    Parameters
    ----------
    input
        Shiny inputs; only ``input.lamb()`` (the slider value, used as a
        base-10 exponent for the regularisation weight) is read.
    output, session
        Required by the server signature; not referenced directly here.
    """

    @reactive.calc
    def current_theta():
        """Coefficient vector fitted for the slider's current lambda.

        Shared by ``mainplot`` and ``pareto_plot`` so the least-squares
        problem is solved once per slider change rather than once per plot.
        """
        lamb = 10**input.lamb()
        # Identity scaled by sqrt(lambda) with its first row dropped, so
        # the first coefficient gets no penalty row.
        A_reg = (np.sqrt(lamb) * np.eye(deg1))[1:]
        A_tilde = np.concatenate((A_train, A_reg), axis=0)
        y_tilde = np.concatenate((y_train, np.zeros(deg1 - 1)))
        return np.linalg.pinv(A_tilde) @ y_tilde

    # ========================================================================
    @render.plot
    def mainplot():
        """Scatter of train/test data with the true and fitted curves."""
        theta = current_theta()
        fig, ax = plt.subplots()
        ax.scatter(x_train, y_train, label='Training')
        ax.scatter(x_test, y_test, label='Test')
        ax.plot(x_plot, y_true, c='black', label='True')
        ax.plot(x_plot, A_plot @ theta, c='red', label='Fitted', linewidth=4)
        ax.legend()
        return fig

    # ========================================================================
    @render.plot
    def pareto_plot():
        """Precomputed (J1, J2) curve with the current lambda's point marked."""
        theta = current_theta()
        j1 = np.linalg.norm(A_train @ theta - y_train) ** 2
        j2 = np.linalg.norm(theta[1:]) ** 2
        fig, ax = plt.subplots()
        ax.plot(J1, J2)
        # Guide lines from the axes to the current (j1, j2) point.
        ax.vlines(j1, 0, j2, color='black')
        ax.hlines(j2, 0, j1, color='black')
        ax.set_xscale('log')
        ax.set_xlabel('J1')
        ax.set_ylabel('J2')
        return fig

    # ========================================================================
    @render.plot
    def shrinkage_plot():
        """Each fitted coefficient versus lambda, with its true value dashed."""
        fig, ax = plt.subplots()
        for i in range(deg1):
            (line,) = ax.plot(lambdas, theta_hats[:, i], label='$\\theta_{}$'.format(i+1))
            # Reuse the colour matplotlib assigned to the path so the dashed
            # true-value line always matches it, whatever the colour cycle
            # and however many coefficients there are.
            ax.axhline(true[i], c=line.get_color(), linestyle='dashed')
        ax.axvline(10**input.lamb(), c='black', linestyle='dashed')
        ax.set_xlabel('$\\lambda$')
        ax.set_ylabel('$\\theta$')
        ax.set_xscale('log')
        ax.legend()
        return fig

    # ========================================================================
    @render.plot
    def error_plot():
        """Train and test RMS error versus lambda, current lambda marked."""
        fig, ax = plt.subplots()
        ax.plot(lambdas, rms_train, label='Train error')
        ax.plot(lambdas, rms_test, label='Test error')
        ax.axvline(10**input.lamb(), c='black', linestyle='dashed')
        ax.set_xscale('log')
        ax.set_xlabel('$\\lambda$')
        ax.set_ylabel('RMS Error')
        ax.legend()
        return fig
# Application object: binds the page layout to the server logic.
app = App(ui=app_ui, server=server)