Skip to content

Commit

Permalink
[RNTuple] support more types in RNTuple writing (#348)
Browse files Browse the repository at this point in the history
* support more types in RNTuple writing

* restore test

* add missing file

* restructure files
  • Loading branch information
Moelf authored Sep 27, 2024
1 parent 8a28452 commit 0679f73
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 35 deletions.
36 changes: 15 additions & 21 deletions src/RNTuple/Writing/TFileWriter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,6 @@ function rnt_write_observe(io::IO, x::T) where T
WriteObservable(io, pos, len, x)
end

"""
    split4_encode(src::AbstractVector{UInt8})

Regroup the bytes of `src` by their position inside consecutive 4-byte groups.

The returned `Vector{UInt8}` contains every 1st byte of each group, followed by
every 2nd, every 3rd and finally every 4th byte.
"""
function split4_encode(src::AbstractVector{UInt8})
    stop = lastindex(src)
    # One strided view per byte position; `vcat` materialises them back to back.
    byte1 = view(src, 1:4:(stop - 3))
    byte2 = view(src, 2:4:(stop - 2))
    byte3 = view(src, 3:4:(stop - 1))
    byte4 = view(src, 4:4:stop)
    return vcat(byte1, byte2, byte3, byte4)
end

function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rntuple_name="myntuple")
if !istable(table)
error("RNTuple writing accepts object compatible with Tables.jl interface, got type $(typeof(table))")
Expand All @@ -488,16 +484,12 @@ function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rn
error("Currently, RNTuple writing only supports a single, UInt32 column, got $input_Ncols columns")
end
input_T = only(input_schema.types)
if input_T != UInt32
error("Currently, RNTuple writing only supports a single, UInt32 column, got type $input_T")
end
input_col = only(columntable(table))
input_length = length(input_col)
if input_length > 65535
error("Input too long: RNTuple writing currently only supports a single page (65535 elements)")
end


rntAnchor_update = Dict{Symbol, Any}()

file_preamble_obs = rnt_write_observe(file, Stubs.file_preamble)
Expand All @@ -516,55 +508,55 @@ function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rn
RBlob1_obs = rnt_write_observe(file, Stubs.RBlob1)
rntAnchor_update[:fSeekHeader] = UInt32(position(file))
rnt_header = UnROOT.RNTupleHeader(zero(UInt64), rntuple_name, "", "ROOT v6.33.01", [
UnROOT.FieldRecord(zero(UInt32), zero(UInt32), zero(UInt32), zero(UInt16), zero(UInt16), 0, -1, -1, string(only(input_schema.names)), "std::uint32_t", "", ""),
], [UnROOT.ColumnRecord(0x14, 0x20, zero(UInt32), 0x00, 0x00, 0),], UnROOT.AliasRecord[], UnROOT.ExtraTypeInfo[])
UnROOT.FieldRecord(zero(UInt32), zero(UInt32), zero(UInt32), zero(UInt16), zero(UInt16), 0, -1, -1, string(only(input_schema.names)), RNTUPLE_WRITE_TYPE_CPPNAME_DICT[input_T], "", ""),
], [UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[input_T]..., zero(UInt32), 0x00, 0x00, 0),], UnROOT.AliasRecord[], UnROOT.ExtraTypeInfo[])

rnt_header_obs = rnt_write_observe(file, rnt_header)
rntAnchor_update[:fNBytesHeader] = rnt_header_obs.len
rntAnchor_update[:fLenHeader] = rnt_header_obs.len

RBlob2_obs = rnt_write_observe(file, Stubs.RBlob2)
page1 = reinterpret(UInt8, input_col)
page1_bytes = Page_write(split4_encode(page1))
page1_position = position(file)
page1_obs = rnt_write_observe(file, page1_bytes)
page1 = rnt_ary_to_page(input_col)
page1_obs = rnt_write_observe(file, page1)

RBlob3_obs = rnt_write_observe(file, Stubs.RBlob3)
cluster_summary = Write_RNTupleListFrame([ClusterSummary(0, input_length)])
nested_page_locations =
UnROOT.RNTuplePageTopList([
UnROOT.RNTuplePageOuterList([
UnROOT.RNTuplePageInnerList([
PageDescription(input_length, UnROOT.Locator(sizeof(input_T) * input_length, page1_position, )),
PageDescription(input_length, UnROOT.Locator(sizeof(input_T) * input_length, page1_obs.position, )),
]),
]),
])

# stub checksum 0x3dec59c009c67e28
pagelink = UnROOT.PageLink(_checksum(rnt_header_obs.object), cluster_summary.payload, nested_page_locations)
pagelink_position = position(file)
pagelink_obs = rnt_write_observe(file, pagelink)

