Nabla.jl has two interfaces, both of which we expose to the end user. We first provide a minimal working example with the high-level interface, and subsequently show how the low-level interface can be used to achieve similar results. More involved examples can be found here.
## A Toy Problem
Consider the gradient of a vector-quadratic function. The following code snippet constructs such a function, along with some inputs at which to evaluate it:
```julia
using Nabla, Random

# Generate some data.
rng, N = MersenneTwister(123456), 2
x, y = randn.(rng, [N, N])
A = randn(rng, N, N)

# Construct a vector-quadratic function in `x` and `y`.
f(x, y) = y' * (A * x)
f(x, y)
```
Only a small amount of matrix calculus is required to find the gradients of `f(x, y)` w.r.t. `x` and `y`, which we denote by `∇x` and `∇y` respectively; they are given by

```julia
(∇x, ∇y) = (A'y, A * x)
```
The high-level interface provides a simple way to "just get the gradients" w.r.t. each argument of `f`:

```julia
∇x, ∇y = ∇(f)(x, y)
```
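We can check these numerically against the analytic gradients derived above (a quick illustration, not part of the interface itself):

```julia
∇x ≈ A' * y  # gradient of f w.r.t. x
∇y ≈ A * x   # gradient of f w.r.t. y
```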
This interface is implemented in `core.jl`, and is a thin wrapper of the low-level interface described below. Here, we first use `∇` to obtain a function which, when evaluated, returns the gradient of `f` w.r.t. each of its inputs at the values of the inputs provided.
We may provide an optional keyword argument to also return the value of `f` evaluated at the inputs:

```julia
(z, (∇x, ∇y)) = ∇(f; get_output=true)(x, y)
```
If the gradient w.r.t. a single argument is all that is required, or only a subset of the arguments of an N-ary function, we recommend closing over the arguments with respect to which you do not wish to take gradients. For example, to take the gradient w.r.t. just `x`, one could do the following:
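```julia
∇(x->f(x, y))(x)
```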
Note that this returns a 1-tuple containing the result, not the result itself!
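To obtain the gradient itself, simply unpack the tuple:

```julia
(∇x,) = ∇(x->f(x, y))(x)
```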
Furthermore, indexable containers such as `Dict`s behave sensibly. For example, the following takes the gradient w.r.t. a `Dict` of inputs via a lambda:

```julia
∇(d->f(d[:x], d[:y]))(Dict(:x=>x, :y=>y))
```

and similarly w.r.t. a `Vector` of inputs:

```julia
∇(v->f(v[1], v[2]))([x, y])
```
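In each case, the object inside the returned 1-tuple should mirror the structure of the input container. For instance, one would expect the `Dict` version to behave roughly as follows (a sketch, assuming gradients are stored under the same keys as the inputs):

```julia
(∇d,) = ∇(d->f(d[:x], d[:y]))(Dict(:x=>x, :y=>y))
∇d[:x] ≈ A' * y  # assumption: gradients keyed like the inputs
∇d[:y] ≈ A * x
```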
The methods considered so far have been completely generically typed. If one wishes to use methods whose argument types are restricted, then one must wrap the method definition in the `@unionise` macro. For example, if only a single definition is required:

```julia
@unionise g(x::Real) = ...
```
Alternatively, if multiple methods / functions are to be defined, the following format is recommended:
```julia
@unionise begin
    g(x::Real) = ...
    g(x::T, y::T) where T<:Real = ...
    foo(x) = ... # This definition is unaffected by `@unionise`.
end
```
`@unionise` simply changes the method signature so that each argument accepts the union of the types specified and Nabla.jl's `Node` type. This has no impact on the performance of your code when arguments of the types specified in the definition are provided, so you can safely `@unionise` code without worrying about potential performance implications.
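For intuition, the effect is roughly the following (an illustrative sketch, not the literal macro expansion):

```julia
# Writing this:
@unionise g(x::Real) = 2x

# behaves roughly as if one had written (illustrative only):
#   g(x::Union{Real, Node{<:Real}}) = 2x

g(1.0)     # dispatches on `Real` exactly as before
∇(g)(1.0)  # now also works, since `g` accepts `Node`s
```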
We now use Nabla.jl's low-level interface to take the gradient of `f` w.r.t. `x` and `y` at the values of `x` and `y` generated above. We first place `x` and `y` into `Leaf` containers, which enables these variables to be traced by Nabla.jl. This is achieved by first creating a `Tape` object, onto which all computations involving `x` and `y` are recorded, as follows:
```julia
tape = Tape()
x_ = Leaf(tape, x)
y_ = Leaf(tape, y)
```
which can be achieved more concisely using Julia's broadcasting capabilities:
```julia
x_, y_ = Leaf.(Tape(), (x, y))
```
Note that it is critical that `x_` and `y_` are constructed using the same `Tape` instance; currently, Nabla.jl will fail silently if this is not the case. We then simply pass `x_` and `y_` to `f`, instead of `x` and `y`:

```julia
z_ = f(x_, y_)
```
We can compute the gradients of `z_` w.r.t. `x_` and `y_` using `∇`, and access them by indexing the output with `x_` and `y_`:

```julia
∇z = ∇(z_)
(∇x, ∇y) = (∇z[x_], ∇z[y_])
```
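As noted earlier, the high-level interface is a thin wrapper around exactly these steps; conceptually, `∇(f)` builds a function along the lines of the following (a simplified sketch, not Nabla.jl's actual implementation):

```julia
# Hypothetical sketch of what ∇(f)(x, y) does internally:
function grad_f(x, y)
    tape = Tape()
    x_, y_ = Leaf(tape, x), Leaf(tape, y)
    z_ = f(x_, y_)      # forward pass, recorded onto the tape
    ∇z = ∇(z_)          # reverse pass
    return (∇z[x_], ∇z[y_])
end
```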
## Gotchas and Best Practice
- Nabla.jl does not currently have complete coverage of the entire standard library, due to finite resources and competing priorities. Particularly notable omissions are the subtypes of `Factorization` objects and all in-place functions. These are both issues which will be resolved in the future.
- The usual RMAD gotcha applies: due to the need to record each of the operations performed in the execution of a function for use in efficient gradient computation, the memory requirement of a programme scales approximately linearly in the length of the programme. Although our use of a dynamically constructed computation graph means that all forms of control flow are supported, long `while` loops should be used with care, so as to avoid running out of memory.
- In a similar vein, develop a (strong) preference for higher-order functions and linear algebra over hand-written for-loops; Nabla.jl has optimisations targeting Julia's higher-order functions (`mapreduce` and friends) and, consequently, loop fusion / "dot-syntax", as well as linear algebra operations, all of which should be used where possible (see the sketch below).
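For example, a sum of squares is better expressed via `sum` than via an explicit loop. The following is a stylistic sketch; the loop version is shown only to illustrate the pattern to avoid:

```julia
# Preferred: a higher-order function, which Nabla.jl can handle via its
# `mapreduce`-family optimisations.
sum_sq(x) = sum(abs2, x)
∇(sum_sq)(randn(5))

# Avoid: an explicit loop. Even where differentiable, each iteration adds
# operations to the tape, so memory grows with the length of the loop.
function sum_sq_loop(x)
    s = zero(eltype(x))
    for xi in x
        s += xi^2
    end
    return s
end
```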