diff --git a/ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl b/ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl index f7c3158..90d2fb2 100644 --- a/ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl +++ b/ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl @@ -8,6 +8,8 @@ using GaussianSplatting # using Statistics # using Zygote +GaussianSplatting.base_array_type(::ROCBackend) = ROCArray + GaussianSplatting.use_ak(::ROCBackend) = true function GaussianSplatting.allocate_pinned(kab, ::Type{T}, shape) where T diff --git a/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl b/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl index 29d8635..ee0761e 100644 --- a/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl +++ b/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl @@ -9,6 +9,8 @@ using GaussianSplatting # using Statistics # using Zygote +GaussianSplatting.base_array_type(::CUDABackend) = CuArray + function GaussianSplatting.allocate_pinned(::CUDABackend, ::Type{T}, shape) where T x = Array{T}(undef, shape) buf = CUDA.register(CUDA.HostMemory, pointer(x), sizeof(x), diff --git a/src/GaussianSplatting.jl b/src/GaussianSplatting.jl index 29e5437..060fc94 100644 --- a/src/GaussianSplatting.jl +++ b/src/GaussianSplatting.jl @@ -70,6 +70,8 @@ include("gui/gui.jl") # Hacky way to get KA.Backend. gpu_backend() = get_backend(Flux.gpu(Array{Int}(undef, 0))) +base_array_type(backend) = error("Not implemented for backend: `$backend`.") + allocate_pinned(kab, T, shape) = error("Pinned memory not supported for `$kab`.") unpin_memory(x) = error("Unpinning memory is not supported for `$(typeof(x))`.") diff --git a/src/rasterization/projection.jl b/src/rasterization/projection.jl index b31f729..0b5d672 100644 --- a/src/rasterization/projection.jl +++ b/src/rasterization/projection.jl @@ -23,7 +23,7 @@ function project( if length(rast.gstate) < n KA.unsafe_free!(rast.gstate) - rast.gstate = GPUArrays.AllocCache.@disable GeometryState(kab, n; extended=rast.mode == :rgbd) + rast.gstate = GPUArrays.@disable GeometryState(kab, n; extended=rast.mode == :rgbd) end project!(kab)( diff --git a/src/rasterization/rasterizer.jl b/src/rasterization/rasterizer.jl index 533e267..b9bf3b8 100644 --- a/src/rasterization/rasterizer.jl +++ b/src/rasterization/rasterizer.jl @@ -226,7 +226,7 @@ function rasterize( n = size(means_3d, 2) if length(rast.gstate) < n KA.unsafe_free!(rast.gstate) - rast.gstate = GPUArrays.AllocCache.@disable GeometryState(kab, n; extended=render_depth) + rast.gstate = GPUArrays.@disable GeometryState(kab, n; extended=render_depth) end (; width, height) = resolution(camera) @@ -291,7 +291,7 @@ function rasterize( if length(rast.bstate) < n_rendered KA.unsafe_free!(rast.bstate) - rast.bstate = GPUArrays.AllocCache.@disable BinningState(kab, n_rendered) + rast.bstate = GPUArrays.@disable BinningState(kab, n_rendered) end # For each instance to be rendered, produce [tile | depth] key diff --git a/src/rasterization/render.jl b/src/rasterization/render.jl index eaa0178..6abfe03 100644 --- a/src/rasterization/render.jl +++ b/src/rasterization/render.jl @@ -35,7 +35,7 @@ function render( if length(rast.bstate) < n_rendered KA.unsafe_free!(rast.bstate) - rast.bstate = GPUArrays.AllocCache.@disable BinningState(kab, n_rendered) + rast.bstate = GPUArrays.@disable BinningState(kab, n_rendered) end # For each instance to be rendered, produce [tile | depth] key diff --git a/src/training.jl b/src/training.jl index da3c95d..cf4bac0 100644 --- a/src/training.jl +++ b/src/training.jl @@ -4,6 +4,7 @@ mutable struct Trainer{ G <: GaussianModel, D <: ColmapDataset, S <: SSIM, + C <: GPUArrays.AllocCache, F, O, } @@ -13,6 +14,8 @@ mutable struct Trainer{ optimizers::O ssim::S + cache::C + points_lr_scheduler::F opt_params::OptimizationParams @@ -25,14 +28,12 @@ function Trainer( rast::GaussianRasterizer, gs::GaussianModel, dataset::ColmapDataset, opt_params::OptimizationParams; ) - # If we are going to use trainer, invalidate its alloc cache. - AT = typeof(gs.points) - GPUArrays.AllocCache.invalidate!(AT, :train_step) - ϵ = 1f-15 kab = get_backend(gs) camera_extent = dataset.camera_extent + cache = GPUArrays.AllocCache(base_array_type(kab)) + optimizers = (; points=NU.Adam(kab, gs.points; lr=opt_params.lr_points_start * camera_extent, ϵ), features_dc=NU.Adam(kab, gs.features_dc; lr=opt_params.lr_feature, ϵ), @@ -51,7 +52,7 @@ function Trainer( densify = true step = 0 Trainer( - rast, gs, dataset, optimizers, ssim, + rast, gs, dataset, optimizers, ssim, cache, points_lr_scheduler, opt_params, densify, step, ids) end @@ -181,9 +182,7 @@ function step!(trainer::Trainer) gs.opacities, gs.scales, gs.rotations) kab = get_backend(rast) - AT = typeof(gs.points) - - GPUArrays.AllocCache.@enable AT :train_step begin + GPUArrays.@enable trainer.cache begin loss, ∇ = Zygote.withgradient( θ..., ) do means_3d, features_dc, features_rest, opacities, scales, rotations @@ -221,7 +220,7 @@ function step!(trainer::Trainer) trainer.step ≥ params.densify_from_iter && trainer.step % params.densification_interval == 0 if do_densify - GPUArrays.AllocCache.invalidate!(AT, :train_step) + GPUArrays.unsafe_free!(trainer.cache) max_screen_size::Int32 = trainer.step > params.opacity_reset_interval ? 20 : 0