RBlob4_obs = rnt_write_observe(file, Stubs.RBlob4)
rntAnchor_update[:fSeekFooter] = UInt32(position(file))
rnt_footer = UnROOT.RNTupleFooter(0, _checksum(rnt_header_obs.object), UnROOT.RNTupleSchemaExtension([], [], [], []), [], [
UnROOT.ClusterGroupRecord(0, input_length, 1, UnROOT.EnvLink(0x000000000000007c, UnROOT.Locator(124, pagelink_position, ))),
UnROOT.ClusterGroupRecord(0, input_length, 1, UnROOT.EnvLink(pagelink_obs.len, UnROOT.Locator(pagelink_obs.len, pagelink_obs.position, ))),
])
rnt_footer_obs = rnt_write_observe(file, rnt_footer)
rntAnchor_update[:fNBytesFooter] = 0xA0
rntAnchor_update[:fLenFooter] = 0xA0
rntAnchor_update[:fNBytesFooter] = rnt_footer_obs.len
rntAnchor_update[:fLenFooter] = rnt_footer_obs.len

tkey32_anchor_position = position(file)
tkey32_anchor = UnROOT.TKey32(0x0000008E, 4, 0x0000004E, Stubs.WRITE_TIME, 64, 1, tkey32_anchor_position, 100, "ROOT::Experimental::RNTuple", rntuple_name, "")
tkey32_anchor = UnROOT.TKey32(0x0000008E, 4, typemin(Int32), Stubs.WRITE_TIME, 64, 1, tkey32_anchor_position, 100, "ROOT::Experimental::RNTuple", rntuple_name, "")
tkey32_anchor_obs1 = rnt_write_observe(file, tkey32_anchor)
tkey32_anchor_update = Dict{Symbol, Any}()
magic_6bytes_obs = rnt_write_observe(file, Stubs.magic_6bytes)
rnt_anchor_obs = rnt_write_observe(file, Stubs.rnt_anchor)
Base.setindex!(rnt_anchor_obs, rntAnchor_update)
tkey32_anchor_update[:fObjlen] = rnt_anchor_obs.len + magic_6bytes_obs.len
Base.setindex!(tkey32_anchor_obs1, tkey32_anchor_update)

tdirectory32_obs[:fSeekKeys] = UInt32(position(file))
tkey32_TDirectory_obs = rnt_write_observe(file, Stubs.tkey32_TDirectory)
n_keys_obs = rnt_write_observe(file, Stubs.n_keys)
tkey32_anchor_obs2 = rnt_write_observe(file, tkey32_anchor)
Base.setindex!(tkey32_anchor_obs2, tkey32_anchor_update)

fileheader_obs[:fSeekInfo] = UInt32(position(file))
tkey32_TStreamerInfo_obs = rnt_write_observe(file, Stubs.tkey32_TStreamerInfo)
Expand All @@ -573,6 +565,8 @@ function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rn
tfile_end_obs = rnt_write_observe(file, Stubs.tfile_end)
fileheader_obs[:fEND] = UInt32(position(file))

flush!(tkey32_anchor_obs1)
flush!(tkey32_anchor_obs2)
flush!(tkey32_tfile_obs)
flush!(tdirectory32_obs)
flush!(fileheader_obs)
Expand Down
25 changes: 25 additions & 0 deletions src/RNTuple/Writing/constants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Lookup table used when writing: Julia element type => `(type id, bit width)`.
# The bit width is always `8 * sizeof(T)` of the key type.
# NOTE(review): ids 0x10-0x15 appear to be the split-encoded column types and
# 0x16-0x19 the plain signed-integer ones -- confirm against the RNTuple spec.
const RNTUPLE_WRITE_TYPE_IDX_DICT = Dict{DataType,Tuple{UInt8,Int}}(
    # floating point
    Float64 => (0x10, 8 * sizeof(Float64)),
    Float32 => (0x11, 8 * sizeof(Float32)),
    Float16 => (0x12, 8 * sizeof(Float16)),
    # unsigned integers
    UInt64 => (0x13, 8 * sizeof(UInt64)),
    UInt32 => (0x14, 8 * sizeof(UInt32)),
    UInt16 => (0x15, 8 * sizeof(UInt16)),
    # signed integers
    Int64 => (0x16, 8 * sizeof(Int64)),
    Int32 => (0x17, 8 * sizeof(Int32)),
    Int16 => (0x18, 8 * sizeof(Int16)),
    Int8 => (0x19, 8 * sizeof(Int8)),
)

