Spectral embedding and clustering

Data $s_1,\ldots,s_N$

To perform our analysis this time, we will use the MNIST handwritten digit dataset.

In [1]:
%matplotlib notebook

from sklearn.datasets import fetch_openml
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

#rcParams['font.size'] = 30

# Load MNIST data from https://www.openml.org/d/554
# as_frame=False returns NumPy arrays, which the array indexing below relies on.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
In [2]:
%matplotlib notebook

np.random.seed(1)

num_samples = 2000

# Select num_samples random MNIST digits to analyze.
data = X[np.random.choice(int(X.shape[0]), num_samples), :]/255.

num_clusters = 10

# Inspect a single digit from the dataset.
digit = data[0,:]
digit = np.reshape(digit, (28,28))

plt.imshow(digit, cmap='Greys')

plt.show()

Kernel $K(s_i,s_j)$

Spectral embedding and other "kernel trick" methods use a kernel $K(s_i, s_j)$ that measures the similarity between data points. Choosing the kernel is part of the procedure and usually takes some intuition and trial and error.

A popular choice is the Gaussian kernel $$ K(s_i, s_j) = \exp\left(- \frac{\|s_i - s_j\|^2}{2\sigma^2}\right), $$ where the width $\sigma$ sets the distance scale over which two points count as similar.

In [3]:
%matplotlib notebook

import numpy.linalg as nla

# Gaussian kernel with variance sigma^2.
sigma = 100.0
def kernel(s_i, s_j):
    return np.exp(-nla.norm(s_i - s_j)**2.0 / (2.0*sigma**2.0))

# The kernel matrix of the data.
kernel_matrix = np.zeros((num_samples, num_samples))
for i in range(num_samples):
    for j in range(i, num_samples):
        kernel_matrix[i,j] = kernel(data[i,:], data[j,:])
        kernel_matrix[j,i] = kernel_matrix[i,j]
        
plt.matshow(kernel_matrix)
Out[3]:
<matplotlib.image.AxesImage at 0x7f4c629d2890>
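
The double loop above makes the definition explicit, but it evaluates about two million kernel values one at a time. As a minimal sketch, the same matrix can be built with array operations using the identity $\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2\,a \cdot b$. The function name `gaussian_kernel_matrix` and the toy data are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def gaussian_kernel_matrix(points, sigma):
    """Return the matrix K[i, j] = exp(-||p_i - p_j||^2 / (2 sigma^2))."""
    sq_norms = np.sum(points**2, axis=1)
    # Pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * points @ points.T
    # Rounding can leave tiny negative values; clip them to zero.
    sq_dists = np.maximum(sq_dists, 0.0)
    return np.exp(-sq_dists / (2.0 * sigma**2))

# Check against an explicit double loop on small synthetic data (not MNIST).
rng = np.random.default_rng(0)
toy_points = rng.random((5, 3))
toy_sigma = 0.7
K_fast = gaussian_kernel_matrix(toy_points, toy_sigma)
K_loop = np.array([[np.exp(-np.linalg.norm(a - b)**2 / (2.0 * toy_sigma**2))
                    for b in toy_points] for a in toy_points])
print(np.allclose(K_fast, K_loop))
```

With the notebook's variables, the call would be `gaussian_kernel_matrix(data, sigma)`.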

Graph Laplacian $L$

A key step of spectral embedding is the construction of the graph Laplacian, $L = D - K$, where $D_{ij} = \delta_{ij} \sum_{l=1}^N K(s_i, s_l)$ is the diagonal degree matrix and $K_{ij} = K(s_i, s_j)$ is the edge weight matrix (the kernel matrix). The graph Laplacian can be interpreted as an effective quadratic Hamiltonian for a network of harmonically coupled springs, with $K_{ij}$ playing the role of the spring constant between points $i$ and $j$.

In [4]:
%matplotlib notebook

degrees = np.sum(kernel_matrix, axis=0)
D = np.diag(degrees)
K = kernel_matrix

L = D - K

plt.matshow(L)
Out[4]:
<matplotlib.image.AxesImage at 0x7f4c62f11c10>
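
The spring interpretation can be checked numerically. For symmetric $K$, the quadratic form satisfies $x^T L x = \frac{1}{2}\sum_{i,j} K_{ij}(x_i - x_j)^2$, and every row of $L$ sums to zero. The following is a small sketch on synthetic data; the variable names and the toy kernel width are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_points = rng.random((6, 2))

# Gaussian kernel matrix for the toy points (width chosen arbitrarily).
toy_sigma = 0.5
diffs = toy_points[:, None, :] - toy_points[None, :, :]
K_toy = np.exp(-np.sum(diffs**2, axis=2) / (2.0 * toy_sigma**2))

# Graph Laplacian L = D - K.
L_toy = np.diag(K_toy.sum(axis=1)) - K_toy

# Each row sums to zero, so the constant vector has eigenvalue zero.
row_sums = L_toy @ np.ones(6)

# "Spring energy": x^T L x equals (1/2) sum_ij K_ij (x_i - x_j)^2.
x = rng.standard_normal(6)
quadratic_form = x @ L_toy @ x
spring_energy = 0.5 * np.sum(K_toy * (x[:, None] - x[None, :])**2)

print(np.allclose(row_sums, 0.0), np.isclose(quadratic_form, spring_energy))
```

Because the energy is a sum of non-negative terms, $L$ is positive semidefinite, which is why its eigenvalues are non-negative.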

Spectrum of $L$

Compute the spectrum of $L$. We ignore the eigenvector with eigenvalue exactly zero: it is always proportional to $(1,\ldots,1)$, because every row of $L$ sums to zero, and so carries no information about the data.

In [5]:
%matplotlib notebook

import numpy.linalg as nla

(eigvals, eigvecs) = nla.eigh(L)

print('Smallest eigenvalues = {}'.format(eigvals[0:2*num_clusters]))

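# Skip the zero eigenvalue at index 0 and keep the x-axis aligned with the true indices.
plt.plot(np.arange(1, len(eigvals)), eigvals[1:], 'ro')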
plt.xlabel('Eigenvalue index')
plt.ylabel('Eigenvalue')

plt.show()
Smallest eigenvalues = [  2.48569139e-13   1.98290340e+03   1.98298467e+03   1.98316849e+03
   1.98318602e+03   1.98416835e+03   1.98420977e+03   1.98447910e+03
   1.98449305e+03   1.98451311e+03   1.98453060e+03   1.98463880e+03
   1.98466076e+03   1.98467592e+03   1.98475062e+03   1.98475523e+03
   1.98479111e+03   1.98481531e+03   1.98485246e+03   1.98492046e+03]
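
To see how the low eigenvectors of $L$ carry cluster information, it helps to look at a case much simpler than MNIST. The sketch below uses two well-separated synthetic blobs and takes the eigenvector of the second-smallest eigenvalue as a one-dimensional spectral embedding; its sign separates the two groups. The toy data, the kernel width, and the sign-based split are illustrative assumptions. A full pipeline would typically run k-means on several eigenvectors instead.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two tight, well-separated 2D blobs of 20 points each (toy data, not MNIST).
blob_a = rng.normal(loc=(0.0, 0.0), scale=0.3, size=(20, 2))
blob_b = rng.normal(loc=(6.0, 0.0), scale=0.3, size=(20, 2))
toy_points = np.vstack([blob_a, blob_b])

# Gaussian kernel matrix and graph Laplacian L = D - K.
toy_sigma = 1.5
diffs = toy_points[:, None, :] - toy_points[None, :, :]
K_toy = np.exp(-np.sum(diffs**2, axis=2) / (2.0 * toy_sigma**2))
L_toy = np.diag(K_toy.sum(axis=1)) - K_toy

# eigh returns eigenvalues in ascending order; column 0 is the constant mode.
eigvals_toy, eigvecs_toy = np.linalg.eigh(L_toy)

# One-dimensional spectral embedding: the second eigenvector.
embedding = eigvecs_toy[:, 1]

# Its sign splits the points into two groups (which group gets 0 or 1 is arbitrary).
toy_labels = (embedding > 0).astype(int)
print(toy_labels)
```

The overall sign of an eigenvector is arbitrary, so only the grouping is meaningful, not which blob is labelled 0 or 1.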

Spectral clustering

Here we use scikit-learn's SpectralClustering to assign each digit to one of the num_clusters clusters:

In [6]:
from sklearn.cluster import KMeans, SpectralClustering

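sigma = 500.0

# Spectral clustering with an RBF affinity. scikit-learn's 'rbf' affinity is
# exp(-gamma * ||x - y||^2), so gamma = 1 / (2 sigma^2) matches the Gaussian kernel defined earlier.
sc = SpectralClustering(n_clusters=num_clusters, gamma=1.0/(2.0*sigma**2.0), affinity='rbf',
                        n_init=100, random_state=0, assign_labels='kmeans').fit(data)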

skl_sc_clusters_info = []
for ind_cluster in range(num_clusters):
    skl_sc_clusters_info.append([])

for ind_point in range(num_samples):
    ind_cluster = sc.labels_[ind_point]
    skl_sc_clusters_info[ind_cluster].append(ind_point)
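
scikit-learn parameterizes its RBF kernel as $\exp(-\gamma \|x - y\|^2)$, so the Gaussian kernel with width $\sigma$ defined earlier corresponds to $\gamma = 1/(2\sigma^2)$. The following is a small check of that convention on synthetic data, using `sklearn.metrics.pairwise.rbf_kernel`; the toy points and width are arbitrary assumptions.

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

rng = np.random.default_rng(0)
toy_points = rng.random((4, 3))
toy_sigma = 0.8

# Gaussian kernel exp(-||a - b||^2 / (2 sigma^2)), computed directly.
diffs = toy_points[:, None, :] - toy_points[None, :, :]
K_direct = np.exp(-np.sum(diffs**2, axis=2) / (2.0 * toy_sigma**2))

# scikit-learn's RBF kernel is exp(-gamma * ||a - b||^2).
K_sklearn = rbf_kernel(toy_points, gamma=1.0 / (2.0 * toy_sigma**2))

print(np.allclose(K_direct, K_sklearn))
```

Using `gamma = 1 / sigma**2` instead gives a kernel that is narrower by a factor of $\sqrt{2}$ in $\sigma$.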

Results

Here we show the first ten digits assigned to each cluster, one cluster per row.

In [7]:
%matplotlib notebook

num_cols = 10

plt.figure()

ind_subplot = 1
for ind_cluster in range(num_clusters):
    members = skl_sc_clusters_info[ind_cluster]
    for ind_col in range(num_cols):
        plt.subplot(num_clusters, num_cols, ind_subplot)
        ind_subplot += 1

        # Hide the tick labels on every panel.
        frame1 = plt.gca()
        frame1.axes.xaxis.set_ticklabels([])
        frame1.axes.yaxis.set_ticklabels([])

        # A cluster may contain fewer than num_cols digits; leave those panels empty.
        if ind_col >= len(members):
            continue

        digit = np.reshape(data[members[ind_col], :], (28,28))
        plt.imshow(digit, cmap='Greys')
        
plt.show()
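
The grid gives a visual impression of the clusters. One possible way to put a number on it is to compare cluster assignments with the true digit labels through a contingency table and a purity score. The sketch below runs on small made-up label arrays; the function name `contingency_table` and the toy labels are assumptions. Applying it to this notebook would additionally require keeping the random indices used to subsample `X`, so that `y` can be subsampled the same way, and converting the labels to integers if they are loaded as strings.

```python
import numpy as np

def contingency_table(cluster_labels, true_labels, num_clusters, num_classes):
    """Count how many points of each true class land in each cluster."""
    table = np.zeros((num_clusters, num_classes), dtype=int)
    # Unbuffered add so that repeated (cluster, class) pairs are all counted.
    np.add.at(table, (cluster_labels, true_labels), 1)
    return table

# Made-up assignments for seven points, three clusters, three classes.
toy_clusters = np.array([1, 1, 0, 0, 2, 2, 2])
toy_truth = np.array([0, 0, 1, 1, 2, 2, 0])

table = contingency_table(toy_clusters, toy_truth, num_clusters=3, num_classes=3)

# Purity: fraction of points that belong to the majority class of their cluster.
purity = table.max(axis=1).sum() / len(toy_truth)
print(table)
print(purity)
```

Cluster indices are arbitrary, so purity compares each cluster with its own majority class rather than expecting cluster 3 to contain the digit 3.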