Design
This page details the key features of the design of Models.jl, which exists to solve the issue highlighted by the following quote:
ML researchers tend to develop general purpose solutions as self-contained packages. A wide variety of these are available as open-source packages ... Using generic packages often results in a glue-code system design pattern, in which a massive amount of supporting code is written to get data into and out of general-purpose packages. Glue-code is costly in the long term because it tends to freeze a system to the peculiarities of a specific package; testing alternatives may become prohibitively expensive.... An important strategy for combating glue-code is to wrap black-box packages into common API’s. This allows supporting infrastructure to be more reusable and reduces the cost of changing packages.
Models.jl wraps mostly preexisting models behind a common API so that they can all be used in the same way. The most important property of the package is therefore that its own API is consistent. Here are some facts about that API:
Models and Templates
A Model is an object that can be used to make predictions by calling predict. A Template is an object that can create a Model by being fit to some data.
All information about how to perform fit, such as hyper-parameters, is stored inside the Template. This is different from some other APIs, which might, for example, pass hyper-parameters as keyword arguments to fit. The Template-based API is preferable because the signature of fit is always the same: one does not have to carry around both a Model type and a varying collection of keyword arguments, which quickly gets complicated when composing wrapper models.
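To illustrate why keeping hyper-parameters in the Template composes well, here is a minimal sketch. It assumes hypothetical local stand-ins: the abstract types Template and Model, the functions fit and predict (which in Models.jl would be methods of StatsBase.fit and StatsBase.predict), and the illustrative RidgeTemplate and CenteredTemplate types. The weights argument is omitted for brevity. None of this is the package's actual implementation.

```julia
using LinearAlgebra

# Hypothetical local stand-ins; in Models.jl these would be the package's
# abstract types and methods of StatsBase.fit / StatsBase.predict.
abstract type Template end
abstract type Model end

# All hyper-parameters live in the Template...
struct RidgeTemplate <: Template
    λ::Float64  # regularisation strength
end

# ...while the Model holds only what was learned from the data.
struct RidgeModel <: Model
    coefficients::Matrix{Float64}  # Variates x Features
end

function fit(template::RidgeTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix)
    X = Matrix{Float64}(inputs)   # Features x Observations
    Y = Matrix{Float64}(outputs)  # Variates x Observations
    # Closed-form ridge regression without intercept: B = Y X' (X X' + λI)⁻¹
    B = (Y * X') / (X * X' + template.λ * I)
    return RidgeModel(B)
end

predict(model::RidgeModel, inputs::AbstractMatrix) = model.coefficients * inputs

# A wrapper Template simply stores the Template it wraps, so no extra
# keyword arguments have to be threaded through fit.
struct CenteredTemplate{T<:Template} <: Template
    inner::T
end

struct CenteredModel{M<:Model,V<:AbstractVector} <: Model
    inner::M
    output_mean::V
end

function fit(template::CenteredTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix)
    μ = vec(sum(outputs; dims=2)) ./ size(outputs, 2)  # mean of each variate
    inner = fit(template.inner, outputs .- μ, inputs)   # fit on centred outputs
    return CenteredModel(inner, μ)
end

predict(model::CenteredModel, inputs::AbstractMatrix) =
    predict(model.inner, inputs) .+ model.output_mean

# Usage: the nested Template carries every hyper-parameter.
template = CenteredTemplate(RidgeTemplate(0.1))
```

Because CenteredTemplate only stores another Template, wrappers can nest to any depth while fit keeps the same shape; with keyword-argument hyper-parameters, each wrapper would instead have to accept and forward its inner model's keywords.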
Calling fit and predict
model = StatsBase.fit(
    template::Template,
    outputs::AbstractMatrix,  # always Variates x Observations
    inputs::AbstractMatrix,   # always Features x Observations
    weights=uweights(Float32, size(outputs, 2))
)::Model

# estimate_type(model) == PointEstimate
outputs = StatsBase.predict(
    model::Model,
    inputs::AbstractMatrix  # always Features x Observations
)::AbstractMatrix  # always Variates x Observations

# estimate_type(model) == DistributionEstimate
outputs = StatsBase.predict(
    model::Model,
    inputs::AbstractMatrix  # always Features x Observations
)::AbstractVector{<:Distribution}  # length Observations

fit takes a Template and some data and returns a Model that has been fit to the data. predict takes a Model (that has been fit from a Template) and some inputs, and produces the predicted outputs.
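The following sketch shows the array shapes involved when calling fit and predict. It uses a hypothetical MeanTemplate that ignores its inputs and predicts the weighted mean of each output variate, local stand-ins for Template, Model, fit and predict, and ones(Float32, n) in place of StatsBase's uweights so that it has no package dependencies.

```julia
# Hypothetical local stand-ins for the package's abstract types.
abstract type Template end
abstract type Model end

# Hypothetical template: predicts the weighted mean of each output variate,
# ignoring the inputs. It exists only to illustrate the call shapes.
struct MeanTemplate <: Template end
struct MeanModel <: Model
    μ::Vector{Float64}  # one entry per output variate
end

# `ones(Float32, n)` stands in for StatsBase's `uweights(Float32, n)`.
function fit(::MeanTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix,
             weights::AbstractVector=ones(Float32, size(outputs, 2)))
    # Observations are always the second dimension of both matrices.
    size(outputs, 2) == size(inputs, 2) == length(weights) ||
        throw(DimensionMismatch("observations must be the second dimension"))
    w = weights ./ sum(weights)
    return MeanModel(vec(outputs * w))
end

# PointEstimate behaviour: returns a Variates x Observations matrix.
predict(model::MeanModel, inputs::AbstractMatrix) = repeat(model.μ, 1, size(inputs, 2))

inputs = rand(3, 5)                    # 3 features x 5 observations
outputs = [1.0 2.0 3.0 4.0 5.0]        # single-output data is still a 1 x 5 matrix
model = fit(MeanTemplate(), outputs, inputs)
predictions = predict(model, rand(3, 2))  # 1 x 2 matrix, every entry ≈ 3.0
```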
Important facts about fit and predict:
- outputs and inputs always have observations as the second dimension, even for a SingleOutput model (that just means the outputs form a 1 x num_obs matrix). See the Julia docs on arrays being column-major.
- The functions must accept any AbstractMatrix for the inputs, and fit must also accept any AbstractMatrix for the outputs. If the underlying implementation needs a plain dense Matrix, then fit/predict should perform the conversion.
- fit always accepts a weights argument. If the underlying Model does not support weighted fitting, then fit should throw an error if the weights passed in are not all equal.
- fit/predict take no keyword arguments, and no arguments other than the ones shown.
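A sketch of how a wrapper might satisfy the conversion and weights rules, assuming a hypothetical backend function backend_fit that only accepts dense Matrix{Float64} arguments and cannot use weights. Template, Model, fit and predict are again local stand-ins rather than the package's definitions, and ones(Float32, n) stands in for uweights.

```julia
using LinearAlgebra

# Hypothetical local stand-ins for the package's abstract types.
abstract type Template end
abstract type Model end

# Hypothetical third-party backend: needs dense Matrix{Float64} arguments and
# has no notion of observation weights. `Y / X` is a least-squares solve.
backend_fit(Y::Matrix{Float64}, X::Matrix{Float64}) = Y / X

struct LeastSquaresTemplate <: Template end
struct LeastSquaresModel <: Model
    coefficients::Matrix{Float64}  # Variates x Features
end

function fit(::LeastSquaresTemplate, outputs::AbstractMatrix, inputs::AbstractMatrix,
             weights::AbstractVector=ones(Float32, size(outputs, 2)))
    # The backend cannot use weights, so reject non-uniform ones instead of
    # silently ignoring them.
    all(==(first(weights)), weights) ||
        throw(ArgumentError("LeastSquaresTemplate does not support non-uniform weights"))
    # Accept any AbstractMatrix (transposes, views, ...) and convert here.
    B = backend_fit(Matrix{Float64}(outputs), Matrix{Float64}(inputs))
    return LeastSquaresModel(B)
end

predict(model::LeastSquaresModel, inputs::AbstractMatrix) =
    model.coefficients * Matrix{Float64}(inputs)

# Usage: a Transpose is an AbstractMatrix but not a Matrix.
inputs = transpose([1.0 0.0; 0.0 1.0; 1.0 1.0])  # 2 features x 3 observations
outputs = [1.0 2.0 3.0]                          # 1 variate x 3 observations
model = fit(LeastSquaresTemplate(), outputs, inputs)
```

Throwing on non-uniform weights, rather than ignoring them, keeps the common fit signature honest: a caller who passes weights to a model that cannot use them finds out immediately.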
Traits
This package largely avoids using complicated abstract types, or relying on a Model having a particular abstract type. Instead we use traits to determine Model behavior.
Here are the current Model traits in use and their possible values:
- estimate_type: determines what kinds of estimates the Model outputs.
  - PointEstimate: Predicts point-estimates of the most likely values.
  - DistributionEstimate: Estimates distributions over possible values.
- output_type: determines how many output variates a Model can learn.
  - SingleOutput: Fits and predicts on a single output only.
  - MultiOutput: Fits and predicts on multiple outputs at a time.
- predict_input_type: determines which datatypes a Model can accept at predict time.
  - PointPredictInput: Real-valued input variables accepted at predict time.
  - PointOrDistributionPredictInput: Either real-valued input variables or distributions over them accepted at predict time.
The traits always agree between a Model and the Template it was fit from. Every Model and Template should define all the listed traits; if predict_input_type is left undefined, it defaults to PointPredictInput.
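A minimal sketch of trait definitions in this style, assuming hypothetical local trait hierarchies (EstimateTrait, OutputTrait, PredictInputTrait and their subtypes) and example GaussianTemplate, GaussianModel and NoisyInputModel types. Each trait function returns a type, the traits are defined once for both the Template and its Model so that they agree, and predict_input_type has a fallback returning PointPredictInput. The real package may define its traits on types rather than on instances.

```julia
# Hypothetical trait hierarchies: each trait value is an abstract type.
abstract type EstimateTrait end
abstract type PointEstimate <: EstimateTrait end
abstract type DistributionEstimate <: EstimateTrait end

abstract type OutputTrait end
abstract type SingleOutput <: OutputTrait end
abstract type MultiOutput <: OutputTrait end

abstract type PredictInputTrait end
abstract type PointPredictInput <: PredictInputTrait end
abstract type PointOrDistributionPredictInput <: PredictInputTrait end

abstract type Template end
abstract type Model end

# Fallback used when a Template or Model does not define predict_input_type.
predict_input_type(::Union{Template,Model}) = PointPredictInput

# Hypothetical example pair; the traits are defined once for both types so
# that the Template and the Model it produces always agree.
struct GaussianTemplate <: Template end
struct GaussianModel <: Model end

const GaussianPair = Union{GaussianTemplate,GaussianModel}
estimate_type(::GaussianPair) = DistributionEstimate
output_type(::GaussianPair) = MultiOutput

# A hypothetical model that overrides the default input trait.
struct NoisyInputModel <: Model end
predict_input_type(::NoisyInputModel) = PointOrDistributionPredictInput
```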
This package implements traits such that the trait function returns an abstract type (rather than an instance). That means that to check a trait one uses:
if estimate_type(model) <: DistributionEstimate
and to dispatch on a trait one uses:
foo(::Type{<:DistributionEstimate}, ...)
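The following sketch, using hypothetical PointModel and DistModel types and local copies of the estimate trait types, shows both a trait check with <: and the usual trait-dispatch pattern, in which a public function looks up the trait and forwards it as an extra first argument. The helper names is_probabilistic and summarise are illustrative only.

```julia
# Hypothetical local copies of the estimate trait types.
abstract type EstimateTrait end
abstract type PointEstimate <: EstimateTrait end
abstract type DistributionEstimate <: EstimateTrait end

abstract type Model end
struct PointModel <: Model end
struct DistModel <: Model end
estimate_type(::PointModel) = PointEstimate
estimate_type(::DistModel) = DistributionEstimate

# Checking a trait: the trait value is a type, so compare with <:.
is_probabilistic(model::Model) = estimate_type(model) <: DistributionEstimate

# Dispatching on a trait: look the trait up once, then let method selection
# on Type{<:...} pick the implementation.
summarise(model::Model) = summarise(estimate_type(model), model)
summarise(::Type{<:PointEstimate}, model::Model) = "point estimates"
summarise(::Type{<:DistributionEstimate}, model::Model) = "distributions"
```

Note that estimate_type(model) isa DistributionEstimate would always be false here, because the trait value is the type DistributionEstimate itself, not an instance of it.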