# Lookup table used when writing: Julia element type => C++ type name string
# stored in the field record of the RNTuple header.
# NOTE(review): the floating-point names follow the `std::floatNN_t` pattern;
# confirm that the RNTuple readers you target accept these spellings.
const RNTUPLE_WRITE_TYPE_CPPNAME_DICT = Dict{DataType,String}(
    # floating point
    Float16 => "std::float16_t",
    Float32 => "std::float32_t",
    Float64 => "std::float64_t",
    # signed integers
    Int8 => "std::int8_t",
    Int16 => "std::int16_t",
    Int32 => "std::int32_t",
    Int64 => "std::int64_t",
    # unsigned integers
    UInt16 => "std::uint16_t",
    UInt32 => "std::uint32_t",
    UInt64 => "std::uint64_t",
)
66 changes: 66 additions & 0 deletions src/RNTuple/Writing/page_writing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
    rnt_ary_to_page(ary::AbstractVector)

Turns an AbstractVector into a page of an RNTuple. The element type must be primitive for this to work.
"""
function rnt_ary_to_page(ary::AbstractVector)
    # Generic fallback, reached only when no element-type-specific method exists.
    # An empty body here would silently return `nothing`, so fail loudly instead.
    throw(ArgumentError(
        "RNTuple writing does not support element type $(eltype(ary))",
    ))
end

# 8-byte elements written with split encoding: the raw bytes are regrouped by
# `split8_encode` before being handed to `Page_write`.
function rnt_ary_to_page(ary::Union{AbstractVector{Float64},AbstractVector{UInt64}})
    raw = reinterpret(UInt8, ary)
    return Page_write(split8_encode(raw))
end

# 4-byte elements written with split encoding (`split4_encode`).
function rnt_ary_to_page(ary::Union{AbstractVector{Float32},AbstractVector{UInt32}})
    raw = reinterpret(UInt8, ary)
    return Page_write(split4_encode(raw))
end

# 2-byte elements written with split encoding (`split2_encode`).
function rnt_ary_to_page(ary::Union{AbstractVector{Float16},AbstractVector{UInt16}})
    raw = reinterpret(UInt8, ary)
    return Page_write(split2_encode(raw))
end

# Signed integers: the reinterpreted bytes go to `Page_write` unchanged, with no
# byte regrouping.
function rnt_ary_to_page(
    ary::Union{
        AbstractVector{Int64},
        AbstractVector{Int32},
        AbstractVector{Int16},
        AbstractVector{Int8},
    },
)
    return Page_write(reinterpret(UInt8, ary))
end

"""
    split8_encode(src::AbstractVector{UInt8})

Regroup the bytes of `src` by their position inside consecutive 8-byte groups:
all 1st bytes first, then all 2nd bytes, and so on up to the 8th. Returns a new
`Vector{UInt8}`.
"""
function split8_encode(src::AbstractVector{UInt8})
    stop = lastindex(src)
    # Plane `k` starts at byte `k` and ends `8 - k` bytes before the last index.
    planes = ntuple(k -> view(src, k:8:(stop - 8 + k)), 8)
    return vcat(planes...)
end
"""
    split4_encode(src::AbstractVector{UInt8})

Regroup the bytes of `src` by their position inside consecutive 4-byte groups:
all 1st bytes first, then all 2nd, 3rd and 4th bytes. Returns a new
`Vector{UInt8}`.
"""
function split4_encode(src::AbstractVector{UInt8})
    stop = lastindex(src)
    # Plane `k` starts at byte `k` and ends `4 - k` bytes before the last index.
    planes = ntuple(k -> view(src, k:4:(stop - 4 + k)), 4)
    return vcat(planes...)
end
"""
    split2_encode(src::AbstractVector{UInt8})

Regroup the bytes of `src` by their position inside consecutive 2-byte groups:
all 1st bytes first, then all 2nd bytes. Returns a new `Vector{UInt8}`.
"""
function split2_encode(src::AbstractVector{UInt8})
    stop = lastindex(src)
    first_bytes = view(src, 1:2:(stop - 1))
    second_bytes = view(src, 2:2:stop)
    return vcat(first_bytes, second_bytes)
end

"""
    _to_zigzag(n)
    _to_zigzag(res::AbstractVector)

Zigzag-encode a signed integer `n`: the value shifted left by one bit is xor-ed
with its sign bit smeared across the whole word (`n >> (bitwidth - 1)`), so that
`0, -1, 1, -2, 2, ...` map to `0, 1, 2, 3, 4, ...`.

