Skip to content

Commit

Permalink
support Bool
Browse files Browse the repository at this point in the history
  • Loading branch information
Moelf committed Oct 17, 2024
1 parent 0719ae5 commit 7e20fae
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/RNTuple/Writing/constants.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const RNTUPLE_WRITE_TYPE_IDX_DICT = Dict(
Index64 => (0x01, sizeof(Index64) * 8),
Index32 => (0x02, sizeof(Index32) * 8),
Bool => (0x06, 1),
Float64 => (0x10, sizeof(UInt64) * 8),
Float32 => (0x11, sizeof(UInt32) * 8),
Float16 => (0x12, sizeof(UInt16) * 8),
Expand All @@ -14,6 +15,7 @@ const RNTUPLE_WRITE_TYPE_IDX_DICT = Dict(
)

const RNTUPLE_WRITE_TYPE_CPPNAME_DICT = Dict(
Bool => "bool",
Float16 => "std::float16_t",
Float32 => "float",
Float64 => "double",
Expand Down
5 changes: 5 additions & 0 deletions src/RNTuple/Writing/page_writing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ Turns an AbstractVector into a page of an RNTuple. The element type must be prim
"""
function rnt_ary_to_page(ary::AbstractVector, cr::ColumnRecord) end

function rnt_ary_to_page(ary::AbstractVector{Bool}, cr::ColumnRecord)
chunks = BitVector(ary).chunks
Page_write(reinterpret(UInt8, chunks))
end

function rnt_ary_to_page(ary::AbstractVector{Float64}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Expand Down
3 changes: 2 additions & 1 deletion src/RNTuple/fieldcolumn_reading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ function read_field(io, field::LeafField{Bool}, page_list)

# pad to nearest 8*k bytes because each chunk needs to be UInt64
bytes = read_pagedesc(io, pages, cr)
append!(bytes, zeros(eltype(bytes), (64 - rem(total_num_elements, 64))÷8))
N_pad = 8 - mod1(length(bytes), 8)
append!(bytes, zeros(eltype(bytes), N_pad))
chunks = reinterpret(UInt64, bytes)

res = BitVector(undef, total_num_elements)
Expand Down
8 changes: 5 additions & 3 deletions test/RNTupleWriting/lowlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ end
@test field_ids == [0,1,2,3,4,5]
end

const RNT_primitive_Ts = [Bool, Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16]

@testset "RNTuple Writing - Single colunm round trips" begin
for _ = 1:10, T in [Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16]
for _ = 1:10, T in RNT_primitive_Ts
newtable = Dict(randstring(rand(2:10)) => rand(T, rand(1:100)))
newio = IOBuffer()
UnROOT.write_rntuple(newio, newtable)
Expand All @@ -260,7 +262,7 @@ end
end

@testset "RNTuple Writing - Multiple colunm round trips" begin
Ts = rand([Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16], 15)
Ts = rand(RNT_primitive_Ts, 15)
Nitems = rand(10:1000)
newtable = Dict(randstring(rand(2:10)) => rand(T, Nitems) for T in Ts)
newio = IOBuffer()
Expand All @@ -284,7 +286,7 @@ end
end

@testset "RNTuple Writing - Vector colunms" begin
Ts = rand([Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16], 15)
Ts = rand(RNT_primitive_Ts, 15)
inner_Nitems = [3,4,0,0,1,2]
newtable = Dict(randstring(rand(2:10)) => [rand(T, Nitems) for Nitems in inner_Nitems] for T in Ts)
newio = IOBuffer()
Expand Down

0 comments on commit 7e20fae

Please sign in to comment.