From 7b5067d6d695548c931302665a1f0c7107c7f042 Mon Sep 17 00:00:00 2001 From: "Documenter.jl" Date: Thu, 9 Jan 2025 10:27:56 +0000 Subject: [PATCH] build based on 1152c08 --- previews/PR435/.documenter-siteinfo.json | 2 +- .../developer_tools/index.html | 6 +- .../forwards_mode_design/index.html | 2 +- .../internal_docstrings/index.html | 60 +++++++++---------- .../running_tests_locally/index.html | 2 +- previews/PR435/index.html | 2 +- previews/PR435/known_limitations/index.html | 2 +- .../algorithmic_differentiation/index.html | 2 +- .../introduction/index.html | 2 +- .../rule_system/index.html | 6 +- .../PR435/utilities/debug_mode/index.html | 4 +- .../utilities/debugging_and_mwes/index.html | 4 +- .../PR435/utilities/defining_rules/index.html | 10 ++-- 13 files changed, 52 insertions(+), 52 deletions(-) diff --git a/previews/PR435/.documenter-siteinfo.json b/previews/PR435/.documenter-siteinfo.json index dd3dc5929..6cd097403 100644 --- a/previews/PR435/.documenter-siteinfo.json +++ b/previews/PR435/.documenter-siteinfo.json @@ -1 +1 @@ -{"documenter":{"julia_version":"1.11.2","generation_timestamp":"2025-01-09T08:53:48","documenter_version":"1.8.0"}} \ No newline at end of file +{"documenter":{"julia_version":"1.11.2","generation_timestamp":"2025-01-09T10:27:48","documenter_version":"1.8.0"}} \ No newline at end of file diff --git a/previews/PR435/developer_documentation/developer_tools/index.html b/previews/PR435/developer_documentation/developer_tools/index.html index 2dc765cd4..ce5ef554e 100644 --- a/previews/PR435/developer_documentation/developer_tools/index.html +++ b/previews/PR435/developer_documentation/developer_tools/index.html @@ -2,16 +2,16 @@ Developer Tools · Mooncake.jl

Developer Tools

Mooncake.jl offers developers to a few convenience functions which give access to the IR that it generates in order to perform AD. These are lightweight wrappers around internals which save you from having to dig in to the objects created by build_rrule.

Since these provide access to internals, they do not follow the usual rules of semver, and may change without notice!

