Skip to content

Commit

Permalink
Update to new alloc cache
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jan 8, 2025
1 parent 79a0076 commit 6f7c7a5
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 13 deletions.
2 changes: 2 additions & 0 deletions ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using GaussianSplatting
# using Statistics
# using Zygote

# Base array type associated with the `ROCBackend`: always `ROCArray`.
function GaussianSplatting.base_array_type(::ROCBackend)
    return ROCArray
end

# Opt the `ROCBackend` into the `use_ak` code path (always `true`).
# NOTE(review): "ak" presumably refers to AcceleratedKernels — confirm against
# the definition of `use_ak` in the main package.
function GaussianSplatting.use_ak(::ROCBackend)
    return true
end

function GaussianSplatting.allocate_pinned(kab, ::Type{T}, shape) where T
Expand Down
2 changes: 2 additions & 0 deletions ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using GaussianSplatting
# using Statistics
# using Zygote

# Base array type associated with the `CUDABackend`: always `CuArray`.
function GaussianSplatting.base_array_type(::CUDABackend)
    return CuArray
end

function GaussianSplatting.allocate_pinned(::CUDABackend, ::Type{T}, shape) where T
x = Array{T}(undef, shape)
buf = CUDA.register(CUDA.HostMemory, pointer(x), sizeof(x),
Expand Down
2 changes: 2 additions & 0 deletions src/GaussianSplatting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ include("gui/gui.jl")
# Workaround for obtaining a KA.Backend: pass an empty host array through
# `Flux.gpu` and ask `get_backend` which backend the resulting array belongs to.
function gpu_backend()
    probe = Array{Int}(undef, 0)
    return get_backend(Flux.gpu(probe))
end

"""
    base_array_type(backend)

Fallback method: raise an error stating that no base array type is
implemented for `backend`.

Backend-specific methods are expected to return an array type instead.
"""
function base_array_type(backend)
    msg = "Not implemented for backend: `$backend`."
    return error(msg)
end

"""
    allocate_pinned(kab, T, shape)

Fallback method: raise an error stating that pinned memory is not supported
for the backend `kab`. The arguments `T` and `shape` are not used here.

Backend-specific methods are expected to provide the actual allocation.
"""
function allocate_pinned(kab, T, shape)
    msg = "Pinned memory not supported for `$kab`."
    return error(msg)
end

"""
    unpin_memory(x)

Fallback method: raise an error stating that unpinning memory is not
supported for values of the type of `x`.
"""
function unpin_memory(x)
    msg = "Unpinning memory is not supported for `$(typeof(x))`."
    return error(msg)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rasterization/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function project(

if length(rast.gstate) < n
KA.unsafe_free!(rast.gstate)
rast.gstate = GPUArrays.AllocCache.@disable GeometryState(kab, n; extended=rast.mode == :rgbd)
rast.gstate = GPUArrays.@disable GeometryState(kab, n; extended=rast.mode == :rgbd)
end

project!(kab)(
Expand Down
4 changes: 2 additions & 2 deletions src/rasterization/rasterizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ function rasterize(
n = size(means_3d, 2)
if length(rast.gstate) < n
KA.unsafe_free!(rast.gstate)
rast.gstate = GPUArrays.AllocCache.@disable GeometryState(kab, n; extended=render_depth)
rast.gstate = GPUArrays.@disable GeometryState(kab, n; extended=render_depth)
end

(; width, height) = resolution(camera)
Expand Down Expand Up @@ -291,7 +291,7 @@ function rasterize(

if length(rast.bstate) < n_rendered
KA.unsafe_free!(rast.bstate)
rast.bstate = GPUArrays.AllocCache.@disable BinningState(kab, n_rendered)
rast.bstate = GPUArrays.@disable BinningState(kab, n_rendered)
end

# For each instance to be rendered, produce [tile | depth] key
Expand Down
2 changes: 1 addition & 1 deletion src/rasterization/render.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function render(

if length(rast.bstate) < n_rendered
KA.unsafe_free!(rast.bstate)
rast.bstate = GPUArrays.AllocCache.@disable BinningState(kab, n_rendered)
rast.bstate = GPUArrays.@disable BinningState(kab, n_rendered)
end

# For each instance to be rendered, produce [tile | depth] key
Expand Down
17 changes: 8 additions & 9 deletions src/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mutable struct Trainer{
G <: GaussianModel,
D <: ColmapDataset,
S <: SSIM,
C <: GPUArrays.AllocCache,
F,
O,
}
Expand All @@ -13,6 +14,8 @@ mutable struct Trainer{
optimizers::O
ssim::S

cache::C

points_lr_scheduler::F
opt_params::OptimizationParams

Expand All @@ -25,14 +28,12 @@ function Trainer(
rast::GaussianRasterizer, gs::GaussianModel,
dataset::ColmapDataset, opt_params::OptimizationParams;
)
# If we are going to use trainer, invalidate its alloc cache.
AT = typeof(gs.points)
GPUArrays.AllocCache.invalidate!(AT, :train_step)

ϵ = 1f-15
kab = get_backend(gs)
camera_extent = dataset.camera_extent

cache = GPUArrays.AllocCache(base_array_type(kab))

optimizers = (;
points=NU.Adam(kab, gs.points; lr=opt_params.lr_points_start * camera_extent, ϵ),
features_dc=NU.Adam(kab, gs.features_dc; lr=opt_params.lr_feature, ϵ),
Expand All @@ -51,7 +52,7 @@ function Trainer(
densify = true
step = 0
Trainer(
rast, gs, dataset, optimizers, ssim,
rast, gs, dataset, optimizers, ssim, cache,
points_lr_scheduler, opt_params, densify, step, ids)
end

Expand Down Expand Up @@ -181,9 +182,7 @@ function step!(trainer::Trainer)
gs.opacities, gs.scales, gs.rotations)

kab = get_backend(rast)
AT = typeof(gs.points)

GPUArrays.AllocCache.@enable AT :train_step begin
GPUArrays.@enable trainer.cache begin
loss, ∇ = Zygote.withgradient(
θ...,
) do means_3d, features_dc, features_rest, opacities, scales, rotations
Expand Down Expand Up @@ -221,7 +220,7 @@ function step!(trainer::Trainer)
trainer.step params.densify_from_iter &&
trainer.step % params.densification_interval == 0
if do_densify
GPUArrays.AllocCache.invalidate!(AT, :train_step)
GPUArrays.unsafe_free!(trainer.cache)

max_screen_size::Int32 =
trainer.step > params.opacity_reset_interval ? 20 : 0
Expand Down

0 comments on commit 6f7c7a5

Please sign in to comment.