Nabla.Arg
— Type
Used to flag which argument of a function the sensitivity x̄ refers to; Arg{N} denotes the N-th argument.
Nabla.Branch
— Type
A Branch is a Node with parents (args).
Fields:
- val::T: the value of this node produced in the forward pass.
- f: the function used to generate this Node.
- args: Values indicating which elements in the tape will require updating by this node.
- tape: the Tape to which this Branch is assigned.
- pos: the location of this Branch in the tape to which it is assigned.
- pullback::B: if there is a custom primitive rule (a ChainRulesCore.rrule), this holds the pullback used to propagate gradients back through the operation. If there is no rule, it is set to nothing. It may also be set to nothing by legacy Nabla rules that have not moved to ChainRules.
Nabla.Leaf
— Type
An element at the 'bottom' of the computational graph.
Fields:
- val: the value of the node.
- tape: the Tape to which this Leaf is assigned.
- pos: the location of this Leaf in the tape to which it is assigned.
Nabla.Node
— Type
Basic unit on the computational graph.
Nabla.Tape
— Type
A topologically ordered collection of Nodes.
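To make the relationship between these types concrete, here is a minimal, self-contained sketch of a tape for scalar expressions. The names ToyNode, ToyLeaf, ToyBranch, ToyTape, leaf!, branch!, partials and backward are hypothetical and are not Nabla's implementation; the tape field is omitted, and only + and * are supported. The sketch relies on the two properties stated above: every node knows its position in a topologically ordered tape, and a Branch remembers its function and its parents' positions.

```julia
# Hypothetical toy types mirroring the roles of Nabla's Node, Leaf, Branch and Tape.
abstract type ToyNode end

struct ToyTape
    nodes::Vector{ToyNode}   # topologically ordered: parents always precede children
end
ToyTape() = ToyTape(ToyNode[])

struct ToyLeaf <: ToyNode
    val::Float64
    pos::Int                 # index of this node in the tape
end

struct ToyBranch <: ToyNode
    val::Float64
    f::Function              # operation that produced this node
    args::Vector{Int}        # tape positions of the parents
    pos::Int
end

# Append a leaf to the tape, assigning it the next position.
function leaf!(tape::ToyTape, x::Real)
    node = ToyLeaf(Float64(x), length(tape.nodes) + 1)
    push!(tape.nodes, node)
    return node
end

# Evaluate f on the parents' values and record the result as a branch.
function branch!(tape::ToyTape, f::Function, parents::ToyNode...)
    val = f(map(p -> p.val, parents)...)
    node = ToyBranch(val, f, [p.pos for p in parents], length(tape.nodes) + 1)
    push!(tape.nodes, node)
    return node
end

# Local partial derivatives of the two supported binary operations.
partials(::typeof(+), a, b) = (1.0, 1.0)
partials(::typeof(*), a, b) = (b, a)

# Reverse sweep: walking the tape backwards reaches each node only after all of
# its children, so its accumulated sensitivity is complete when it is propagated.
function backward(tape::ToyTape, y::ToyNode, ȳ::Float64=1.0)
    grads = zeros(length(tape.nodes))
    grads[y.pos] = ȳ
    for node in reverse(tape.nodes)
        node isa ToyBranch || continue
        parent_vals = [tape.nodes[i].val for i in node.args]
        for (i, d) in zip(node.args, partials(node.f, parent_vals...))
            grads[i] += grads[node.pos] * d
        end
    end
    return grads
end

tape = ToyTape()
x = leaf!(tape, 3.0)
y = leaf!(tape, 4.0)
z = branch!(tape, *, x, y)   # z = x * y
w = branch!(tape, +, z, x)   # w = x * y + x
grads = backward(tape, w)
grads[x.pos]                 # 5.0, i.e. ∂w/∂x = y + 1
grads[y.pos]                 # 3.0, i.e. ∂w/∂y = x
```

Indexing the result by a node's pos plays the same role as indexing a Tape by a Node.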
Nabla.check_errs
— Method
check_errs(f, ȳ::∇ArrayOrScalar, x::T, v::T, ε_abs::∇Scalar=1e-10, ε_rel::∇Scalar=1e-7)::Bool where T
Check that the directional derivative of f at x in direction v estimated by finite differencing agrees with the one computed by reverse-mode automatic differentiation (RMAD), both scaled by the reverse-mode sensitivity ȳ, to within an absolute error of ε_abs and a relative error of ε_rel, in both the allocating and in-place modes.
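The comparison check_errs performs can be sketched for a single, allocating mode as follows. This is not Nabla's code: check_directional is a hypothetical name, the reverse-mode result is supplied by hand as a function vjp(ȳ, x) returning ȳ' J_f(x), and the central-difference step h is an assumed value.

```julia
using LinearAlgebra: dot

# Hypothetical analogue of check_errs (allocating mode only). `vjp(ȳ, x)` must
# return ȳ' * J_f(x) as a vector shaped like x.
function check_directional(f, vjp, ȳ, x::Vector{Float64}, v::Vector{Float64};
                           ε_abs=1e-10, ε_rel=1e-7, h=1e-6)
    # Central finite-difference estimate of J_f(x) * v, contracted with ȳ.
    fd = dot(ȳ, (f(x .+ h .* v) .- f(x .- h .* v)) ./ (2h))
    # Reverse-mode estimate of the same quantity: (ȳ' * J_f(x)) * v.
    rmad = dot(vjp(ȳ, x), v)
    # Passes if the difference is within max(ε_abs, ε_rel * max(|fd|, |rmad|)).
    return isapprox(fd, rmad; atol=ε_abs, rtol=ε_rel)
end

x = [1.0, 2.0, 3.0]
v = [0.5, -1.0, 2.0]
ȳ = [1.0, 0.1, -2.0]
square(x) = x .^ 2
good_vjp(ȳ, x) = 2 .* x .* ȳ   # J_f(x) is diagonal with entries 2x
bad_vjp(ȳ, x) = x .* ȳ         # deliberately missing the factor 2

ok = check_directional(square, good_vjp, ȳ, x, v)    # true
bad = check_directional(square, bad_vjp, ȳ, x, v)    # false
```

Contracting with both ȳ and v reduces each side to a single scalar, so one tolerance check covers the whole Jacobian along that direction.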
Nabla.domain1
— Method
domain1{T}(in_domain::Function, measure::Function, points::Vector{T})
domain1(f::Function)
Attempt to find a domain for a unary, scalar function f.
Arguments
- in_domain::Function: function that takes a single argument x and returns whether x is in f's domain.
- measure::Function: function that measures the size of a set of points for f.
- points::Vector{T}: ordered set of test points from which to construct the domain.
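One plausible way to combine these three arguments is sketched below. It assumes, without confirmation from the entry above, that a domain is reported as the end points of the contiguous run of in-domain test points that measure rates as largest. domain1_sketch is a hypothetical name, not Nabla's function.

```julia
# Hypothetical sketch of the idea behind domain1 (not Nabla's implementation):
# split the ordered test points into maximal runs that lie in f's domain, keep
# the run that `measure` rates as largest, and return its end points.
function domain1_sketch(in_domain::Function, measure::Function, points::Vector{T}) where T
    runs = Vector{T}[]
    current = T[]
    for x in points
        if in_domain(x)
            push!(current, x)
        elseif !isempty(current)
            push!(runs, current)
            current = T[]
        end
    end
    isempty(current) || push!(runs, current)
    isempty(runs) && return nothing          # no test point is in the domain
    best = runs[argmax(map(measure, runs))]
    return (first(best), last(best))
end

points = collect(-3.0:0.5:3.0)
in_dom(x) = x < -2 || x > 0                  # two separate runs of valid points
spread(run) = last(run) - first(run)         # size of a run of points
r = domain1_sketch(in_dom, spread, points)   # (0.5, 3.0)
```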
Nabla.domain2
— Method
domain2(f::Function)
Attempt to find a rectangular domain for a binary, scalar function f.
Nabla.in_domain
— Method
in_domain(f::Function, x::Float64...)
Check whether an input x is in the domain of a scalar, real function f.
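A minimal sketch of such a check, assuming (this is not stated above) that "in the domain" means evaluating f neither throws a DomainError nor returns a non-real or non-finite value. in_domain_sketch is a hypothetical name; treating infinite results as out of domain is a choice of this sketch.

```julia
# Hypothetical sketch, not Nabla's in_domain: x is in f's domain if f(x...)
# returns a finite real number without throwing a DomainError.
function in_domain_sketch(f::Function, x::Float64...)
    try
        y = f(x...)
        return y isa Real && isfinite(y)
    catch err
        err isa DomainError && return false
        rethrow()                 # unrelated errors are not domain failures
    end
end

in_domain_sketch(sqrt, -1.0)          # false: sqrt throws a DomainError
in_domain_sketch(log, 2.0)            # true
in_domain_sketch(atan, 1.0, 2.0)      # true: binary functions work the same way
```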
Nabla.preprocess
— Method
preprocess(f, y, ȳ, xs...) = ()
The default implementation of preprocess returns an empty Tuple. Individual sensitivity implementations should add methods specific to their use case. The output is passed to ∇ as the 3rd or 4th argument in the new-x̄ and update-x̄ cases respectively.
preprocess is invoked with y and xs still boxed. The default implementation calls unbox on them and then calls preprocess on the unboxed values. If your preprocessing needs the boxed values, overload preprocess(f, y::Node, ȳ, xs...); if it needs them unboxed, overloading preprocess(f, y, ȳ, xs...) is sufficient.
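The point of preprocess is to compute work shared by several argument sensitivities once. The sketch below illustrates that pattern on unboxed scalars for y = a / b, where both sensitivities need ȳ / b. ToyArg, toy_preprocess and toy_∇ are hypothetical stand-ins for Arg, preprocess and ∇, not Nabla's own methods, and boxing is ignored.

```julia
# Hypothetical stand-ins for Arg, preprocess and ∇ (boxing is ignored here).
struct ToyArg{N} end

# Default: nothing to precompute.
toy_preprocess(f, y, ȳ, xs...) = ()

# For y = a / b both sensitivities need ȳ / b, so compute it once.
toy_preprocess(::typeof(/), y, ȳ, a, b) = ȳ / b

# ∂y/∂a = 1/b, so ā = ȳ / b = p.
toy_∇(::typeof(/), ::Type{ToyArg{1}}, p, y, ȳ, a, b) = p
# ∂y/∂b = -a / b^2, so b̄ = -(ȳ / b) * (a / b) = -p * y.
toy_∇(::typeof(/), ::Type{ToyArg{2}}, p, y, ȳ, a, b) = -p * y

a, b = 3.0, 2.0
y = a / b                                   # 1.5
ȳ = 1.0
p = toy_preprocess(/, y, ȳ, a, b)           # 0.5, shared by both sensitivities
ā = toy_∇(/, ToyArg{1}, p, y, ȳ, a, b)      # 0.5
b̄ = toy_∇(/, ToyArg{2}, p, y, ȳ, a, b)      # -0.75
```

Functions without a specialised method fall through to the default and receive ().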
Nabla.∇
— Method
∇(f; get_output::Bool=false)
Returns a function which, when evaluated with arguments accepted by f, returns the gradient with respect to each of those arguments. If get_output is true, the result of calling f on the given arguments is also returned.
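The calling convention described above can be mimicked in a few lines. In this hypothetical sketch, toy_gradient swaps in central finite differences for Nabla's reverse-mode AD and accepts only Float64 arguments; it illustrates the shape of the returned function and the effect of get_output, not how Nabla computes gradients.

```julia
# Hypothetical sketch of the calling convention only: gradients come from
# central finite differences, not from reverse-mode AD as in Nabla.
function toy_gradient(f; get_output::Bool=false, h::Float64=1e-6)
    return function (args::Float64...)
        grads = ntuple(length(args)) do i
            up = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
            dn = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
            (f(up...) - f(dn...)) / (2h)
        end
        # One gradient per argument, optionally paired with f's output.
        return get_output ? (f(args...), grads) : grads
    end
end

g(x, y) = x * y + sin(x)
∇g = toy_gradient(g)
x̄, ȳ = ∇g(1.0, 2.0)                                     # ≈ (2 + cos(1), 1)
val, grads = toy_gradient(g; get_output=true)(1.0, 2.0)  # val == g(1.0, 2.0)
```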
Nabla.∇
— Method
∇(y::Node{<:∇Scalar})
∇(y::Node{T}, ȳ::T) where T
Return a Tape object which can be indexed using Nodes, each element of which contains the result of multiplying ȳ by the transpose of the Jacobian of the function specified by the Tape object in y. If y is a scalar and ȳ = 1, this is equivalent to computing the gradient of y with respect to each of the elements in the Tape.
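In symbols, writing the function recorded on the tape as y = f(x) with Jacobian J_f(x), each stored element is the vector-Jacobian product below; for scalar y and ȳ = 1 it reduces to the gradient.

```latex
\bar{x} = J_f(x)^{\top} \bar{y},
\qquad \bigl(J_f(x)\bigr)_{ij} = \frac{\partial y_i}{\partial x_j},
\qquad y \in \mathbb{R},\ \bar{y} = 1 \;\Longrightarrow\; \bar{x} = \nabla_x y .
```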
∇(f::Function, ::Type{Arg{N}}, p, y, ȳ, x...)
To implement a new reverse-mode sensitivity for the N-th argument of function f, add a method of this form. p is the output of preprocess, x (that is, x1, x2, ...) are the inputs to the function, y is its output and ȳ is the reverse-mode sensitivity of y.
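The dispatch pattern can be illustrated with matrix multiplication, where Y = A * B gives Ā = Ȳ Bᵀ and B̄ = Aᵀ Ȳ. ToyArg and toy_∇ below are hypothetical stand-ins that only mirror the argument order of the signature above; they are not Nabla's definitions, and p is passed as () because nothing is precomputed.

```julia
# Hypothetical stand-ins mirroring the shape of ∇(f, ::Type{Arg{N}}, p, y, ȳ, x...).
struct ToyArg{N} end

# For Y = A * B the sensitivities are Ā = Ȳ * Bᵀ and B̄ = Aᵀ * Ȳ.
toy_∇(::typeof(*), ::Type{ToyArg{1}}, p, Y, Ȳ, A::AbstractMatrix, B::AbstractMatrix) = Ȳ * B'
toy_∇(::typeof(*), ::Type{ToyArg{2}}, p, Y, Ȳ, A::AbstractMatrix, B::AbstractMatrix) = A' * Ȳ

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
Y = A * B
Ȳ = ones(2, 2)
Ā = toy_∇(*, ToyArg{1}, (), Y, Ȳ, A, B)   # [1.0 1.0; 1.0 1.0]
B̄ = toy_∇(*, ToyArg{2}, (), Y, Ȳ, A, B)   # [4.0 4.0; 6.0 6.0]
```

Because N is a type parameter, the method for each argument is selected by dispatch rather than by a runtime branch.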
∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...)
This is the optional in-place version of ∇ which, if implemented, should mutate x̄ by adding the gradient to it.
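For the first argument of Y = A * B, an in-place method can accumulate Ȳ Bᵀ into an existing x̄ without allocating a temporary, using the five-argument LinearAlgebra.mul!. ToyArg and toy_∇ are hypothetical stand-ins for Arg and ∇; only the argument order follows the signature above.

```julia
using LinearAlgebra: mul!

struct ToyArg{N} end

# Allocating version for Y = A * B, first argument: Ā = Ȳ * Bᵀ.
toy_∇(::typeof(*), ::Type{ToyArg{1}}, p, Y, Ȳ, A::AbstractMatrix, B::AbstractMatrix) = Ȳ * B'

# In-place version: add Ȳ * Bᵀ to the existing Ā.
# mul!(C, X, Z, α, β) overwrites C with X * Z * α + C * β.
function toy_∇(Ā::AbstractMatrix, ::typeof(*), ::Type{ToyArg{1}}, p, Y, Ȳ,
               A::AbstractMatrix, B::AbstractMatrix)
    return mul!(Ā, Ȳ, B', 1.0, 1.0)
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
Y = A * B
Ȳ = ones(2, 2)
Ā = fill(10.0, 2, 2)                      # sensitivity accumulated so far
toy_∇(Ā, *, ToyArg{1}, (), Y, Ȳ, A, B)    # Ā is now [11.0 11.0; 11.0 11.0]
```

Accumulating rather than overwriting matters when A feeds into several operations, since each contributes a term to Ā.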
Nabla.@explicit_intercepts
— Macro
@explicit_intercepts(f::Symbol, type_tuple::Expr, is_node::Expr[, kwargs::Expr])
@explicit_intercepts(f::Symbol, type_tuple::Expr)
Create a collection of methods which intercept the function calls to f in which at least one argument is a Node. The types of the arguments are specified by the type tuple expression type_tuple. If some arguments are not differentiable, they can be specified by providing a boolean vector is_node which marks the differentiable arguments with true and the non-differentiable ones with false. Keyword arguments to add to the function signature can be specified in kwargs, which must be a NamedTuple.
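The effect of explicit interception can be sketched by hand for f = * and the argument types (Real, Real): one method is generated for every combination in which at least one argument is wrapped, so calls with plain Reals still reach Base's own methods. Tracked and unwrap are hypothetical stand-ins for Node and unboxing, and the loop is not the macro's actual expansion; with is_node, combinations wrapping a non-differentiable argument would simply be omitted.

```julia
# Hypothetical wrapper standing in for a Node.
struct Tracked
    val::Float64
end
unwrap(x::Tracked) = x.val
unwrap(x) = x

# One method per combination in which at least one argument is Tracked;
# (Real, Real) is deliberately not intercepted.
for (T1, T2) in ((Tracked, Real), (Real, Tracked), (Tracked, Tracked))
    @eval Base.:*(a::$T1, b::$T2) = Tracked(unwrap(a) * unwrap(b))
end

r1 = Tracked(2.0) * 3.0   # Tracked(6.0)
r2 = 2.0 * Tracked(3.0)   # Tracked(6.0)
r3 = 2.0 * 3.0            # 6.0, still a plain Float64
```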
Nabla.@union_intercepts
— Macro
@union_intercepts f type_tuple invoke_type_tuple [kwargs]
Interception strategy based on adding a method to f which accepts the union of each of the types specified by type_tuple. If none of the arguments are Nodes, then the method of f specified by invoke_type_tuple is invoked. If applicable, keyword arguments should be provided as a NamedTuple; they are added to the generated function's signature.
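The union strategy can be sketched with a single method over a union of the plain type and a wrapper type, falling back to an existing method through Base.invoke when nothing is wrapped. Tracked, unwrap and combine are hypothetical; Tracked subtypes Number here only so that Union{Float64, Tracked} is strictly narrower than Number, which keeps dispatch unambiguous. This is not the macro's actual expansion.

```julia
# Hypothetical wrapper standing in for a Node.
struct Tracked <: Number
    val::Float64
end
unwrap(x::Tracked) = x.val
unwrap(x) = x

# An existing, non-intercepted method (the analogue of invoke_type_tuple).
combine(x::Number, y::Number) = x + 2y

# One method over the union of each argument's type and the wrapper type.
function combine(x::Union{Float64,Tracked}, y::Union{Float64,Tracked})
    if !(x isa Tracked || y isa Tracked)
        # Nothing is wrapped: run the original method explicitly.
        return invoke(combine, Tuple{Number,Number}, x, y)
    end
    return Tracked(combine(unwrap(x), unwrap(y)))
end

c1 = combine(1.0, 2.0)            # 5.0, via invoke
c2 = combine(Tracked(1.0), 2.0)   # Tracked(5.0)
c3 = combine(1, 2)                # 5, Ints never reach the union method
```

Without the invoke fallback, the union method would shadow the original for plain Float64 arguments, which is why the entry above requires invoke_type_tuple.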
Nabla.@unionise
— Macro
@unionise code
Transform code such that each function definition accepts Node objects as arguments, without affecting dispatch in other ways.
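The kind of rewrite @unionise performs can be sketched with a small macro that widens each annotated argument x::T of a function definition to x::Union{T, Tracked}. Tracked, widen_arg and @toy_unionise are hypothetical; the sketch handles only plain f(x::T, ...) signatures (no where clauses, return-type annotations or default values) and, unlike Nabla, does nothing to make the function body work on wrapped values.

```julia
# Hypothetical wrapper standing in for a Node.
struct Tracked
    val::Float64
end

# Rewrite an argument annotation `name::T` to `name::Union{T, Tracked}`.
function widen_arg(arg)
    if arg isa Expr && arg.head == :(::) && length(arg.args) == 2
        name, T = arg.args
        return :($name::Union{$T, Tracked})
    end
    return arg                      # untyped or unsupported arguments are kept
end

macro toy_unionise(ex)
    ex isa Expr && ex.head in (:function, :(=)) || error("expected a function definition")
    sig = ex.args[1]
    sig isa Expr && sig.head == :call || error("only plain f(x::T, ...) signatures are handled")
    sig.args[2:end] = map(widen_arg, sig.args[2:end])
    return esc(ex)
end

@toy_unionise double(x::Real) = 2x
d = double(3.0)                       # 6.0, Real arguments dispatch as before
accepts = hasmethod(double, Tuple{Tracked})   # true: the signature now admits Tracked
```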