Gaussian processes with Jax in 80 lines

Summary: A minimal Gaussian process implementation using Jax. This tutorial covers kernel computation, conditioning of Gaussian distributions, parameter transformations, and gradient-based optimisation.

It is assumed that the reader of this article is familiar with the applications of Gaussian process (GP) models, and has maybe even fit one using a popular package such as GPflow or GPyTorch. In this article, however, we will lift the lid on how one can create a stable and efficient GP implementation. We will be using Jax to facilitate this; however, any scientific library supporting automatic differentiation would be suitable.

Before we get started, we’ll make the imports necessary for this notebook. We’ll be using Jax’s numpy module for most scientific evaluations, TensorFlow Probability’s Jax substrate for access to probability distributions, Jax’s optimizers module for gradient-based optimisation, and Jax’s scipy.linalg module for Cholesky factorisation of matrices. Further to this, we enforce double precision in Jax. This additional precision is critical for numerically stable matrix factorisations and solves, but more on that later.

import jax
jax.config.update("jax_enable_x64", True)
from typing import Callable, Tuple
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.example_libraries import optimizers  # formerly jax.experimental.optimizers
from jax.scipy.linalg import cho_factor, cho_solve

key = jr.PRNGKey(123)

Data

We’ll be considering 1D regression in this notebook. For a set of input locations $x$, the corresponding response variable $y$ is computed by $y_i = f(x_i) + \epsilon_i$, where $\epsilon_i\sim\mathcal{N}(0, 0.2^2)$ (i.e., Gaussian noise with standard deviation $0.2$) and $f$ is given by $$f(x) = \sin(4x) + \cos(2x)\,.$$ Our aim is to recover $f$ using a Gaussian process.

N = 50
noise = 0.2

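# Use separate keys so the inputs and the noise are drawn independently.
x_key, y_key = jr.split(key)
x = jr.uniform(x_key, minval=-3.0, maxval=3.0, shape=(N, 1)).sort(axis=0)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(y_key, shape=signal.shape) * noise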

plt.plot(x, y, 'o', label='Observations')
plt.plot(jnp.linspace(-3., 3., num=500), f(jnp.linspace(-3., 3., num=500)), label='True function')
plt.legend(loc='best')

Building a GP

Kernels

A GP is defined in terms of its mean function $\mu$ and its kernel $k$. For this tutorial, we’ll use the squared exponential kernel and we’ll make the standard assumption that our mean function is zero. For a variance parameter $\sigma^2$ and lengthscale $\ell$, we can write the squared exponential kernel as $$k(x, x') = \sigma^2 \exp\left(\frac{-0.5\lVert x - x' \rVert_2^2}{\ell^2}\right) . \tag{1}$$

Jax encourages a functional programming style. As such, we’ll first define a generic function that accepts three arguments: a pair of inputs $x$ and $x'$, and a dictionary of parameters containing the lengthscale and variance terms defined above. To compute a kernel matrix, we then have to evaluate this function for every pair of inputs. Fortunately, Jax provides the vmap function which, when nested twice, evaluates a function on every pair of items drawn from two arrays in a vectorised manner, without an explicit Python loop. Let’s now code this up.

def rbf_kernel(x: jnp.array, y: jnp.array, params: dict) -> jnp.array:
    ell, sigma = params["lengthscale"], params["variance"]
    tau = jnp.sum(jnp.square(x / ell - y / ell))
    return sigma * jnp.exp(-0.5 * tau)


def evaluate_kernel(
    kernel_fn: Callable, x: jnp.array, y: jnp.array, params: dict
) -> jnp.array:
    K = jax.vmap(lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, params))(y))(x)
    return K 
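As a quick sanity check of the double-vmap construction, the sketch below compares it against a broadcasting-based implementation of Equation (1). The helper rbf_broadcast and the test inputs are hypothetical, introduced only for this check; rbf_kernel and evaluate_kernel are repeated so that the snippet runs on its own.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def rbf_kernel(x, y, params):
    ell, sigma = params["lengthscale"], params["variance"]
    tau = jnp.sum(jnp.square(x / ell - y / ell))
    return sigma * jnp.exp(-0.5 * tau)


def evaluate_kernel(kernel_fn, x, y, params):
    # Outer vmap runs over rows of x, inner vmap over rows of y -> (len(x), len(y)).
    return jax.vmap(lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, params))(y))(x)


def rbf_broadcast(x, y, params):
    # Hypothetical reference: pairwise squared distances via broadcasting.
    diff = x[:, None, :] - y[None, :, :]
    sq_dist = jnp.sum(jnp.square(diff), axis=-1) / params["lengthscale"] ** 2
    return params["variance"] * jnp.exp(-0.5 * sq_dist)


params = {"lengthscale": jnp.array(0.7), "variance": jnp.array(2.0)}
xa = jnp.linspace(-1.0, 1.0, 4).reshape(-1, 1)
xb = jnp.linspace(0.0, 2.0, 3).reshape(-1, 1)

K_vmap = evaluate_kernel(rbf_kernel, xa, xb, params)
K_bcast = rbf_broadcast(xa, xb, params)
print(K_vmap.shape)  # (4, 3)
```

The broadcasting version materialises an intermediate array of shape (n, m, d), whereas the vmap version lets us write the kernel for a single pair of inputs and have Jax handle the batching.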

Conditioning on data

With a GP prior now defined, we can condition on our observed data to yield a GP posterior. Equipped with this, we can query the GP’s posterior mean and variance at a set of new points to obtain predictions. Computing these two terms requires the inverse of the $n\times n$ matrix $K = K_{xx} + \sigma_n^2 I$, where $K_{xx}$ is the kernel matrix of the training inputs and $\sigma_n^2$ is the observation noise variance. The matrix $K$ is symmetric and positive-definite and consequently admits a Cholesky factorisation $K = LL^{\top}$, with $L$ lower triangular. Rather than forming $K^{-1}$ explicitly, we can solve the required linear systems with triangular solves against $L$, which is both computationally cheaper and more numerically stable. In the following predict function, we’ll make use of Jax’s scipy.linalg module to compute the Cholesky factor and carry out these solves.
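For reference, these are the standard zero-mean GP predictive equations, writing $x_\star$ for the test inputs, $K_{xx}$ for the training kernel matrix, $K_{\star x} = k(x_\star, x)$ for the test-train cross-covariance, $K_{\star\star} = k(x_\star, x_\star)$, and $\sigma_n^2$ for the observation noise variance (the obs_noise entry of our parameters):

$$\begin{aligned}\mu_\star &= K_{\star x}\left(K_{xx} + \sigma_n^2 I\right)^{-1} y \\ \Sigma_\star &= K_{\star\star} - K_{\star x}\left(K_{xx} + \sigma_n^2 I\right)^{-1} K_{x\star}\end{aligned}$$

Writing $K_{xx} + \sigma_n^2 I = LL^{\top}$, each product $(LL^{\top})^{-1}b = L^{-\top}(L^{-1}b)$ reduces to a forward and a backward triangular solve, which is what cho_solve performs given the factor returned by cho_factor.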

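def predict(
    x: jnp.array, y: jnp.array, params: dict
) -> Callable:
    n = x.shape[0]
    # Training Gram matrix plus the observation noise variance on its diagonal.
    Kxx = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
    Kxx += jnp.eye(n) * params["likelihood"]["obs_noise"]
    prior_mean = jnp.zeros_like(x)
    # (factor, lower) pair holding the lower Cholesky factor of Kxx.
    L = cho_factor(Kxx, lower=True)
    prior_distance = y - prior_mean
    # weights = Kxx^{-1} (y - prior_mean), computed with two triangular solves.
    weights = cho_solve(L, prior_distance)

    def mean_and_variance(test_points: jnp.array) -> Tuple[jnp.array, jnp.array]:
        # Cross-covariance between training and test inputs, shape (n, m).
        Kxt = evaluate_kernel(rbf_kernel, x, test_points, params["kernel"])
        mu = jnp.dot(Kxt.T, weights)
        # Prior covariance between the test inputs, shape (m, m).
        Ktt = evaluate_kernel(
            rbf_kernel, test_points, test_points, params["kernel"]
        )
        latents = cho_solve(L, Kxt)
        # Posterior mean, shape (m, 1), and full posterior covariance, shape (m, m).
        return mu, Ktt - jnp.dot(Kxt.T, latents)

    return mean_and_variance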
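To illustrate the Cholesky route in isolation, the following sketch solves a small symmetric positive-definite system with cho_factor and cho_solve and checks the result against jnp.linalg.solve. The matrix K and right-hand side y are made-up stand-ins for the noisy Gram matrix Kxx and the training targets, not values from this notebook.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Made-up symmetric positive-definite matrix (diagonally dominant by construction).
K = jnp.array([[2.0, 0.5, 0.1], [0.5, 1.0, 0.3], [0.1, 0.3, 1.5]])
y = jnp.array([[1.0], [-2.0], [0.5]])

# cho_factor returns a (factor, lower) pair that cho_solve accepts directly.
factor = cho_factor(K, lower=True)
alpha_chol = cho_solve(factor, y)

# Reference solution from a general-purpose solver; no explicit inverse is formed.
alpha_direct = jnp.linalg.solve(K, y)
```

Once the factor is computed, it can be reused for every subsequent solve, which is why predict factorises the training Gram matrix once and reuses L for both the weights and the latents.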

Parameters

Before we can make any predictions we must define a set of model parameters. Along with our kernel’s parameters, we also must deal with the observational noise parameter that parameterises our model’s likelihood function. We’ll organise these three terms using a dictionary.

It is helpful to pause here and acknowledge that this is another area where Jax diverges from other scientific frameworks such as TensorFlow and PyTorch. Jax’s transformations, such as grad, jit and vmap, assume that the functions they act on are pure, i.e., free of side-effects. In practice, this means that quantities we wish to update, such as model parameters, should be passed explicitly into functions as arguments rather than stored as mutable state within an object’s attributes.
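To make this concrete, here is a toy sketch (the loss and its parameter names are hypothetical) showing that jax.grad of a pure function of a nested dictionary returns gradients with the same nested structure, and that an update step produces a new dictionary rather than modifying the old one in place.

```python
import jax
import jax.numpy as jnp


def loss(params: dict) -> jnp.ndarray:
    # Pure: the output depends only on the explicit `params` argument.
    return params["a"] ** 2 + 3.0 * params["nested"]["b"]


params = {"a": jnp.array(2.0), "nested": {"b": jnp.array(5.0)}}

# jax.grad returns gradients with the same nested structure as `params`.
grads = jax.grad(loss)(params)

# An update builds a new dictionary; `params` itself is left unchanged.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```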

params = {
    "kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
    "likelihood": {"obs_noise": jnp.array(1.0)},
}
x_test = jnp.linspace(-3.2, 3.2, num=300).reshape(-1, 1)
mu, sigma2 = predict(x, y, params)(x_test)

def plot(x_train, y_train, x_test, mu, sigma2, ax=None):
    # sigma2 is the full posterior covariance; its diagonal holds pointwise variances.
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(x_train, y_train, 'o', label='Observations')
    ax.plot(x_test, mu, label='Predictive mean', linewidth=2)
    ax.fill_between(
        x_test.squeeze(),
        mu.squeeze() - 3 * jnp.sqrt(jnp.diag(sigma2).squeeze()),
        mu.squeeze() + 3 * jnp.sqrt(jnp.diag(sigma2).squeeze()),
        alpha=0.1,
        color="tab:blue",
        label=r"3 $\sigma$",
    )
    ax.fill_between(
        x_test.squeeze(),
        mu.squeeze() - jnp.sqrt(jnp.diag(sigma2).squeeze()),
        mu.squeeze() + jnp.sqrt(jnp.diag(sigma2).squeeze()),
        alpha=0.3,
        color="tab:blue",
        label=r"1 $\sigma$",
    )
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set(xlabel=r'$x$', ylabel=r'$f(x)$')
    return ax

ax = plot(x, y, x_test, mu, sigma2)
ax.legend(loc="best")

This looks OK, but we can probably do better if we pick a better set of model parameters. To see this, let’s do a coarse grid search over the kernel’s lengthscale and variance terms and plot the results.

from itertools import product

vals = [0.05, 0.5, 1.]
fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(18, 10), tight_layout=True)
for (ell, sigma), ax in zip(list(product(vals,repeat=2)), axes.ravel()):
    params['kernel']['lengthscale'] = jnp.array(ell)
    params['kernel']['variance'] = jnp.array(sigma)
    mu, sigma2 = predict(x, y, params)(x_test)
    ax = plot(x, y, x_test, mu, sigma2, ax = ax)
    ax.set_title(f"Lengthscale: {ell}     Variance: {sigma}")

plt.show()

Some parameter configurations are clearly better than others. For small lengthscale and large variance values, our GP clearly overfits the data, whilst for large lengthscale and low variance, the opposite is true as we oversmooth the data. Rather than relying on a grid search, we can select parameters in a more principled way by maximising the GP’s marginal log-likelihood. To achieve this, we’ll first define the marginal log-likelihood as a function and then use Jax’s grad function to compute derivatives of this function with respect to the parameters. Using these derivatives, we can carry out gradient-based optimisation to learn an optimal parameter set.
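For reference, with a zero mean function and Gaussian observation noise of variance $\sigma_n^2$, the marginal log-likelihood of the $n$ training targets $y$ under parameters $\theta$ has the standard closed form

$$\log p(y \mid \theta) = -\frac{1}{2} y^{\top}\left(K_{xx} + \sigma_n^2 I\right)^{-1} y - \frac{1}{2}\log\left\lvert K_{xx} + \sigma_n^2 I\right\rvert - \frac{n}{2}\log 2\pi ,$$

where $K_{xx}$ is the kernel matrix of the training inputs. With the Cholesky factorisation $K_{xx} + \sigma_n^2 I = LL^{\top}$, the log-determinant simplifies to $2\sum_{i=1}^{n}\log L_{ii}$. The first term rewards fitting the data, while the log-determinant penalises overly flexible models, which helps guard against the overfitting seen in the grid search.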

Transforming parameters

Each of the three parameters considered here is strictly positive. Because of this, there is a danger when optimising them that we could step too far and arrive at a negative value. For this reason, we transform our parameters so that they are defined on the entire real line, make an optimisation step, and then back-transform the parameter values onto their original constrained space. We therefore require that the transformation is bijective, i.e., a one-to-one mapping, so that the transforming and back-transforming steps are exact inverses of each other. For this tutorial, we will use the softplus function $g$ and its inverse:

$$\begin{aligned}g: \mathbb{R} \to \mathbb{R}_{>0}\, , \quad & \text{where} \ g(x) = \log(1+\exp(x)) \\ g^{-1}: \mathbb{R}_{>0} \to \mathbb{R}\, , \quad & \text{where} \ g^{-1}(x) = \log(\exp(x) - 1)\end{aligned}$$

We need not do anything too fancy here, so we’ll just define two simple lambda functions to achieve this bijection.

softplus = lambda x: jnp.log(1.0 + jnp.exp(x))
inv_softplus = lambda x: jnp.log(jnp.exp(x) - 1.0)
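The lambdas above are fine for the moderate values used here, but exp(x) overflows for large inputs. The sketch below shows one numerically safer alternative, assuming we are happy to rely on jax.nn.softplus and to rewrite the inverse as $x + \log(-\mathrm{expm1}(-x))$, which is algebraically equal to $\log(\exp(x) - 1)$. The function names are hypothetical.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def softplus_stable(x):
    # jax.nn.softplus evaluates log(1 + exp(x)) without overflowing for large x.
    return jax.nn.softplus(x)


def inv_softplus_stable(y):
    # log(exp(y) - 1) = y + log(1 - exp(-y)) = y + log(-expm1(-y)),
    # which avoids overflow in exp(y) and stays accurate for small y.
    return y + jnp.log(-jnp.expm1(-y))


def naive_softplus(x):
    return jnp.log(1.0 + jnp.exp(x))


values = jnp.array([1e-3, 0.2, 1.0, 50.0, 800.0])
round_trip = softplus_stable(inv_softplus_stable(values))
```

Either pair could be dropped in place of the lambdas; the round trip recovers the inputs, while the naive softplus returns inf at 800.0 because exp(800.0) overflows in double precision.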

It would be quite cumbersome to apply this function to each entry of our parameter dictionary by hand, particularly given that we’ve organised our parameters into a set of nested dictionaries. Jax provides an elegant solution through the tree_map function (available as jax.tree_util.tree_map). Jax treats containers such as dictionaries as PyTrees whose leaves are arrays. tree_map takes a function and a PyTree, applies the function to each of the tree’s leaves, and returns a tree of identical structure containing the transformed leaves. To see this, we’ll now unconstrain our parameter set.

unconstrained_params = jax.tree_util.tree_map(inv_softplus, params)
print(unconstrained_params)
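As a toy illustration of tree_map on a nested dictionary like ours (the values here are made up), the output keeps the same structure, and jax.tree_util.tree_leaves exposes the leaves in a deterministic order, with dictionary keys sorted.

```python
import jax
import jax.numpy as jnp

# Made-up values with the same nesting as the notebook's parameter dictionary.
nested = {
    "kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)},
    "likelihood": {"obs_noise": jnp.array(0.5)},
}

# Apply a function to every leaf; the dictionary structure is preserved.
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, nested)

# Leaves come back in a deterministic order (dictionary keys are sorted).
leaves = jax.tree_util.tree_leaves(nested)
```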

With parameters now defined on the entire real line, we’ll now go ahead and define our GP’s marginal log-likelihood function.

def marginal_log_likelihood(
    x: jnp.array, y: jnp.array
) -> Callable[[dict], jnp.array]:
    n = x.shape[0]
    mu = jnp.zeros(shape=x.shape[0])

    def objective(params: dict, jitter_amount: float = 1e-8):
        # Map the unconstrained parameters back onto the positive real line.
        params = jax.tree_util.tree_map(softplus, params)
        Kff = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
        noise_matrix = jnp.eye(n) * params["likelihood"]["obs_noise"]
        # A small jitter keeps the Cholesky factorisation numerically stable.
        gram_matrix = Kff + noise_matrix + jnp.eye(n) * jitter_amount
        L = jnp.linalg.cholesky(gram_matrix)
        dist = tfp.distributions.MultivariateNormalTriL(mu, L)
        # Return the *negative* marginal log-likelihood so it can be minimised.
        return -dist.log_prob(y.squeeze())

    return objective

In the above function, we use TensorFlow Probability to evaluate the log-probability density function of a multivariate normal distribution. This is by no means a requirement, and one can certainly define their own log-pdf; however, this is a nice point to show how TensorFlow Probability can be seamlessly integrated into Jax code. Note also that objective returns the negative marginal log-likelihood, since the optimiser minimises its objective.
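As an example of defining the log-pdf by hand, the sketch below evaluates a multivariate normal log-density through a Cholesky factor and compares it with jax.scipy.stats.multivariate_normal.logpdf. The function name gaussian_log_pdf and the test covariance are hypothetical; applied to the GP’s noisy Gram matrix and negated, it should agree with the TensorFlow Probability objective up to floating-point error.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def gaussian_log_pdf(y, mean, cov):
    """Log-density of N(mean, cov) at y, computed via a Cholesky factor."""
    n = y.shape[0]
    L = jnp.linalg.cholesky(cov)
    # Solve L z = (y - mean) so that z^T z = (y - mean)^T cov^{-1} (y - mean).
    z = solve_triangular(L, (y - mean)[:, None], lower=True)
    # For cov = L L^T, log|cov| = 2 * sum(log(diag(L))).
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * jnp.sum(z**2) - 0.5 * log_det - 0.5 * n * jnp.log(2.0 * jnp.pi)


# Made-up symmetric positive-definite covariance and observation.
cov = jnp.array([[2.0, 0.3, 0.1], [0.3, 1.5, 0.2], [0.1, 0.2, 1.0]])
y = jnp.array([0.5, -1.0, 0.25])
mean = jnp.zeros(3)

ours = gaussian_log_pdf(y, mean, cov)
reference = multivariate_normal.logpdf(y, mean, cov)
```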

With everything in place, we can now go ahead and compile our marginal log-likelihood function with jax.jit and use Jax’s Adam optimiser to carry out gradient-based optimisation of this function with respect to our parameter set.

mll = jax.jit(marginal_log_likelihood(x, y))

opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(unconstrained_params)

def step(i, opt_state):
    unconstrained_params = get_params(opt_state)
    v, g = jax.value_and_grad(mll)(unconstrained_params)
    return opt_update(i, g, opt_state), v

mlls = []
for i in range(500):
    opt_state, objective = step(i, opt_state)
    mlls.append(-objective)
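The loop above relies on the softplus transformation to keep the parameters positive. The following toy sketch isolates that idea, with plain gradient descent standing in for Adam: a single unit-size step applied directly to a positive quantity (starting at 2.0, with a hypothetical target of 0.5) overshoots to -1.0, whereas the same step size applied to the unconstrained value keeps softplus of it strictly positive throughout. The target, step size and iteration count are illustrative assumptions.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

softplus = jax.nn.softplus
target = 0.5  # hypothetical positive value the constrained parameter should reach
step_size = 1.0  # illustrative step size


def loss(u):
    # u lives on the whole real line; softplus(u) is strictly positive.
    return (softplus(u) - target) ** 2


# One gradient step taken directly on the constrained value s = 2.0:
# d/ds (s - target)^2 = 2 (s - target), so s becomes 2.0 - 1.0 * 3.0 = -1.0.
s = 2.0
s_after_one_step = s - step_size * 2.0 * (s - target)

# The same step size on the unconstrained value u, where softplus(u) = 2.0.
grad_fn = jax.jit(jax.grad(loss))
u = jnp.log(jnp.expm1(2.0))
constrained_history = []
for _ in range(200):
    u = u - step_size * grad_fn(u)
    constrained_history.append(float(softplus(u)))
```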

Assessing convergence

We can now plot the sequence of marginal log-likelihood evaluations to check whether or not our optimisation routine converged. If it was successful, then we would hope to see the marginal log-likelihood plateau after a certain number of iterations.

plt.plot(jnp.arange(len(mlls)), mlls)
plt.show()
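Beyond eyeballing the plot, one could also check for a plateau numerically. The helper below is a hypothetical sketch, not part of Jax: it declares convergence when the last window recorded values vary by less than a relative tolerance. Both window and rel_tol are arbitrary defaults that would need tuning.

```python
from typing import Sequence


def has_plateaued(
    history: Sequence[float], window: int = 50, rel_tol: float = 1e-4
) -> bool:
    """Hypothetical convergence check on a sequence of objective values.

    Returns True when the last `window` values span less than `rel_tol`
    relative to the magnitude of the most recent value.
    """
    if len(history) < window:
        return False
    recent = [float(v) for v in history[-window:]]
    spread = max(recent) - min(recent)
    # Guard against dividing by zero if the objective sits exactly at 0.
    scale = max(abs(recent[-1]), 1e-12)
    return spread / scale < rel_tol
```

For example, has_plateaued(mlls) could be called once the optimisation loop has finished.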

Inspecting the learned posterior

The optimisation curve looks good, so we can be happy that we’ve converged to a final set of parameter values. It is of interest here to inspect these values and see how they compare to the values used earlier on in the grid-search step. To do this though, we first need to back-transform our parameters onto the positive real-line.

final_params = jax.tree_util.tree_map(softplus, get_params(opt_state))
print(final_params)

We can see that the learned observational noise parameter is much lower than the value of 1 that was being used earlier. Conversely, the kernel parameters are quite similar to the best-performing configuration found in the earlier grid search.

Finally, we can use these parameters to query our posterior distribution at a new set of inputs.

mu, sigma2 = predict(x, y, final_params)(x_test)

ax = plot(x, y, x_test, mu, sigma2)
ax.plot(x_test, f(x_test), label='True function', color='red', linestyle = '--')
ax.legend(loc='best')

Visually, this looks great. Our GP has uncertainty bands that widen as we move away from the data and narrow in regions where data is abundant. Further, our posterior mean has done a pretty good job of recovering the true latent function and certainly hasn’t overfit to the data.
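The widening of the uncertainty bands away from the data follows from the predictive variance: once a test point is many lengthscales from every training input, its cross-covariance with the data is essentially zero, so the posterior variance falls back to the prior variance. The sketch below checks this on made-up data, using a hypothetical broadcasting RBF helper with a lengthscale of 0.5 and unit variance.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(a, b, lengthscale=0.5, variance=1.0):
    # Hypothetical helper: RBF kernel matrix between two 1D input arrays.
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


# Made-up training data on [-1, 1]; one query inside the data, one far outside.
x_train = jnp.linspace(-1.0, 1.0, 10)
y_train = jnp.sin(3.0 * x_train).reshape(-1, 1)
obs_noise = 0.01
x_query = jnp.array([0.0, 5.0])

factor = cho_factor(rbf(x_train, x_train) + obs_noise * jnp.eye(10), lower=True)
K_xq = rbf(x_train, x_query)  # shape (10, 2)
mean = K_xq.T @ cho_solve(factor, y_train)  # shape (2, 1)
var = jnp.diag(rbf(x_query, x_query) - K_xq.T @ cho_solve(factor, K_xq))
print(var)  # small at x = 0.0, close to the prior variance of 1.0 at x = 5.0
```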

Conclusions

To wrap up, this tutorial has given a barebones implementation of Gaussian processes using Jax. Through this implementation, light has been shone on some of the more computational aspects required to fit Gaussian processes, such as parameter transformations and Cholesky factorisations. Finally, many of the principles outlined in this post can be used to extend Gaussian processes to non-conjugate and/or sparse settings.

System configuration

%load_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Just give me the code

import jax

jax.config.update("jax_enable_x64", True)
from typing import Callable, Tuple
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.example_libraries import optimizers  # formerly jax.experimental.optimizers
from jax.scipy.linalg import cho_factor, cho_solve

key = jr.PRNGKey(123)

softplus = lambda x: jnp.log(1.0 + jnp.exp(x))
inv_softplus = lambda x: jnp.log(jnp.exp(x) - 1.0)


def rbf_kernel(x: jnp.array, y: jnp.array, params: dict) -> jnp.array:
    ell, sigma = params["lengthscale"], params["variance"]
    tau = jnp.sum(jnp.square(x / ell - y / ell))
    return sigma * jnp.exp(-0.5 * tau)


def evaluate_kernel(
    kernel_fn: Callable, x: jnp.array, y: jnp.array, params: dict
) -> jnp.array:
    return jax.vmap(
        lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, params))(y)
    )(x)


def marginal_log_likelihood(
    x: jnp.array, y: jnp.array
) -> Callable[[dict], jnp.array]:
    n = x.shape[0]
    mu = jnp.zeros(shape=x.shape[0])

    def objective(params: dict, jitter_amount: float = 1e-8):
        params = jax.tree_util.tree_map(softplus, params)
        Kff = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
        noise_matrix = jnp.eye(n) * params["likelihood"]["obs_noise"]
        gram_matrix = Kff + noise_matrix + jnp.eye(n) * jitter_amount
        L = jnp.linalg.cholesky(gram_matrix)
        return (
            jnp.array(-1.0)
            * tfp.distributions.MultivariateNormalTriL(mu, L)
            .log_prob(y.squeeze())
            .sum()
        )

    return objective


# %%
def predict(
    x: jnp.array, y: jnp.array, params: dict
) -> Callable:
    n = x.shape[0]
    Kxx = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
    Kxx += jnp.eye(n) * params["likelihood"]["obs_noise"]
    prior_mean = jnp.zeros_like(x)
    L = cho_factor(Kxx, lower=True)
    prior_distance = y - prior_mean
    weights = cho_solve(L, prior_distance)

    def mean_and_variance(test_points) -> jnp.array:
        Kfx = evaluate_kernel(rbf_kernel, x, test_points, params["kernel"])
        mu = jnp.dot(Kfx.T, weights)
        Kxx = evaluate_kernel(
            rbf_kernel, test_points, test_points, params["kernel"]
        )
        latents = cho_solve(L, Kfx)
        return mu, Kxx - jnp.dot(Kfx.T, latents)

    return mean_and_variance


# %%
if __name__ == "__main__":
    params = {
        "kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
        "likelihood": {"obs_noise": jnp.array(1.0)},
    }

    N = 50
    noise = 0.2

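    # Use separate keys so the inputs and the noise are drawn independently.
    x_key, y_key = jr.split(key)
    x = (
        jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(N,))
        .sort()
        .reshape(-1, 1)
    )
    f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
    signal = f(x)
    y = signal + jr.normal(y_key, shape=signal.shape) * noise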

    mll = jax.jit(marginal_log_likelihood(x, y))

    unconstrained_params = jax.tree_util.tree_map(inv_softplus, params)
    print(mll(unconstrained_params))

    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
    opt_state = opt_init(unconstrained_params)

    def step(i, opt_state):
        unconstrained_params = get_params(opt_state)
        g = jax.grad(mll)(unconstrained_params)
        return opt_update(i, g, opt_state)

    for i in range(500):
        opt_state = step(i, opt_state)

    print(mll(get_params(opt_state)))
    final_params = jax.tree_util.tree_map(softplus, get_params(opt_state))

    x_test = jnp.linspace(-3.2, 3.2, num=300).reshape(-1, 1)
    mu, sigma = predict(x, y, final_params)(x_test)

    plt.plot(x, y, "+", color="tab:orange", label="Obs")
    plt.plot(x_test, mu, label="Preds", color="tab:blue")
    plt.fill_between(
        x_test.squeeze(),
        mu.squeeze() - 3 * jnp.sqrt(jnp.diag(sigma).squeeze()),
        mu.squeeze() + 3 * jnp.sqrt(jnp.diag(sigma).squeeze()),
        alpha=0.2,
        color="tab:blue",
        label=r"3 $\sigma$",
    )
    plt.fill_between(
        x_test.squeeze(),
        mu.squeeze() - jnp.sqrt(jnp.diag(sigma).squeeze()),
        mu.squeeze() + jnp.sqrt(jnp.diag(sigma).squeeze()),
        alpha=0.4,
        color="tab:blue",
        label=r"1 $\sigma$",
    )
    plt.legend(loc="best")
    plt.savefig("preds.png")