A GP implementation using Jax that shows kernel computation, conditioning of Gaussian distributions, parameter transformations, and gradient-based optimisation.
It is assumed that the reader of this article is familiar with the applications
of Gaussian process (GP) models, and has perhaps even fitted one using a popular
package such as GPFlow or GPyTorch. In this article, however, we will lift the
lid on how one can create a stable and efficient GP implementation. We will be
using Jax to facilitate this, although any scientific library that supports
automatic differentiation would be suitable.
Before we get started, we'll make the imports necessary for this notebook.
We'll be using Jax's numpy module for most numerical computation, TensorFlow
Probability's Jax substrate for access to probability distributions, Jax's
optimizers module for gradient-based optimisation, and Jax's linalg module
for Cholesky factorisation of matrices. Further to this, we enforce double
precision in Jax. This additional precision is critical for stable matrix
inversions, but more on that later.
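As a small aside (a standalone snippet, separate from the imports below), the effect of the flag is easy to see: with x64 enabled, newly created arrays default to float64 rather than float32.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
print(jnp.zeros(3).dtype)  # float64; without the flag this would be float32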
import jax
jax.config.update("jax_enable_x64", True)
from typing import Callable, Tuple
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.experimental import optimizers
from jax.scipy.linalg import cho_factor, cho_solve
key = jr.PRNGKey(123)
Data
We'll be considering 1D regression in this notebook. For a set of input
locations $x$, the corresponding response variable $y$ is computed as $y_i =
f(x_i) + \epsilon_i$, where $\epsilon_i\sim\mathcal{N}(0, 0.2^2)$ and $f$ is given
by $$f(x) = \sin(4x) + \cos(2x).$$ Our aim is to recover $f$ using a Gaussian
process.
N = 50
noise = 0.2
# Use separate PRNG keys for the input locations and the observation noise.
key, x_key, noise_key = jr.split(key, 3)
x = jr.uniform(x_key, minval=-3.0, maxval=3.0, shape=(N, 1)).sort(axis=0)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(noise_key, shape=signal.shape) * noise
plt.plot(x, y, 'o', label='Observations')
plt.plot(jnp.linspace(-3., 3., 500), f(jnp.linspace(-3., 3., 500)), label='True function')
plt.legend(loc='best')
Building a GP
Kernels
A GP is defined in terms of its mean function $\mu$ and its kernel $k$. For this
tutorial, we'll use the squared exponential kernel and make the standard
assumption that our mean function is zero. For a variance parameter $\sigma^2$ and
lengthscale $\ell$, we can write the squared exponential kernel as
$$
k(x, x') =
\sigma^2 \exp\left(-\frac{\lVert x - x' \rVert_2^2}{2\ell^2}\right) . \tag{1}
$$
Jax encourages a functional style of programming. As such, we'll first define a
generic function that accepts three arguments: a pair of inputs $x$ and $x'$, and
a dictionary of parameters that will contain the above defined lengthscale and
variance terms.
To compute a kernel matrix, we will then have to evaluate this function for
every pair of inputs. Fortunately, Jax provides the vmap
function that, when
used twice, computes functions on every pair of items from two arrays in a
highly efficient manner. Let’s now code this up.
def rbf_kernel(x: jax.Array, y: jax.Array, params: dict) -> jax.Array:
    # Squared exponential kernel evaluated on a single pair of inputs.
    ell, sigma = params["lengthscale"], params["variance"]
    tau = jnp.sum(jnp.square(x / ell - y / ell))
    return sigma * jnp.exp(-0.5 * tau)


def evaluate_kernel(
    kernel_fn: Callable, x: jax.Array, y: jax.Array, params: dict
) -> jax.Array:
    # Nested vmap: evaluate kernel_fn on every pair of rows from x and y.
    K = jax.vmap(lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, params))(y))(x)
    return K
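As a quick check (not part of the original notebook), we can build the Gram matrix for our 50 training inputs with an ad-hoc parameter dictionary and confirm its shape and diagonal:

demo_params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}
K = evaluate_kernel(rbf_kernel, x, x, demo_params)
print(K.shape)          # (50, 50)
print(jnp.diag(K)[:3])  # each entry equals the variance, since k(x, x) = sigma^2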
Conditioning on data
With a GP prior now defined, we can condition on our observed data to yield a GP
posterior. Equipped with this, we can query the GP's posterior mean and variance
at a set of new points to obtain predictions. Computing these two terms requires
solving linear systems involving the $n\times n$ covariance matrix $K$. However,
$K$ is symmetric and positive-definite and consequently admits a Cholesky
factorisation $K = LL^{\top}$. Solving these systems through $L$, rather than
explicitly inverting $K$, is both computationally cheaper and more numerically
stable. In the following predict function, we'll make use of Jax's linalg module
to compute Cholesky factors.
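Written out for test inputs $x_{\star}$, and with a zero prior mean, these two quantities are
$$
\begin{aligned}
\mu_{\star} &= K_{\star x}\left(K_{xx} + \sigma_n^2 I\right)^{-1} y\,, \\
\Sigma_{\star} &= K_{\star\star} - K_{\star x}\left(K_{xx} + \sigma_n^2 I\right)^{-1} K_{x\star}\,,
\end{aligned}
$$
where $\sigma_n^2$ denotes the observation-noise variance. The Cholesky factor of $K_{xx} + \sigma_n^2 I$ lets us evaluate both expressions without ever forming the matrix inverse explicitly.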
def predict(
    x: jax.Array, y: jax.Array, params: dict
) -> Callable[[jax.Array], Tuple[jax.Array, jax.Array]]:
    n = x.shape[0]
    # Gram matrix of the training inputs, plus observation noise on the diagonal.
    Kxx = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
    Kxx += jnp.eye(n) * params["likelihood"]["obs_noise"]
    prior_mean = jnp.zeros_like(x)
    # Cholesky factorisation of the noisy Gram matrix.
    L = cho_factor(Kxx, lower=True)
    prior_distance = y - prior_mean
    weights = cho_solve(L, prior_distance)

    def mean_and_variance(test_points: jax.Array) -> Tuple[jax.Array, jax.Array]:
        Kfx = evaluate_kernel(rbf_kernel, x, test_points, params["kernel"])
        mu = jnp.dot(Kfx.T, weights)
        Kss = evaluate_kernel(
            rbf_kernel, test_points, test_points, params["kernel"]
        )
        latents = cho_solve(L, Kfx)
        return mu, Kss - jnp.dot(Kfx.T, latents)

    return mean_and_variance
Parameters
Before we can make any predictions, we must define a set of model parameters.
Along with our kernel's parameters, we must also specify the observational
noise parameter that parameterises our model's likelihood function. We'll
organise these three terms using a dictionary.
It is helpful to pause here and acknowledge that this is another area where Jax
diverges from other scientific frameworks such as TensorFlow and PyTorch. Due to
Jax's functional nature, functions should be free of side-effects. In practice,
this means that any values we wish to change, such as the model's parameters, are
passed to functions explicitly and returned as new objects, rather than being
mutated in place as attributes of a stateful model object.
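As a tiny illustration of this pattern (the helper update_lengthscale below is hypothetical and not used elsewhere in this post), "changing" a parameter means constructing a new dictionary rather than mutating the old one in place:

def update_lengthscale(params: dict, new_value: float) -> dict:
    # Build new dictionaries rather than assigning into the old ones.
    new_kernel = {**params["kernel"], "lengthscale": jnp.array(new_value)}
    return {**params, "kernel": new_kernel}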
params = {
    "kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
    "likelihood": {"obs_noise": jnp.array(1.0)},
}
x_test = jnp.linspace(-3.2, 3.2, num=300).reshape(-1, 1)
mu, sigma2 = predict(x, y, params)(x_test)
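Before plotting, a quick sanity check (not in the original post) on the shapes returned by predict: one mean value per test point, and a full predictive covariance matrix over the 300 test inputs.

print(mu.shape)      # (300, 1)
print(sigma2.shape)  # (300, 300)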
def plot(x_train, y_train, x_test, mu, sigma2, ax=None, legend=False):
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(x_train, y_train, 'o', label='Observations')
    ax.plot(x_test, mu, label='Predictive mean', linewidth=2)
    # Shade three and one standard-deviation bands around the predictive mean.
    ax.fill_between(
        x_test.squeeze(),
        mu.squeeze() - 3 * jnp.sqrt(jnp.diag(sigma2).squeeze()),
        mu.squeeze() + 3 * jnp.sqrt(jnp.diag(sigma2).squeeze()),
        alpha=0.1,
        color="tab:blue",
        label=r"3 $\sigma$",
    )
    ax.fill_between(
        x_test.squeeze(),
        mu.squeeze() - jnp.sqrt(jnp.diag(sigma2).squeeze()),
        mu.squeeze() + jnp.sqrt(jnp.diag(sigma2).squeeze()),
        alpha=0.3,
        color="tab:blue",
        label=r"1 $\sigma$",
    )
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set(xlabel=r'$x$', ylabel=r'$f(x)$')
    return ax


ax = plot(x, y, x_test, mu, sigma2)
ax.legend(loc="best")
This looks OK, but we can probably do better with a more carefully chosen set of
model parameters. To see this, let's do a coarse grid search over the kernel's
lengthscale and variance terms and plot the results.
from itertools import product

vals = [0.05, 0.5, 1.]
fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(18, 10), tight_layout=True)
for (ell, sigma), ax in zip(product(vals, repeat=2), axes.ravel()):
    params['kernel']['lengthscale'] = jnp.array(ell)
    params['kernel']['variance'] = jnp.array(sigma)
    mu, sigma2 = predict(x, y, params)(x_test)
    ax = plot(x, y, x_test, mu, sigma2, ax=ax)
    ax.set_title(f"Lengthscale: {ell} Variance: {sigma}")
plt.show()
Some parameter configurations are clearly better than others. For small
lengthscale and large variance values, our GP clearly overfits the data, whilst
for large lengthscale and small variance values the opposite is true and we
oversmooth the data. To find an optimal set of parameters, though, we can do
something more principled than a grid search and instead optimise the GP's
marginal log-likelihood. To achieve this, we'll first define the marginal
log-likelihood as a function and then use Jax's grad function to compute its
derivatives with respect to the parameters. Using these derivatives, we can
carry out gradient-based optimisation to learn an optimal parameter set.
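As a minimal illustration of the mechanism on a toy function (not the GP objective itself), jax.grad turns a scalar-valued function into a function that returns its derivative:

g = jax.grad(lambda theta: (theta - 2.0) ** 2)
print(g(5.0))  # 6.0, since d/dtheta of (theta - 2)^2 is 2 * (theta - 2)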
Each of the three parameters considered here is strictly positive. Because of
this, there is a danger when optimising them that we could step too far and
arrive at a negative value. For this reason, we transform our parameters so that
they are defined on the entire real line, make an optimisation step, and then
back-transform the parameter values so that they are re-defined on their original
constrained space. We therefore require that our transformation is bijective,
i.e., a one-to-one mapping, so that the transforming and back-transforming steps
are exact inverses of one another. For this tutorial, we will use the softplus
function, which maps according to
$$\begin{aligned}g: \mathbb{R} \to \mathbb{R}_{>0}\, , \quad & \text{where} \ g(x) = \log(1+\exp(x)) \\
g^{-1}: \mathbb{R}_{>0} \to \mathbb{R}\, , \quad & \text{where} \ g^{-1}(x) = \log(\exp(x) - 1)\,.\end{aligned}$$
We need not do anything too fancy here, so we'll
just define two simple lambda functions to achieve this bijection.
softplus = lambda x: jnp.log(1.0 + jnp.exp(x))
inv_softplus = lambda x: jnp.log(jnp.exp(x) - 1.0)
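As a quick sanity check (not part of the original pipeline), the two maps should round-trip on any positive value:

theta = jnp.array(0.5)
print(softplus(inv_softplus(theta)))  # ~0.5
print(inv_softplus(softplus(theta)))  # ~0.5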
It would be quite cumbersome to apply this function to our parameter dictionary
by hand, particularly given that we've organised our parameters into a set of
nested dictionaries. Jax provides an elegant solution to this, though, through
the tree_map function. Jax sees objects such as a dictionary as a PyTree, with
the leaves of the tree being arrays. tree_map takes two arguments, a function
and a PyTree; it traverses the PyTree and applies the given function to the
tree's leaves. The returned value is a tree of identical shape whose leaves have
been transformed according to the supplied function. To see this, we'll now
unconstrain our parameter set.
unconstrained_params = jax.tree_map(inv_softplus, params)
print(unconstrained_params)
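Since every parameter was initialised to 1.0, each transformed leaf should now read roughly $\log(e - 1) \approx 0.541$.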
With the parameters now defined on the entire real line, we can go ahead and
define our GP's marginal log-likelihood function.
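For reference, with a zero mean function the quantity we are maximising is
$$
\log p(y \mid x, \theta) = -\frac{1}{2} y^{\top}\left(K_{\theta} + \sigma_n^2 I\right)^{-1} y - \frac{1}{2}\log\left\lvert K_{\theta} + \sigma_n^2 I \right\rvert - \frac{n}{2}\log 2\pi\,,
$$
where $K_{\theta}$ is the kernel matrix under parameters $\theta$ and $\sigma_n^2$ is the observation-noise variance. This is exactly the multivariate normal log-density evaluated in the code below; we return its negative so that the optimiser can minimise it.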
def marginal_log_likelihood(
    x: jax.Array, y: jax.Array
) -> Callable[[dict], jax.Array]:
    n = x.shape[0]
    mu = jnp.zeros(shape=x.shape[0])

    def objective(params: dict, jitter_amount: float = 1e-8):
        # Map the unconstrained parameters back onto the positive real line.
        params = jax.tree_map(softplus, params)
        Kff = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
        noise_matrix = jnp.eye(n) * params["likelihood"]["obs_noise"]
        # A small jitter term on the diagonal guards against numerical issues
        # when computing the Cholesky factor.
        gram_matrix = Kff + noise_matrix + jnp.eye(n) * jitter_amount
        L = jnp.linalg.cholesky(gram_matrix)
        # Return the negative marginal log-likelihood so that we can minimise it.
        return (
            jnp.array(-1.0)
            * tfp.distributions.MultivariateNormalTriL(mu, L)
            .log_prob(y.squeeze())
            .sum()
        )

    return objective
In the above function, one can see that we use TensorFlow Probability to
evaluate the log-probability density function of a multivariate normal
distribution. This is by no means a requirement, and one can certainly define
their own log-pdf; however, it is a nice point at which to show how TensorFlow
Probability can be seamlessly integrated into Jax code.
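For completeness, here is a sketch of what such a hand-rolled log-pdf could look like, written in terms of the Cholesky factor $L$ of the covariance matrix. The helper mvn_log_prob is an illustration only and is not used elsewhere in this post.

from jax.scipy.linalg import solve_triangular


def mvn_log_prob(y: jax.Array, mu: jax.Array, L: jax.Array) -> jax.Array:
    # Gaussian log-density with covariance L @ L.T, for 1D arrays y and mu,
    # evaluated without forming the covariance matrix or its inverse.
    n = y.shape[0]
    diff = y - mu
    # Solve L @ alpha = diff, so that alpha.T @ alpha equals the quadratic form.
    alpha = solve_triangular(L, diff, lower=True)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * (jnp.dot(alpha, alpha) + log_det + n * jnp.log(2.0 * jnp.pi))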
With everything in place, we can now go ahead and compile our marginal
log-likelihood function and use Jax's adam optimiser to carry out gradient-based
optimisation of this function with respect to our parameter set.
mll = jax.jit(marginal_log_likelihood(x, y))
opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(unconstrained_params)


def step(i, opt_state):
    # One optimisation step: evaluate the objective and its gradient, then
    # update the optimiser state.
    unconstrained_params = get_params(opt_state)
    v, g = jax.value_and_grad(mll)(unconstrained_params)
    return opt_update(i, g, opt_state), v


mlls = []
for i in range(500):
    opt_state, objective = step(i, opt_state)
    mlls.append(-objective)
Assessing convergence
We can now plot the set of marginal log-likelihood evaluations to check whether
or not our optimisation routine converged. If we were successful, then we would
hope to see the marginal log-likelihood evaluations plateau after a certain
number of iterations.
plt.plot(jnp.arange(len(mlls)), mlls)
plt.show()
Inspecting the learned posterior
The optimisation curve looks good, so we can be happy that we’ve converged to a
final set of parameter values. It is of interest here to inspect these values
and see how they compare to the values used earlier on in the grid-search step.
To do this, though, we first need to back-transform our parameters onto the
positive real line.
final_params = jax.tree_map(softplus, get_params(opt_state))
print(final_params)
We can see that the learned observational noise parameter is much lower than the
value of 1 that was being used earlier. Conversely, the kernel parameters are
quite similar to the best-performing configuration found in the earlier
grid-search routine.
Finally, we can use these parameters to query our posterior distribution at a
new set of inputs.
mu, sigma = predict(x, y, final_params)(x_test)
ax = plot(x, y, x_test, mu, sigma)
ax.plot(x_test, f(x_test), label='True function', color='red', linestyle = '--')
ax.legend(loc='best')
Visually, this looks great. Our GP has nice uncertainty bands that widen as
we move away from the data and narrow in regions where data is abundant.
Further, our posterior mean has done a pretty good job of recovering the
true latent function and certainly hasn't overfit the data.
Conclusions
To wrap up, this tutorial has given a bare-bones implementation of Gaussian
processes using Jax. Through this implementation, light has been shone on some
of the computational details required to fit Gaussian processes, such as
parameter transformations and Cholesky factorisations. Finally, many of the
principles outlined in this post can be used to extend Gaussian processes to
non-conjugate and/or sparse settings.
System configuration
%load_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'
Just give me the code
import jax
jax.config.update("jax_enable_x64", True)
from typing import Callable, Tuple
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.experimental import optimizers
from jax.scipy.linalg import cho_factor, cho_solve
key = jr.PRNGKey(123)
softplus = lambda x: jnp.log(1.0 + jnp.exp(x))
inv_softplus = lambda x: jnp.log(jnp.exp(x) - 1.0)
def rbf_kernel(x: jax.Array, y: jax.Array, params: dict) -> jax.Array:
ell, sigma = params["lengthscale"], params["variance"]
tau = jnp.sum(jnp.square(x / ell - y / ell))
return sigma * jnp.exp(-0.5 * tau)
def evaluate_kernel(
kernel_fn: Callable, x: jax.Array, y: jax.Array, params: dict
) -> jax.Array:
return jax.vmap(
lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, params))(y)
)(x)
def marginal_log_likelihood(
x: jax.Array, y: jax.Array
) -> Callable[[dict], jax.Array]:
n = x.shape[0]
mu = jnp.zeros(shape=x.shape[0])
def objective(params: dict, jitter_amount: float = 1e-8):
params = jax.tree_map(softplus, params)
Kff = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
noise_matrix = jnp.eye(n) * params["likelihood"]["obs_noise"]
gram_matrix = Kff + noise_matrix + jnp.eye(n) * jitter_amount
L = jnp.linalg.cholesky(gram_matrix)
return (
jnp.array(-1.0)
* tfp.distributions.MultivariateNormalTriL(mu, L)
.log_prob(y.squeeze())
.sum()
)
return objective
# %%
def predict(
    x: jax.Array, y: jax.Array, params: dict
) -> Callable[[jax.Array], Tuple[jax.Array, jax.Array]]:
n = x.shape[0]
Kxx = evaluate_kernel(rbf_kernel, x, x, params["kernel"])
Kxx += jnp.eye(n) * params["likelihood"]["obs_noise"]
prior_mean = jnp.zeros_like(x)
L = cho_factor(Kxx, lower=True)
prior_distance = y - prior_mean
weights = cho_solve(L, prior_distance)
    def mean_and_variance(test_points: jax.Array) -> Tuple[jax.Array, jax.Array]:
Kfx = evaluate_kernel(rbf_kernel, x, test_points, params["kernel"])
mu = jnp.dot(Kfx.T, weights)
        Kss = evaluate_kernel(
            rbf_kernel, test_points, test_points, params["kernel"]
        )
        latents = cho_solve(L, Kfx)
        return mu, Kss - jnp.dot(Kfx.T, latents)
return mean_and_variance
# %%
if __name__ == "__main__":
params = {
"kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
"likelihood": {"obs_noise": jnp.array(1.0)},
}
    N = 50
    noise = 0.2
    # Use separate PRNG keys for the input locations and the observation noise.
    key, x_key, noise_key = jr.split(key, 3)
    x = (
        jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(N,))
        .sort()
        .reshape(-1, 1)
    )
    f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
    signal = f(x)
    y = signal + jr.normal(noise_key, shape=signal.shape) * noise
mll = jax.jit(marginal_log_likelihood(x, y))
unconstrained_params = jax.tree_map(inv_softplus, params)
print(mll(unconstrained_params))
opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(unconstrained_params)
def step(i, opt_state):
unconstrained_params = get_params(opt_state)
g = jax.grad(mll)(unconstrained_params)
return opt_update(i, g, opt_state)
for i in range(500):
opt_state = step(i, opt_state)
print(mll(get_params(opt_state)))
final_params = jax.tree_map(softplus, get_params(opt_state))
x_test = jnp.linspace(-3.2, 3.2, num=300).reshape(-1, 1)
mu, sigma = predict(x, y, final_params)(x_test)
plt.plot(x, y, "+", color="tab:orange", label="Obs")
plt.plot(x_test, mu, label="Preds", color="tab:blue")
plt.fill_between(
x_test.squeeze(),
mu.squeeze() - 3 * jnp.sqrt(jnp.diag(sigma).squeeze()),
mu.squeeze() + 3 * jnp.sqrt(jnp.diag(sigma).squeeze()),
alpha=0.2,
color="tab:blue",
label=r"1 $\sigma$",
)
plt.fill_between(
x_test.squeeze(),
mu.squeeze() - jnp.sqrt(jnp.diag(sigma).squeeze()),
mu.squeeze() + jnp.sqrt(jnp.diag(sigma).squeeze()),
alpha=0.4,
color="tab:blue",
label=r"1 $\sigma$",
)
plt.legend(loc="best")
plt.savefig("preds.png")