Mooncake.primal_irFunction
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true
source
Mooncake.fwd_irFunction
fwd_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_irFunction
rvs_ir(
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_irFunction
rvs_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
+true

Arguments

Keyword Arguments

source diff --git a/previews/PR435/developer_documentation/forwards_mode_design/index.html b/previews/PR435/developer_documentation/forwards_mode_design/index.html index b063a2667..3c7f4087a 100644 --- a/previews/PR435/developer_documentation/forwards_mode_design/index.html +++ b/previews/PR435/developer_documentation/forwards_mode_design/index.html @@ -36,4 +36,4 @@ 3 │ %2 = invoke rule_for_h($(Dual(Main.h, NoTangent())), _3::Dual{Float64, Float64}, %1::Dual{Float64, Float64})::Dual{Float64, Float64} 4 └── return %2 => Dual{Float64, Float64}

Observe that:

  1. All Arguments have been incremented by 1. i.e. _2 has been replaced with _3. This corresponds to the fact that the arguments to the rule have all been shuffled along by one, and the rule itself is now the first argument.
  2. Everything has been turned into a Dual.
  3. Constants such as Dual(Main.g, NoTangent()) appear directly in the code (here as QuoteNodes).

(In practice it might be that we actually construct the Dualed constants on the lines immediately preceding a call and rely on the compiler to optimise them back into the call directly).

Here, as before, we have not specified exactly what rule_for_f, rule_for_g, and rule_for_h are. This is intentional – they are just callables satisfying the Forwards-Rule Interface. In the following we show how to derive rule_for_f, and show how rule_for_g and rule_for_h might be methods of Mooncake.frule!!, or themselves derived rules.

Rule Derivation Outline

Equipped with some intuition about what a derived rule ought to look like, we examine how we go about producing it algorithmically.

Rule derivation is implemented via the function Mooncake.build_frule. This function accepts as arguments a context and a signature / Base.MethodInstance / MistyClosure and, roughly speaking, does the following:

  1. Look up the optimised Compiler.IRCode.
  2. Apply a series of standardising transformations to the IRCode.
  3. Transform each statement according to a set of rules to produce a new IRCode.
  4. Apply standard Julia optimisations to this new IRCode.
  5. Put this code inside a MistyClosure in order to produce an executable object.
  6. Wrap this MistyClosure in a DerivedFRule to handle various bits of book-keeping around varargs.

In order:

Looking up the Compiler.IRCode.

This is done using Mooncake.lookup_ir. This function has methods with will return the IRCode associated to:

  1. signatures (e.g. Tuple{typeof(f), Float64})
  2. Base.MethodInstances (relevant for :invoke expressions – see Statement Transformation below)
  3. MistyClosures.MistyClosure objects, which is essential when computing higher order derivatives and Hessians by applying Mooncake.jl to itself.

Standardisation

We apply the following transformations to the Julia IR. They can all be found in ir_normalisation.jl:

  1. Mooncake.foreigncall_to_call: convert Expr(:foreigncall, ...) expressions into Expr(:call, Mooncake._foreigncall_, ...) expressions.
  2. Mooncake.new_to_call: convert Expr(:new, ...) expressions to Expr(:call, Mooncake._new_, ...) expressions.
  3. Mooncake.splatnew_to_call: convert Expr(:splatnew, ...) expressions to Expr(:call, Mooncake._splat_new_...) expressions.
  4. Mooncake.intrinsic_to_function: convert Expr(:call, ::IntrinsicFunction, ...) to calls to the corresponding function in Mooncake.IntrinsicsWrappers.

The purpose of converting Expr(:foreigncall...), Expr(:new, ...) and Expr(:splatnew, ...) into Expr(:call, ...)s is to enable us to differentiate such expressions by adding methods to frule!!(::Dual{typeof(Mooncake._foreigncall_)}), frule!!(::Dual{typeof(Mooncake._new_)}), and frule!!(::Dual{typeof(Mooncake._splat_new_)}), in exactly the same way that we would for any other regular Julia function.

The purpose of translating Expr(:call, ::IntrinsicFunction, ...) is to do with type stability – see the docstring for the Mooncake.IntrinsicsWrappers module for more info.

Statement Transformation

Each statment which can appear in the Julia IR is transformed by a method of Mooncake.make_fwds_ad_stmts. Consequently, this transformation phase simply corresponds to iterating through all of the expressions in the IRCode, applying Mooncake.make_fwd_ad_stmts to each to produce new IRCode. To understand how to modify IRCode and insert new instructions, see Oxinabox's Gist.

We provide here a high-level summary of the transformations for the most important Julia IR statements, and refer readers to the methods of Mooncake.make_fwds_ad_stmts for the definitive explanation of what transformation is applied, and the rationale for applying it. In particular there are quite a number more statements which can appear in Julia IR than those listed here and, for those we do list here, there are typically a few edge cases left out.

Expr(:invoke, method_instance, f, x...) and Expr(:call, f, x...)

:call expressions correspond to dynamic dispatch, while :invoke expressions correspond to static dispatch. That is, if you see an :invoke expression, you know for sure that the compiler knows enough information about the types of f and x to prove exactly which specialisation of which method to call. This specialisation is method_instance. This typically happens when the compiler is able to prove the types of f and x. Conversely, a :call expression typically occurs when the compiler has not been able to deduce the exact types of f and x, and therefore not been able to figure out what to call. It therefore has to wait until runtime to figure out what to call, resulting in dynamic dispatch.

As we saw earlier, the idea is to translate these kinds of expressions into something vaguely along the lines of

Expr(:call, rule_for_f, f, x...)

There are three cases to consider, in order of preference:

Primitives:

If is_primitive returns true when applied to the signature constructed from the static types of f and x, then we simply replace the expression with Expr(:call, frule!!, f, x...), regardless whether we have an :invoke or :call expression. (Due to the Standardisation steps, it regularly happens that we see :call expressions in which we actually do know enough type information to do this, e.g. for Mooncake._new_ :call expressions).

Static Dispatch:

In the case of :invoke nodes we know for sure at rule compilation time what rule_for_f must be. We derive a rule for the call by passing method_instance to Mooncake.build_frule. (In practice, we might do this lazily, but while retaining enough information to maintain type stability. See the Mooncake.LazyDerivedRule for how this is handled in reverse-mode).

Dynamic Dispatch:

If we have a :call expression and are not able to prove that is_primitive will return true, we must defer dispatch until runtime. We do this by replacing the :call expression with a call to a DynamicFRule, which simply constructs (or retrieves from a cache) the rule at runtime. Reverse-mode utilises a similar strategy via Mooncake.DynamicDerivedRule.

The above was written in terms of f and x. In practice, of course, we encounter various kinds of constants (e.g. Base.sin), Arguments (e.g. _3), and Core.SSAValues (e.g. %5). The translation rules for these are:

  1. constants are turned into constant duals in which the tangent is zero,
  2. Arguments are incremented by 1.
  3. SSAValues are left as-is.

Core.GotoNodes

These remain entirely unchanged.

Core.GotoIfNot

These require minor modification. Suppose that a Core.GotoIfNot of the form Core.GotoIfNot(%5, 4) is encountered in the primal. Since %5 will be a Dual in the derived rule, we must pull out the primal field, and pass that to the conditional instead. Therefore, these statments get lowered to two lines in the derived rule. For example, Core.GotoIfNot(%5, 4) would be translated to:

%n = getfield(%5, :primal)
-Core.GotoIfNot(%n, 4)

Core.PhiNode

Core.PhiNode looks something like the following in the general case:

φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)

They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.

Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:

  1. Core.SSAValues are unchanged.
  2. Core.Arguments are incremented by 1 (as always).
  3. constants are translated into constant duals.
  4. undefined values remain undefined.

So the above example would be translated into something like

φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)

Optimisation

The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.

Put IRCode in MistyClosure

Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.

Put the MistyClosure in a DerivedFRule

See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the "rule" that users get.

Batch Mode

So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.

In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of "batched tangents". Then, instead of propagating a "primal-tangent" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).

Forwards vs Reverse Implementation

The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:

  1. forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.
  2. forwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.

Comparison with ForwardDiff.jl

With reference to the limitations of ForwardDiff.jl, there are a few noteworthy differences between ForwardDiff.jl and this implementation:

  1. :foreigncalls pose much less of a problem for Mooncake's forward-mode than for ForwardDiff.jl, because we can write a rule for any method of any function. In essence, you can only (reliably) write rules for ForwardDiff.jl via dispatch on ForwardDiff.Dual.
  2. the target function can be of any arity in Mooncake.jl, but must be unary in ForwardDiff.jl.
  3. there are no limitations on the argument type constraints that Mooncake.jl can handle, while ForwardDiff.jl requires that argument type constraints be <:Real or arrays of <:Real.
  4. No special storage types are required with Mooncake.jl, while ForwardDiff.jl requires that any container you write to is able to contain ForwardDiff.Duals.
+Core.GotoIfNot(%n, 4)

Core.PhiNode

Core.PhiNode looks something like the following in the general case:

φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)

They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.

Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:

  1. Core.SSAValues are unchanged.
  2. Core.Arguments are incremented by 1 (as always).
  3. constants are translated into constant duals.
  4. undefined values remain undefined.

So the above example would be translated into something like

φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)

Optimisation

The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.

Put IRCode in MistyClosure

Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.

Put the MistyClosure in a DerivedFRule

See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the "rule" that users get.

Batch Mode

So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.

In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of "batched tangents". Then, instead of propagating a "primal-tangent" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).

