An overview of how a Gaussian process package can be built using Objax and JAX.

This article is currently very much a work in progress, so please excuse any typographical and grammatical errors.

There exist several excellent Gaussian process (GP) software packages; GPFlow and GPyTorch are just two examples in Python. In this series, I am setting out to build an alternative package.

Before we get into design choices, I’ll briefly describe a GP model for regression, mainly to define the notation used in this blog. For a more detailed introduction to GPs, I would strongly encourage you to watch Richard Turner’s excellent lecture on GPs or to refer to Rasmussen and Williams.

Gaussian process refresher

Given a dataset $\mathcal{D}$ comprising $N$ inputs $X = \{x_i\}_{i=1}^{N}$, with each $x_i \in \mathbb{R}^d$, and outputs $\mathbf{y} \in \mathbb{R}^N$, a common task is to learn a function $f$ that relates $X$ and $\mathbf{y}$ through $$ \mathbf{y} = f(X) + \epsilon \ \ \text{, where } \epsilon \sim \mathcal{N}(0, \sigma_n^2 I).$$ Letting $\mu : \mathbb{R}^d \rightarrow \mathbb{R}$ be a mean function and $k : \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}$ be a positive-definite kernel function that produces the Gram matrix $K_{xx}$ such that $[K_{xx}]_{i, j}=k(x_i, x_j)$, we place a GP prior distribution directly on $f$ such that $$f(x) \sim \mathcal{GP}(\mu(x), k(x, x')).$$ In what follows, we will assume that $\mu(x) = 0$ for all $x$ without loss of generality.

To complete our Bayesian model, we must also posit a likelihood function. With the Gaussian likelihood $\mathbf{y} \mid f \sim \mathcal{N}(f, \sigma_n^2 I)$, where $\sigma_n^2$ is the observational noise variance, the posterior predictive distribution of the process at some new points $X^{\star}$ is $$\begin{aligned}p(f_{\star} \mid \mathcal{D}) & = \int_{\mathbb{R}^N}p(f_{\star}, f \mid \mathcal{D}) \,\mathrm{d}f \\ & = \int_{\mathbb{R}^N}p(f_{\star} \mid f)\,p(f \mid \mathcal{D})\,\mathrm{d}f. \end{aligned}$$ Both constituent distributions in this integral are Gaussian, so $p(f_{\star} \mid \mathcal{D})$ is also Gaussian, with mean $\hat{\mu}_{\star}$ and covariance $\hat{\Sigma}_{\star\star}$. Both terms have closed-form expressions: $$\hat{\mu}_{\star} = K_{\star x}(K_{xx} + \sigma_n^2 I)^{-1}\mathbf{y} \quad \text{and} \quad \hat{\Sigma}_{\star\star}=K_{\star\star}-K_{\star x}(K_{xx}+\sigma_n^2 I)^{-1}K_{x\star}.$$

We can then find optimal model hyperparameters $\theta$ by maximising the log marginal likelihood $$\log p(\mathbf{y}) = -\frac{1}{2} \left(\log\det(K_{xx}+\sigma_n^2 I) + \mathbf{y}^{\top}(K_{xx} + \sigma_n^2 I)^{-1}\mathbf{y} + N \log (2 \pi) \right)$$ with respect to $\theta$. This can be done in a number of ways, the simplest being gradient ascent, i.e. $\theta^{(t+1)} = \theta^{(t)} + \alpha \nabla_{\theta}\log p(\mathbf{y})$, where $\alpha$ is a step size and $\nabla_{\theta}$ is the gradient operator with respect to $\theta$. Equivalently, we can run gradient descent on the negative log marginal likelihood, which is what standard optimisers expect.
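To make the closed-form predictive expressions concrete, here is a minimal sketch in JAX that evaluates $\hat{\mu}_{\star}$ and $\hat{\Sigma}_{\star\star}$ directly from the formulas. It assumes a zero mean and an RBF kernel with unit variance and lengthscale (the kernel is defined formally later as equation (1)); the helper names `rbf` and `predictive` are illustrative only, and linear solves are used in place of an explicit matrix inverse.

```python
import jax.numpy as jnp


def rbf(A, B, variance=1.0, lengthscale=1.0):
    # Pairwise squared Euclidean distances between rows of A (N, d) and B (M, d).
    sq = jnp.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq / lengthscale**2)


def predictive(Xstar, X, y, noise_var):
    Ky = rbf(X, X) + noise_var * jnp.eye(X.shape[0])  # K_xx + sigma_n^2 I
    Ksx = rbf(Xstar, X)                                # K_{*x}
    Kss = rbf(Xstar, Xstar)                            # K_{**}
    # Solve linear systems rather than forming (K_xx + sigma_n^2 I)^{-1} explicitly.
    mean = Ksx @ jnp.linalg.solve(Ky, y)
    cov = Kss - Ksx @ jnp.linalg.solve(Ky, Ksx.T)
    return mean, cov


X = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)
y = jnp.sin(X).squeeze()
# One test point inside the data range and one far outside it.
Xstar = jnp.array([[0.5], [10.0]])
mean, cov = predictive(Xstar, X, y, noise_var=1e-2)
```

Far from the data, $K_{\star x}$ is essentially zero, so the predictive mean falls back to the zero prior mean and the predictive variance to the prior variance $k(x_\star, x_\star) = 1$.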

Package aims

Not another Gaussian process package

Now, the natural question is: “why do we need another GP package?”. This would be perfectly valid if all we cared about was fitting state-of-the-art GP models efficiently. Whilst I will use JAX in what follows, I will not be trying to compete with packages such as GPFlow and GPyTorch in terms of computational efficiency. Experienced teams of statisticians, machine learners and computer scientists who know far more than me about software development have spent several years refining these packages, so it would be naive of me to think I could create something better in a few weeks.

That said, I do hope that this package feels more natural to users familiar with GP theory. What do I mean by this? Well, currently the process of fitting a GP in a package such as GPFlow goes something along the lines of:

data = (X, y)
kernel = RBF()
likelihood = Gaussian()
model = GPR(data, kernel, likelihood)
optimiser.minimise(model.log_marginal_likelihood, model.trainable_variables)

You may have noticed that this looks very different to the model we described in our refresher. There, we treated the GP as simply a prior distribution over some function $f$, and the likelihood and dataset only entered when we formed the posterior. In the code above, by contrast, the data and likelihood are stored within the model as soon as it is instantiated. To be a “little closer to the maths”, we should rethink this design…

As you would write it

To act on this motivation, we have designed GPBlocks to take the following form. We first state our prior beliefs in the form of a Gaussian process:

mean_func = Zero()
kernel = RBF()
prior = GP(mean_func, kernel)

Some data $\mathcal{D} = \{X, \mathbf{y}\}$ is then observed. Through a likelihood function, we state our beliefs about the process that generated this data.

data = (X, y)
likelihood = Gaussian()

With a prior distribution and likelihood formed, we can go on to form our posterior. Because the likelihood is Gaussian, our GP prior is conjugate to it and the posterior is analytically available. Practically, this means that we need only optimise the posterior with respect to the model’s hyperparameters which, in this example, are three terms: the observational noise and the kernel’s variance and lengthscale. The package is structured such that the posterior is formed computationally in the same way you would write it down on paper.

posterior = prior * likelihood
optimise(posterior, posterior.hyperparameters)

With a posterior distribution found, we can then query it in the usual way: either by sampling directly from the posterior, or by using the posterior predictive distribution to make predictions at some new input values.

samps = sample(posterior, n_samples=5)
ystar = predict(posterior, xstar)

Package implementation

Underlying framework

With a high-level overview laid out, it is now useful to think about how we might implement such a framework. A fundamental consideration here is which parts of the code are either tedious to write manually or expensive to compute; for these, we are best served by a dependency that is already optimised to perform such operations efficiently. In the GP setting, this yields the following list:

  • Gradients - these are often expensive and tedious to code up. A package with automatic differentiation support would be both mentally and computationally beneficial.
  • Matrix operations - computing matrix products and inverses is often the most demanding step when fitting GPs. A dependency with these operations written in optimised C/C++ or CUDA code would greatly accelerate this step.
  • Parameter handling - The parameters of a GP can often be scattered around several objects such as a kernel, likelihood, inducing points (in the sparse case), and as such it can be fiddly to keep track of them all.

Several frameworks meet this specification, namely TensorFlow, PyTorch and Jax, to name just a few. I’ve opted to base GPBlocks on Jax, which may seem odd given its relative immaturity compared to well-established frameworks such as TensorFlow and PyTorch. However, this decision is motivated primarily by the following two points:

  1. Jax has a very similar interface to NumPy. This appeals to the fundamental aim of GPBlocks: an easy to work with and adapt GP framework for researchers. By having a syntax that is very similar to NumPy (a framework most Python users are familiar with), I hope that people are able to easily understand the code and adapt it to their needs.
  2. I’ve never used Jax. Now this is a purely selfish point, but I’ve not had a good excuse yet to use Jax and I think it looks really nice, so this project serves as a perfect opportunity to learn Jax.

I’ll supplement Jax with Objax, which allows for efficient handling of parameters within a model. This is accomplished through the TrainVar class: wrapping any parameter in TrainVar and assigning it to a model that inherits from Objax’s Module class means that all trainable parameters can be returned through the .vars() method. This may seem somewhat abstract now, but the usefulness of this interaction will hopefully become apparent in the sequel.

Code skeleton

In what follows, I’ll lay out the general structure of GPBlocks. Along the way, I’ll hopefully shed some light on my thought process when designing the package, whilst also demonstrating how others can adapt it to their own specific needs.

Kernels

The workhorse of a GP is the kernel function, so it makes sense to start by setting up this object. Broadly speaking, kernel functions can be classed as stationary or non-stationary. For now, we’ll only concern ourselves with stationary kernels, whose value is invariant to translations of the inputs; the kernels we consider here depend on the inputs only through the distance between them, $\tau = \lVert x-x'\rVert$. A common stationary kernel is the radial basis function (RBF) kernel $$k(x, x') = \sigma^2 \exp\left(-\frac{\lVert x-x' \rVert_2^2}{2\ell^2} \right)\tag{1}$$ where $\sigma^2$ is the kernel’s variance parameter and $\ell$ the kernel’s lengthscale. Informally, the variance controls the vertical scale of functions drawn from the GP, and the lengthscale the horizontal scale, i.e. how quickly the correlation between two points decays as they move apart.
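As a quick check on the roles of $\sigma^2$ and $\ell$, the sketch below evaluates (1) directly for single pairs of points. The function name `rbf` is illustrative and not part of the package.

```python
import jax.numpy as jnp


def rbf(x, x_prime, variance=1.0, lengthscale=1.0):
    # Equation (1): sigma^2 * exp(-||x - x'||^2 / (2 * ell^2)).
    sq_dist = jnp.sum((x - x_prime) ** 2)
    return variance * jnp.exp(-sq_dist / (2.0 * lengthscale**2))


x = jnp.array([0.0])
same = rbf(x, x, variance=2.0)                                     # tau = 0
at_ell = rbf(x, jnp.array([0.5]), variance=2.0, lengthscale=0.5)  # tau = ell
```

At zero distance the kernel returns $\sigma^2$, its maximum; when $\tau = \ell$ it has decayed to $\sigma^2 e^{-1/2}$. A larger lengthscale therefore means slower decay and smoother sample functions.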

To make future extensions easier, we’ll have a base kernel class. In it, we’ll declare a __call__(self, x, y) method that, for an RBF kernel, computes (1), and a second method gram(func, x, y) that computes the kernel’s Gram matrix given two, possibly identical, input arrays $x$ and $y$ and a kernel-specific function func. Together, this yields the following base class:

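from typing import Callable

import jax.numpy as jnp
from jax import vmap
from objax import Module, TrainVar


class Kernel(Module):
    def __init__(self, name: str = None):
        self.name = name

    @staticmethod
    def gram(func: Callable, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # Map over the rows of x for a single fixed row of y.
        mapx1 = vmap(lambda x, y: func(x=x, y=y), in_axes=(0, None), out_axes=0)
        # Map over the rows of y, stacking the results along axis 1 -> (N, M, ...).
        mapx2 = vmap(lambda x, y: mapx1(x, y), in_axes=(None, 0), out_axes=1)
        return mapx2(x, y)

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray):
        raise NotImplementedError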
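The double `vmap` in `gram` can look opaque. The standalone sketch below builds an N×M Gram matrix in the same way, using a hypothetical pairwise function `sq_exp` in place of a kernel’s `k_func`, and checks the result against plain NumPy-style broadcasting.

```python
import jax.numpy as jnp
from jax import vmap


def sq_exp(x, y):
    # Unit-variance, unit-lengthscale RBF evaluated on a single pair of points.
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2))


def gram(func, x, y):
    # Inner map: vary the row of x (axis 0) while holding one row of y fixed -> (N,).
    mapx1 = vmap(lambda a, b: func(a, b), in_axes=(0, None), out_axes=0)
    # Outer map: vary the row of y, stacking results along axis 1 -> (N, M).
    mapx2 = vmap(lambda a, b: mapx1(a, b), in_axes=(None, 0), out_axes=1)
    return mapx2(x, y)


x = jnp.linspace(-1.0, 1.0, 4).reshape(-1, 1)
y = jnp.linspace(0.0, 2.0, 3).reshape(-1, 1)
K = gram(sq_exp, x, y)

# Reference: the same matrix computed with broadcasting.
K_ref = jnp.exp(-0.5 * jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))
```

The appeal of the `vmap` route is that a new kernel only has to define how it acts on one pair of points; batching over both inputs comes for free.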

With these foundations established, we can construct a stationary kernel class. As discussed above, a stationary kernel is a function of the distance between its two inputs, so we’ll attach a method that returns the difference between two inputs, from which each kernel can compute that distance. We’ll also attach two extra attributes to the stationary class: a lengthscale and a variance parameter. Each defaults to 1, so that a more specific kernel that requires only one of the two can simply leave the other at its default.

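class Stationary(Kernel):
    def __init__(self,
                 lengthscale: jnp.ndarray = jnp.array([1.]),
                 variance: jnp.ndarray = jnp.array([1.]),
                 name: str = "Stationary"):
        super().__init__(name=name)
        # Wrapping in TrainVar registers these as trainable parameters in Objax.
        self.lengthscale = TrainVar(lengthscale)
        self.variance = TrainVar(variance)

    @staticmethod
    def dist(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # Element-wise difference; kernels reduce this to a (squared) distance.
        return x - y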

With the groundwork in place, the only thing left to do for our RBF kernel is define its __call__ method. For now, I’ll implement (1) in its own method, though there’s probably no harm in placing this code inside __call__ itself.

class RBF(Stationary):
    def __init__(self,
                 lengthscale: jnp.ndarray = jnp.array([1.]),
                 variance: jnp.ndarray = jnp.array([1.]),
                 name: str = "RBF"):
        super().__init__(lengthscale=lengthscale, variance=variance, name=name)

    def k_func(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # Squared Euclidean distance ||x - y||^2.
        tau = jnp.sum(self.dist(x, y)**2)
        # Equation (1): sigma^2 * exp(-||x - y||^2 / (2 * ell^2)).
        return self.variance.value * jnp.exp(-0.5 * tau / jnp.square(self.lengthscale.value))

    def __call__(self, X: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
        return self.gram(self.k_func, X, Y).squeeze()

Mean Function

Typically in machine learning, we standardise our data and then use a zero mean function in our GP. For now, we’ll just implement a zero mean function, but this object is simple to extend should someone wish to create their own mean function. To make extension easier, we’ll create a base class for all mean functions, much like we did for kernel functions.

class MeanFunction(Module):
    def __init__(self, name: str = "Mean Function"):
        self.name = name

    def __call__(self, X: jnp.ndarray):
        raise NotImplementedError

We’ll now inherit from this object and define our zero mean function:

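class ZeroMean(MeanFunction):
    def __init__(self, name: str = "Zero Mean"):
        super().__init__(name=name)

    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
        # One zero per input point, shaped (N, 1) whatever the input dimension.
        return jnp.zeros((X.shape[0], 1), dtype=X.dtype)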

Gaussian process prior

With mean and kernel function objects defined, we now have all the constituents required to define a GP prior. Because the Gram matrix can be numerically close to singular, we often have to add a small constant (jitter) to its diagonal. This keeps all of its eigenvalues bounded away from zero, so that operations such as the Cholesky decomposition succeed. When working in higher floating-point precision (e.g. 64-bit), a user may wish to specify a smaller amount of jitter, so we’ll also add it as a non-trainable model attribute.
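To see why jitter matters, the sketch below builds an RBF Gram matrix from inputs that contain a duplicated point, which makes the matrix exactly singular. Adding a small multiple of the identity restores positive definiteness, so the Cholesky factorisation goes through. The helper `rbf_gram` and the input values are illustrative only; without jitter, JAX does not raise an error here, and the factor may instead silently contain NaNs.

```python
import jax.numpy as jnp


def rbf_gram(x, lengthscale=1.0):
    # Gram matrix of a unit-variance RBF kernel for 1-D inputs.
    sq = (x[:, None] - x[None, :]) ** 2
    return jnp.exp(-0.5 * sq / lengthscale**2)


# The first two inputs are identical, so two rows of K coincide and K is singular.
x = jnp.array([0.0, 0.0, 1.0])
K = rbf_gram(x)

jitter = 1e-6
K_jit = K + jitter * jnp.eye(x.shape[0])
L = jnp.linalg.cholesky(K_jit)  # lower-triangular factor with K_jit = L @ L.T
```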

class Prior(Module):
    def __init__(self,
                 kernel: Kernel,
                 mean_function: MeanFunction = ZeroMean(),
                 jitter: float = 1e-6):
        self.meanf = mean_function
        self.kernel = kernel
        self.jitter = jitter

With the base object defined, we can now add some methods to the class.

It can often be helpful to visualise samples from the GP prior. That is, given a GP prior $p(f) = \mathcal{GP}(m, k)$, where $m$ and $k$ are the GP’s mean and kernel functions, we sometimes wish to draw samples from $p(f)$. At a finite set of inputs $X$, this amounts to sampling from the multivariate Gaussian $\mathcal{N}(m(X), K_{xx})$. To facilitate this, we define a sample method that draws $n$ samples at a set of inputs $X$:

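    def sample(self, X: jnp.ndarray, key, n_samples: int = 1):
        # Assumes `import jax.random as jr` at module level.
        Inn = jnp.eye(X.shape[0])
        mu = self.meanf(X)
        # Prior covariance at X, with jitter on the diagonal for stability.
        cov = self.kernel(X, X) + self.jitter * Inn
        # Returns an array of shape (n_samples, N).
        return jr.multivariate_normal(key,
                                      mu.reshape(-1),
                                      cov,
                                      shape=(n_samples, ))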

The key argument here is a Jax idiosyncrasy that handles pseudo random number generation. For the interested reader, it’s well worth visiting the Jax documentation for more details on this.
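As a small illustration of that idiosyncrasy, the sketch below shows that reusing a key reproduces exactly the same draws, while a freshly split key gives new ones. The seed and the 3×3 covariance are arbitrary.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(42)
key, subkey = jr.split(key)  # split off a fresh key; never reuse a consumed one

mean = jnp.zeros(3)
cov = jnp.array([[1.0, 0.5, 0.0],
                 [0.5, 1.0, 0.5],
                 [0.0, 0.5, 1.0]])

draws_a = jr.multivariate_normal(subkey, mean, cov, shape=(5,))
draws_b = jr.multivariate_normal(subkey, mean, cov, shape=(5,))  # same key: identical draws
_, other = jr.split(key)
draws_c = jr.multivariate_normal(other, mean, cov, shape=(5,))   # new key: new draws
```

Because randomness is an explicit input, functions such as `sample` remain pure, which is what allows Jax to JIT-compile and vectorise them safely.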

The final method that we’ll attach to the GP prior for now is the __mul__ magic method. The purpose of this method will become more apparent when we implement a posterior.

    def __mul__(self, other: "Likelihood"):
        # `prior * likelihood` returns a Posterior (both classes are defined below).
        return Posterior(self, other)
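Here is a minimal, dependency-free sketch of the operator-overloading pattern behind `prior * likelihood`. The classes are simplified, hypothetical stand-ins rather than the GPBlocks classes, and the `__rmul__` line is an optional addition (not in the original code) that also lets `likelihood * prior` work.

```python
class Posterior:
    def __init__(self, prior, likelihood):
        self.prior = prior
        self.likelihood = likelihood


class Prior:
    def __init__(self, kernel=None):
        self.kernel = kernel

    def __mul__(self, other):
        # Python routes `prior * likelihood` here.
        return Posterior(self, other)

    # Optional: `likelihood * prior` falls back to this because the
    # likelihood class defines no __mul__ of its own.
    def __rmul__(self, other):
        return Posterior(self, other)


class Gaussian:
    def __init__(self, noise=1.0):
        self.noise = noise


prior = Prior()
likelihood = Gaussian()
post_a = prior * likelihood
post_b = likelihood * prior
```

The name `Posterior` is only looked up when `__mul__` actually runs, so in the real package it can be defined after the prior. A bare annotation such as `other: Likelihood`, however, may be evaluated when the method is defined (depending on the Python version), which is why quoting it is the safer choice.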

Likelihood

With a prior established, the only component now required to define our posterior is a likelihood function. Conjugacy greatly simplifies matters, as we are spared from learning the GP’s latent values $f$. For this reason, whilst we are trying to get a minimal working example up and running, we’ll restrict ourselves to the case where $p(\mathbf{y} \mid f)=\mathcal{N}(f, \sigma_n^2 I)$. That is, we’ll consider only Gaussian likelihood functions, where $\sigma_n^2$ is an observation-level noise variance that we’ll seek to optimise later. We can follow a similar structure to our existing objects and go ahead and implement this.

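class Likelihood(Module):
    def __init__(self, name: str = "Likelihood"):
        self.name = name


class Gaussian(Likelihood):
    def __init__(self, noise: jnp.ndarray = jnp.array([1.0]), name: str = "Gaussian"):
        super().__init__(name=name)
        # Observation noise variance sigma_n^2, trainable via Objax.
        self.noise = TrainVar(noise)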


We now have prior and likelihood objects defined. Using the prior’s __mul__ method, we can multiply these two objects to generate a posterior distribution.

class Posterior(Module):
    def __init__(self, prior: Prior, likelihood: Gaussian):
        self.kernel = prior.kernel
        self.meanf = prior.meanf
        self.likelihood = likelihood
        self.jitter = prior.jitter

As mentioned, we are assuming a Gaussian likelihood here, so we need only learn the GP’s hyperparameters $\theta$; in this instance, the kernel’s lengthscale and variance, along with the observational noise. We can learn these by optimising the GP’s log marginal likelihood with respect to $\theta$. With a framework such as Jax, this task is straightforward once the log marginal likelihood has been implemented, as the derivatives with respect to $\theta$ are computed by Jax’s built-in automatic differentiation.
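As a standalone illustration of that automatic differentiation, `jax.grad` below differentiates a negative log marginal likelihood with respect to every entry of a dictionary of hyperparameters in one call. The function and variable names are hypothetical, and, following the same convention as the `marginal_ll` method, the likelihood’s `noise` is treated as the variance $\sigma_n^2$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def neg_mll(params, X, y):
    # RBF Gram matrix with the current hyperparameters.
    sq = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    K = params["variance"] * jnp.exp(-0.5 * sq / params["lengthscale"] ** 2)
    # "noise" is added to the diagonal, i.e. it is the noise variance.
    cov = K + params["noise"] * jnp.eye(X.shape[0])
    return -multivariate_normal.logpdf(y, jnp.zeros_like(y), cov)


X = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)
y = jnp.sin(X).squeeze()
params = {"lengthscale": 1.0, "variance": 1.0, "noise": 0.1}

# Gradients come back with the same dictionary structure as `params`.
grads = jax.grad(neg_mll)(params, X, y)
```

For the noise variance there is a known closed form to compare against: $\partial \log p(\mathbf{y}) / \partial \sigma_n^2 = \tfrac{1}{2}\left(\boldsymbol{\alpha}^\top\boldsymbol{\alpha} - \operatorname{tr}\left[(K_{xx}+\sigma_n^2 I)^{-1}\right]\right)$ with $\boldsymbol{\alpha} = (K_{xx}+\sigma_n^2 I)^{-1}\mathbf{y}$, which makes a handy sanity check on the autodiff output.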

We’ll attach this method to the Posterior class. Typically, it is more efficient and numerically stable to work with the covariance matrix through its Cholesky decomposition. However, at the time of writing, the only implementation of a multivariate Gaussian in Jax is found in jax.scipy.stats, and it only accepts the dense covariance matrix.

    def marginal_ll(self, X: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        Inn = jnp.eye(X.shape[0])
        mu = self.meanf(X)
        cov = self.kernel(X, X) + self.jitter * Inn
        cov += self.likelihood.noise.value * Inn
        # L = jnp.linalg.cholesky(cov)
        # TODO: Return the logpdf w.r.t. the Cholesky, not the full cov.
        lpdf = multivariate_normal.logpdf(y.squeeze(), mu.squeeze(), cov)
        return lpdf
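One way to address the TODO above without a Cholesky-parameterised Gaussian class is to evaluate the log marginal likelihood directly from a Cholesky factor $L$, using $\log\det = 2\sum_i \log L_{ii}$ and two triangular solves. This is a standalone sketch with a hypothetical function name, checked against `jax.scipy.stats.multivariate_normal.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax.scipy.stats import multivariate_normal


def log_marginal_likelihood_chol(y, mu, cov):
    # Lower Cholesky factor: cov = L @ L.T.
    L = jnp.linalg.cholesky(cov)
    resid = y - mu
    # alpha = cov^{-1} (y - mu), computed via two triangular solves.
    alpha = cho_solve((L, True), resid)
    # log det(cov) = 2 * sum(log(diag(L))).
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    n = y.shape[0]
    return -0.5 * (resid @ alpha + logdet + n * jnp.log(2.0 * jnp.pi))


X = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)
y = jnp.sin(X).squeeze()
cov = jnp.exp(-0.5 * (X - X.T) ** 2) + 0.1 * jnp.eye(5)  # RBF Gram + noise variance

lml = log_marginal_likelihood_chol(y, jnp.zeros(5), cov)
ref = multivariate_normal.logpdf(y, jnp.zeros(5), cov)
```

The same factor $L$ can then be reused for prediction, so the Gram matrix only needs to be decomposed once per set of hyperparameters.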

Optimisers minimise a function, but we wish to maximise the log marginal likelihood. For tidiness, we’ll therefore return the negative log marginal likelihood in a separate method.

    def neg_mll(self, X: jnp.ndarray, y: jnp.ndarray):
        return -self.marginal_ll(X, y)

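The only thing left to do now is define the predictive posterior. That is, we wish to compute the posterior distribution over a set of unseen points $X^{\star}$, conditional on our training data $X, \mathbf{y}$. In GPBlocks, a utility function get_factorisations computes the Cholesky decomposition and the corresponding weights $\alpha = (K_{xx} + \sigma_n^2 I)^{-1}\mathbf{y}$; to keep the snippet self-contained, the same computation is written inline here. We then return the predictive mean and predictive covariance of the GP.

    def predict(self, Xstar, X, y):
        # Assumes `from jax.scipy.linalg import cho_factor, cho_solve` at module level.
        sigma = self.likelihood.noise.value
        Inn = jnp.eye(X.shape[0])
        # Cholesky factor of K_xx + (jitter + sigma_n^2) I.
        L = cho_factor(self.kernel(X, X) + (self.jitter + sigma) * Inn, lower=True)
        # alpha = (K_xx + sigma_n^2 I)^{-1} (y - m(X))
        alpha = cho_solve(L, y.reshape(-1) - self.meanf(X).reshape(-1))
        Kfx = self.kernel(Xstar, X)
        mu = self.meanf(Xstar).reshape(-1) + jnp.dot(Kfx, alpha)
        v = cho_solve(L, Kfx.T)
        Kff = self.kernel(Xstar, Xstar)
        cov = Kff - jnp.dot(Kfx, v)
        return mu, cov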


We can now see how this all fits together. First, we’ll simulate 50 data points. Our inputs are drawn according to $x_1, \ldots , x_{50} \overset{iid}{\sim} \mathcal{U}(-3, 3)$ and the corresponding outputs are $y_i = \sin(4x_i) + \cos(2 x_i) + \epsilon_i$, where $\epsilon_i \sim \mathcal{N}(0, 0.1^2)$.

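import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(123)  # arbitrary seed
# Separate keys for the inputs and the noise so that the two are independent.
key, xkey, nkey = jr.split(key, 3)

N = 50
noise = 0.1  # noise standard deviation

X = jnp.sort(jr.uniform(key=xkey, minval=-3.0, maxval=3.0, shape=(N,))).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(X)
y = signal + jr.normal(nkey, shape=signal.shape) * noise
Xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(Xtest)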

We’ll now define our GP prior through a zero-mean function and an RBF kernel.

kernel = RBF(lengthscale=jnp.array([1.0]), variance=jnp.array([1.0]))
meanf = ZeroMean()

prior = Prior(kernel=kernel, mean_function=meanf, jitter=1e-6)

We can see the effect that our kernel has on the smoothness of the samples from the GP prior by simulating a set of realisations from the prior.

samples = prior.sample(X, key, n_samples=20)

[Figure: samples drawn from the GP prior]

We’ll now define our posterior as the product of the prior and a Gaussian likelihood.

likelihood = Gaussian()
posterior = prior * likelihood

Using the optimisers that ship with Objax and our posterior’s negative log marginal likelihood, we can now optimise the GP’s hyperparameters. To speed this up, we’ll define a single optimisation step in its own function and then JIT-compile it.

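import objax

# Stochastic gradient descent over every TrainVar in the posterior.
opt = objax.optimizer.SGD(posterior.vars())
# Returns (gradients, values) for the negative log marginal likelihood.
gv = objax.GradValues(posterior.neg_mll, posterior.vars())

def train_op(x, label):
    g, v = gv(x, label)
    opt(0.01, g)  # learning rate of 0.01
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())
nits = 100
loss = [train_op(X, y.squeeze())[0].item() for _ in range(nits)]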
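Objax is not strictly required for this step. As an alternative sketch (a swapped-in approach, not the one used in GPBlocks), the same kind of JIT-compiled training step can be written in plain Jax with `jax.value_and_grad` and a hand-rolled gradient-descent update. It also optimises the logarithm of each hyperparameter, a common trick not used in the original code, which keeps the lengthscale, variance and noise positive; plain SGD on the raw values does not guarantee this. The seed, learning rate and data are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def neg_mll(log_params, X, y):
    # Hyperparameters are stored on the log scale; exp(.) keeps them positive.
    p = jax.tree_util.tree_map(jnp.exp, log_params)
    sq = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    K = p["variance"] * jnp.exp(-0.5 * sq / p["lengthscale"] ** 2)
    # Noise variance plus a small jitter on the diagonal.
    cov = K + (p["noise"] + 1e-6) * jnp.eye(X.shape[0])
    return -multivariate_normal.logpdf(y, jnp.zeros_like(y), cov)


@jax.jit
def train_step(log_params, X, y):
    loss, grads = jax.value_and_grad(neg_mll)(log_params, X, y)
    # Plain gradient descent (learning rate 0.005) on the log-hyperparameters.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.005 * g, log_params, grads)
    return new_params, loss


# Toy data in the style of the example above.
kx, kn = jax.random.split(jax.random.PRNGKey(0))
X = jnp.sort(jax.random.uniform(kx, (50,), minval=-3.0, maxval=3.0)).reshape(-1, 1)
y = jnp.sin(4 * X[:, 0]) + jnp.cos(2 * X[:, 0]) + 0.1 * jax.random.normal(kn, (50,))

log_params = {"lengthscale": jnp.array(0.0), "variance": jnp.array(0.0), "noise": jnp.array(0.0)}
losses = []
for _ in range(100):
    log_params, loss = train_step(log_params, X, y)
    losses.append(float(loss))
```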

The optimisation was run for 100 iterations, but the parameter estimates had converged after around 45 steps. We can now use our posterior distribution to make predictions at our previously defined test locations $X^{\star}$.

mu, cov = posterior.predict(Xtest, X, y.squeeze())

The resulting predictions appear consistent with the shape of the underlying latent function. Further, the GP’s posterior returns realistic uncertainty estimates, as seen through the 95% credible interval surrounding the predictive mean. Reassuringly, the uncertainty band widens as we move away from the range of the data. At the same time, the GP’s mean reverts to the global zero mean; a behaviour we would expect of a properly functioning GP.
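A pointwise 95% credible band like the one described can be computed from the predictive mean and covariance returned by `predict`. The sketch below is illustrative (the function name and toy inputs are not from the package); it clips tiny negative variances, which can appear through floating-point round-off, before taking square roots.

```python
import jax.numpy as jnp


def credible_band(mu, cov, z=1.96):
    # Marginal predictive variances; clip small negative round-off to zero.
    var = jnp.maximum(jnp.diag(cov), 0.0)
    std = jnp.sqrt(var)
    return mu - z * std, mu + z * std


mu = jnp.zeros(3)
cov = jnp.diag(jnp.array([1.0, 4.0, -1e-9]))
lower, upper = credible_band(mu, cov)
```

Note that this band describes the latent function $f_\star$; a band for new noisy observations would add $\sigma_n^2$ to the diagonal of the covariance first.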

Extending the package

Sparse schemes

One problem with the current implementation is the need to factorise (or invert) the $N \times N$ Gram matrix every time we evaluate the GP’s log marginal likelihood. This operation scales cubically in the number of data points, $\mathcal{O}(N^3)$, and as such becomes prohibitive once the dataset grows beyond several thousand points.

PyPi submission

Once the package was in a stable and usable form, I submitted it to PyPi. To be honest, the only reason for doing this was to allow people to install it through pip install gpjax instead of having to clone the repo and run python setup.py install. As a forewarning, if you’re hoping to replicate my process for uploading to PyPi, then you’ll need an account with both PyPi and test.PyPi. Both are free; it’s just worth pointing out now before we get into the nitty-gritty details of packaging.

Generally speaking, submitting to PyPi wasn’t too much trouble; you can find the package here if you’re interested. The first thing to do was to get the setup.py file in order. I’m not saying this is how you should structure the file, but this was the shape of mine upon initially submitting to PyPi:

from setuptools import setup, find_packages


def parse_requirements_file(filename):
    with open(filename, encoding="utf-8") as fid:
        # Skip blank lines; every remaining line is one requirement.
        requires = [l.strip() for l in fid.readlines() if l.strip()]
    return requires


setup(
    name="GPJax",
    version="0.1.0",
    author='Thomas Pinder',
    # Include every (sub)package, e.g. gpjax.gps, but not the tests.
    packages=find_packages(".", exclude=['tests']),
    description='Didactic Gaussian processes in Jax and ObJax.',
    long_description="GPJax aims to provide a low-level interface to Gaussian process models. Code is written entirely in Jax and Objax to enhance readability, and structured so as to allow researchers to easily extend the code to suit their own needs. When defining GP prior in GPJax, the user need only specify a mean and kernel function. A GP posterior can then be realised by computing the product of our prior with a likelihood function. The idea behind this is that the code should be as close as possible to the maths that we would write on paper when working with GP models.",
    # Assumed: dependencies are read from the repository's requirements.txt.
    install_requires=parse_requirements_file("requirements.txt"),
    keywords=['gaussian-processes jax machine-learning bayesian'],
)

As I say, there are probably much better ways to structure setup.py files, but this worked for me.

Next, I stored the package’s version number in the root __init__.py file under the __version__ variable. This is not a requirement for submitting your package to PyPi; it just helps, once the package is installed, to confirm which version you’re using.

Finally, I constructed the compressed source distribution (tar.gz for me, as I run Linux) and the corresponding wheel (.whl) file. Both can be created from the setup.py file by running the following command in your terminal:

python setup.py sdist bdist_wheel

These files get stored within the dist/ directory. You can check that the correct files are within your compressed tar.gz file by running tar tzf dist/GPJax-0.1.0.tar.gz; this only lists the contents of the archive and doesn’t actually extract the files.

The twine package then allows us to check the validity of these files and then submit them to PyPi. To check the validity of your compressed package files, run

twine check dist/*

If your files pass, then it’s strongly recommended to check that the upload to PyPi will succeed by first uploading to test.PyPi. This can be done through

twine upload --repository-url https://test.pypi.org/legacy/ dist/*

You can check whether this was successful by navigating to test.PyPi, where you should be able to find your package either by searching or, if you’re quick enough, on the homepage. Fortunately, for me at least, this process worked seamlessly the first time, so I went ahead and uploaded to PyPi itself using

twine upload dist/*

And that was it. I could now install GPJax by running pip install gpjax from my terminal. This section is not intended to be a how-to guide on submitting a package to PyPi; if you want one, I’d strongly recommend the guide at RealPython.

EDIT: After successfully uploading to PyPi, I did have some problems importing the package. Specifically, every time I tried to run import gpjax, I would receive the error ModuleNotFoundError: No module named 'gpjax.gps'. This was confusing, as that specific import was used in the package’s unit tests, which weren’t failing. Further, I could run a local installation with python setup.py develop and again receive no issues. After much head banging, I identified that my setup.py file needed the line packages=find_packages(".", exclude=['tests']), within the setup() function call. The reason appears to be that gpjax.gps is a subpackage of gpjax: setuptools only builds the packages it is told about through packages=, so without the find_packages() call (which discovers gpjax and all of its subpackages) the subpackage was never included in the distribution and therefore couldn’t be imported. The unit tests and the develop install most likely worked because both import the code directly from the source tree. This is at least my understanding, and I may well have oversimplified the root cause; if so, please do let me know.
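To see what find_packages() actually picks up, the sketch below builds a throwaway directory layout and runs it. The layout is hypothetical and assumes gps is a subpackage directory; note that find_packages() only treats directories containing an __init__.py file as packages, so a missing __init__.py silently drops a subpackage too.

```python
import os
import tempfile

from setuptools import find_packages


def touch(path):
    open(path, "w").close()


with tempfile.TemporaryDirectory() as root:
    # gpjax/ with a gpjax/gps/ subpackage, plus a tests/ folder to exclude.
    for pkg in ["gpjax", os.path.join("gpjax", "gps"), "tests"]:
        os.makedirs(os.path.join(root, pkg))
        touch(os.path.join(root, pkg, "__init__.py"))
    found = sorted(find_packages(root, exclude=["tests"]))

with tempfile.TemporaryDirectory() as root:
    # Same layout, but gpjax/gps/ has no __init__.py this time.
    os.makedirs(os.path.join(root, "gpjax", "gps"))
    touch(os.path.join(root, "gpjax", "__init__.py"))
    found_without_init = sorted(find_packages(root))

print(found)               # ['gpjax', 'gpjax.gps']
print(found_without_init)  # ['gpjax']
```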

Package automation

There are a number of incredibly useful tools for package development in Python. In this section, I’ll briefly summarise the tools and packages that I’ve found useful in automating the nitty-gritty parts of package development.

Semantic versioning

I try to stick to semantic versioning for the major projects that I work on. In short, semantic versioning requires version numbers to have three numerical components: a major, minor and patch number. I’ll initialise the version of GPJax to 0.1.0. I’ve opted to start with a major version number of 0 to, hopefully, signal to users that the package is still in its infancy. Typically, once packages reach version 1, they tend to be more stable than in their 0.x.x releases.

As and when bugs inevitably occur, a fix will be pushed, incrementing the patch version number. When functionality is added in a backward-compatible manner, the minor version number is incremented, and when newly introduced changes make existing code incompatible, the major version number is incremented.
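The rules above are mechanical enough to write down in a few lines. This hypothetical `bump` helper is only a sketch of the semantic versioning logic (it is not how bumpversion is implemented). Note that incrementing a component resets every component to its right back to zero.

```python
def bump(version: str, part: str) -> str:
    major, minor, patch = (int(v) for v in version.split("."))
    if part == "major":
        # Breaking change: reset minor and patch.
        return f"{major + 1}.0.0"
    if part == "minor":
        # Backward-compatible feature: reset patch.
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        # Bug fix only.
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown version part: {part!r}")


print(bump("0.1.0", "patch"))  # 0.1.1
print(bump("0.1.3", "minor"))  # 0.2.0
print(bump("0.2.1", "major"))  # 1.0.0
```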

To keep track of all of this, I’ll use the bumpversion package. This is a really helpful package that increments the package’s version number in both the setup.py file and the package’s root __init__.py file. For example, the patch version number can be incremented by running the following command

bumpversion --current-version 0.1.0 patch setup.py gpjax/__init__.py

This, although seemingly trivial, is really helpful as it just means there’s one less thing to keep track of when managing the package.

Unit tests

Unit tests are a really useful way to avoid accidentally breaking your code when a change has unanticipated effects downstream. I don’t find writing tests very interesting, so any package that speeds up the process is welcome. Fortunately, PyTest satisfies this criterion.

As I understand it, PyTest will automatically treat any function whose name begins with test (in files named test_*.py or *_test.py) as a unit test, and the test’s outcome is determined by its assertions. An additional component of PyTest that I find useful is the mark.parametrize decorator, which can precede any test. This allows a range of different arguments and configurations to be passed to the test. To see an example of this, consider the following unit test, which checks whether the Gram matrix generated by the kernel is square. Through the mark.parametrize decorator, this test is run for 1, 2 and 5-dimensional inputs.

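import jax.numpy as jnp
import pytest

# RBF is the kernel class defined earlier in this post.


@pytest.mark.parametrize("dim", [1, 2, 5])
def test_shape(dim):
    x = jnp.linspace(-1., 1., num=10).reshape(-1, 1)
    if dim > 1:
        # Stack copies of the column to form a (10, dim) input.
        x = jnp.hstack([x]*dim)
    kern = RBF()
    gram = kern(x, x)
    assert gram.shape[0] == x.shape[0]
    assert gram.shape[0] == gram.shape[1]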

Personally, I like this style of writing unit tests as it’s relatively straightforward and allows me to get on with writing the GP code.

Code coverage

One metric for quantifying how much of your code base is exercised by your unit tests is code coverage: the percentage of lines in your code base that are executed by at least one of your tests. As such, 100% is the ideal.

The Codecov web app is a nice way to automate the computation of this metric. By giving the app read access to your codebase’s repository, the following GitHub Actions YAML file can be used to automatically run your tests and upload the results to Codecov.

name: Python Master Workflow
on:
  push:
    branches:
      - 'master'
jobs:
  run:
    name: Codecov Workflow
    runs-on: ubuntu-18.04
    steps:
      - uses: actions/checkout@v1
      - name: Set up Python
        uses: actions/setup-python@master
        with:
          python-version: 3.8
      - name: Generate coverage report
        run: |
          pip install -r requirements.txt
          pip install -e .
          pip install pytest
          pip install pytest-cov
          pytest --cov=./ --cov-report=xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.xml
          flags: unittests

On Codecov’s site, you can then get a detailed breakdown of which lines are covered and which are not. You can also get a coverage badge to place on your repository. I’ll admit that setting up the GitHub action above took some wrangling, but this blog post was really useful.

Documentation

Installing Sphinx

For documentation, I’ve been using Sphinx in conjunction with ReadTheDocs. This was one of the more fiddly components of the package to set up, although, once set up, it is effortless to maintain. From looking at several major Python projects, there seem to be many ways to manage documentation, so this is just my process, and you may well find a different approach more intuitive.

First up, you can install Sphinx using the following pip command

pip install -U Sphinx

Once installed, navigate to your project’s root directory and run

mkdir docs
cd docs/
sphinx-quickstart

This will present you with an interactive terminal routine. The options that I selected were

  • Separate source and build directories (y/n) [n]: y
  • Project name: GPJax
  • Author name(s): Thomas Pinder
  • Project release []: 0.2.0
  • Project language [en]: en

From your docs/ directory, this will create two folders: a source/ and a build/ directory. If you haven’t already, I would recommend adding the build/ directory to your .gitignore file, as this directory will contain all of your compiled .html files and would unnecessarily bloat the repo. In addition to these two directories, you’ll also have a Makefile and a make.bat file. I won’t go into the make.bat file, as it’s Windows-specific and I have no knowledge of how it works. The Makefile is highly convenient, though, as it allows all of your documentation to be built by simply running the make html command. You can go ahead and run this now, and if you launch a web server (python -m http.server) from the docs/build/html/ directory, you should see an empty documentation site being served.

Configuring Sphinx

It’s helpful to load a few Sphinx extensions before writing any documentation, as some of them can save a lot of time. They live within the docs/source/conf.py file, and you can simply add them to the already existing extensions list about halfway down the file. I have the following extensions loaded:

extensions = [

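You’ll also want Sphinx to be able to import your package, so add sys.path.insert(0, os.path.abspath('../..')) near the top of conf.py (making sure os and sys are imported). Since conf.py sits in docs/source/, this points two directories up, to the project root.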
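For reference, here is a sketch of how the top of docs/source/conf.py might look. The three extensions are examples that ship with Sphinx, not necessarily the list GPJax loads: autodoc pulls documentation from docstrings, napoleon parses Google-style sections such as Args: and Returns:, and viewcode links each documented object to its source.

```python
# docs/source/conf.py (excerpt): a sketch, not the exact GPJax configuration
import os
import sys

# Sphinx executes conf.py from its own directory (docs/source/),
# so '../..' resolves to the project root where the package lives.
sys.path.insert(0, os.path.abspath("../.."))

# Example extensions only; all three are bundled with Sphinx.
extensions = [
    "sphinx.ext.autodoc",   # generate documentation from docstrings
    "sphinx.ext.napoleon",  # understand Google/NumPy-style docstring sections
    "sphinx.ext.viewcode",  # add links to the highlighted source code
]
```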

Creating documentation

At this point Sphinx is working, so we just need to tell it what information should be loaded. Sphinx works by converting a series of reStructuredText (.rst) files into HTML. I don’t think it’s worthwhile going into the structure and syntax of .rst here; if you’ve not come across it before, then I’d recommend one of the Sphinx/reStructuredText cheatsheets.

The real power of Sphinx comes from the fact that it can automatically extract the docstrings from your Python objects and their methods, provided you’ve documented all your code. For example, a docstring should look similar to the following, which I currently have for the base kernel object:

import jax.numpy as jnp
from jax import vmap
from objax import Module
from typing import Callable, Optional

class Kernel(Module):
    """
    Base class for all kernel functions. By inheriting the `Module` class from Objax, seamless interaction with model parameters is provided.
    """
    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional naming of the kernel.
        """
        self.name = name
        self.spectral = False

    @staticmethod
    def gram(func: Callable, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """
        Compute the kernel's gram matrix given two, possibly identical, Jax arrays.

        Args:
            func: The kernel function to be called for any two values in x and y.
            x: An NxD array of inputs.
            y: An MxD array of inputs.

        Returns:
            An NxM gram matrix.
        """
        # Map over the rows of x for a single point y, giving an N-vector
        mapx1 = vmap(lambda x, y: func(x=x, y=y), in_axes=(0, None), out_axes=0)
        # Map over the rows of y, stacking each N-vector as a column: NxM
        mapx2 = vmap(lambda x, y: mapx1(x, y), in_axes=(None, 0), out_axes=1)
        return mapx2(x, y)
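As an aside, it can help to see what the double vmap in gram actually returns. The sketch below copies the method body into a plain function so that it runs without Objax, and applies it to `rbf`, a hypothetical squared-exponential kernel written for a single pair of points; the inputs are arbitrary. The result is cross-checked against a direct broadcasting computation.

```python
import jax.numpy as jnp
from jax import vmap


def gram(func, x, y):
    """Standalone copy of Kernel.gram: evaluate func on every (x_i, y_j) pair."""
    # Inner map: over the rows of x for one fixed y_j, giving shape (N,)
    mapx1 = vmap(lambda x, y: func(x=x, y=y), in_axes=(0, None), out_axes=0)
    # Outer map: over the rows of y, stacking each result as a column, giving (N, M)
    mapx2 = vmap(lambda x, y: mapx1(x, y), in_axes=(None, 0), out_axes=1)
    return mapx2(x, y)


def rbf(x, y, lengthscale=1.0, variance=1.0):
    """Hypothetical squared-exponential kernel on two single D-dimensional points."""
    return variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)


x = jnp.linspace(0.0, 1.0, 6).reshape(3, 2)   # N=3 points in D=2
y = jnp.linspace(-1.0, 0.0, 4).reshape(2, 2)  # M=2 points in D=2
K = gram(rbf, x, y)

# Cross-check against a direct broadcasting computation of the same kernel
sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
K_direct = jnp.exp(-0.5 * sq_dists)
```

The appeal of this design is that each kernel only has to be written for a single pair of points; vmap lifts it to the full Gram matrix.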

You’ll see that the docstrings under the object and method definitions all follow a similar structure, and it is this consistency that allows Sphinx to parse the text into nicely formatted HTML. To point Sphinx to the correct Python file, you can create a new .rst file in your documentation’s source/ folder. Within that file, all you need to do is point Sphinx to the module as follows:


Base kernel
===========

.. currentmodule:: gpjax.kernels

.. autoclass:: Kernel
   :members:

After the header, the first Sphinx directive points Sphinx to the correct module, whilst the second pulls in the Kernel object and its respective methods. This will, in turn, generate a block of HTML that summarises the entire Kernel object.
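Once several modules need documenting, it can be handy to generate these pages. The sketch below builds the same kind of page as above; `autodoc_page` and `write_page` are hypothetical helpers, and kernels.rst is an example filename. Two things worth knowing: a reStructuredText title underline should be at least as long as the title, and the new page also needs to be listed under the `.. toctree::` directive in index.rst for it to appear in the site's navigation.

```python
from pathlib import Path
from typing import Iterable


def autodoc_page(title: str, module: str, classes: Iterable[str]) -> str:
    """Build reStructuredText that documents `classes` from `module` via autodoc."""
    # The underline is made exactly as long as the title text
    lines = [title, "=" * len(title), "", f".. currentmodule:: {module}", ""]
    for name in classes:
        # :members: includes the documented methods, not just the class docstring
        lines += [f".. autoclass:: {name}", "   :members:", ""]
    return "\n".join(lines)


def write_page(source_dir: str, filename: str, text: str) -> Path:
    """Write the page into the Sphinx source directory, e.g. docs/source/."""
    path = Path(source_dir) / filename
    path.write_text(text)
    return path
```

For instance, `write_page("docs/source", "kernels.rst", autodoc_page("Base kernel", "gpjax.kernels", ["Kernel"]))` would recreate the page above.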

Hosting documentation

You can now set about hosting your documentation. I opted to use ReadTheDocs as it is free and straightforward to set up. After creating an account, you can link your GitHub repository to your ReadTheDocs profile, where it will then appear. Click on the repository from ReadTheDocs, and then simply select the build option to construct your documentation.

Over time you may wish to install packages specifically for documentation purposes. If you do, then you do not have to add them as package dependencies; instead, you can create a requirements.txt file in your docs/ directory. Then, from within ReadTheDocs, navigate to the admin tab and select advanced settings. Under the default settings header, you can point ReadTheDocs to your docs/requirements.txt file in the Requirements file option.


For formatting, I like to use yapf. This is a command-line code formatter that can be installed with pip: pip install yapf. Once installed, you can point yapf at your source code’s folder and format all files using yapf -i <source_folder>/*.py. The -i flag overwrites the files in place, so it may be worth removing it if you’re just experimenting with the package, in which case yapf prints the formatted code to the terminal instead.
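One caveat: the shell glob <source_folder>/*.py only matches files at the top level of the folder, so any sub-packages are skipped. The sketch below uses yapf's own --recursive flag instead and defaults to --diff, which prints the proposed changes without touching any files. `yapf_command` and `run_yapf` are hypothetical helpers, and gpjax as the source folder is an assumption.

```python
import shutil
import subprocess
from typing import List


def yapf_command(source_folder: str, in_place: bool = False) -> List[str]:
    """Build a yapf call over every .py file under `source_folder`, sub-packages included."""
    # --diff only prints the proposed changes; --in-place overwrites the files
    mode = "--in-place" if in_place else "--diff"
    return ["yapf", mode, "--recursive", source_folder]


def run_yapf(source_folder: str = "gpjax", in_place: bool = False) -> int:
    """Run yapf and return its exit code."""
    if shutil.which("yapf") is None:
        raise RuntimeError("yapf not found on PATH; install it with `pip install yapf`")
    return subprocess.run(yapf_command(source_folder, in_place)).returncode
```

Reviewing the diff first and only then calling `run_yapf(in_place=True)` keeps the overwrite step deliberate.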


I’ve been working with GPs for several years now, primarily using GaussianProcesses.jl and GPFlow. As such, many of the fundamentals of this package are inspired by these code bases. Rasmussen and Williams’ Gaussian Processes for Machine Learning has also been an invaluable point of reference, and conversations with Christopher Nemeth and Ti John helped to mould the structure of this package.

As discussed earlier in this article, my motivation for using Jax in this work was partly my wish to learn the package. In doing so, the biweekly reading group at Lancaster with Jeremie Coullon and Matthew Ludkin has been a great way to dig into the nitty-gritty internals of Jax. Further, the Uncertain Gaussian processes blog post/resource by J Emmanuel Johnson was a great way to see how Jax could be used for GP modelling.