diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 57177dd6..d285c2d0 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -11,6 +11,8 @@ Changelog](https://keepachangelog.com). ([5a1cc29](https://github.com/JuliaInterop/Clang.jl/commit/5a1cc29c154ed925f01e59dfd705cbf8042158e4)). - Added bindings for Clang 17, which should allow compatibility with Julia 1.12 ([#494]). +- Experimental support for automatically dereferencing struct fields in + `Base.getproperty()` with the `auto_field_dereference` option ([#502]). ### Fixed diff --git a/gen/generator.toml b/gen/generator.toml index 08cbfe04..b3029e6c 100644 --- a/gen/generator.toml +++ b/gen/generator.toml @@ -181,6 +181,19 @@ wrap_variadic_function = false # generate getproperty/setproperty! methods for the types in the following list field_access_method_list = [] +# EXPERIMENTAL: +# By default the getproperty(x::Ptr, ::Symbol) methods created for wrapped +# types will return pointers (Ptr{T}) to the struct fields. That behaviour is +# useful for accessing nested struct fields but it does require explicitly +# calling unsafe_load() every time. When enabled this option will automatically +# call unsafe_load() for you *except on nested struct fields and arrays*, which +# should make explicitly calling unsafe_load() unnecessary in most cases. A @ptr +# macro will be defined for cases where you really do want a pointer to a field +# (e.g. for writing), which supports syntax like `@ptr(foo.bar)`. +# +# This should be used with `field_access_method_list`. +auto_field_dereference = false + # the generator will prefix the function argument names in the following list with a "_" to # prevent the generated symbols from conflicting with the symbols defined and exported in Base. 
function_argument_conflict_symbols = [] diff --git a/src/generator/codegen.jl b/src/generator/codegen.jl index 5de6ecae..6e1a090b 100644 --- a/src/generator/codegen.jl +++ b/src/generator/codegen.jl @@ -296,19 +296,20 @@ end ############################### Struct ############################### -function _emit_getproperty_ptr!(body, root_cursor, cursor, options) +function _emit_pointer_access!(body, root_cursor, cursor, options) field_cursors = fields(getCursorType(cursor)) field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors for field_cursor in field_cursors n = name(field_cursor) if isempty(n) - _emit_getproperty_ptr!(body, root_cursor, field_cursor, options) + _emit_pointer_access!(body, root_cursor, field_cursor, options) continue end fsym = make_symbol_safe(n) fty = getCursorType(field_cursor) ty = translate(tojulia(fty), options) offset = getOffsetOf(getCursorType(root_cursor), n) + if isBitField(field_cursor) w = getFieldDeclBitWidth(field_cursor) @assert w <= 32 # Bit fields should not be larger than int(32 bits) @@ -322,12 +323,63 @@ function _emit_getproperty_ptr!(body, root_cursor, cursor, options) end end -# Base.getproperty(x::Ptr, f::Symbol) -> Ptr +# getptr(x::Ptr, f::Symbol) -> Ptr +function emit_getptr!(dag, node, options) + sym = make_symbol_safe(node.id) + signature = Expr(:call, :getptr, :(x::Ptr{$sym}), :(f::Symbol)) + body = Expr(:block) + _emit_pointer_access!(body, node.cursor, node.cursor, options) + + push!(body.args, :(error($("Unrecognized field of type `$sym`") * ": $f"))) + push!(node.exprs, Expr(:function, signature, body)) + return dag +end + +function emit_deref_getproperty!(body, root_cursor, cursor, options) + field_cursors = fields(getCursorType(cursor)) + field_cursors = isempty(field_cursors) ? 
children(cursor) : field_cursors + for field_cursor in field_cursors + n = name(field_cursor) + if isempty(n) + emit_deref_getproperty!(body, root_cursor, field_cursor, options) + continue + end + fsym = make_symbol_safe(n) + fty = getCursorType(field_cursor) + canonical_type = getCanonicalType(fty) + + return_expr = :(getptr(x, f)) + + # Automatically dereference all field types except for nested structs + # and arrays. + if !(canonical_type isa Union{CLRecord, CLConstantArray}) && !isBitField(field_cursor) + return_expr = :(unsafe_load($return_expr)) + elseif isBitField(field_cursor) + return_expr = :(getbitfieldproperty(x, $return_expr)) + end + + ex = :(f === $(QuoteNode(fsym)) && return $return_expr) + push!(body.args, ex) + end +end + +# Base.getproperty(x::Ptr, f::Symbol) function emit_getproperty_ptr!(dag, node, options) + auto_deref = get(options, "auto_field_dereference", false) sym = make_symbol_safe(node.id) + + # If automatically dereferencing, we first need to emit getptr!() + if auto_deref + emit_getptr!(dag, node, options) + end + signature = Expr(:call, :(Base.getproperty), :(x::Ptr{$sym}), :(f::Symbol)) body = Expr(:block) - _emit_getproperty_ptr!(body, node.cursor, node.cursor, options) + if auto_deref + emit_deref_getproperty!(body, node.cursor, node.cursor, options) + else + _emit_pointer_access!(body, node.cursor, node.cursor, options) + end push!(body.args, :(return getfield(x, f))) getproperty_expr = Expr(:function, signature, body) push!(node.exprs, getproperty_expr) @@ -370,10 +422,14 @@ end function emit_setproperty!(dag, node, options) sym = make_symbol_safe(node.id) signature = Expr(:call, :(Base.setproperty!), :(x::Ptr{$sym}), :(f::Symbol), :v) - store_expr = :(unsafe_store!(getproperty(x, f), v)) + + auto_deref = get(options, "auto_field_dereference", false) + pointer_getter = auto_deref ? 
:getptr : :getproperty + store_expr = :(unsafe_store!($pointer_getter(x, f), v)) + if is_bitfield_type(node.type) body = quote - fptr = getproperty(x, f) + fptr = $pointer_getter(x, f) if fptr isa Ptr $store_expr else @@ -398,7 +454,7 @@ function get_names_types(root_cursor, cursor, options) for field_cursor in field_cursors n = name(field_cursor) if isempty(n) - _emit_getproperty_ptr!(root_cursor, field_cursor, options) + _emit_pointer_access!(root_cursor, field_cursor, options) continue end fsym = make_symbol_safe(n) diff --git a/src/generator/passes.jl b/src/generator/passes.jl index 9da947a8..2abb4e9b 100644 --- a/src/generator/passes.jl +++ b/src/generator/passes.jl @@ -1094,6 +1094,7 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict) use_native_enum = get(general_options, "use_julia_native_enum_type", false) print_CEnum = get(general_options, "print_using_CEnum", true) wrap_variadic_function = get(codegen_options, "wrap_variadic_function", false) + auto_deref = get(codegen_options, "auto_field_dereference", false) show_info && @info "[ProloguePrinter]: print to $(x.file)" open(x.file, "w") do io @@ -1186,6 +1187,21 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict) println(io, string(set_expr), "\n") end + if auto_deref + println(io, raw""" + macro ptr(expr) + if !Meta.isexpr(expr, :.) 
+ error("Expression is not a property access, cannot use @ptr on it.") + end + + quote + local penultimate_obj = $(esc(expr.args[1])) + getptr(penultimate_obj, $(esc(expr.args[2]))) + end + end + """) + end + # print prelogue patches if !isempty(prologue_file_path) println(io, read(prologue_file_path, String)) diff --git a/test/generators.jl b/test/generators.jl index 9f72c7e0..4f250d4f 100644 --- a/test/generators.jl +++ b/test/generators.jl @@ -249,3 +249,104 @@ end @test docstring_has("callback") end end + +@testset "Struct getproperty()/setproperty!()" begin + options = Dict("general" => Dict{String, Any}("auto_mutability" => true, + "auto_mutability_with_new" => false, + "auto_mutability_includelist" => ["WithFields"]), + "codegen" => Dict{String, Any}("field_access_method_list" => ["WithFields", "Other"])) + + # Test the default getproperty()/setproperty!() behaviour + mktemp() do path, io + options["general"]["output_file_path"] = path + ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options) + build!(ctx) + + println(read(path, String)) + + m = Module() + Base.include(m, path) + + # We now have to run in the latest world to use the new definitions + Base.invokelatest() do + obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1)) + + GC.@preserve obj begin + obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj)) + + # The default getproperty() should basically always return a + # pointer to the field (except for bitfields, which are tested + # elsewhere). 
+ @test obj_ptr.int_value isa Ptr{Cint} + @test obj_ptr.int_ptr isa Ptr{Ptr{Cint}} + @test obj_ptr.struct_value isa Ptr{m.Other} + @test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct} + @test obj_ptr.array isa Ptr{NTuple{2, Cint}} + + # Sanity test + int_value = unsafe_load(obj_ptr.int_value) + @test int_value == obj.int_value + + # Test setproperty!() + obj_ptr.int_value = int_value + 1 + @test unsafe_load(obj_ptr.int_value) == int_value + 1 + end + end + end + + # Test the auto_field_dereference option + mktemp() do path, io + options["general"]["output_file_path"] = path + options["codegen"]["auto_field_dereference"] = true + ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options) + build!(ctx) + + println(read(path, String)) + + m = Module() + Base.include(m, path) + + # We now have to run in the latest world to use the new definitions + Base.invokelatest() do + obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1)) + + GC.@preserve obj begin + obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj)) + + # Test getproperty() + @test obj_ptr.int_value isa Cint + @test obj_ptr.int_value == obj.int_value + @test obj_ptr.int_ptr isa Ptr{Cint} + + @test obj_ptr.struct_value isa Ptr{m.Other} + @test obj_ptr.struct_value.i == obj.struct_value.i + @test obj_ptr.struct_ptr isa Ptr{m.Other} + @test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct} + + @test obj_ptr.array isa Ptr{NTuple{2, Cint}} + + @test_throws ErrorException obj_ptr.foo + + # Test @ptr + val_ptr = @eval m @ptr $obj_ptr.int_value + @test val_ptr isa Ptr{Cint} + int_ptr = @eval m @ptr $obj_ptr.int_ptr + @test int_ptr isa Ptr{Ptr{Cint}} + + @test_throws LoadError (@eval m @ptr $obj_ptr) + @test_throws ErrorException (@eval m @ptr $obj_ptr.foo) + + # Test setproperty!() + new_value = obj.int_value * 2 + obj_ptr.int_value = new_value + @test obj.int_value == new_value + + new_value = obj.struct_value.i * 2 + 
obj_ptr.struct_value.i = new_value + @test obj.struct_value.i == new_value + + @test_throws ErrorException obj_ptr.foo = 1 + end + end + end +end diff --git a/test/include/struct-properties.h b/test/include/struct-properties.h new file mode 100644 index 00000000..d836eba1 --- /dev/null +++ b/test/include/struct-properties.h @@ -0,0 +1,18 @@ +typedef struct { + int i; +} TypedefStruct; + +struct Other { + int i; +}; + +struct WithFields { + int int_value; + int* int_ptr; + + struct Other struct_value; + struct Other* struct_ptr; + TypedefStruct typedef_struct_value; + + int array[2]; +}; diff --git a/test/test_bitfield.jl b/test/test_bitfield.jl index ad32b441..fec17ec8 100644 --- a/test/test_bitfield.jl +++ b/test/test_bitfield.jl @@ -61,21 +61,8 @@ function build_libbitfield() error("Could not build libbitfield binary") end - # Generate wrappers - @info "Building libbitfield wrapper" - args = get_default_args() - headers = joinpath(@__DIR__, "build", "include", "bitfield.h") - options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml")) - lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? 
"bitfield.dll" : "libbitfield") - options["general"]["library_name"] = "\"$(escape_string(lib_path))\"" - options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl") - ctx = create_context(headers, args, options) - build!(ctx) - - # Call a function to ensure build is successful - include("LibBitField.jl") - m = Base.@invokelatest LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3) - Base.@invokelatest LibBitField.toBitfield(Ref(m)) + # Test the binary + generate_wrappers(false) catch e @warn "Building libbitfield failed: $e" success = false @@ -83,22 +70,51 @@ function build_libbitfield() return success end +function generate_wrappers(auto_deref::Bool) + @info "Building libbitfield wrapper" + args = get_default_args() + headers = joinpath(@__DIR__, "build", "include", "bitfield.h") + options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml")) + options["codegen"]["auto_field_dereference"] = auto_deref + options["codegen"]["field_access_method_list"] = ["BitField"] + + lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? 
"bitfield.dll" : "libbitfield") + options["general"]["library_name"] = "\"$(escape_string(lib_path))\"" + options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl") + ctx = create_context(headers, args, options) + build!(ctx) + # Call a function to ensure build is successful + anonmod = Module() + Base.include(anonmod, "LibBitField.jl") + m = Base.@invokelatest anonmod.LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3) + Base.@invokelatest anonmod.LibBitField.toBitfield(Ref(m)) + + return anonmod +end @testset "Bitfield" begin if build_libbitfield() - bf = Ref(LibBitField.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3))) - m = Ref(LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)) - GC.@preserve bf m begin - pbf = Ptr{LibBitField.BitField}(pointer_from_objref(bf)) - pm = Ptr{LibBitField.Mirror}(pointer_from_objref(m)) - @test LibBitField.toMirror(bf) == m[] - @test LibBitField.toBitfield(m).a == bf[].a - @test LibBitField.toBitfield(m).b == bf[].b - @test LibBitField.toBitfield(m).c == bf[].c - @test LibBitField.toBitfield(m).d == bf[].d - @test LibBitField.toBitfield(m).e == bf[].e - @test LibBitField.toBitfield(m).f == bf[].f + # Test the wrappers with and without auto-dereferencing. In the case of + # bitfields they should have identical behaviour. + for auto_deref in [false, true] + anonmod = generate_wrappers(auto_deref) + lib = anonmod.LibBitField + + bf = Ref(lib.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3))) + m = Ref(lib.Mirror(10, 1.5, 1e6, -4, 7, 3)) + + GC.@preserve bf m begin + pbf = Ptr{lib.BitField}(pointer_from_objref(bf)) + pm = Ptr{lib.Mirror}(pointer_from_objref(m)) + @test lib.toMirror(bf) == m[] + @test lib.toBitfield(m).a == bf[].a + @test lib.toBitfield(m).b == bf[].b + @test lib.toBitfield(m).c == bf[].c + @test lib.toBitfield(m).d == bf[].d + @test lib.toBitfield(m).e == bf[].e + @test lib.toBitfield(m).f == bf[].f + end end end end