Custom Sensitivities
Part of the power of Nabla is its extensibility, specifically in the form of defining custom sensitivities for functions. This is accomplished by defining methods for ∇ that specialize on the function for which you'd like to define sensitivities.
Given a function of the form $f(x_1, \ldots, x_n)$, we want to be able to compute $\frac{\partial f}{\partial x_i}$ for all $i$ of interest as efficiently as possible. Defining our own sensitivities $\bar{x}_i$ means that $f$ will be taken as a "unit," and its intermediate operations are not written separately to the tape. For more details on that, refer to the Details section of the documentation.
Intercepting calls
Nabla's approach to RMAD is based on operator overloading. Specifically, for each $x_i$ we wish to differentiate, we need a method for f that accepts a Node in position $i$. There are two primary ways to go about this: @explicit_intercepts and @unionise.
@explicit_intercepts
When f has already been defined, we can extend it to accept Nodes using this macro.
Nabla.@explicit_intercepts — Macro.

@explicit_intercepts(f::Symbol, type_tuple::Expr, is_node::Expr[, kwargs::Expr])
@explicit_intercepts(f::Symbol, type_tuple::Expr)

Create a collection of methods which intercept the function calls to f in which at least one argument is a Node. Types of arguments are specified by the type tuple expression in type_tuple. If there are arguments which are not differentiable, they can be specified by providing a boolean vector is_node, which marks the differentiable arguments with true and the non-differentiable ones with false. Keyword arguments to add to the function signature can be specified in kwargs, which must be a NamedTuple.
As a trivial example, take sin for scalar values (not matrix sine). We extend it for Nodes as
import Base: sin # ensure sin can be extended without qualification
@explicit_intercepts sin Tuple{Real}
This generates the following code:
begin
    function sin(##367::Node{<:Real})
        #= REPL[7]:1 =#
        Branch(sin, (##367,), getfield(##367, :tape))
    end
end
And so calling sin with a Node argument will produce a Branch that holds information about the call.
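To see this in action, here's a minimal sketch of what that recording looks like, assuming the intercept above has been evaluated (Tape, Leaf, and Branch are exported by Nabla; the input value 0.5 is arbitrary):

using Nabla

x = Leaf(Tape(), 0.5)   # wrap a plain value as a Node on a fresh tape
y = sin(x)              # dispatches to the intercepted method above
y isa Branch            # true: the call is recorded on the tape rather than just evaluated

# Given a sensitivity for sin (see "Defining sensitivities" below), ∇ can use the recorded tape:
∇(sin)(0.5)             # expected to be ≈ (cos(0.5),)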
For a nontrivial example, take the sum function, which accepts a function argument that gets mapped over the input prior to reduction by addition, as well as a dims keyword argument that permits summing over a subset of the dimensions of the input. We want to differentiate with respect to the input array, but not with respect to the function argument nor the dimension. (Note that Nabla cannot currently differentiate with respect to keyword arguments.) We can extend this for Nodes as
import Base: sum
@explicit_intercepts(
    sum,
    Tuple{Function, AbstractArray{<:Real}},
    [false, true],
    (dims=:,),
)
The signature of the call to @explicit_intercepts here may look a bit complex, so let's break it down. It's saying that we want to intercept calls to sum for methods which accept a Function and an AbstractArray{<:Real}, and that we do not want to differentiate with respect to the function argument (false) but do want to differentiate with respect to the array (true). Furthermore, methods of this form will have the keyword argument dims, which defaults to :, and we'd like to make sure we're able to capture that when we intercept.
This macro generates the following code:
quote
    function sum(##363::Function, ##364::Node{<:Array}; dims=:)
        #= REPL[2]:1 =#
        Branch(sum, (##363, ##364), getfield(##364, :tape); dims=dims)
    end
end
As you can see, it defines a new method for sum which has positional arguments of the given types, with the second extended for Nodes, as well as the given keyword arguments. Notice that we do not accept a Node for the function argument; this is by virtue of using false in that position in the call to @explicit_intercepts.
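As a quick illustration of what this buys us, here's a minimal sketch (the closure and random input below are made up for the example; Nabla already ships intercepts and sensitivities for sum):

using Nabla

x = randn(3, 4)

# When x is a Node, sum(abs2, x) hits the intercepted method above and is
# recorded as a single Branch on the tape.
∇(x -> sum(abs2, x))(x)   # expected to be ≈ (2 .* x,)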
@unionise
If f has not yet been defined and you know off the bat that you want it to be able to work with Nabla, you can annotate its definition with @unionise.
Nabla.@unionise — Macro.

@unionise code

Transform code such that each function definition accepts Node objects as arguments, without affecting dispatch in other ways.
As a simple example,
@unionise f(x::Matrix, p::Real) = norm(x, p)
For each type-constrained argument xi in the method definition's signature, @unionise changes the type constraint from T to Union{T, Node{<:T}}, allowing f to work with Nodes without needing to define separate methods. In this example, the macro expands the definition to
f(x::Union{Matrix, Node{<:Matrix}}, p::Union{Real, Node{<:Real}}) = begin
    #= REPL[9]:1 =#
    norm(x, p)
end
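For a runnable flavour of the workflow, here's a minimal sketch using a made-up function g whose body only uses operations Nabla already knows how to differentiate (the name, signature, and inputs are purely illustrative):

using Nabla

# Defined up front to accept Nodes, so no separate extension step is needed.
@unionise g(x::AbstractVector, y::AbstractVector) = sum(x) * sum(y)

x, y = randn(3), randn(3)
∇(g)(x, y)   # expected to return a tuple of gradients, one per argument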
Defining sensitivities
Now that our function f works with Nodes, we want to define a method for ∇ for each argument xi that we're interested in differentiating. Thus, for each argument position i we care about, we'll define a method of ∇ that looks like:
function Nabla.∇(::typeof(f), ::Type{Arg{i}}, _, y, ȳ, x1, ..., xn)
    # Compute x̄i
end
The method signature contains all of the information it needs to compute the derivative:

- f, the function
- Arg{i}, which specifies which of the xi we're computing the sensitivity of
- _ (placeholder, typically unused)
- y, the result of y = f(x1, ..., xn)
- ȳ, the "incoming" sensitivity propagated to this call
- x1, ..., xn, the inputs to f
A fully worked example is provided in the Details section of the documentation.
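As a quick illustration of the shape such a method takes, here's how the sensitivity for the scalar sin intercepted earlier could be written by hand (Nabla already provides this definition; the method below is purely illustrative):

using Nabla

# ∂/∂x sin(x) = cos(x), scaled by the incoming sensitivity ȳ
Nabla.∇(::typeof(sin), ::Type{Arg{1}}, _, y, ȳ, x::Real) = ȳ * cos(x)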
Testing sensitivities
In order to ensure correctness for custom sensitivity definitions, we can compare the results against those computed by the method of finite differences. The finite differencing itself is implemented in the Julia package FDM, but Nabla defines and exports functionality that permits checking results against finite differencing.
The primary workhorse function for this is check_errs.
Nabla.check_errs — Function.

check_errs(
    f,
    ȳ::∇ArrayOrScalar,
    x::T,
    v::T,
    ε_abs::∇Scalar=1e-10,
    ε_rel::∇Scalar=1e-7
)::Bool where T
Check that the difference between the finite differencing estimate and the RMAD computation of the directional derivative of f at x in direction v, for both allocating and in-place modes, has absolute and relative errors within ε_abs and ε_rel respectively, when scaled by reverse-mode sensitivity ȳ.
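For example, a sensitivity for the scalar sin from earlier could be checked as follows (a minimal sketch; the random point, direction, and sensitivity are arbitrary):

using Nabla

x = randn()    # point at which to check
v = randn()    # direction for the directional derivative
ȳ = randn()    # reverse-mode sensitivity to scale by

check_errs(sin, ȳ, x, v)   # expected to return true when the sensitivity is correct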