Implementing a scalable multi-output GP model with exact inference

This is the final post in our series about multi-output Gaussian process (GP) models. In the first post, we described how to generalise single-output GPs to multi-output GPs (MOGPs). We also introduced the Mixing Model Hierarchy (MMH), as a way to classify and organise a large number of MOGP models from the literature. In the second post, we discussed the Instantaneous Linear Mixing Model (ILMM), the base model of the MMH, showing how its low-rank assumption can be exploited to speed up inference via simple linear algebra tricks. We used those results to motivate the Orthogonal Instantaneous Linear Mixing Model (OILMM), a version of the ILMM which scales even more favourably, allowing us to model up to tens of millions of points on a regular laptop.

In this post, we give concrete implementation details of an inference algorithm for the OILMM, showing how simple it is to obtain an efficient implementation. We present some simple code, as well as links to public implementations in both Python and Julia.

We start with a quick recap of the OILMM.

The Orthogonal Instantaneous Linear Mixing Model (OILMM)

The Orthogonal Instantaneous Linear Mixing Model (OILMM) is a multi-output GP (MOGP) model designed with scalability in mind. It describes the data as a linear combination of latent (unobserved) processes, which are themselves described as independent GPs. Mathematically:

\[\begin{align} y(t) \sim \mathcal{GP}(Hx(t), \delta_{tt'} \Sigma), \\ x(t) \sim \mathcal{GP}(0, K(t, t')), \end{align}\]

where \(x(t)\) are the latent processes, \(H\) is the mixing matrix, \(\delta_{tt'}\) is a Kronecker delta, and \(\Sigma\) is a matrix that specifies the noise model. The expressions above define a general Instantaneous Linear Mixing Model (ILMM). To obtain an orthogonal ILMM (OILMM), the mixing matrix \(H\) must have orthogonal columns, and the noise matrix \(\Sigma\) must be of the form \(\Sigma = \sigma^2I + HDH^\top\), where \(I\) is an identity matrix and \(D > 0\) is diagonal.

The key aspect of an OILMM is that it turns a MOGP problem into a set of independent single-output GP problems, which brings a very significant gain in scalability. This independence also allows the OILMM to be trivially combined with other single-output GP techniques, such as sparse GPs or state-space approximations. If you are curious about how this is possible, we recommend checking out our previous post for a high-level explanation, or our paper, Scalable Exact Inference in Multi-Output Gaussian Processes, for a rigorous discussion. We will now focus on the practical computational aspects of the OILMM.

Implementing the OILMM

Let’s start by showing the algorithm for implementing the OILMM in a general regression/prediction setting. All that is needed is a regular GP package. To illustrate, we’ll show code using AbstractGPs in Julia, and Stheno[1] in Python[2]. We choose these packages because they are minimal and the code reads almost like pseudo-code, making it straightforward to follow even for people who have never used Julia or Python before.

We discuss the procedure for performing inference, for sampling from the posterior, and for computing the log-likelihood of the data. We will assume that the OILMM has \(p\) outputs, \(n\) timestamps, and \(m\) latent processes.

Notation

Symbol Type Description
\(U\) Truncated orthogonal \(p \times m\) matrix Orthogonal part of the mixing matrix, \(H = US^{1/2}\)
\(S\) Positive, diagonal \(m \times m\) matrix Diagonal part of the mixing matrix, \(H = US^{1/2}\)
\(\sigma^2\) Positive real number Part of the observation noise
\(D\) Positive, diagonal \(m \times m\) matrix Part of the observation noise deriving from the latent processes
\(Y\) \(p \times n\) matrix Matrix of observations

Performing inference

There are five steps to performing inference in the OILMM framework:

  1. Build the projection.
  2. Project the observations to the latent space.
  3. Project the noise to the latent space.
  4. Independently condition each latent process on the projected observations, using the projected noise.
  5. Transform the posterior means and covariances back to the observation space, using the mixing matrix.

Step 0: preliminaries

Let’s start with some basic definitions for the example code.

Julia:

using AbstractGPs
using LinearAlgebra
using Statistics

n = 100 # Number of timestamps

# Model specification:
p = 10                # Number of outputs
m = 3                 # Number of latent processes
σ² = 0.1              # Observation noise
D = Diagonal(rand(m)) # Latent noises

Python:

import lab as B
from matrix import Dense, Diagonal
from stheno import GP, Matern52

n = 100  # Number of timestamps

# Model specification:
p = 10         # Number of outputs
m = 3          # Number of latent processes
noise = 0.1    # Observation noise
d = B.rand(m)  # Latent noises

Step 1: build the projection

We know that the original space, where the observations are made, and the latent space are connected via the mixing matrix \(H\). Thus, to project observations to the latent space we need to determine the left inverse of \(H\), which we call \(T\). It is easy to see that \(T = S^{-1/2}U^\top\). Notice that this matrix \(T\) is the same one that defines the sufficient statistic we used to motivate the OILMM.

Julia:

# Sample a random mixing matrix.
U, s, _ = svd(randn(p, m))
H = U * Diagonal(broadcast(sqrt, s))

# Build the projection.
T = Diagonal(sqrt.(s)) \ transpose(U)

Python:

# Sample a random mixing matrix.
U, s, _ = B.svd(B.randn(p, m))
U, S = Dense(U), Diagonal(s)
H = U @ S

# Build the projection.
T = B.inv(B.sqrt(S)) @ U.T

Step 2: project the observations

Taking the observations to the latent space is done by left-multiplying by \(T\). Thus, the projected observations can be written as \(Y_{\text{proj}} = TY\). Some intuition for why this takes the observations to the latent space is as follows: if the observations \(Y\) were generated as \(Y = HX\), then \(TY = THX = X\) recovers \(X\). We call \(y^{(i)}_\text{proj}\) the \(i\)-th row of \(Y_\text{proj}\), corresponding to the observations of the \(i\)-th latent process.

Julia:

# Sample some noisy data.
x = range(0.0, 10.0; length=n)
Y = transpose(rand(GP(Matern52Kernel())(x, σ²), p)) # Generate noisy sample data from some GP.

# Project the observations.
Y_proj = T * Y

Python:

# Sample some noisy data.
x = B.linspace(0, 10, n)
f = GP(Matern52())
Y = f(x, noise).sample(p).T  # Generate sample data from some GP.

# Project the observations.
Y_proj = T @ Y

Step 3: project the noise

We start by noting that \(Y_{\text{proj}} = TY\) means that \(Y_{\text{proj}}\) still carries the noise present in the observations \(Y\). We know that \(Ty(t) \mid x(t), H \sim \mathcal{GP}(THx(t), \delta_{tt'} T \Sigma T^\top)\), so we must compute \(\Sigma_T = T \Sigma T^\top\), which we call the projected noise. \(\Sigma_T\) is a diagonal matrix by construction.
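
To see why \(\Sigma_T\) is diagonal, we can write out the projection explicitly, using \(H = US^{1/2}\), \(U^\top U = I_m\), and the OILMM noise model \(\Sigma = \sigma^2 I_p + HDH^\top\):

\[\begin{align} \Sigma_T = T \Sigma T^\top &= S^{-1/2}U^\top \left( \sigma^2 I_p + U S^{1/2} D S^{1/2} U^\top \right) U S^{-1/2} \\ &= \sigma^2 S^{-1} + D, \end{align}\]

which is exactly the diagonal quantity constructed in the code below.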

Julia:

ΣT = repeat(σ² ./ s + diag(D), 1, n) # Repeat the same noise matrix for every timestamp.

Python:

noise_proj = noise / B.diag(S) + d

Step 4: condition latent processes

Since \(\Sigma_T\) is diagonal and the latent processes are independent under the prior, the latent processes can be conditioned independently. Thus, the GP corresponding to the \(i\)-th latent process is conditioned on \(y^{(i)}_ \text{proj}\) using the corresponding noise \((\Sigma_T)_{ii}\). Since this only involves dealing with single-output GPs, any available GP package can be used for this. This makes it particularly easy to construct the OILMM with your favourite GP package. Moreover, any scaling technique for single-output GPs can be used here, such as the sparse inducing points technique by Titsias.

Julia:

lats = [GP(Matern52Kernel()) for _ in 1:m] # Noiseless latent processes

# Condition the latent processes.
lats_post = [posterior(lats[j](x, ΣT[j, :]), Y_proj[j, :]) for j in 1:m]

Python:

lats = [GP(Matern52()) for _ in range(m)]  # Noiseless latent processes

# Condition the latent processes.
lats_post = [f | (f(x, ni), yi) for f, ni, yi in zip(lats, noise_proj, Y_proj)]

Step 5: transform posterior latent processes to observation space

For the predictive mean of the full MOGP, simply compute the predictive means of each of the posterior latent processes, stack them in an \(m \times n'\) matrix \(\mu\) (with each row corresponding to a different latent process), and then multiply by the mixing matrix, obtaining \(\text{predictive mean} = H \mu \in \mathrm{R}^{p \times n'}\).

Julia:

# Compute the posterior marginal means `M`.
M_latent = vcat([transpose(mean(lats_post[j](x))) for j in 1:m]...)
M = H * M_latent

Python:

# Compute the posterior marginal means `M`.
M_latent = B.concat(*[f.mean(x).T for f in lats_post], axis=0)
M = H @ M_latent

For the predictive marginal variances, repeat the same process as for the predictive means, but stacking the variances \(\nu^{(i)}\) of each latent process instead, obtaining \begin{equation} \text{predictive marginal variances} = (H \circ H) \nu, \end{equation} with \(\circ\) denoting the element-wise (Hadamard) product. To obtain the variances of noisy predictions, as in the code below, add the latent noise \(D\) to \(\nu\) before mixing and the observation noise \(\sigma^2\) afterwards.

Julia:

# Compute the posterior marginal variances `V`.
V_latent = vcat([transpose(var(lats_post[j](x))) for j in 1:m]...)
V = abs2.(H) * (V_latent .+ D.diag) .+ σ²

Python:

# Compute the posterior marginal variances `V`.
V_latent = B.concat(*[f.kernel.elwise(x).T for f in lats_post], axis=0)
V = (H ** 2) @ (V_latent + d[:, None]) + noise

It is also possible to compute full predictive covariance matrices, by observing that for any two given points in time, say \(t_1\) and \(t_2\),

\begin{equation} (\text{predictive covariance})_{t_1t_2} = H K_{\text{posterior}}(t_1, t_2) H^\top \in \mathrm{R}^{p \times p} \end{equation}

with

\begin{equation} K_{\text{posterior}}(t_1, t_2) = \mathrm{Diagonal}([k^{(1)}(t_1, t_2), \cdots, k^{(m)}(t_1, t_2)]) \in \mathrm{R}^{m \times m} \end{equation}

and \(k^{(i)}\) the posterior kernel of the \(i\)-th latent process. Putting this together takes a bit of careful index-keeping and depends on how you are organising the dimensions (time first or outputs first), but the computation itself is still straightforward.
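
As an illustration, here is a minimal Julia sketch of this computation, assuming the `lats_post`, `H`, and `m` defined above, and assuming AbstractGPs’ cross-covariance method `cov(f, x, y)`; the helper name `predictive_covariance` is ours, purely for illustration.

Julia:

# Full p × p posterior covariance of the noiseless process between t1 and t2.
# (Observation noise is not included; add it for noisy predictions when t1 == t2.)
function predictive_covariance(t1::Real, t2::Real)
    # Posterior cross-covariance of each latent process between t1 and t2.
    k_post = [only(cov(lats_post[j], [t1], [t2])) for j in 1:m]
    return H * Diagonal(k_post) * transpose(H)
end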

Sampling from the posterior

Drawing samples from the posterior is rather similar to computing the predictive mean in step 5 above. Because the posterior latent processes remain independent from each other, we can sample each of them independently, which is a functionality that any GP package provides. This way, we can stack a single sample from each latent process into a matrix \(\hat X \in \mathrm{R}^{m \times n'}\), and then transform it via the mixing matrix, obtaining \(\text{posterior sample} = H \hat X\).

Julia:

# Sample from the noiseless posterior.
F_latent = vcat([transpose(rand(lats_post[j](x))) for j in 1:m]...)
F = H * F_latent
F_noisy = F .+ sqrt(σ²) .* randn(size(F))

Python:

# Sample from the noiseless posterior.
F_latent = B.concat(*[f(x).sample().T for f in lats_post], axis=0)
F = H @ F_latent
F_noisy = F + B.sqrt(noise) * B.randn(*B.shape(F))

Computing the log-likelihood of the data

Computing the log-likelihood of data is a three-step process. First, we compute the log-likelihood for each latent process independently. Then, we compute a term that is independent of the latent kernels, which we identify as a regularisation term. Finally, we combine the two terms.

Step 1: compute the likelihood under the latent processes

First, we must project the data to the latent space, as described in the inference section above, giving us \(Y_{\text{proj}} = TY\). Each row of this matrix will correspond to the time series of a different latent process. Since they are fully independent, all we have to do is compute the log-likelihood of the \(i\)-th row under the \(i\)-th latent process. We call this quantity \(\operatorname{LL}_i\). This can be done using any GP package.

Julia:

lml_latents = [logpdf(lats[j](x, ΣT[j, :]), Y_proj[j, :]) for j in 1:m]

Python:

lml_latents = [f(x, ni).logpdf(yi) for f, ni, yi in zip(lats, noise_proj, Y_proj)]

Step 2: add regularisation term

The log-likelihood of the data under an OILMM does not depend solely on the latent processes, as it must account for the effects of the projection step. As we show in our work, the log-likelihood can be written as the sum of two terms. The first one is the log-likelihood of the projected data under the latent processes, which we computed in the step above. The second one is a term that accounts for the loss of information during the projection. Since it prevents the data from being projected to zero, it can be seen as a regularisation term. This term is given by:

\[\begin{equation} \text{regulariser} = - \frac{n}{2} \log |S| - \frac{n (p-m)}{2} \log 2 \pi \sigma^2 - \frac{1}{2\sigma^2} (\left \| Y \right \|_F^2 - \left \| U^T Y \right \|_F^2), \end{equation}\]

with \(\left \| Y \right \|_F\) the Frobenius norm.

Julia:

regulariser = -(
    n * (sum(log, s) + (p - m) * log(2π * σ²)) +
    (sum(abs2, Y) - sum(abs2, Diagonal(sqrt.(s)) * Y_proj)) / σ²
) / 2

Python:

regulariser = -0.5 * (
    n * (p - m) * B.log(2 * B.pi * noise)
    + n * B.logdet(S)
    + (B.sum(Y ** 2) - B.sum((B.sqrt(S) @ Y_proj) ** 2)) / noise
)

Step 3: combine both terms

The log-likelihood of the data is given by the sum of the log-likelihoods under each of the latent processes, \(\operatorname{LL}_i\), as computed in step 1 above, and the regularisation term from the previous step. That is:

\begin{equation} \log p(Y) = \text{regulariser} + \sum_{i=1}^m \text{LL}_i. \end{equation}

Julia:

loglik = regulariser + sum(lml_latents)

Python:

loglik = regulariser + sum(lml_latents)

Summary of section

We highlight that all of the implementation steps above are quite simple to perform with any available GP package, even one with no MOGP functionality. We consider this simplicity one of the strengths of the method. If you are interested in dedicated, off-the-shelf implementations, we list packages for both Python and Julia further below.
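
For reference, here is a minimal end-to-end sketch in Julia that strings together steps 1–5 and the log-likelihood computation, assuming the definitions from Step 0. The function name `oilmm_inference` is ours, purely for illustration, and is not part of any package.

Julia:

function oilmm_inference(x, Y, U, s, σ², D, kernels)
    p, m = size(U)
    n = length(x)

    # Steps 1–3: projection, projected observations, projected noise.
    T = Diagonal(sqrt.(s)) \ transpose(U)
    Y_proj = T * Y
    ΣT = σ² ./ s + diag(D)

    # Step 4: condition each latent process independently.
    lats = [GP(kernels[j]) for j in 1:m]
    lats_post = [posterior(lats[j](x, ΣT[j]), Y_proj[j, :]) for j in 1:m]

    # Log-likelihood: latent terms plus the regularisation term.
    lml_latents = [logpdf(lats[j](x, ΣT[j]), Y_proj[j, :]) for j in 1:m]
    regulariser = -(
        n * (sum(log, s) + (p - m) * log(2π * σ²)) +
        (sum(abs2, Y) - sum(abs2, Diagonal(sqrt.(s)) * Y_proj)) / σ²
    ) / 2

    return lats_post, regulariser + sum(lml_latents)
end

Calling, for instance, `oilmm_inference(x, Y, U, s, σ², D, [Matern52Kernel() for _ in 1:m])` reproduces the quantities computed step by step above; predictive means and variances then follow from `lats_post` as in step 5.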

Time and memory complexities

Scalability is one of the key strengths of the OILMM, so it is relevant to discuss the time and memory complexities involved in utilising the method. We’ll consider the case of \(p\) outputs, observed at \(n\) timestamps each, and an OILMM with \(m\) latent processes.

In realistic cases we typically have \(n > p > m\), meaning that costs in both time and memory will be dominated by the storage and inversion of the covariance matrix. We have discussed in detail how we arrive at these costs in our two previous posts in the series. In the table below we summarise the results using the general case for MOGPs and the ILMM as reference. The linear scaling on \(m\) presented by the OILMM is a direct consequence of inference being done independently for each latent process, as described in step 4 of the “Performing Inference” section above.

Table 1: Time and memory scaling for storing and inverting the covariance matrix under the general MOGPs, the ILMM, and the OILMM.

Model Time Memory
General MOGP \(\mathcal{O}(n^3p^3)\) \(\mathcal{O}(n^2p^2)\)
ILMM \(\mathcal{O}(n^3m^3)\) \(\mathcal{O}(n^2m^2)\)
OILMM \(\mathcal{O}(n^3m)\) \(\mathcal{O}(n^2m)\)
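
To make these scalings concrete, consider an illustrative example (the numbers are ours, chosen only for illustration) with \(p = 10\) outputs, \(m = 3\) latent processes, and \(n = 10^4\) timestamps. Measuring time in floating-point operations and memory in stored matrix entries:

\[\begin{align} \text{General MOGP:} \quad (np)^3 &= 10^{15}, & (np)^2 &= 10^{10}, \\ \text{ILMM:} \quad (nm)^3 &= 2.7 \times 10^{13}, & (nm)^2 &= 9 \times 10^{8}, \\ \text{OILMM:} \quad n^3 m &= 3 \times 10^{12}, & n^2 m &= 3 \times 10^{8}. \end{align}\]

In this example the OILMM reduces the dominant time cost by roughly two to three orders of magnitude relative to the general MOGP, and by roughly one order of magnitude relative to the ILMM.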

For cases in which \(n\) is too large and even the favourable scaling presented by the OILMM is still prohibitive, it is straightforward to combine the OILMM with scaling techniques designed for single-output GPs, such as the one developed by Titsias or the one developed by Hartikainen and Särkkä. Focusing only on the implementation aspects, the addition of these scaling techniques only affects step 4 of the “Performing inference” section and step 1 of the “Computing the log-likelihood of the data” section above, in which the independent single-output GPs should be treated according to the chosen approximation (a brief sketch follows Table 2 below). In the table below we present the scaling costs for the combination of the OILMM with these two methods, assuming \(r\) inducing points in Titsias’ method, and dimensionality \(d\) in Hartikainen and Särkkä’s method. Typically \(r \ll n\) and \(d \ll n, m\). In Bruinsma et al., we used Hartikainen and Särkkä’s method and a separable spatio-temporal kernel to efficiently train the OILMM over several million data points.

Table 2: Time and memory scaling for performing inference under the OILMM combined with single-output GP scaling techniques.

Model Time Memory
OILMM + Titsias \(\mathcal{O}(nmr^2)\) \(\mathcal{O}(nmr)\)
OILMM + Hartikainen and Särkkä \(\mathcal{O}(nmd^3)\) \(\mathcal{O}(nmd^2)\)
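
To illustrate how little the implementation changes, here is a hedged sketch of step 4 using a state-space approximation from TemporalGPs.jl in place of exact single-output GPs. It assumes the `x`, `Y_proj`, `ΣT`, and `m` defined in the “Performing inference” section, and that the inputs are sorted (as a range is), which TemporalGPs requires.

Julia:

using TemporalGPs # State-space approximations for single-output GPs.

# Wrap each latent process in its state-space (SDE) representation.
lats_sde = [to_sde(GP(Matern52Kernel()), SArrayStorage(Float64)) for _ in 1:m]

# Condition exactly as in step 4; only the latent processes have changed.
# The projected noise is constant over time, so a scalar noise suffices.
lats_post = [posterior(lats_sde[j](x, ΣT[j, 1]), Y_proj[j, :]) for j in 1:m]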

Strictly speaking, there are other costs involved in applying the OILMM, such as the cost of storing the data in memory, building the matrix \(T\) to project the data into the latent space, performing this projection, and building the predictive marginal means and variances. Usually, these costs are largely dominated by the costs shown in Table 1 above, and they become relevant only when the number of timestamps \(n\) becomes comparable to the number of latent processes \(m\), which is not a common setting (that would likely be too little data to train the model effectively). We summarise these costs in the table below.

Table 3: Time and memory scaling for performing secondary tasks under the OILMM.

Task Time Memory
Storing data n/a \(\mathcal{O}(np)\)
Building matrix T \(\mathcal{O}(m^2p)\) \(\mathcal{O}(mp)\)
Projecting data \(\mathcal{O}(nmp)\) \(\mathcal{O}(np)\)
Building marginal statistics \(\mathcal{O}(nmp)\) \(\mathcal{O}(np)\)

Conclusion

With this post we conclude a three-part series on multi-output Gaussian process models, with emphasis on the OILMM.

In the first part of the series we presented a very brief introduction to MOGPs, arguing that they can be viewed simply as single-output GPs acting over an extended input space. In that post we also introduced the Mixing Model Hierarchy, which attempts to organise a large number of MOGP models from the literature using a simple and widespread ILMM as reference.

In the second post we delved a bit deeper into the ILMM, discussing the mathematical tricks that make it more scalable. We used one of these tricks to motivate and present the OILMM, which improves on the scalability of the ILMM.

In this post we learned how to efficiently implement the OILMM in practice, and shared some of our implementations in both Julia and Python.

We hope these posts have served to highlight some of the interesting properties of MOGPs, and might serve as a general (albeit not too technical) introduction to the OILMM. Below we offer some open implementations of the OILMM we have made in Python and in Julia.

Open implementations

We hope that this post shows how simple it is to implement the OILMM, and can serve as a reference for implementing the model in any language. We also offer open implementations in both Julia and Python, which should make the OILMM readily accessible to everyone. Both implementations are based on the GP package Stheno (which we already discussed in our post about linear Gaussian process models using Jax), and present very similar APIs, adapted to the particularities of each language. Below we briefly show a simple application with each of the implementations.

Julia

This implementation can be found in the package OILMMs.jl.

Without learning:

using AbstractGPs
using LinearAlgebra
using OILMMs
using Random
using TemporalGPs

# Specify and construct an OILMM.
p = 10
m = 3
U, s, _ = svd(randn(p, m))
σ² = 0.1

f = OILMM(
    [to_sde(GP(Matern52Kernel()), SArrayStorage(Float64)) for _ in 1:m],
    U,
    Diagonal(s),
    Diagonal(rand(m) .+ 0.1),
);

# Sample from the model. LARGE DATA SET!
x = MOInput(RegularSpacing(0.0, 1.0, 1_000_000), p);
fx = f(x, σ²);
rng = MersenneTwister(123456);
y = rand(rng, fx);

# Compute the logpdf of the data under the model.
logpdf(fx, y)

# Perform posterior inference. This produces another OILMM.
f_post = posterior(fx, y)

# Compute the posterior marginals.
# We can also use `rand` and `logpdf` as before.
post_marginals = marginals(f_post(x));

With learning:

using AbstractGPs
using OILMMs
using TemporalGPs

# Load standard packages from the Julia ecosystem
using LinearAlgebra
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Random
using Zygote # Algorithmic Differentiation

# Specify OILMM parameters as a NamedTuple.
# Utilise orthogonal and positive from ParameterHandling.jl to constrain appropriately.
p = 2
m = 1
θ_init = (
    U = orthogonal(randn(p, m)),
    s = positive.(rand(m) .+ 0.1),
    σ² = positive(0.1),
)

# Define a function which builds an OILMM, given a NamedTuple of parameters.
function build_oilmm(θ::NamedTuple)
    # Here we adopt a state-space approximation for better scalability. We
    # could have instead chosen to use regular GPs, for instance
    # `GP(SEKernel())`, without the call to `to_sde`.
    return OILMM(
        [to_sde(GP(Matern52Kernel()), SArrayStorage(Float64)) for _ in 1:m],
        θ.U,
        Diagonal(θ.s),
        Diagonal(zeros(m)),
    )
end

# Generate some synthetic data to train on.
f = build_oilmm(ParameterHandling.value(θ_init));
const x = MOInput(RegularSpacing(0.0, 0.01, 1_000_000), p);
fx = f(x, 0.1);
rng = MersenneTwister(123456);
const y = rand(rng, fx);

# Define a function which computes the negative log marginal likelihood given the parameters.
function objective(θ::NamedTuple)
    f = build_oilmm(θ)
    return -logpdf(f(x, θ.σ²), y)
end

# Build a version of the objective function which can be used with Optim.jl.
θ_init_flat, unflatten = flatten(θ_init);
unpack(θ::Vector{<:Real}) = ParameterHandling.value(unflatten(θ))
objective(θ::Vector{<:Real}) = objective(unpack(θ))

# Utilise Optim.jl + Zygote.jl to optimise the model parameters.
training_results = Optim.optimize(
    objective,
    θ -> only(Zygote.gradient(objective, θ)),
    θ_init_flat + randn(length(θ_init_flat)), # Add some noise to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
    ),
    Optim.Options(show_trace = true);
    inplace=false,
)

# Compute posterior marginals at optimal solution.
θ_opt = unpack(training_results.minimizer)
f = build_oilmm(θ_opt)
f_post = posterior(f(x, θ_opt.σ²), y)
fx = marginals(f_post(x))

Python

The dependencies for this implementation can be installed via a call to `pip install oilmm jax jaxlib` in the command line.

import numpy as np
import jax.numpy as jnp
from stheno import EQ, GP
from oilmm.jax import OILMM

def build_latent_processes(params):
    # Return models for latent processes, which are noise-contaminated GPs.
    return [
        (
            # Create GPs with learnable variances initialised to one and
            # learnable length scales, also initialised to one.
            p.variance.positive(1) * GP(EQ().stretch(p.length_scale.positive(1))),
            # Use learnable noise variances, initialised to `1e-2`.
            p.noise.positive(1e-2),
        )
        for p, _ in zip(params, range(3))
    ]


# Construct model.
prior = OILMM(jnp.float32, build_latent_processes, num_outputs=6)

# Create some sample data.
x = np.linspace(0, 10, 100)
y = prior.sample(x)

# Fit OILMM.
prior.fit(x, y, trace=True, jit=True)
prior.vs.print()  # Print all learned parameters.

# Make predictions.
posterior = prior.condition(x, y)
mean, var = posterior.predict(x)
lower = mean - 1.96 * np.sqrt(var)
upper = mean + 1.96 * np.sqrt(var)

Notes

  1. To install Stheno, simply run `pip install stheno` from the command line.

  2. We also use the Python package LAB to make our code agnostic to the linear algebra backend.