Implementing a scalable multi-output GP model with exact inference
30 Jul 2021

This is the final post in our series about multi-output Gaussian process (GP) models. In the first post, we described how to generalise single-output GPs to multi-output GPs (MOGPs). We also introduced the Mixing Model Hierarchy (MMH), as a way to classify and organise a large number of MOGP models from the literature. In the second post, we discussed the Instantaneous Linear Mixing Model (ILMM), the base model of the MMH, showing how its low-rank assumption can be exploited to speed up inference via simple linear algebra tricks. We used those results to motivate the Orthogonal Instantaneous Linear Mixing Model (OILMM), a version of the ILMM which scales even more favourably, allowing us to model up to tens of millions of points on a regular laptop.
In this post, we give concrete implementation details of an inference algorithm for the OILMM, showing how simple it is to obtain an efficient implementation. We present some simple code, as well as links to public implementations in both Python and Julia.
We start with a quick recap of the OILMM.
The Orthogonal Instantaneous Linear Mixing Model (OILMM)
The Orthogonal Instantaneous Linear Mixing Model (OILMM) is a multi-output GP (MOGP) model designed with scalability in mind. It describes the data as a linear combination of latent (unobserved) processes, which are themselves described as independent GPs. Mathematically:

$$y(t) = Hx(t) + \varepsilon(t), \qquad x_j \sim \mathcal{GP}(0, k_j), \qquad \varepsilon(t) \sim \mathcal{N}\big(0,\, \sigma^2 I_p + HDH^\top\big),$$

where $y(t) \in \mathbb{R}^p$ collects the $p$ outputs, $x(t) \in \mathbb{R}^m$ collects the $m$ independent latent processes, $H = US^{1/2}$ is the mixing matrix, with $U$ a $p \times m$ matrix with orthonormal columns ($U^\top U = I_m$) and $S$ a positive, diagonal $m \times m$ matrix, $\sigma^2 > 0$ is the observation noise variance, and $D$ is a positive, diagonal $m \times m$ matrix of latent noise variances.
The key aspect of an OILMM is that it turns a MOGP problem into a set of independent single-output GP problems, which brings a very significant gain in scalability. This independence also allows the OILMM to be trivially combined with other single-output GP techniques, such as sparse GPs or state-space approximations. If you are curious about how this is possible, we recommend checking out our previous post for a high-level explanation, or our paper, Scalable Exact Inference in Multi-Output Gaussian Processes, for a rigorous discussion. We will now focus on the practical computational aspects of the OILMM.
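To give a one-line flavour of why this works, note that applying the projection $T = S^{-1/2}U^\top$ (built explicitly in the implementation below) to the model above gives

$$Ty(t) = x(t) + T\varepsilon(t), \qquad \operatorname{Cov}\big(T\varepsilon(t)\big) = T\big(\sigma^2 I_p + HDH^\top\big)T^\top = \sigma^2 S^{-1} + D,$$

which is diagonal: each projected data stream depends on a single latent process and carries independent noise, so it can be handled by standard single-output GP machinery.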
Implementing the OILMM
Let’s start by showing the algorithm for implementing the OILMM in a general regression/prediction setting. All that is needed is a regular GP package. To illustrate, we’ll show code using AbstractGPs in Julia, and Stheno in Python. We choose these packages because they are minimal, and the code reads almost like pseudo-code, making it straightforward to follow even for people who have never used Julia or Python before.
We discuss the procedure for performing inference, for sampling from the posterior, and for computing the log-likelihood of the data. Throughout, we will assume that the OILMM has $p$ outputs and $m$ latent processes, observed at $n$ timestamps.
Notation
Symbol | Type | Description
---|---|---
$U$ | Truncated orthogonal $p \times m$ matrix | Orthogonal part of the mixing matrix, $H = US^{1/2}$
$S$ | Positive, diagonal $m \times m$ matrix | Diagonal part of the mixing matrix, $H = US^{1/2}$
$\sigma^2$ | Positive real number | Part of the observation noise
$D$ | Positive, diagonal $m \times m$ matrix | Part of the observation noise deriving from the latent processes
$Y$ | $p \times n$ matrix | Matrix of observations
Performing inference
There are five steps to performing inference in the OILMM framework:
- Build the projection.
- Project the observations to the latent space.
- Project the noise to the latent space.
- Independently condition each latent process on the projected observations, using the projected noise.
- Transform the posterior means and covariances back to the observation space, using the mixing matrix.
Step 0: preliminaries
Let’s start with some basic definitions for the example code.
Julia:
using AbstractGPs
using LinearAlgebra
using Statistics
n = 100 # Number of timestamps
# Model specification:
p = 10 # Number of outputs
m = 3 # Number of latent processes
σ² = 0.1 # Observation noise
D = Diagonal(rand(m)) # Latent noises
Python:
import lab as B
from matrix import Dense, Diagonal
from stheno import GP, Matern52
n = 100 # Number of timestamps
# Model specification:
p = 10 # Number of outputs
m = 3 # Number of latent processes
noise = 0.1 # Observation noise
d = B.rand(m) # Latent noises
Step 1: build the projection
We know that the original space, where the observations are made, and the latent space are connected via the mixing matrix $H = US^{1/2}$. The projection that takes observations to the latent space is its left inverse, $T = S^{-1/2}U^\top$ (so that $TH = I_m$), which we can build directly from $U$ and $S$.
Julia:
# Sample a random mixing matrix.
U, s, _ = svd(randn(p, m))
H = U * Diagonal(broadcast(sqrt, s))
# Build the projection.
T = Diagonal(sqrt.(s)) \ transpose(U)
Python:
# Sample a random mixing matrix.
U, s, _ = B.svd(B.randn(p, m))
U, S = Dense(U), Diagonal(s)
H = U @ S
# Build the projection.
T = B.inv(B.sqrt(S)) @ U.T
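Although it is not part of the algorithm, it can be reassuring to verify numerically that the projection really is a left inverse of the mixing matrix. A one-line check, reusing the Julia variables from above:

Julia:
# Optional sanity check: T is a left inverse of H (up to numerical error).
@assert T * H ≈ I(m)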
Step 2: project the observations
Taking the observations to the latent space is done by left-multiplying them by $T$, that is, $Y_{\text{proj}} = TY$.
Julia:
# Sample some noisy data.
x = range(0.0, 10.0; length=n)
Y = transpose(rand(GP(Matern52Kernel())(x, σ²), p)) # Generate noisy sample data from some GP.
# Project the observations.
Y_proj = T * Y
Python:
# Sample some noisy data.
x = B.linspace(0, 10, n)
f = GP(Matern52())
Y = f(x, noise).sample(p).T # Generate sample data from some GP.
# Project the observations.
Y_proj = T @ Y
Step 3: project the noise
We start by noting that the noise in the observation space is $\sigma^2 I_p + HDH^\top$. Projecting it to the latent space with $T$ gives $\Sigma_T = T(\sigma^2 I_p + HDH^\top)T^\top = \sigma^2 S^{-1} + D$, which is diagonal. Each latent process therefore receives its own independent noise, with variance given by the corresponding diagonal entry of $\Sigma_T$.
Julia:
ΣT = repeat(σ² ./ s + diag(D), 1, n) # Repeat the same noise matrix for every timestamp.
Python:
noise_proj = noise / B.diag(S) + d
Step 4: condition latent processes
Since the projected noise $\Sigma_T$ is diagonal, the latent processes can be conditioned independently, each on the corresponding row of $Y_{\text{proj}}$ and with the corresponding projected noise. This is plain single-output GP conditioning, which any GP package provides.
Julia:
lats = [GP(Matern52Kernel()) for _ in 1:m] # Noiseless latent processes
# Condition the latent processes.
lats_post = [posterior(lats[j](x, ΣT[j, :]), Y_proj[j, :]) for j in 1:m]
Python:
lats = [GP(Matern52()) for _ in range(m)] # Noiseless latent processes
# Condition the latent processes.
lats_post = [f | (f(x, ni), yi) for f, ni, yi in zip(lats, noise_proj, Y_proj)]
Step 5: transform posterior latent processes to observation space
For the predictive mean of the full MOGP, simply compute the predictive means of each of the posterior latent processes, stack them in an $m \times n$ matrix $M_{\text{latent}}$, and bring them back to the observation space via the mixing matrix: $M = HM_{\text{latent}}$.
Julia:
# Compute the posterior marginal means `M`.
M_latent = vcat([transpose(mean(lats_post[j](x))) for j in 1:m]...)
M = H * M_latent
Python:
# Compute the posterior marginal means `M`.
M_latent = B.concat(*[f.mean(x).T for f in lats_post], axis=0)
M = H @ M_latent
For the predictive marginal variances, repeat the same process as with the predictive means, but stacking the variances of the latent processes in an $m \times n$ matrix $V_{\text{latent}}$. The latent noises are added, the result is mixed with the elementwise square of $H$, and the observation noise is added on top: $V = (H \circ H)(V_{\text{latent}} + d) + \sigma^2$, where $d$ (the diagonal of $D$) is added to every column of $V_{\text{latent}}$ and $\circ$ denotes the elementwise product.
Julia:
# Compute the posterior marginal variances `V`.
V_latent = vcat([transpose(var(lats_post[j](x))) for j in 1:m]...)
V = abs2.(H) * (V_latent .+ D.diag) .+ σ²
Python:
# Compute the posterior marginal variances `V`.
V_latent = B.concat(*[f.kernel.elwise(x).T for f in lats_post], axis=0)
V = (H ** 2) @ (V_latent + d[:, None]) + noise
It is also possible to compute full predictive covariance matrices, by observing that, for any two given points in time, say $t_1$ and $t_2$, the covariance between the outputs is given by

$$\operatorname{cov}\big(y(t_1), y(t_2)\big) = H\,C^{\text{latent}}_{t_1 t_2}\,H^\top + \delta_{t_1 t_2}\big(\sigma^2 I_p + HDH^\top\big),$$

with $C^{\text{latent}}_{t_1 t_2}$ the $m \times m$ diagonal matrix containing the covariance of each posterior latent process between $t_1$ and $t_2$, and $\delta_{t_1 t_2}$ equal to one if $t_1 = t_2$ and zero otherwise.
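A minimal sketch of this computation with AbstractGPs, reusing `lats_post`, `H`, `D`, and `σ²` from the snippets above (the two timestamps are arbitrary illustrative choices):

Julia:
# Full p×p posterior covariance between two timestamps t₁ and t₂.
t1, t2 = [1.5], [2.5]
# Diagonal m×m covariance of the posterior latent processes between t₁ and t₂.
C_latent = Diagonal([only(cov(lats_post[j], t1, t2)) for j in 1:m])
# Mix back to the observation space; the noise only enters when t₁ == t₂.
C = H * C_latent * transpose(H)
if t1 == t2
    C += σ² * I + H * D * transpose(H)
end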
Sampling from the posterior
Drawing samples from the posterior is rather similar to computing the predictive mean in step 5 above. Because the posterior latent processes remain independent from each other, we can sample each of them independently, which is a functionality that any GP package provides. This way, we can stack a single sample from each latent process into an $m \times n$ matrix $F_{\text{latent}}$, and bring it to the observation space via $F = HF_{\text{latent}}$. If noisy samples are desired, observation noise can be added on top, as illustrated below.
Julia:
# Sample from the noiseless posterior.
F_latent = vcat([transpose(rand(lats_post[j](x))) for j in 1:m]...)
F = H * F_latent
F_noisy = F .+ sqrt(σ²) .* randn(size(F))
Python:
# Sample from the noiseless posterior.
F_latent = B.concat(*[f(x).sample().T for f in lats_post], axis=0)
F = H @ F_latent
F_noisy = F + B.sqrt(noise) * B.randn(*B.shape(F))
Computing the log-likelihood of the data
Computing the log-likelihood of data is a three-step process. First, we compute the log-likelihood for each latent process independently. Then, we compute a term that is independent of the latent kernels, which we identify as a regularisation term. Finally, we combine the two terms.
Step 1: compute the likelihood under the latent processes
First, we must project the data to the latent space, as described in the inference section above, giving us $Y_{\text{proj}} = TY$ and the projected noise $\Sigma_T$. Then, the log-likelihood of each row of $Y_{\text{proj}}$ is computed under the corresponding latent process, using the corresponding projected noise, as a regular single-output GP log-likelihood.
Julia:
lml_latents = [logpdf(lats[j](x, ΣT[j, :]), Y_proj[j, :]) for j in 1:m]
Python:
lml_latents = [f(x, ni).logpdf(yi) for f, ni, yi in zip(lats, noise_proj, Y_proj)]
Step 2: add regularisation term
The log-likelihood of the data under an OILMM does not depend solely on the latent processes, as it must account for the effects of the projection step. As we show in our work, the log-likelihood can be written as the sum of two terms. The first one is the log-likelihood of the projected data under the latent processes, which we computed in the step above. The second one is a term that accounts for the loss of information during the projection. Since it prevents the data from being projected to zero, it can be seen as a regularisation term. This term is given by:
$$-\frac{n(p - m)}{2}\log\big(2\pi\sigma^2\big) - \frac{n}{2}\log|S| - \frac{1}{2\sigma^2}\Big(\|Y\|_F^2 - \|S^{1/2}\,Y_{\text{proj}}\|_F^2\Big),$$

with $\|\cdot\|_F$ denoting the Frobenius norm.
Julia:
regulariser = -(
    n * ((p - m) * log(2π * σ²) + sum(log, s)) +
    (sum(abs2, Y) - sum(abs2, Diagonal(sqrt.(s)) * Y_proj)) / σ²
) / 2
Python:
regulariser = -0.5 * (
n * (p - m) * B.log(2 * B.pi * noise)
+ n * B.logdet(S)
+ (B.sum(Y ** 2) - B.sum((B.sqrt(S) @ Y_proj) ** 2)) / noise
)
Step 3: combine both terms
The log-likelihood of the data is given by the sum of the log-likelihoods under each of the latent processes plus the regularisation term from the previous step.
Julia:
loglik = regulariser + sum(lml_latents)
Python:
loglik = regulariser + sum(lml_latents)
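As an optional sanity check for small problems, and assuming Distributions.jl is available, we can compare this value against the brute-force MOGP log-likelihood, obtained by building the full $np \times np$ covariance matrix. This is only feasible for small $n$ and $p$, but it illustrates that inference in the OILMM is exact. A sketch in Julia, reusing the variables defined in the previous sections:

Julia:
using Distributions  # Only needed for this check.
# Full covariance of vec(Y), with the columns of Y (one per timestamp) stacked.
k = Matern52Kernel()
Σ_obs = σ² * I + H * D * transpose(H)  # Observation noise in the original space.
K = sum(kron(kernelmatrix(k, x), H[:, j] * transpose(H[:, j])) for j in 1:m) +
    kron(Matrix{Float64}(I, n, n), Σ_obs)
loglik_bruteforce = logpdf(MvNormal(zeros(n * p), Symmetric(K)), vec(Y))
# `loglik_bruteforce` should agree with `loglik` up to numerical error.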
Summary of section
We highlight that all of the implementation steps above are quite simple to perform using any GP package that is available, even if it has no MOGP implementation. We consider the simplicity of the implementation one of the strengths of this method. However, if you are interested in dedicated implementations of the method that work off-the-shelf, we list available packages for both Python and Julia below.
Time and memory complexities
Scalability is one of the key strengths of the OILMM, so it is relevant to discuss the time and memory complexities involved in utilising the method. We’ll consider the case of $p$ outputs, observed at $n$ timestamps, modelled with $m$ latent processes.
In realistic cases we typically have $n \gg p > m$, i.e. many more timestamps than outputs, and fewer latent processes than outputs. Table 1 below shows how the cost of storing and inverting the covariance matrix, which dominates the cost of inference, scales under a general MOGP, the ILMM, and the OILMM.
Table 1: Time and memory scaling for storing and inverting the covariance matrix under a general MOGP, the ILMM, and the OILMM.

Model | Time | Memory
---|---|---
General MOGP | $\mathcal{O}(n^3p^3)$ | $\mathcal{O}(n^2p^2)$
ILMM | $\mathcal{O}(n^3m^3)$ | $\mathcal{O}(n^2m^2)$
OILMM | $\mathcal{O}(mn^3)$ | $\mathcal{O}(mn^2)$
For cases in which $n$ becomes prohibitively large, the OILMM can be combined with scaling techniques designed for single-output GPs, such as the sparse inducing-point approximation of Titsias (2009) or the state-space approximation of Hartikainen and Särkkä (2010). Table 2 below shows the resulting costs, with $r$ the number of inducing points per latent process and $d$ the dimension of the state space.
Table 2: Time and memory scaling for performing inference under the OILMM combined with single-output GP scaling techniques.
Model | Time | Memory
---|---|---
OILMM + Titsias | $\mathcal{O}(nmr^2)$ | $\mathcal{O}(nmr)$
OILMM + Hartikainen and Särkkä | $\mathcal{O}(nmd^3)$ | $\mathcal{O}(nmd^2)$
To be rigorous, there are other costs involved in applying the OILMM, such as the cost of storing the data in memory, building the matrix T to project the data into the latent space, performing this projection, and building the predictive marginal means and variances. Usually, these costs are largely dominated by the costs shown in Table 1 above, and they only become relevant when the number of timestamps $n$ is small or the number of outputs $p$ is very large. These secondary costs are summarised in Table 3 below.
Table 3: Time and memory scaling for performing secondary tasks under the OILMM.
Task | Time | Memory
---|---|---
Storing data | — | $\mathcal{O}(np)$
Building matrix T | $\mathcal{O}(pm^2)$ | $\mathcal{O}(pm)$
Projecting data | $\mathcal{O}(nmp)$ | $\mathcal{O}(nm)$
Building marginal statistics | $\mathcal{O}(nmp)$ | $\mathcal{O}(np)$
Conclusion
With this post we conclude a three-part series on multi-output Gaussian process models, with emphasis on the OILMM.
In the first part of the series we presented a very brief introduction to MOGPs, arguing that they can be viewed simply as single-output GPs acting over an extended input space. In that post we also introduced the Mixing Model Hierarchy, which attempts to organise a large number of MOGP models from the literature, using the simple and widespread ILMM as a reference.
In the second post we delved a bit deeper into the ILMM, discussing the mathematical tricks that make it more scalable. We used one of these tricks to motivate and present the OILMM, which improves on the scalability of the ILMM.
In this post we learned how to efficiently implement the OILMM in practice, and shared some of our implementations in both Julia and Python.
We hope these posts have served to highlight some of the interesting properties of MOGPs, and might serve as a general (albeit not too technical) introduction to the OILMM. Below we offer some open implementations of the OILMM we have made in Python and in Julia.
Open implementations
We hope that this post shows how simple it is to implement the OILMM, and can serve as a reference for implementing the model in any language. We also offer open implementations in both Julia and Python, which should make the OILMM readily accessible to everyone. Both implementations are based on the GP package Stheno (about which we already talked in our post about linear Gaussian process models using Jax), and present very similar APIs, adapted to the particularities of each language. Below we briefly show a simple application with each of the implementations.
Julia
This implementation can be found in OILMMs.jl
Without learning:
using AbstractGPs
using LinearAlgebra
using OILMMs
using Random
using TemporalGPs
# Specify and construct an OILMM.
p = 10
m = 3
U, s, _ = svd(randn(p, m))
σ² = 0.1
f = OILMM(
[to_sde(GP(Matern52Kernel()), SArrayStorage(Float64)) for _ in 1:m],
U,
Diagonal(s),
Diagonal(rand(m) .+ 0.1),
);
# Sample from the model. LARGE DATA SET!
x = MOInput(RegularSpacing(0.0, 1.0, 1_000_000), p);
fx = f(x, σ²);
rng = MersenneTwister(123456);
y = rand(rng, fx);
# Compute the logpdf of the data under the model.
logpdf(fx, y)
# Perform posterior inference. This produces another OILMM.
f_post = posterior(fx, y)
# Compute the posterior marginals.
# We can also use `rand` and `logpdf` as before.
post_marginals = marginals(f_post(x));
With learning:
using AbstractGPs
using OILMMs
using TemporalGPs
# Load standard packages from the Julia ecosystem
using LinearAlgebra
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Random
using Zygote # Algorithmic Differentiation
# Specify OILMM parameters as a NamedTuple.
# Utilise orthogonal and positive from ParameterHandling.jl to constrain appropriately.
p = 2
m = 1
θ_init = (
U = orthogonal(randn(p, m)),
s = positive.(rand(m) .+ 0.1),
σ² = positive(0.1),
)
# Define a function which builds an OILMM, given a NamedTuple of parameters.
function build_oilmm(θ::NamedTuple)
return OILMM(
# Here we adopt a state-space approximation for better
# scalability. We could have instead chosen to use regular
# GPs, for instance, `GP(SEKernel())`, without the call to
# `to_sde`.
[to_sde(GP(Matern52Kernel()), SArrayStorage(Float64)) for _ in 1:m],
θ.U,
Diagonal(θ.s),
Diagonal(zeros(m)),
)
end
# Generate some synthetic data to train on.
f = build_oilmm(ParameterHandling.value(θ_init));
const x = MOInput(RegularSpacing(0.0, 0.01, 1_000_000), p);
fx = f(x, 0.1);
rng = MersenneTwister(123456);
const y = rand(rng, fx);
# Define a function which computes the negative log marginal likelihood given the parameters.
function objective(θ::NamedTuple)
f = build_oilmm(θ)
return -logpdf(f(x, θ.σ²), y)
end
# Build a version of the objective function which can be used with Optim.jl.
θ_init_flat, unflatten = flatten(θ_init);
unpack(θ::Vector{<:Real}) = ParameterHandling.value(unflatten(θ))
objective(θ::Vector{<:Real}) = objective(unpack(θ))
# Utilise Optim.jl + Zygote.jl to optimise the model parameters.
training_results = Optim.optimize(
objective,
θ -> only(Zygote.gradient(objective, θ)),
θ_init_flat + randn(length(θ_init_flat)), # Add some noise to make learning non-trivial
BFGS(
alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
linesearch = Optim.LineSearches.BackTracking(),
),
Optim.Options(show_trace = true);
inplace=false,
)
# Compute posterior marginals at optimal solution.
θ_opt = unpack(training_results.minimizer)
f = build_oilmm(θ_opt)
f_post = posterior(f(x, θ_opt.σ²), y)
fx = marginals(f_post(x))
Python
The dependencies for this implementation can be installed via a call to pip install oilmm jax jaxlib in the command line.
import numpy as np
import jax.numpy as jnp
from stheno import EQ, GP
from oilmm.jax import OILMM
def build_latent_processes(params):
    # Return models for latent processes, which are noise-contaminated GPs.
    return [
        (
            # Create GPs with learnable variances initialised to one and
            # learnable length scales, also initialised to one.
            p.variance.positive(1) * GP(EQ().stretch(p.length_scale.positive(1))),
            # Use learnable noise variances, initialised to `1e-2`.
            p.noise.positive(1e-2),
        )
        for p, _ in zip(params, range(3))
    ]
# Construct model.
prior = OILMM(jnp.float32, build_latent_processes, num_outputs=6)
# Create some sample data.
x = np.linspace(0, 10, 100)
y = prior.sample(x)
# Fit OILMM.
prior.fit(x, y, trace=True, jit=True)
prior.vs.print() # Print all learned parameters.
# Make predictions.
posterior = prior.condition(x, y)
mean, var = posterior.predict(x)
lower = mean - 1.96 * np.sqrt(var)
upper = mean + 1.96 * np.sqrt(var)