Forwards vs Reverse Implementation

The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:

  1. forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.
  2. forwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.

Comparison with ForwardDiff.jl

With reference to the limitations of ForwardDiff.jl, there are a few noteworthy differences between ForwardDiff.jl and this implementation:

  1. :foreigncalls pose much less of a problem for Mooncake's forward-mode than for ForwardDiff.jl, because we can write a rule for any method of any function. In essence, you can only (reliably) write rules for ForwardDiff.jl via dispatch on ForwardDiff.Dual.
  2. the target function can be of any arity in Mooncake.jl, but must be unary in ForwardDiff.jl.
  3. there are no limitations on the argument type constraints that Mooncake.jl can handle, while ForwardDiff.jl requires that argument type constraints be <:Real or arrays of <:Real.
  4. No special storage types are required with Mooncake.jl, while ForwardDiff.jl requires that any container you write to is able to contain ForwardDiff.Duals.
diff --git a/previews/PR435/developer_documentation/internal_docstrings/index.html b/previews/PR435/developer_documentation/internal_docstrings/index.html index 3a5c5ad9f..723e43c6f 100644 --- a/previews/PR435/developer_documentation/internal_docstrings/index.html +++ b/previews/PR435/developer_documentation/internal_docstrings/index.html @@ -1,17 +1,17 @@ -Internal Docstrings · Mooncake.jl

Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.

The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

Mooncake.TerminatorType
Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}

A Union of the possible types of a terminator node.

source
Core.Compiler.IRCodeMethod
IRCode(bb_code::BBCode)

Produce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.

All IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.

In the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.

source
Mooncake.ADInfoType
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
source
Mooncake.ADStmtInfoType
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback
source
Mooncake.BBCodeMethod
BBCode(ir::IRCode)

Convert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.

All PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.

See IRCode for conversion back to IRCode.

Note that IRCode(BBCode(ir)) should be equal to the identity function.

source
Mooncake.BBCodeMethod
BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})

Make a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.

source
Mooncake.BBCodeType
BBCode(
+Internal Docstrings · Mooncake.jl

Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.

The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

Mooncake.TerminatorType
Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}

A Union of the possible types of a terminator node.

source
Core.Compiler.IRCodeMethod
IRCode(bb_code::BBCode)

Produce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.

All IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.

In the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.

source
Mooncake.ADInfoType
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
source
Mooncake.ADStmtInfoType
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback
source
Mooncake.BBCodeMethod
BBCode(ir::IRCode)

Convert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.

All PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.

See IRCode for conversion back to IRCode.

Note that IRCode(BBCode(ir)) should be equal to the identity function.

source
Mooncake.BBCodeMethod
BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})

Make a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.

source
Mooncake.BBCodeType
BBCode(
     blocks::Vector{BBlock}
     argtypes::Vector{Any}
     sptypes::Vector{CC.VarState}
     linetable::Vector{Core.LineInfoNode}
     meta::Vector{Expr}
-)

A BBCode is a data structure which is similar to IRCode, but adds additional structure.

In particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.

The consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.

Additionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.

Finally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.

source
Mooncake.BBlockMethod
BBlock(id::ID, inst_pairs::Vector{IDInstPair})

Convenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.

source
Mooncake.BBlockType
BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)

A basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.

Each BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.

The nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.

Note that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.

source
Mooncake.BlockStackType

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.

source
Mooncake.CannotProduceZeroRDataFromTypeType
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

source
Mooncake.ConfigType
Config(; debug_mode=false, silence_debug_messages=false)

Configuration struct for use with ADTypes.AutoMooncake.

source
Mooncake.DebugPullbackMethod
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

source
Mooncake.DebugPullbackType
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse pass counterpart to DebugRRule

source
Mooncake.DebugRRuleMethod
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

source
Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
Mooncake.DefaultCtxType
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

source
Mooncake.DynamicDerivedRuleType
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.

source
Mooncake.FDataType
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

source
Mooncake.IDType
ID()

An ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.

This counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.

source
Mooncake.IDPhiNodeType
IDPhiNode(edges::Vector{ID}, values::Vector{Any})

Like a PhiNode, but edges are IDs rather than Int32s.

source
Mooncake.InstVectorType
const InstVector = Vector{NewInstruction}

Note: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.

source
Mooncake.LazyDerivedRuleType
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

source
Mooncake.LazyZeroRDataType
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.

Users should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

source
Mooncake.MinimalCtxType
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

source
Mooncake.NoPullbackMethod
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

source
Mooncake.RRuleZeroWrapperType
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

source
Mooncake.SharedDataPairsType
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

source
Mooncake.StackType
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.

source
Mooncake.SwitchType
Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)

A switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:

goto dests[1] if not conds[1]
+)

A BBCode is a data structure which is similar to IRCode, but adds additional structure.

In particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.

The consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.

Additionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.

Finally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.

source
Mooncake.BBlockMethod
BBlock(id::ID, inst_pairs::Vector{IDInstPair})

Convenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.

source
Mooncake.BBlockType
BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)

A basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.

Each BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.

The nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.

Note that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.

source
Mooncake.BlockStackType

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.

source
Mooncake.CannotProduceZeroRDataFromTypeType
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

source
Mooncake.ConfigType
Config(; debug_mode=false, silence_debug_messages=false)

Configuration struct for use with ADTypes.AutoMooncake.

source
Mooncake.DebugPullbackMethod
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

source
Mooncake.DebugPullbackType
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse pass counterpart to DebugRRule

source
Mooncake.DebugRRuleMethod
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

source
Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
Mooncake.DefaultCtxType
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

source
Mooncake.DynamicDerivedRuleType
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.

source
Mooncake.FDataType
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

source
Mooncake.IDType
ID()

An ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.

This counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.

source
Mooncake.IDPhiNodeType
IDPhiNode(edges::Vector{ID}, values::Vector{Any})

Like a PhiNode, but edges are IDs rather than Int32s.

