Skip to content

Commit

Permalink
Add support for automatically calling unsafe_load() in getproperty()
Browse files Browse the repository at this point in the history
Copying the description from the code:
> By default the getproperty!(x::Ptr, ::Symbol) methods created for wrapped
> types will return pointers (Ptr{T}) to the struct fields. That behaviour is
> useful for accessing nested struct fields but it does require explicitly
> calling unsafe_load() every time. When enabled this option will automatically
> call unsafe_load() for you *except on nested struct fields and arrays*, which
> should make explicitly calling unsafe_load() unnecessary in most cases.
  • Loading branch information
JamesWrigley committed Aug 20, 2024
1 parent a226276 commit d901cb7
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Changelog](https://keepachangelog.com).
([5a1cc29](https://github.com/JuliaInterop/Clang.jl/commit/5a1cc29c154ed925f01e59dfd705cbf8042158e4)).
- Added bindings for Clang 17, which should allow compatibility with Julia 1.12
([#494]).
- Experimental support for automatically dereferencing struct fields in
`Base.getproperty()` with the `auto_field_dereference` option ([#502]).

### Fixed

Expand Down
13 changes: 13 additions & 0 deletions gen/generator.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ wrap_variadic_function = false
# generate getproperty/setproperty! methods for the types in the following list
field_access_method_list = []

# EXPERIMENTAL:
# By default the getproperty!(x::Ptr, ::Symbol) methods created for wrapped
# types will return pointers (Ptr{T}) to the struct fields. That behaviour is
# useful for accessing nested struct fields but it does require explicitly
# calling unsafe_load() every time. When enabled this option will automatically
# call unsafe_load() for you *except on nested struct fields and arrays*, which
# should make explicitly calling unsafe_load() unnecessary in most cases. A @ptr
# macro will be defined for cases where you really do want a pointer to a field
# (e.g. for writing), which supports syntax like `@ptr(foo.bar)`.
#
# This should be used with `field_access_method_list`.
auto_field_dereference = false

# the generator will prefix the function argument names in the following list with a "_" to
# prevent the generated symbols from conflicting with the symbols defined and exported in Base.
function_argument_conflict_symbols = []
Expand Down
70 changes: 63 additions & 7 deletions src/generator/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,20 @@ end

############################### Struct ###############################

function _emit_getproperty_ptr!(body, root_cursor, cursor, options)
function _emit_pointer_access!(body, root_cursor, cursor, options)
field_cursors = fields(getCursorType(cursor))
field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors
for field_cursor in field_cursors
n = name(field_cursor)
if isempty(n)
_emit_getproperty_ptr!(body, root_cursor, field_cursor, options)
_emit_pointer_access!(body, root_cursor, field_cursor, options)
continue
end
fsym = make_symbol_safe(n)
fty = getCursorType(field_cursor)
ty = translate(tojulia(fty), options)
offset = getOffsetOf(getCursorType(root_cursor), n)

if isBitField(field_cursor)
w = getFieldDeclBitWidth(field_cursor)
@assert w <= 32 # Bit fields should not be larger than int(32 bits)
Expand All @@ -322,12 +323,63 @@ function _emit_getproperty_ptr!(body, root_cursor, cursor, options)
end
end

# Base.getproperty(x::Ptr, f::Symbol) -> Ptr
# getptr(x::Ptr, f::Symbol) -> Ptr
function emit_getptr!(dag, node, options)
sym = make_symbol_safe(node.id)
signature = Expr(:call, :getptr, :(x::Ptr{$sym}), :(f::Symbol))
body = Expr(:block)
_emit_pointer_access!(body, node.cursor, node.cursor, options)

push!(body.args, :(error($("Unrecognized field of type `$sym`") * ": $f")))
push!(node.exprs, Expr(:function, signature, body))
return dag
end

function emit_deref_getproperty!(body, root_cursor, cursor, options)
field_cursors = fields(getCursorType(cursor))
field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors
for field_cursor in field_cursors
n = name(field_cursor)
if isempty(n)
emit_deref_getproperty!(body, root_cursor, field_cursor, options)
continue
end
fsym = make_symbol_safe(n)
fty = getCursorType(field_cursor)
canonical_type = getCanonicalType(fty)

return_expr = :(getptr(x, f))

# Automatically dereference all field types except for nested structs
# and arrays.
if !(canonical_type isa Union{CLRecord, CLConstantArray}) && !isBitField(field_cursor)
return_expr = :(unsafe_load($return_expr))
elseif isBitField(field_cursor)
return_expr = :(getbitfieldproperty(x, $return_expr))
end

ex = :(f === $(QuoteNode(fsym)) && return $return_expr)
push!(body.args, ex)
end
end

# Base.getproperty(x::Ptr, f::Symbol)
function emit_getproperty_ptr!(dag, node, options)
auto_deref = get(options, "auto_field_dereference", false)
sym = make_symbol_safe(node.id)

# If automatically dereferencing, we first need to emit getptr!()
if auto_deref
emit_getptr!(dag, node, options)
end

signature = Expr(:call, :(Base.getproperty), :(x::Ptr{$sym}), :(f::Symbol))
body = Expr(:block)
_emit_getproperty_ptr!(body, node.cursor, node.cursor, options)
if auto_deref
emit_deref_getproperty!(body, node.cursor, node.cursor, options)
else
_emit_pointer_access!(body, node.cursor, node.cursor, options)
end
push!(body.args, :(return getfield(x, f)))
getproperty_expr = Expr(:function, signature, body)
push!(node.exprs, getproperty_expr)
Expand Down Expand Up @@ -370,10 +422,14 @@ end
function emit_setproperty!(dag, node, options)
sym = make_symbol_safe(node.id)
signature = Expr(:call, :(Base.setproperty!), :(x::Ptr{$sym}), :(f::Symbol), :v)
store_expr = :(unsafe_store!(getproperty(x, f), v))

auto_deref = get(options, "auto_field_dereference", false)
pointer_getter = auto_deref ? :getptr : :getproperty
store_expr = :(unsafe_store!($pointer_getter(x, f), v))

if is_bitfield_type(node.type)
body = quote
fptr = getproperty(x, f)
fptr = $pointer_getter(x, f)
if fptr isa Ptr
$store_expr
else
Expand All @@ -398,7 +454,7 @@ function get_names_types(root_cursor, cursor, options)
for field_cursor in field_cursors
n = name(field_cursor)
if isempty(n)
_emit_getproperty_ptr!(root_cursor, field_cursor, options)
_emit_pointer_access!(root_cursor, field_cursor, options)
continue
end
fsym = make_symbol_safe(n)
Expand Down
16 changes: 16 additions & 0 deletions src/generator/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,7 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict)
use_native_enum = get(general_options, "use_julia_native_enum_type", false)
print_CEnum = get(general_options, "print_using_CEnum", true)
wrap_variadic_function = get(codegen_options, "wrap_variadic_function", false)
auto_deref = get(codegen_options, "auto_field_dereference", false)

show_info && @info "[ProloguePrinter]: print to $(x.file)"
open(x.file, "w") do io
Expand Down Expand Up @@ -1186,6 +1187,21 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict)
println(io, string(set_expr), "\n")
end

if auto_deref
println(io, raw"""
macro ptr(expr)
if !Meta.isexpr(expr, :.)
error("Expression is not a property access, cannot use @ptr on it.")
end
quote
local penultimate_obj = $(esc(expr.args[1]))
getptr(penultimate_obj, $(esc(expr.args[2])))
end
end
""")
end

# print prelogue patches
if !isempty(prologue_file_path)
println(io, read(prologue_file_path, String))
Expand Down
101 changes: 101 additions & 0 deletions test/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,104 @@ end
@test docstring_has("callback")
end
end

@testset "Struct getproperty()/setproperty!()" begin
options = Dict("general" => Dict{String, Any}("auto_mutability" => true,
"auto_mutability_with_new" => false,
"auto_mutability_includelist" => ["WithFields"]),
"codegen" => Dict{String, Any}("field_access_method_list" => ["WithFields", "Other"]))

# Test the default getproperty()/setproperty!() behaviour
mktemp() do path, io
options["general"]["output_file_path"] = path
ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options)
build!(ctx)

println(read(path, String))

m = Module()
Base.include(m, path)

# We now have to run in the latest world to use the new definitions
Base.invokelatest() do
obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1))

GC.@preserve obj begin
obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj))

