Design

This page details the key features of the design of Models.jl, which exists to solve the issue highlighted by the following quote:

ML researchers tend to develop general purpose solutions as self-contained packages. A wide variety of these are available as open-source packages ... Using generic packages often results in a glue-code system design pattern, in which a massive amount of supporting code is written to get data into and out of general-purpose packages. Glue-code is costly in the long term because it tends to freeze a system to the peculiarities of a specific package; testing alternatives may become prohibitively expensive.... An important strategy for combating glue-code is to wrap black-box packages into common API’s. This allows supporting infrastructure to be more reusable and reduces the cost of changing packages.

Sculley et al., 2015

Models.jl provides a common API for mostly preexisting models so that they can all be used in the same way. It is therefore essential that the API of Models.jl is itself uniform across all models. Here are some facts about that API:

Models and Templates

A Model is an object that can be used to make predictions by calling predict. A Template is an object that can create a Model by being fit to some data.

All information about how to perform the fit, such as hyperparameters, is stored inside the Template. This differs from some other APIs, which might, for example, pass hyperparameters as keyword arguments to fit. The Template-based API is superior to these because the signature of fit is always the same: one does not have to carry around both a Model type and a varying collection of keyword arguments, which gets complicated when composing wrapper models.
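As a minimal sketch of this design, the hypothetical RidgeTemplate below carries its penalty λ as a field, so fit needs nothing beyond the template and the data. The hypothetical wrapper StandardizeTemplate then composes with any inner template without knowing that template's hyperparameters. To keep the block self-contained, it defines its own Template and Model abstract types and its own fit/predict functions rather than extending StatsBase.fit and StatsBase.predict, and it leaves out the weights argument.

```julia
using LinearAlgebra, Statistics

# Hypothetical stand-ins for the abstract types provided by Models.jl.
abstract type Template end
abstract type Model end

# A ridge-regression Template: the hyperparameter λ lives here, not in fit.
struct RidgeTemplate <: Template
    λ::Float64
end

struct RidgeModel <: Model
    coefs::Matrix{Float64}  # Variates x Features
end

function fit(t::RidgeTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix)
    X = Matrix{Float64}(inputs)   # Features x Observations
    Y = Matrix{Float64}(outputs)  # Variates x Observations
    # Closed-form ridge solution: W = Y Xᵀ (X Xᵀ + λI)⁻¹
    return RidgeModel((Y * X') / (X * X' + t.λ * I))
end

predict(m::RidgeModel, inputs::AbstractMatrix) = m.coefs * inputs

# A wrapper Template that standardizes each feature before delegating.
# It calls fit(t.inner, ...) without knowing the inner Template's
# hyperparameters, because they are stored inside t.inner.
struct StandardizeTemplate{T<:Template} <: Template
    inner::T
end

struct StandardizeModel{M<:Model} <: Model
    inner::M
    μ::Vector{Float64}  # per-feature mean
    σ::Vector{Float64}  # per-feature standard deviation
end

function fit(t::StandardizeTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix)
    μ = vec(mean(inputs; dims=2))
    σ = vec(std(inputs; dims=2))
    inner = fit(t.inner, outputs, (inputs .- μ) ./ σ)
    return StandardizeModel(inner, μ, σ)
end

predict(m::StandardizeModel, inputs::AbstractMatrix) =
    predict(m.inner, (inputs .- m.μ) ./ m.σ)
```

Replacing RidgeTemplate(0.0) inside StandardizeTemplate(RidgeTemplate(0.0)) with any other Template requires no change to the wrapper's code.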

Calling fit and predict

model = StatsBase.fit(
    template::Template,
    outputs::AbstractMatrix,  # always Variates x Observations
    inputs::AbstractMatrix,   # always Features x Observations
    weights::AbstractWeights  # optional positional argument; defaults to uweights(Float32, size(outputs, 2))
)::Model

# estimate_type(model) == PointEstimate
outputs = StatsBase.predict(
    model::Model,
    inputs::AbstractMatrix  # always Features x Observations
)::AbstractMatrix  # always Variates x Observations

# estimate_type(model) == DistributionEstimate
outputs = StatsBase.predict(
    model::Model,
    inputs::AbstractMatrix  # always Features x Observations
)::AbstractVector{<:Distribution}  # length Observations

fit takes in a Template and some data and returns a Model that has been fit to the data. predict takes a Model (that has been fit from a Template) and produces a predicted output.

Important facts about fit and predict:

  • outputs and inputs always have observations as the second dimension, even for a SingleOutput model (that just means the output is a 1 x num_obs matrix). See the Julia documentation on arrays being column-major.
  • fit and predict must accept any AbstractMatrix for the inputs, and fit must also accept any AbstractMatrix for the outputs. If the underlying implementation needs a plain dense Matrix, then fit/predict should perform the conversion.
  • fit always accepts a weights argument. If the underlying Model does not support weighted fitting, then fit should throw an error if the weights passed in are not all equal.
  • fit/predict take no keyword arguments, nor any arguments other than the ones shown.
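The rules above can be sketched with a hypothetical MedianTemplate, whose model predicts the per-variate median of the training outputs for every observation and which, for the sake of the example, does not support weighted fitting. As an assumption to keep the block dependency-free, fit and predict are local functions rather than StatsBase methods, and a plain vector of ones stands in for uweights(Float32, size(outputs, 2)).

```julia
using Statistics

# Hypothetical stand-ins for the abstract types provided by Models.jl.
abstract type Template end
abstract type Model end

struct MedianTemplate <: Template end

struct MedianModel <: Model
    medians::Vector{Float64}  # one entry per output variate
end

# weights is an optional *positional* argument; there are no keyword arguments.
# A vector of ones stands in for StatsBase's uniform weights.
function fit(
    ::MedianTemplate,
    outputs::AbstractMatrix,  # Variates x Observations
    inputs::AbstractMatrix,   # Features x Observations
    weights::AbstractVector=ones(size(outputs, 2)),
)
    size(outputs, 2) == size(inputs, 2) == length(weights) ||
        throw(DimensionMismatch("outputs, inputs and weights must have the same number of observations"))
    # This model cannot use weights, so it refuses anything but uniform weights.
    all(==(first(weights)), weights) ||
        throw(ArgumentError("MedianTemplate does not support non-uniform weights"))
    Y = Matrix{Float64}(outputs)  # accept any AbstractMatrix, convert to dense
    return MedianModel([median(row) for row in eachrow(Y)])
end

# Output is always Variates x Observations, even for a single output variate.
function predict(model::MedianModel, inputs::AbstractMatrix)
    return repeat(model.medians, 1, size(inputs, 2))
end
```

With this sketch, fit(MedianTemplate(), Y, X) and fit(MedianTemplate(), Y, X, fill(0.2, 5)) behave the same for a 5-observation dataset, while passing unequal weights throws an ArgumentError.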

Traits

This package largely avoids using complicated abstract types, or relying on a Model having a particular abstract type. Instead we use traits to determine Model behavior.

Here are the current Model traits in use and their possible values:

The traits always agree between the Model and the Template. Every Model and Template should define all the listed traits. If left undefined, the PredictInputTrait will have the default value of PointPredictInput.
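A minimal sketch of these rules follows. Only estimate_type, PointEstimate, DistributionEstimate, PredictInputTrait and PointPredictInput are named on this page; the trait function name predict_input_type and the GaussianTemplate/GaussianModel pair are hypothetical, and all types are redefined locally so the block runs without Models.jl. The trait values are abstract types, the trait functions return those types themselves, and a single fallback method supplies the PointPredictInput default.

```julia
# Hypothetical stand-ins for the abstract types provided by Models.jl.
abstract type Template end
abstract type Model end

# Trait values are abstract types; trait functions return the type itself.
abstract type PointEstimate end
abstract type DistributionEstimate end

abstract type PredictInputTrait end
abstract type PointPredictInput <: PredictInputTrait end

# Hypothetical name for the PredictInputTrait function. This fallback gives
# every Template and Model the default value PointPredictInput.
predict_input_type(::Union{Template, Model}) = PointPredictInput

# A hypothetical Template/Model pair. estimate_type is defined for both
# types, so a fitted Model agrees with the Template it was fit from.
struct GaussianTemplate <: Template end
struct GaussianModel <: Model end

estimate_type(::GaussianTemplate) = DistributionEstimate
estimate_type(::GaussianModel) = DistributionEstimate
```

Because GaussianModel defines no predict_input_type method of its own, it falls back to PointPredictInput.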

This package implements traits so that the trait function returns an abstract type (rather than an instance). That means that to check a trait one uses:

if estimate_type(model) <: DistributionEstimate

and to dispatch on a trait one uses:

foo(::Type{<:DistributionEstimate}, ...)
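Both idioms can be sketched together as follows. PointEstimate, DistributionEstimate and estimate_type are redefined locally so the block runs on its own, and the model types and the is_probabilistic and describe functions are hypothetical. Since estimate_type returns a type, the check compares with <: (an abstract type is not an instance of itself), and the dispatch methods take a Type{<:...} argument.

```julia
# Local stand-ins for the types used on this page.
abstract type Model end
abstract type PointEstimate end
abstract type DistributionEstimate end

# Hypothetical models with different estimate types.
struct MeanModel <: Model end
struct GaussianModel <: Model end

estimate_type(::MeanModel) = PointEstimate
estimate_type(::GaussianModel) = DistributionEstimate

# Checking a trait: the trait function returns a type, so compare with <:.
is_probabilistic(model::Model) = estimate_type(model) <: DistributionEstimate

# Dispatching on a trait: pass the returned type as the first argument and
# write one method per trait value.
describe(model::Model) = describe(estimate_type(model), model)
describe(::Type{<:PointEstimate}, model) = "point predictions"
describe(::Type{<:DistributionEstimate}, model) = "distribution predictions"
```

Adding a new trait value only requires a new describe method; the one-argument entry point stays unchanged.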