Design
This page details the key features of the design of Models.jl, which exists to solve the issue highlighted by the following quote:
ML researchers tend to develop general purpose solutions as self-contained packages. A wide variety of these are available as open-source packages ... Using generic packages often results in a glue-code system design pattern, in which a massive amount of supporting code is written to get data into and out of general-purpose packages. Glue-code is costly in the long term because it tends to freeze a system to the peculiarities of a specific package; testing alternatives may become prohibitively expensive.... An important strategy for combating glue-code is to wrap black-box packages into common API’s. This allows supporting infrastructure to be more reusable and reduces the cost of changing packages.
Models.jl provides a common API for mostly preexisting models, so that they can all be used in the same way. It is therefore essential that Models.jl itself has a consistent API. Here are some facts about that API:
Models and Templates
A Model is an object that can be used to make predictions by calling predict. A Template is an object that can create a Model by being fit to some data.
All information about how to perform fit, such as hyper-parameters, is stored inside the Template. This differs from some other APIs, which might, for example, pass hyper-parameters as keyword arguments to fit. The Template-based API is preferable because fit then always has the same signature. One does not have to carry around both a Model type and a varying collection of keyword arguments, which would get complicated when composing wrapper models.
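To illustrate keeping hyper-parameters in the Template, here is a minimal sketch assuming a hypothetical ridge-regression pair, RidgeTemplate and RidgeModel, that is not part of Models.jl. To stay self-contained, it defines stand-in Template/Model supertypes and its own fit/predict functions instead of extending StatsBase.fit and StatsBase.predict. It also uses a plain vector of ones as the default weights in place of StatsBase's uweights. The regularization strength λ lives in the template, so every template is fit with exactly the same call.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the package's Template and Model supertypes.
abstract type Template end
abstract type Model end

# Hypothetical ridge-regression template: the hyper-parameter λ is stored here
# instead of being passed to fit as a keyword argument.
struct RidgeTemplate <: Template
    λ::Float64
end

# The fitted model only needs its learned coefficients.
struct RidgeModel <: Model
    coefs::Matrix{Float64}   # Variates x Features
end

# Same positional signature for every template: (template, outputs, inputs, weights).
function fit(t::RidgeTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix,
             weights::AbstractVector=ones(size(outputs, 2)))
    X = Matrix{Float64}(inputs)       # Features x Observations
    Y = Matrix{Float64}(outputs)      # Variates x Observations
    D = Diagonal(Float64.(weights))   # one weight per observation
    # Weighted ridge solution: W = Y D Xᵀ (X D Xᵀ + λI)⁻¹
    W = (Y * D * X') / (X * D * X' + t.λ * I)
    return RidgeModel(W)
end

# Returns a Variates x Observations matrix.
predict(m::RidgeModel, inputs::AbstractMatrix) = m.coefs * inputs

inputs = [1.0 2.0 3.0 4.0]   # 1 feature x 4 observations
outputs = 2 .* inputs        # 1 variate x 4 observations
for template in (RidgeTemplate(0.0), RidgeTemplate(10.0))
    model = fit(template, outputs, inputs)   # identical call for both templates
    println(predict(model, [5.0 6.0]))
end
```

A wrapper model that holds an inner Template can call fit on it without knowing anything about the hyper-parameters it carries.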
Calling fit and predict
model = StatsBase.fit(
template::Template,
outputs::AbstractMatrix, # always Variates x Observations
inputs::AbstractMatrix, # always Features x Observations
weights=uweights(Float32, size(outputs, 2))
)::Model
# estimate_type(model) == PointEstimate
outputs = StatsBase.predict(
model::Model,
inputs::AbstractMatrix # always Features x Observations
)::AbstractMatrix # always Variates x Observations
# estimate_type(model) == DistributionEstimate
outputs = StatsBase.predict(
model::Model,
inputs::AbstractMatrix # always Features x Observations
)::AbstractVector{<:Distribution} # length Observations
fit takes in a Template and some data and returns a Model that has been fit to the data. predict takes a Model (that has been fit from a Template) and produces a predicted output.
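The following minimal sketch shows the shapes involved in a fit/predict round trip for a point-estimate model. MeanTemplate and MeanModel are hypothetical and simply predict the weighted mean of each output variate. The stand-in supertypes, the local fit/predict functions, and the vector of ones used as default weights (in place of StatsBase's uweights) are assumptions made so the example runs on its own.

```julia
# Hypothetical stand-ins for the package's Template and Model supertypes.
abstract type Template end
abstract type Model end

# Hypothetical template whose model predicts the weighted mean of each output variate.
struct MeanTemplate <: Template end
struct MeanModel <: Model
    means::Vector{Float64}   # one entry per output variate
end

# Default weights: uniform, one per observation (a stand-in for uweights).
function fit(::MeanTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix,
             weights::AbstractVector=ones(Float32, size(outputs, 2)))
    size(outputs, 2) == size(inputs, 2) == length(weights) ||
        throw(DimensionMismatch("outputs, inputs and weights need the same number of observations"))
    w = weights ./ sum(weights)
    return MeanModel(outputs * w)   # (Variates x Observations) * (Observations) -> Variates
end

# One column per input observation: a Variates x Observations matrix.
predict(m::MeanModel, inputs::AbstractMatrix) = repeat(m.means, 1, size(inputs, 2))

inputs = rand(3, 5)                       # 3 features x 5 observations
outputs = [1.0 2.0 3.0 4.0 5.0;           # 2 variates x 5 observations
           0.0 0.0 0.0 0.0 10.0]
model = fit(MeanTemplate(), outputs, inputs)
predictions = predict(model, rand(3, 2))  # 2 variates x 2 observations
```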
Important facts about fit and predict:
- outputs and inputs always have observations as the second dimension, even if the Model is SingleOutput (that just means the outputs will be a 1 x num_obs matrix; see the Julia docs on arrays being column-major).
- The functions must accept any AbstractMatrix for the inputs, and for the outputs (fit only). If the underlying implementation needs a plain dense Matrix, then fit/predict should perform the conversion.
- fit always accepts a weights argument. If the underlying Model does not support weighted fitting, then fit should throw an error if the weights passed in are not all equal.
- fit/predict take no keyword arguments, and no arguments other than the ones shown.
Traits
This package largely avoids using complicated abstract types or relying on a Model having a particular abstract type. Instead, it uses traits to determine Model behavior.
Here are the current Model traits in use and their possible values:
- estimate_type - determines what kinds of estimates the Model outputs.
  - PointEstimate: Predicts point-estimates of the most likely values.
  - DistributionEstimate: Estimates distributions over possible values.
- output_type - determines how many output variates a Model can learn.
  - SingleOutput: Fits and predicts on a single output only.
  - MultiOutput: Fits and predicts on multiple outputs at a time.
- predict_input_type - determines which datatypes a Model can accept at predict time.
  - PointPredictInput: Real-valued input variables accepted at predict time.
  - PointOrDistributionPredictInput: Either real-valued or distributions of input variables accepted at predict time.
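The trait values in the list above could be laid out as abstract-type hierarchies, as in the following minimal sketch. The leaf names and PredictInputTrait come from this page, but the supertypes EstimateTrait and OutputTrait, the LinearTemplate/LinearModel pair, and the instance-forwarding methods are assumptions for illustration, not the package's actual definitions.

```julia
# Hypothetical stand-ins for the package's Template and Model supertypes.
abstract type Template end
abstract type Model end

# Trait values as abstract types (EstimateTrait and OutputTrait are assumed names).
abstract type EstimateTrait end
abstract type PointEstimate <: EstimateTrait end
abstract type DistributionEstimate <: EstimateTrait end

abstract type OutputTrait end
abstract type SingleOutput <: OutputTrait end
abstract type MultiOutput <: OutputTrait end

abstract type PredictInputTrait end
abstract type PointPredictInput <: PredictInputTrait end
abstract type PointOrDistributionPredictInput <: PredictInputTrait end

# Hypothetical template/model pair for a multi-output point predictor.
struct LinearTemplate <: Template end
struct LinearModel <: Model
    coefs::Matrix{Float64}
end
const LinearTypes = Union{LinearTemplate,LinearModel}

# Each trait function returns the abstract type itself, not an instance of it.
estimate_type(::Type{<:LinearTypes}) = PointEstimate
output_type(::Type{<:LinearTypes}) = MultiOutput
predict_input_type(::Type{<:LinearTypes}) = PointPredictInput

# Allow the traits to be queried on instances as well as on types.
estimate_type(x::Union{Template,Model}) = estimate_type(typeof(x))
output_type(x::Union{Template,Model}) = output_type(typeof(x))
predict_input_type(x::Union{Template,Model}) = predict_input_type(typeof(x))
```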
The traits always agree between a Model and the Template it was fit from. Every Model and Template should define all the listed traits, with one exception: if predict_input_type is left undefined, the PredictInputTrait defaults to PointPredictInput.
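One way to provide that default is a fallback method on the supertypes, which specific pairs then override. The following minimal sketch assumes this approach; GPTemplate/GPModel and MeanTemplate/MeanModel are hypothetical, as are the stand-in supertypes. The GP pair declares the same trait value on both the template and the model so that they agree, while the Mean pair relies on the fallback.

```julia
# Hypothetical stand-ins for the package's Template and Model supertypes.
abstract type Template end
abstract type Model end

abstract type PredictInputTrait end
abstract type PointPredictInput <: PredictInputTrait end
abstract type PointOrDistributionPredictInput <: PredictInputTrait end

# Fallback: anything that does not declare predict_input_type gets PointPredictInput.
predict_input_type(::Type{<:Union{Template,Model}}) = PointPredictInput
predict_input_type(x::Union{Template,Model}) = predict_input_type(typeof(x))

# Hypothetical pair that accepts distributions at predict time; declaring the value
# for both types at once keeps the template and model traits in agreement.
struct GPTemplate <: Template end
struct GPModel <: Model end
predict_input_type(::Type{<:Union{GPTemplate,GPModel}}) = PointOrDistributionPredictInput

# Hypothetical pair that leaves the trait undefined and so uses the fallback.
struct MeanTemplate <: Template end
struct MeanModel <: Model end
```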
This package implements traits such that each trait function returns an abstract type (rather than an instance of one). That means a trait is checked with a subtype test, since isa would always be false for a type value:
if estimate_type(model) <: DistributionEstimate
and to dispatch on a trait one uses:
foo(::Type{<:DistributionEstimate}, ...)
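Both patterns can be seen side by side in the following minimal sketch, which uses hypothetical PointModel and ProbModel types, a local estimate_type, and illustrative helpers is_probabilistic and describe (none of these are Models.jl definitions). The check uses <: on the returned type, and the dispatch passes the trait value as an extra first argument so that Julia selects a method by its Type{<:...} signature.

```julia
# Hypothetical stand-ins: a Model supertype and the estimate trait values.
abstract type Model end
abstract type EstimateTrait end
abstract type PointEstimate <: EstimateTrait end
abstract type DistributionEstimate <: EstimateTrait end

# Hypothetical models with different estimate types.
struct PointModel <: Model end
struct ProbModel <: Model end
estimate_type(::PointModel) = PointEstimate
estimate_type(::ProbModel) = DistributionEstimate

# Checking: the trait value is a type, so compare with <:, not isa.
is_probabilistic(m::Model) = estimate_type(m) <: DistributionEstimate

# Dispatching: forward the trait value and dispatch on Type{<:...}.
describe(m::Model) = describe(estimate_type(m), m)
describe(::Type{<:PointEstimate}, m::Model) = "point predictions"
describe(::Type{<:DistributionEstimate}, m::Model) = "distribution predictions"
```

Dispatching this way lets a new estimate type be supported by adding a method, without touching an if/else chain.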