• Home
  • Teaching visualisations
    • VMLS: Least squares
    • VMLS: Lighting problem
    • VMLS: Regularisation
    • VMLS: Constrained least squares
  • Teaching tools
    • canvasapi
    • OTA tool
  • CV

VMLS: Regularisation

#| '!! shinylive warning !!': |
#|   shinylive does not work in self-contained HTML documents.
#|   Please set `embed-resources: false` in your metadata.
#| standalone: true
#| viewerHeight: 620

from shiny import App, render, ui, reactive
import numpy as np
import matplotlib.pyplot as plt

def vandermonde(t, n):
    """Return the m-by-n polynomial (Vandermonde) matrix for the points ``t``.

    Entry ``(j, i)`` is ``t[j]**i``, so column 0 is all ones and column ``i``
    holds the ``i``-th power. Multiplying by a coefficient vector ``c`` of
    length ``n`` evaluates ``c[0] + c[1]*t + ... + c[n-1]*t**(n-1)`` at every
    point.

    Parameters
    ----------
    t : array_like, shape (m,)
        Evaluation points. Lists/tuples and integer arrays are accepted.
    n : int
        Number of columns (polynomial degree + 1).

    Returns
    -------
    numpy.ndarray of float64, shape (m, n)
    """
    # Convert to float so integer inputs still yield a float matrix.
    t = np.asarray(t, dtype=float)
    # Broadcasting (m, 1) ** (n,) computes every power in one vectorised step.
    return t[:, None] ** np.arange(n)

def rms(y):
    """Return ``np.linalg.norm(y)`` scaled by ``1 / sqrt(len(first axis))``.

    For a 1-D array this is the root-mean-square of its entries.
    """
    n_entries = y.shape[0]
    magnitude = np.linalg.norm(y)
    return magnitude / np.sqrt(n_entries)

# Synthetic data: noisy samples of the degree-4 polynomial whose coefficients
# are `true` (true[0] multiplies t**0, i.e. it is the constant term). All
# random draws come from the seeded global RNG in a fixed order (train x,
# train noise, test x, test noise); reordering them changes the data.
np.random.seed(10)
N_train = 10
x_train = np.random.uniform(low=-3, high=3, size=N_train)
true = np.array([-3, -2.7, -2, 0.8, 0.5])
deg1 = true.shape[0]  # number of coefficients = polynomial degree + 1
y_train = vandermonde(x_train, deg1) @ true + np.random.normal(scale=3, size=N_train)

N_test = 20
x_test = np.random.uniform(low=-3, high=3, size=N_test)
y_test = vandermonde(x_test, deg1) @ true + np.random.normal(scale=3, size=N_test)

# Evenly spaced grid covering all sample points, used to draw smooth curves.
x_all = np.concatenate([x_train, x_test])
x_plot = np.linspace(np.min(x_all), np.max(x_all))
A_plot = vandermonde(x_plot, deg1)
y_true = A_plot @ true

# Precompute the regularisation path over L log-spaced values of λ in
# [1e-4, 1e7].
L = 1000
lambdas = 10**np.linspace(-4, 7, L)
theta_hats = np.zeros((L, deg1))

A_train = vandermonde(x_train, deg1)
A_test = vandermonde(x_test, deg1)

rms_train = np.zeros(L)
rms_test = np.zeros(L)
J1 = np.zeros(L)  # ||A_train θ - y_train||^2
J2 = np.zeros(L)  # ||θ[1:]||^2 (the constant term θ1 is not penalised)

# Regularised least squares is solved as the stacked problem
#   [A_train; sqrt(λ)·I[1:]] θ ≈ [y_train; 0].
# The right-hand side and the unscaled penalty rows do not depend on λ, so
# they are built once instead of on every iteration.
y_tilde = np.concatenate((y_train, np.zeros(deg1 - 1)))
penalty_rows = np.eye(deg1)[1:]  # row 0 dropped: intercept is unpenalised

for i, lamb in enumerate(lambdas):
    A_tilde = np.concatenate((A_train, np.sqrt(lamb) * penalty_rows), axis=0)
    theta_hat = np.linalg.pinv(A_tilde) @ y_tilde
    residual_train = A_train @ theta_hat - y_train

    theta_hats[i] = theta_hat
    rms_train[i] = rms(residual_train)
    rms_test[i] = rms(A_test @ theta_hat - y_test)
    J1[i] = np.linalg.norm(residual_train) ** 2
    J2[i] = np.linalg.norm(theta_hat[1:]) ** 2