source
Mooncake.InstVectorType
const InstVector = Vector{NewInstruction}

Note: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.

source
Mooncake.LazyDerivedRuleType
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

source
Mooncake.LazyZeroRDataType
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.

Users should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

source
Mooncake.MinimalCtxType
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

source
Mooncake.NoPullbackMethod
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

source
Mooncake.RRuleZeroWrapperType
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

source
Mooncake.SharedDataPairsType
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

source
Mooncake.StackType
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.

source
Mooncake.SwitchType
Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)

A switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:

goto dests[1] if not conds[1]
 goto dests[2] if not conds[2]
 ...
 goto dests[N] if not conds[N]
-goto fallthrough_dest

where the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.

Switch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.

source
Mooncake.ZeroRDataType
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

source
Base.insert!Method
Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing

Inserts stmt and id into bb immediately before the nth instruction.

source
Mooncake.__flatten_varargsMethod
__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).

source
Mooncake.__insts_to_instruction_streamMethod
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

source
Mooncake.__line_numbers_to_block_numbers!Method
__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)

Converts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.

For context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.

This code is copied over directly from the body of Core.Compiler.inflate_ir!.

source
Mooncake.__pop_blk_stack!Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

source
Mooncake.__push_blk_stack!Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.

source
Mooncake.__run_rvs_pass!Method
__run_rvs_pass!(
+goto fallthrough_dest

where the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.

Switch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.

source
Mooncake.ZeroRDataType
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

source
Base.insert!Method
Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing

Inserts stmt and id into bb immediately before the nth instruction.

source
Mooncake.__flatten_varargsMethod
__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).

source
Mooncake.__insts_to_instruction_streamMethod
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

source
Mooncake.__line_numbers_to_block_numbers!Method
__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)

Converts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.

For context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.

This code is copied over directly from the body of Core.Compiler.inflate_ir!.

source
Mooncake.__pop_blk_stack!Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

source
Mooncake.__push_blk_stack!Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.

source
Mooncake.__run_rvs_pass!Method
__run_rvs_pass!(
     P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
-) where {sig}

Used in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).

source
Mooncake.__unflatten_codual_varargsMethod
__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).

source
Mooncake.__value_and_gradient!!Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
+) where {sig}

Used in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).

source
Mooncake.__unflatten_codual_varargsMethod
__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).

source
Mooncake.__value_and_gradient!!Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
 f(x, y) = sum(x .* y)
 x = [2.0, 2.0]
 y = [1.0, 1.0]
@@ -29,12 +29,12 @@
 )
 # output
 
-(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
source
Mooncake.__value_and_pullback!!Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.

source
Mooncake._block_nums_to_idsMethod
_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}

Assign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.

source
Mooncake._build_graph_of_cfgMethod
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.

source
Mooncake._compute_all_predecessorsMethod
_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._compute_all_successorsMethod
_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._control_flow_graphMethod
_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG

Internal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.

source
Mooncake._distance_to_entryMethod
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.

source
Mooncake._find_id_uses!Method
_find_id_uses!(d::Dict{ID, Bool}, x)

Helper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.

For example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.

source
Mooncake.__value_and_pullback!!Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.

source
Mooncake._block_nums_to_idsMethod
_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}

Assign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.

source
Mooncake._build_graph_of_cfgMethod
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.

source
Mooncake._compute_all_predecessorsMethod
_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._compute_all_successorsMethod
_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._control_flow_graphMethod
_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG

Internal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.

source
Mooncake._distance_to_entryMethod
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.

source
Mooncake._find_id_uses!Method
_find_id_uses!(d::Dict{ID, Bool}, x)

Helper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.

For example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.

source
Mooncake._foreigncall_Method
function _foreigncall_(
     ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
-) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

source
Mooncake._ids_to_line_numbersMethod
_ids_to_line_numbers(bb_code::BBCode)::InstVector

For each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.

source
Mooncake._is_reachableMethod
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.

source
Mooncake._lines_to_blocksMethod
_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector

Pulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.

source
Mooncake._lower_switch_statementsMethod
_lower_switch_statements(bb_code::BBCode)

Converts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.

source
Mooncake._mapMethod
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type as element will always be assigned.

Requires that y and x have the same size.

source
Mooncake._new_Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.

source
Mooncake._remove_double_edgesMethod
_remove_double_edges(ir::BBCode)::BBCode

If the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.

source
Mooncake._sort_blocks!Method
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.

For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.

WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.

source
Mooncake._ssa_to_idsMethod
_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)

Produce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.

source
Mooncake._ssas_to_idsMethod
_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}

Assigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).

source
Mooncake._to_ssasMethod
_to_ssas(d::Dict, inst::NewInstruction)

Like _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).

source
Mooncake._typeofMethod
_typeof(x)

Central definition of typeof, which is specific to the use-required in this package.

source
Mooncake.ad_stmt_infoMethod
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

source
Mooncake.add_data!Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

source
Mooncake.add_data!Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.

source
Mooncake.add_data_if_not_singleton!Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.

source
Mooncake.always_initialisedMethod
always_initialised(::Type{P}) where {P}

Returns a tuple with number of fields equal to the number of fields in P. The nth field is set to true if the nth field of P is initialised, and false otherwise.

source
Mooncake.arrayifyMethod
arrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat}})

Return the primal field of x, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.

source
Mooncake.build_primitive_rruleMethod
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.

source
Mooncake.characterise_unique_predecessor_blocksMethod
characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
-    Tuple{Dict{ID, Bool}, Dict{ID, Bool}}

We call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.

Returns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.

Context:

This information is important for optimising AD because knowing that b is a unique predecessor means that

  1. on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
  2. on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.

Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.

source
Mooncake.characterise_used_idsMethod
characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}

For each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.

source
Mooncake.collect_stmtsMethod
collect_stmts(ir::BBCode)::Vector{IDInstPair}

Produce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir.

source
Mooncake.collect_stmtsMethod
collect_stmts(bb::BBlock)::Vector{IDInstPair}

Returns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.