# The default getproperty() should basically always return a
# pointer to the field (except for bitfields, which are tested
# elsewhere).
@test obj_ptr.int_value isa Ptr{Cint}
@test obj_ptr.int_ptr isa Ptr{Ptr{Cint}}
@test obj_ptr.struct_value isa Ptr{m.Other}
@test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct}
@test obj_ptr.array isa Ptr{NTuple{2, Cint}}

# Sanity test
int_value = unsafe_load(obj_ptr.int_value)
@test int_value == obj.int_value

# Test setproperty!()
obj_ptr.int_value = int_value + 1
@test unsafe_load(obj_ptr.int_value) == int_value + 1
end
end
end

# Test the auto_field_dereference option
mktemp() do path, io
options["general"]["output_file_path"] = path
options["codegen"]["auto_field_dereference"] = true
ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options)
build!(ctx)

println(read(path, String))

m = Module()
Base.include(m, path)

# We now have to run in the latest world to use the new definitions
Base.invokelatest() do
obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1))

GC.@preserve obj begin
obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj))

# Test getproperty()
@test obj_ptr.int_value isa Cint
@test obj_ptr.int_value == obj.int_value
@test obj_ptr.int_ptr isa Ptr{Cint}

@test obj_ptr.struct_value isa Ptr{m.Other}
@test obj_ptr.struct_value.i == obj.struct_value.i
@test obj_ptr.struct_ptr isa Ptr{m.Other}
@test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct}