The vector method encodes element-wise into a freshly allocated `similar(res)`
and leaves `res` untouched.
"""
_to_zigzag(n) = xor(n << 1, n >> (8 * sizeof(n) - 1))

function _to_zigzag(res::AbstractVector)
    out = similar(res)
    @simd for i in eachindex(out, res)
        out[i] = _to_zigzag(res[i])
    end
    return out
end
16 changes: 4 additions & 12 deletions src/RNTuple/fieldcolumn_reading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,14 @@ function read_field(io, field::RNTupleCardinality{T}, page_list) where T
return res::_field_output_type(field)
end

# Zigzag-decode `n`: drop the low (sign) bit with a logical shift, then flip all
# remaining bits when that low bit was set. `>>>` (not `>>`) keeps inputs whose
# top bit is set from being sign-extended into a wrong result.
_from_zigzag(n) = xor(n >>> 1, -(n & 1))
# Zigzag-encode `n`: xor the value shifted left by one with its smeared sign bit.
_to_zigzag(n) = xor(n << 1, n >> 63)
# Zigzag-decode `n`: drop the low (sign) bit with a logical shift, then flip all
# remaining bits when that low bit was set. `>>>` (not `>>`) keeps inputs whose
# top bit is set from being sign-extended into a wrong result.
_from_zigzag(n) = xor(n >>> 1, -(n & 1))

# In-place variant: overwrite every element of `res` with its decoded value and
# return the same array.
function _from_zigzag!(res::AbstractVector)
    @simd for i in eachindex(res)
        res[i] = _from_zigzag(res[i])
    end
    return res
end

# In-place zigzag encoding: replace each element of `res` with `_to_zigzag` of
# itself and return `res`.
function _to_zigzag!(res::AbstractVector)
    map!(_to_zigzag, res, res)
    return res
end

# Output container type for a `LeafField{T}`: a plain `Vector{T}`.
function _field_output_type(::Type{LeafField{T}}) where {T}
    return Vector{T}
end
function read_field(io, field::LeafField{T}, page_list) where T
cr = field.columnrecord
Expand Down Expand Up @@ -193,7 +185,7 @@ function _detect_encoding(typenum)
split = 14 <= typenum <= 21 || 26 <= typenum <= 28
zigzag = 26 <= typenum <= 28
delta = 14 <= typenum <= 15
return split, zigzag, delta
return (;split, zigzag, delta)
end

"""
Expand All @@ -208,7 +200,7 @@ column since `pagedesc` only contains `num_elements` information.
"""
function read_pagedesc(io, pagedescs::AbstractVector{PageDescription}, cr::ColumnRecord)
nbits = cr.nbits
split, zigzag, delta = _detect_encoding(cr.type)
(;split, zigzag, delta) = _detect_encoding(cr.type)
list_num_elements = [-p.num_elements for p in pagedescs]
total_num_elements = sum(list_num_elements)
if any(<(0), total_num_elements)
Expand All @@ -221,7 +213,7 @@ function read_pagedesc(io, pagedescs::AbstractVector{PageDescription}, cr::Colum
tmp = Vector{UInt8}(undef, 65536)

tip = 1
for i in eachindex(pagedescs)
for i in eachindex(list_num_elements, pagedescs)
pagedesc = pagedescs[i]
# when nbits == 1 for bits, need RoundUp
uncomp_size = div(list_num_elements[i] * nbits, 8, RoundUp)
Expand Down
2 changes: 2 additions & 0 deletions src/UnROOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ include("RNTuple/highlevel.jl")
include("RNTuple/fieldcolumn_reading.jl")
include("RNTuple/displays.jl")

include("RNTuple/Writing/constants.jl")
include("RNTuple/Writing/page_writing.jl")
include("RNTuple/Writing/TFileWriter.jl")
include("RNTuple/Writing/Stubs.jl")

Expand Down
6 changes: 4 additions & 2 deletions test/RNTupleWriting/lowlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,11 @@ UnROOT.write_rntuple(myio, mytable; rntuple_name="myntuple")
mio = take!(myio)
write("/tmp/mine.root", mio)
@test MINE == mio
end

for _ = 1:100
newtable = Dict(randstring(rand(2:10)) => rand(UInt32, rand(1:1000)))
@testset "RNTuple Writing - Single colunm round trips" begin
for _ = 1:50, T in [Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16]
newtable = Dict(randstring(rand(2:10)) => rand(T, rand(1:1000)))
newio = IOBuffer()
UnROOT.write_rntuple(newio, newtable)
nio = take!(newio)
Expand Down

0 comments on commit 0679f73

Please sign in to comment.