source
Mooncake.comms_channelMethod
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.

source
Mooncake.conclude_rvs_blockMethod
conclude_rvs_block(
+) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

source
Mooncake._ids_to_line_numbersMethod
_ids_to_line_numbers(bb_code::BBCode)::InstVector

For each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.

source
Mooncake._is_reachableMethod
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.

source
Mooncake._lines_to_blocksMethod
_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector

Pulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.

source
Mooncake._lower_switch_statementsMethod
_lower_switch_statements(bb_code::BBCode)

Converts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.

source
Mooncake._mapMethod
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type as element will always be assigned.

Requires that y and x have the same size.

source
Mooncake._new_Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.

source
Mooncake._remove_double_edgesMethod
_remove_double_edges(ir::BBCode)::BBCode

If the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.

source
Mooncake._sort_blocks!Method
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.

For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.

WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.

source
Mooncake._ssa_to_idsMethod
_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)

Produce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.

source
Mooncake._ssas_to_idsMethod
_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}

Assigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).

source
Mooncake._to_ssasMethod
_to_ssas(d::Dict, inst::NewInstruction)

Like _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).

source
Mooncake._typeofMethod
_typeof(x)

Central definition of typeof, which is specific to the use-required in this package.

source
Mooncake.ad_stmt_infoMethod
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

source
Mooncake.add_data!Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

source
Mooncake.add_data!Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.

source
Mooncake.add_data_if_not_singleton!Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.

source
Mooncake.always_initialisedMethod
always_initialised(::Type{P}) where {P}

Returns a tuple with number of fields equal to the number of fields in P. The nth field is set to true if the nth field of P is initialised, and false otherwise.

source
Mooncake.arrayifyMethod
arrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat}})

Return the primal field of x, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.

source
Mooncake.build_primitive_rruleMethod
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.

source
Mooncake.characterise_unique_predecessor_blocksMethod
characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
+    Tuple{Dict{ID, Bool}, Dict{ID, Bool}}

We call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.

Returns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.

Context:

This information is important for optimising AD because knowing that b is a unique predecessor means that

  1. on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
  2. on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.

Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.

source
Mooncake.characterise_used_idsMethod
characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}

For each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.

source
Mooncake.collect_stmtsMethod
collect_stmts(ir::BBCode)::Vector{IDInstPair}

Produce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir.

source
Mooncake.collect_stmtsMethod
collect_stmts(bb::BBlock)::Vector{IDInstPair}

Returns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.

source
Mooncake.comms_channelMethod
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.

source
Mooncake.conclude_rvs_blockMethod
conclude_rvs_block(
     blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
-)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

source
Mooncake.const_codualMethod
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

source
Mooncake.const_codual_stmtMethod
const_codual_stmt(stmt, info::ADInfo)

Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.

source
Mooncake.create_comms_insts!Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.

Returns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.

source
Mooncake.fdata_field_typeMethod
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.fix_up_invoke_inference!Method
fix_up_invoke_inference!(ir::IRCode)

The Problem

Consider the following:

@noinline function bar!(x)
+)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

source
Mooncake.const_codualMethod
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

source
Mooncake.const_codual_stmtMethod
const_codual_stmt(stmt, info::ADInfo)

Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.

source
Mooncake.create_comms_insts!Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.

Returns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.

source
Mooncake.fdata_field_typeMethod
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.fix_up_invoke_inference!Method
fix_up_invoke_inference!(ir::IRCode)

The Problem

Consider the following:

@noinline function bar!(x)
     x .*= 2
 end
 
@@ -45,20 +45,20 @@
 1-element Vector{Any}:
 2 1 ─     invoke Main.bar!(_2::Vector{Float64})::Any
 3 └──     return Main.nothing
-   => Nothing

Observe that the type inferred for the first line is Any. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}.

This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference infering Any rather than Vector{Float64} causes type instabilities in the code that Mooncake generates, which can have catastrophic conseqeuences for performance.

The Solution

:invoke expressions contain the Core.MethodInstance associated to them, which contains a Core.CodeCache, which contains the return type of the :invoke. This function looks for :invoke statements whose return type is inferred to be Any in ir, and modifies it to be the return type given by the code cache.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
+   => Nothing

Observe that the type inferred for the first line is Any. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}.

This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference infering Any rather than Vector{Float64} causes type instabilities in the code that Mooncake generates, which can have catastrophic conseqeuences for performance.

The Solution

:invoke expressions contain the Core.MethodInstance associated to them, which contains a Core.CodeCache, which contains the return type of the :invoke. This function looks for :invoke statements whose return type is inferred to be Any in ir, and modifies it to be the return type given by the code cache.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
     interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
-)

Used by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.

source
Mooncake.id_to_line_mapMethod
id_to_line_map(ir::BBCode)

Produces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.

source
Mooncake.insert_before_terminator!Method
insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing

If the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
+)

Used by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.

source
Mooncake.id_to_line_mapMethod
id_to_line_map(ir::BBCode)

Produces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.

source
Mooncake.insert_before_terminator!Method
insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing

If the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
     inst::Vector{Any},
     argtypes::Vector{Any},
     sptypes::Vector{CC.VarState}=CC.VarState[],
-) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source
Mooncake.is_reachable_return_nodeMethod
is_reachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
-getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preserve(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
+) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source
Mooncake.is_reachable_return_nodeMethod
is_reachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
+getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preserve(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
 y = expr
 Expr(:gc_preserve_end, token)

These expressions guarantee that any memory associated x1 and x2 not be freed until the :gc_preserve_end expression is reached.

In the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.

To achieve this, we replace the primal code with

# store `x` in `pb_gc_preserve` to prevent it from being freed.
 _, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)
@@ -74,11 +74,11 @@
 _, dargs... = foo_pb(dy)
 
 # No-op pullback associated to `gc_preserve`.
-pb_gc_preserve(NoRData())
source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lookup_irMethod
lookup_ir(
     interp::AbstractInterpreter,
     sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
-)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
-setfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
+)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
+setfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
     pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
 )

