Skip to content

Commit

Permalink
Support @check_allocs at callsites
Browse files Browse the repository at this point in the history
  • Loading branch information
tecosaur committed Nov 21, 2023
1 parent 48a8c34 commit 31f2f96
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 15 deletions.
65 changes: 50 additions & 15 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,22 @@ end

"""
@check_allocs ignore_throw=true (function def)
@check_allocs ignore_throw=true func(...)
Wraps the provided function definition so that all calls to it will be automatically
checked for allocations.
If the check fails, an `AllocCheckFailure` exception is thrown containing the detailed
failures, including the backtrace for each defect.
Note: All calls to the wrapped function are effectively a dynamic dispatch, which
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
only guarantees the absence of allocations after the function has started running.
`@check_allocs` can also be applied to a function call, which operates by creating
an anonymous function that is passed to `@check_allocs` and then immediately calling
the wrapped result.
!!! note
All calls to the wrapped function are effectively a dynamic dispatch, which
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
only guarantees the absence of allocations after the function has started running.
# Example
```jldoctest
Expand All @@ -45,23 +51,27 @@ julia> multiply(1.5, 3.5) # no allocations for Float64
5.25
julia> multiply(rand(3,3), rand(3,3)) # matmul needs to allocate the result
ERROR: @check_alloc function contains 1 allocations.
ERROR: @check_alloc function contains 1 allocations (1 allocations / 0 dynamic dispatches).
Stacktrace:
[1] macro expansion
@ ~/repos/AllocCheck/src/macro.jl:134 [inlined]
@ ~/.julia/dev/AllocCheck/src/macro.jl:157 [inlined]
[2] multiply(x::Matrix{Float64}, y::Matrix{Float64})
@ Main ./REPL[2]:133
@ Main ./REPL[2]:156
[3] top-level scope
@ REPL[5]:1
@ REPL[4]:1
julia> @check_allocs 1.5 * 3.5 # check a call
5.25
```
"""
macro check_allocs(ex...)
kws, body = extract_keywords(ex)
if _is_func_def(body)
return _check_allocs_macro(body, __module__, __source__; kws...)
return _check_allocs_defun(body, __module__, __source__; kws...)
elseif Meta.isexpr(body, :call)
return _check_allocs_call(body, __module__, __source__; kws...)
else
error("@check_allocs used on something other than a function definition")
error("@check_allocs used on anything other than a function definition or call")
end
end

Expand Down Expand Up @@ -117,13 +127,20 @@ function forward_args!(func_def)
args, kwargs
end

function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
function _check_allocs_defun(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
(; original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(ex, mod, source; ignore_throw)
quote
local $f_sym = $(esc(original_fn))
$wrapper_fn
end
end

function _check_allocs_wrap_fn(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
# Transform original function to a renamed version with flattened args
def = splitdef(deepcopy(ex))
normalize_args!(def)
original_fn = combinedef(def)
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym()
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym("fn_alias")

# Next, create a wrapper function that will compile the original function on-the-fly.
def = splitdef(ex)
Expand All @@ -149,8 +166,26 @@ function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; igno
def[:body].args[1] = source

wrapper_fn = combinedef(def)
return quote
local $f_sym = $(esc(original_fn))
$(wrapper_fn)

(; original_fn, f_sym, wrapper_fn)
end

function _check_allocs_call(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
fn = first(ex.args)
args = ex.args[2:end]
args_template = if !isempty(args) && Meta.isexpr(first(args), :parameters)
kwargs = Expr(:parameters, map(a -> if Meta.isexpr(a, :kw) first(a.args) else a end::Symbol, first(args).args)...)
[kwargs, map(_ -> gensym("arg"), 2:length(args))...]
else
[map(_ -> gensym("arg"), 1:length(args))...]
end
passthrough_defun = Expr(:function, Expr(:tuple, args_template...), Expr(:call, fn, args_template...))
(original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(passthrough_defun, mod, source; ignore_throw)
af_sym = gensym("alloccheck_fn")
quote
let $f_sym = $(esc(original_fn))
$af_sym = $wrapper_fn
$(Expr(:call, af_sym, map(esc, args)...))
end
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ end
@test mysum2(x, y) == x + y
@check_allocs (x::Bar)(y::Bar) = x.val + y.val
@test Bar(x)(Bar(y)) == x + y

# Callsite forms
@test 1 + x == @check_allocs 1 + x
@test x^2 == @check_allocs (a -> a^2)(x)
@test_throws AllocCheck.AllocCheckFailure @check_allocs same_ccall()
end


Expand Down

0 comments on commit 31f2f96

Please sign in to comment.