# Page layout: a sidebar holding the λ slider, next to two cards of two plots
# each. The slider value is the base-10 exponent of λ (the server converts it
# with 10**value), spanning the same -4 … 7 range as the precomputed path.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.sidebar(
            ui.input_slider('lamb', 'log10(λ)', min=-4, max=7, value=-4, step=0.1),
        ),
        ui.layout_columns(
            ui.card(
                ui.output_plot("mainplot", height="240px"),
                ui.output_plot("pareto_plot", height="240px"),
            ),
            ui.card(
                ui.output_plot("shrinkage_plot", height="240px"),
                ui.output_plot("error_plot", height="240px"),
            ),
        ),
    ),
)


def server(input, output, session):
    """Server logic: render the four plots for the λ chosen on the slider.

    The slider holds log10(λ). The regularised fit for the selected λ is
    computed once in a reactive calc and shared by the fitted-curve and Pareto
    plots. The shrinkage and error plots draw the regularisation path
    precomputed at module level and mark the selected λ with a dashed line.
    """

    @reactive.calc
    def selected_lambda():
        # The slider stores the exponent, so convert back to λ itself.
        return 10**input.lamb()

    @reactive.calc
    def selected_theta_hat():
        # Regularised least squares via the stacked system
        #   [A_train; sqrt(λ)·I[1:]] θ ≈ [y_train; 0].
        # Row 0 of I is dropped so the constant term θ1 is not penalised,
        # consistent with j2 below, which excludes θ[0].
        A_reg = np.sqrt(selected_lambda()) * np.eye(deg1)[1:]
        A_tilde = np.concatenate((A_train, A_reg), axis=0)
        y_tilde = np.concatenate((y_train, np.zeros(deg1 - 1)))
        return np.linalg.pinv(A_tilde) @ y_tilde

    # ========================================================================

    @render.plot
    def mainplot():
        """Training/test data, the true polynomial and the current fit."""
        theta_hat = selected_theta_hat()

        fig, ax = plt.subplots()
        ax.scatter(x_train, y_train, label='Training')
        ax.scatter(x_test, y_test, label='Test')
        ax.plot(x_plot, y_true, c='black', label='True')
        ax.plot(x_plot, A_plot @ theta_hat, c='red', label='Fitted', linewidth=4)
        ax.legend()
        return fig

    # ========================================================================

    @render.plot
    def pareto_plot():
        """Trade-off curve (J1, J2) with the current λ's point marked."""
        theta_hat = selected_theta_hat()
        j1 = np.linalg.norm(A_train @ theta_hat - y_train) ** 2
        j2 = np.linalg.norm(theta_hat[1:]) ** 2

        fig, ax = plt.subplots()
        ax.plot(J1, J2)
        # Guide lines from the current (j1, j2) point towards both axes.
        ax.vlines(j1, 0, j2, color='black')
        ax.hlines(j2, 0, j1, color='black')
        ax.set_xscale('log')
        ax.set_xlabel('J1')
        ax.set_ylabel('J2')
        return fig

    # ========================================================================

    @render.plot
    def shrinkage_plot():
        """Coefficient paths θ_i(λ), with each true coefficient dashed."""
        fig, ax = plt.subplots()
        for i in range(deg1):
            (path,) = ax.plot(lambdas, theta_hats[:, i], label=f'$\\theta_{{{i + 1}}}$')
            # Reuse the path's own colour so each true value matches its curve
            # regardless of the active colour cycle or number of coefficients.
            ax.axhline(true[i], c=path.get_color(), linestyle='dashed')

        ax.axvline(selected_lambda(), c='black', linestyle='dashed')
        ax.set_xlabel('$\\lambda$')
        ax.set_ylabel('$\\theta$')
        ax.set_xscale('log')
        ax.legend()
        return fig

    # ========================================================================

    @render.plot
    def error_plot():
        """Train and test RMS error along the path, current λ marked."""
        fig, ax = plt.subplots()
        ax.plot(lambdas, rms_train, label='Train error')
        ax.plot(lambdas, rms_test, label='Test error')
        ax.axvline(selected_lambda(), c='black', linestyle='dashed')
        ax.set_xscale('log')
        ax.set_xlabel('$\\lambda$')
        ax.set_ylabel('RMS Error')
        ax.legend()
        return fig


# Combine the page layout and the server function into the Shiny app object.
app = App(ui=app_ui, server=server)
 
 

This page is built with Quarto.