preds_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of

prev_block = pop!(block_stack)
 not_pred_was_1 = !(prev_block == ID(1))
@@ -87,41 +87,41 @@
     not_pred_was_1 => ID(1),
     not_pred_was_2 => ID(2),
     ID(3)
-)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.

source
Mooncake.misty_closureMethod
misty_closure(
+)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.

source
Mooncake.new_instFunction
new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction

Create a NewInstruction with fields:

  • stmt = stmt
  • type = type
  • info = CC.NoCallInfo()
  • line = Int32(1)
  • flag = flag
source
Mooncake.new_inst_vecMethod
new_inst_vec(x::CC.InstructionStream)

Convert an Compiler.InstructionStream into a list of Compiler.NewInstructions.

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with:calls toMooncake.splatnew_`,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls to lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.new_instFunction
new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction

Create a NewInstruction with fields:

  • stmt = stmt
  • type = type
  • info = CC.NoCallInfo()
  • line = Int32(1)
  • flag = flag
source
Mooncake.new_inst_vecMethod
new_inst_vec(x::CC.InstructionStream)

Convert an Compiler.InstructionStream into a list of Compiler.NewInstructions.

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with:calls toMooncake.splatnew_`,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls to lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.opaque_closureMethod
opaque_closure(
     ret_type::Type,
     ir::IRCode,
     @nospecialize env...;
     isva::Bool=false,
     do_compile::Bool=true,
-)::Core.OpaqueClosure{<:Tuple, ret_type}

Construct a Core.OpaqueClosure. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile), but instead of letting Core.compute_oc_rettype figure out the return type from ir, impose ret_type as the return type.

Warning

User beware: if the Core.OpaqueClosure produced by this function ever returns anything which is not an instance of a subtype of ret_type, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!

Extended Help

This is needed in Mooncake.jl because make extensive use of our ability to know the return type of a couple of specific OpaqueClosures without actually having constructed them – see LazyDerivedRule. Without the capability to specify the return type, we have to guess what type compute_ir_rettype will return for a given IRCode before we have constructed the IRCode and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.

Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.

By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc if we fail to specify ret_type correctly.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.phi_nodesMethod
phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}

Returns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.

source
Mooncake.prepare_gradient_cacheMethod
prepare_gradient_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_gradient!!. See the docstring for Mooncake.value_and_gradient!! for more info.

source
Mooncake.prepare_pullback_cacheMethod
prepare_pullback_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_gradient!!. See the docstring for Mooncake.value_and_gradient!! for more info.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
+)::Core.OpaqueClosure{<:Tuple, ret_type}

Construct a Core.OpaqueClosure. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile), but instead of letting Core.compute_oc_rettype figure out the return type from ir, impose ret_type as the return type.

Warning

User beware: if the Core.OpaqueClosure produced by this function ever returns anything which is not an instance of a subtype of ret_type, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!

Extended Help

This is needed in Mooncake.jl because make extensive use of our ability to know the return type of a couple of specific OpaqueClosures without actually having constructed them – see LazyDerivedRule. Without the capability to specify the return type, we have to guess what type compute_ir_rettype will return for a given IRCode before we have constructed the IRCode and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.

Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.

By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc if we fail to specify ret_type correctly.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.phi_nodesMethod
phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}

Returns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.

source
Mooncake.prepare_gradient_cacheMethod
prepare_gradient_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_gradient!!. See the docstring for Mooncake.value_and_gradient!! for more info.

source
Mooncake.prepare_pullback_cacheMethod
prepare_pullback_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_gradient!!. See the docstring for Mooncake.value_and_gradient!! for more info.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_unreachable_blocks!Method
remove_unreachable_blocks!(ir::BBCode)::BBCode

If a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:

  1. it has no predecessors and it is not the first block, or
  2. all of its predecessors are themselves unreachable.

For example, consider the following IR:

julia> ir = Mooncake.ircode(
+true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_unreachable_blocks!Method
remove_unreachable_blocks!(ir::BBCode)::BBCode

If a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:

  1. it has no predecessors and it is not the first block, or
  2. all of its predecessors are themselves unreachable.

For example, consider the following IR:

julia> ir = Mooncake.ircode(
            Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],
            Any[Any, Any, Any],
        );

There is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:

julia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))
-1 1 ─     return nothing

In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.reverse_data_ref_stmtsMethod
reverse_data_ref_stmts(info::ADInfo)

Create the :new statements which initialise the reverse-data Refs. Interpolates the initial rdata directly into the statement, which is safe because it is always a bits type.

source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
+1 1 ─     return nothing

In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.reverse_data_ref_stmtsMethod
reverse_data_ref_stmts(info::ADInfo)

Create the :new statements which initialise the reverse-data Refs. Interpolates the initial rdata directly into the statement, which is safe because it is always a bits type.

source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
 function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
     return rrule_wrapper(f, args...)
-end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
+end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
 %7 = φ (#2 => 5., #3 => _2)

Let the tangent refs associated to %6, %7, and _1be denotedt%6,t%7, andt1resp., and letpredidbe#2`, then this function will produce a basic block of the form

increment_ref!(t_1, t%6)
 nothing
-goto #2

The call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

source
Mooncake.seed_id!Method
seed_id!()

Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.

This is akin to setting the random seed associated to a random number generator globally.

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
+goto #2

The call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

source
Mooncake.seed_id!Method
seed_id!()

Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.

This is akin to setting the random seed associated to a random number generator globally.

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
     (ID(5), new_inst(:(getfield(_1, 1)))),
     (ID(3), new_inst(:(getfield(_1, 2)))),
-]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stable_allMethod
stable_all(x::NTuple{N, Bool}) where {N}

all(x::NTuple{N, Bool}) does not constant-fold nicely on 1.10 if the values of x are known statically. This implementation constant-folds nicely on both 1.10 and 1.11, so can be used in its place in situations where this is important.

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.terminatorMethod
terminator(bb::BBlock)