@test obj_ptr.array isa Ptr{NTuple{2, Cint}}

@test_throws ErrorException obj_ptr.foo

# Test @ptr
val_ptr = @eval m @ptr $obj_ptr.int_value
@test val_ptr isa Ptr{Cint}
int_ptr = @eval m @ptr $obj_ptr.int_ptr
@test int_ptr isa Ptr{Ptr{Cint}}

@test_throws LoadError (@eval m @ptr $obj_ptr)
@test_throws ErrorException (@eval m @ptr $obj_ptr.foo)

# Test setproperty!()
new_value = obj.int_value * 2
obj_ptr.int_value = new_value
@test obj.int_value == new_value

new_value = obj.struct_value.i * 2
obj_ptr.struct_value.i = new_value
@test obj.struct_value.i == new_value

@test_throws ErrorException obj_ptr.foo = 1
end
end
end
end
18 changes: 18 additions & 0 deletions test/include/struct-properties.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
typedef struct {
int i;
} TypedefStruct;

struct Other {
int i;
};

struct WithFields {
int int_value;
int* int_ptr;

struct Other struct_value;
struct Other* struct_ptr;
TypedefStruct typedef_struct_value;

int array[2];
};
70 changes: 43 additions & 27 deletions test/test_bitfield.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,44 +61,60 @@ function build_libbitfield()
error("Could not build libbitfield binary")
end

# Generate wrappers
@info "Building libbitfield wrapper"
args = get_default_args()
headers = joinpath(@__DIR__, "build", "include", "bitfield.h")
options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml"))
lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield")
options["general"]["library_name"] = "\"$(escape_string(lib_path))\""
options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl")
ctx = create_context(headers, args, options)
build!(ctx)

# Call a function to ensure build is successful
include("LibBitField.jl")
m = Base.@invokelatest LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)
Base.@invokelatest LibBitField.toBitfield(Ref(m))
# Test the binary
generate_wrappers(false)
catch e
@warn "Building libbitfield failed: $e"
success = false
end
return success
end

function generate_wrappers(auto_deref::Bool)
@info "Building libbitfield wrapper"
args = get_default_args()
headers = joinpath(@__DIR__, "build", "include", "bitfield.h")
options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml"))
options["codegen"]["auto_field_dereference"] = auto_deref
options["codegen"]["field_access_method_list"] = ["BitField"]

lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield")
options["general"]["library_name"] = "\"$(escape_string(lib_path))\""
options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl")
ctx = create_context(headers, args, options)
build!(ctx)

# Call a function to ensure build is successful
anonmod = Module()
Base.include(anonmod, "LibBitField.jl")
m = Base.@invokelatest anonmod.LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)
Base.@invokelatest anonmod.LibBitField.toBitfield(Ref(m))

return anonmod
end

@testset "Bitfield" begin
if build_libbitfield()
bf = Ref(LibBitField.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3)))
m = Ref(LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3))
GC.@preserve bf m begin
pbf = Ptr{LibBitField.BitField}(pointer_from_objref(bf))
pm = Ptr{LibBitField.Mirror}(pointer_from_objref(m))
@test LibBitField.toMirror(bf) == m[]
@test LibBitField.toBitfield(m).a == bf[].a
@test LibBitField.toBitfield(m).b == bf[].b
@test LibBitField.toBitfield(m).c == bf[].c
@test LibBitField.toBitfield(m).d == bf[].d
@test LibBitField.toBitfield(m).e == bf[].e
@test LibBitField.toBitfield(m).f == bf[].f
# Test the wrappers with and without auto-dereferencing. In the case of
# bitfields they should have identical behaviour.
for auto_deref in [false, true]
anonmod = generate_wrappers(auto_deref)
lib = anonmod.LibBitField

bf = Ref(lib.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3)))
m = Ref(lib.Mirror(10, 1.5, 1e6, -4, 7, 3))

GC.@preserve bf m begin
pbf = Ptr{lib.BitField}(pointer_from_objref(bf))
pm = Ptr{lib.Mirror}(pointer_from_objref(m))
@test lib.toMirror(bf) == m[]
@test lib.toBitfield(m).a == bf[].a
@test lib.toBitfield(m).b == bf[].b
@test lib.toBitfield(m).c == bf[].c
@test lib.toBitfield(m).d == bf[].d
@test lib.toBitfield(m).e == bf[].e
@test lib.toBitfield(m).f == bf[].f
end
end
end
end

0 comments on commit d901cb7

Please sign in to comment.