Returns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
+]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stable_allMethod
stable_all(x::NTuple{N, Bool}) where {N}

all(x::NTuple{N, Bool}) does not constant-fold nicely on 1.10 if the values of x are known statically. This implementation constant-folds nicely on both 1.10 and 1.11, so can be used in its place in situations where this is important.

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.terminatorMethod
terminator(bb::BBlock)

Returns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
 
 julia> foo(x::Vararg{Int}) = 5
 foo (generic function with 1 method)
@@ -131,7 +131,7 @@
 julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
 
 julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
-(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData,

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.

If you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
+(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData,

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.

If you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
 
 julia> using ChainRulesCore
 
@@ -176,7 +176,7 @@
        TestUtils.test_rule(
            Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
        )
-Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implemented more complicated method of is_primitive in the usual way.

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
+Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implemented more complicated method of is_primitive in the usual way.

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
 foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
 

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
 scale (generic function with 1 method)
@@ -196,7 +196,7 @@
 julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
 
 julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
-(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
+(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
 
 julia> foo(x) = 5
 foo (generic function with 1 method)
@@ -218,7 +218,7 @@
 true
 
 julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
-(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
+(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
 Core.IntrinsicFunction
 
 julia> typeof(Core.Intrinsics.sub_float)
@@ -228,4 +228,4 @@
     # return add_float and its pullback
 elseif
     ...
-end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.

Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .

source
+end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.

Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .

source
diff --git a/previews/PR435/developer_documentation/running_tests_locally/index.html b/previews/PR435/developer_documentation/running_tests_locally/index.html index 210b56d8a..98142d88a 100644 --- a/previews/PR435/developer_documentation/running_tests_locally/index.html +++ b/previews/PR435/developer_documentation/running_tests_locally/index.html @@ -1,2 +1,2 @@ -Running Tests Locally · Mooncake.jl

Running Tests Locally

Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.

There are two workflows for running tests, discussed below.

Main Testing Functionality

For all code in src, Mooncake's tests are organised as follows:

  1. Things that are required for most / all test suites are loaded up in test/front_matter.jl.
  2. The tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.

Thus, a workflow that I (Will) find works very well is the following:

  1. Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.
  2. start the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.
  3. using TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.
  4. Run the following: using Pkg; Pkg.activate("."); TestEnv.activate(); include("test/front_matter.jl"); to set up your environment.
  5. include whichever test file you want to run the tests from.
  6. Modify code, and re-include tests to check it has done was you need. Loop this until done.
  7. Make a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.

The purpose of this approach is to:

  1. Avoid restarting the REPL each time you make a change, and
  2. Run the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.

If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.

Extension and Integration Testing

Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.

These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.

+Running Tests Locally · Mooncake.jl

Running Tests Locally

Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.

There are two workflows for running tests, discussed below.

Main Testing Functionality

For all code in src, Mooncake's tests are organised as follows:

  1. Things that are required for most / all test suites are loaded up in test/front_matter.jl.
  2. The tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.

Thus, a workflow that I (Will) find works very well is the following:

  1. Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.
  2. start the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.
  3. using TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.
  4. Run the following: using Pkg; Pkg.activate("."); TestEnv.activate(); include("test/front_matter.jl"); to set up your environment.
  5. include whichever test file you want to run the tests from.
  6. Modify code, and re-include tests to check it has done was you need. Loop this until done.
  7. Make a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.

The purpose of this approach is to:

  1. Avoid restarting the REPL each time you make a change, and
  2. Run the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.

If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.

Extension and Integration Testing

Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.

These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.

diff --git a/previews/PR435/index.html b/previews/PR435/index.html index 70b6a923f..db672d245 100644 --- a/previews/PR435/index.html +++ b/previews/PR435/index.html @@ -1,2 +1,2 @@ -Mooncake.jl · Mooncake.jl

Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!

+Mooncake.jl · Mooncake.jl

Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!

diff --git a/previews/PR435/known_limitations/index.html b/previews/PR435/known_limitations/index.html index 222288462..32e4fb0fb 100644 --- a/previews/PR435/known_limitations/index.html +++ b/previews/PR435/known_limitations/index.html @@ -38,4 +38,4 @@ Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0]) # output -(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.

+(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.

diff --git a/previews/PR435/understanding_mooncake/algorithmic_differentiation/index.html b/previews/PR435/understanding_mooncake/algorithmic_differentiation/index.html index dd44d3ee1..1c5a75d5a 100644 --- a/previews/PR435/understanding_mooncake/algorithmic_differentiation/index.html +++ b/previews/PR435/understanding_mooncake/algorithmic_differentiation/index.html @@ -23,4 +23,4 @@ D f [x] (\dot{x}) &= [(D \mathcal{l} [g(x)]) \circ (D g [x])](\dot{x}) \nonumber \\ &= \langle \bar{y}, D g [x] (\dot{x}) \rangle \nonumber \\ &= \langle D g [x]^\ast (\bar{y}), \dot{x} \rangle, \nonumber -\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.

Summary

This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.

Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD achieves this by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.

[1]
M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).
[2]
T. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).
+\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.

Summary

This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.

Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD achieves this by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.

[1]
M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).
[2]
T. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).
diff --git a/previews/PR435/understanding_mooncake/introduction/index.html b/previews/PR435/understanding_mooncake/introduction/index.html index a14bf1c2c..ad85b2709 100644 --- a/previews/PR435/understanding_mooncake/introduction/index.html +++ b/previews/PR435/understanding_mooncake/introduction/index.html @@ -1,2 +1,2 @@ -Introduction · Mooncake.jl

Introduction

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.

  1. we recap what AD is, and introduce the mathematics necessary to understand is,
  2. explain how this mathematics relates to functions and data structures in Julia, and
  3. how this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

+Introduction · Mooncake.jl

Introduction

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.

  1. we recap what AD is, and introduce the mathematics necessary to understand is,
  2. explain how this mathematics relates to functions and data structures in Julia, and
  3. how this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

diff --git a/previews/PR435/understanding_mooncake/rule_system/index.html b/previews/PR435/understanding_mooncake/rule_system/index.html index 3660ead54..0f01f2cc8 100644 --- a/previews/PR435/understanding_mooncake/rule_system/index.html +++ b/previews/PR435/understanding_mooncake/rule_system/index.html @@ -51,7 +51,7 @@ x::Float64 end

you will find that it is

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64}}

Primitive Types

We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.

One interesting case are Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example

julia> tangent_type(Ptr{Float64})
-Ptr{Float64}
source

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_typeMethod
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
+Ptr{Float64}
source

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_typeMethod
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
 (NoFData, NoRData)

Float64

The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so

julia> (fdata_type(Float64), rdata_type(Float64))
 (NoFData, Float64)

Vector{Float64}

The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so

julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
 (Vector{Float64}, NoRData)

Tuple{Float64, Vector{Float64}, Int}

This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:

julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
@@ -70,7 +70,7 @@
            z::Int
        end

has tangent type

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}

and fdata / rdata types

julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
-(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.

source

CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
+(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.

source

CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
 foo (generic function with 1 method)

The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.

Under the framework introduced above, the model for this function is

\[f(x) = (x, x_1 + \sum_{n=1}^N (x_2)_n)\]

where the vector in the second element of x is of length $N$. Now, following our usual steps, the derivative is

\[D f [x](\dot{x}) = (\dot{x}, \dot{x}_1 + \sum_{n=1}^N (\dot{x}_2)_n)\]

A gradient for this is a tuple $(\bar{y}_x, \bar{y}_a)$ where $\bar{y}_a \in \RR$ and $\bar{y}_x \in \RR \times \RR^N$. A quick derivation will show that the adjoint is

\[D f [x]^\ast(\bar{y}) = ((\bar{y}_x)_1 + \bar{y}_a, (\bar{y}_x)_2 + \bar{y}_a \mathbf{1})\]

where $\mathbf{1}$ is the vector of length $N$ in which each element is equal to $1$. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).

Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.

julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
            dx_fdata = x.dx
            function dfoo_adjoint(dy::Float64)
@@ -105,4 +105,4 @@
 typeof(g)
 
 # output
-typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.

+typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.

diff --git a/previews/PR435/utilities/debug_mode/index.html b/previews/PR435/utilities/debug_mode/index.html index 1c214d032..d9e6cfc61 100644 --- a/previews/PR435/utilities/debug_mode/index.html +++ b/previews/PR435/utilities/debug_mode/index.html @@ -2,5 +2,5 @@ Debug Mode · Mooncake.jl

Debug Mode

The Problem

A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like

function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
     plus_reverse_pass(dz::Real) = NoRData(), dz, dz
     return zero_fcodual(primal(x) + primal(y))
-end

and calls

rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.

This is implemented via DebugRRule:

Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source

You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rruleFunction
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

source
build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

source
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

source

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
-AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because is has substantial performance implications.

+end

and calls

rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.

This is implemented via DebugRRule:

Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source

You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rruleFunction
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

source
build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

source
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

source

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
+AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because is has substantial performance implications.

diff --git a/previews/PR435/utilities/debugging_and_mwes/index.html b/previews/PR435/utilities/debugging_and_mwes/index.html index 9fa63cd2d..d18391de2 100644 --- a/previews/PR435/utilities/debugging_and_mwes/index.html +++ b/previews/PR435/utilities/debugging_and_mwes/index.html @@ -7,5 +7,5 @@ interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), debug_mode::Bool=false, unsafe_perturb::Bool=false, -)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.

Arguments

Keyword Arguments

source

This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. can be used without having to manually generate tangents.

Example

For example

f(x) = Core.bitcast(Float64, x)
-Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.

+)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.

Arguments

Keyword Arguments

source

This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. can be used without having to manually generate tangents.

Example

For example

f(x) = Core.bitcast(Float64, x)
+Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.

diff --git a/previews/PR435/utilities/defining_rules/index.html b/previews/PR435/utilities/defining_rules/index.html index 87ba02742..64e8b30de 100644 --- a/previews/PR435/utilities/defining_rules/index.html +++ b/previews/PR435/utilities/defining_rules/index.html @@ -19,7 +19,7 @@ julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) -(20.0, (NoTangent(), 4.0))source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
+(20.0, (NoTangent(), 4.0))
source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
 
 julia> foo(x) = 5
 foo (generic function with 1 method)
@@ -41,7 +41,7 @@
 true
 
 julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
-(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjointFunction
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
+(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjointFunction
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
 
 julia> foo(x::Vararg{Int}) = 5
 foo (generic function with 1 method)
@@ -51,7 +51,7 @@
 julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
 
 julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
-(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
+(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
 
 julia> using ChainRulesCore
 
@@ -96,10 +96,10 @@
        TestUtils.test_rule(
            Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
        )
-Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source

Adding Methods To rrule!! And build_primitive_rrule

If the above strategies do not work for you, you should first implement a method of Mooncake.is_primitive for the signature of interest:

Mooncake.is_primitiveFunction
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source

Then implement a method of one of the following:

Mooncake.rrule!!Function
rrule!!(f::CoDual, x::CoDual...)

Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.

Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.

using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
+Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source

Adding Methods To rrule!! And build_primitive_rrule

If the above strategies do not work for you, you should first implement a method of Mooncake.is_primitive for the signature of interest:

Mooncake.is_primitiveFunction
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source

Then implement a method of one of the following:

Mooncake.rrule!!Function
rrule!!(f::CoDual, x::CoDual...)

Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.

Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.

using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
 y, pb!! = rrule!!(zero_fcodual(sin), CoDual(5.0, NoFData()))
 pb!!(1.0)
 
 # output
 
-(NoRData(), 0.28366218546322625)
source
Mooncake.build_primitive_rruleFunction
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.

source
+(NoRData(), 0.28366218546322625)source
Mooncake.build_primitive_rruleFunction
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.

source