-
-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Grid Sampling for 3D images. #627
base: master
Are you sure you want to change the base?
Conversation
Thanks for putting this together! The failing integration tests seem unrelated, RNNCell and InstanceNorm timed out. |
src/sampling.jl
Outdated
# abc are the index of the vertex of the cube (001,010...) | ||
|
||
# Initialize gradient accumulators | ||
gix, giy, giz = 0.0, 0.0, 0.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should probably be zero(V), ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing me out.
src/sampling.jl
Outdated
|
||
# ∀ channel: Calculate trilinear weighted voxel value. | ||
@inbounds for c in 1:iC | ||
r = 0.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r should be zero(T)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right changing to zero(T) can make the function more type stable.
Oh and here is some barebones code you can play with to visualize the gradients Click meusing NNlib: grid_sample, ∇grid_sample
using GLMakie
in_size = 32
out_size = 65
xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)
gauss(x) = exp(-x^2)
input = gauss.(hypot.(xr', yr))
input = input[:,:,:,:]
heatmap(input[:,:,1,1])
v = collect(range(-1.3, 1.3, length=out_size))
xgrid = repeat(v, 1, out_size)
ygrid = repeat(v', out_size)
grid = reshape(stack((xgrid, ygrid), dims=1), 2, out_size, out_size, 1)
# make grid non-uniform so that we can actually see a derivative wrt input
scaler = hypot.(xgrid, ygrid)
grid = grid .* stack((scaler, scaler); dims=1)[:,:,:,:]
Δ = zeros(eltype(input), size(grid)[2:end-1]..., size(input)[end-1:end]...) .+ 1
resampled = grid_sample(input[:,:,:,:], grid)
heatmap(resampled[:,:,1,1])
begin
# dx, dgrid = ∇grid_sample(Δ, Float32.(input)[:,:,:,:], grid);
dx, dgrid = ∇grid_sample(Δ, input[:,:,:,:], grid);
# heatmap(dx[:,:,1,1])
heatmap(dgrid[1,:,:,1])
end
xr_grid = range(-2, 2, length=out_size)
yr_grid = range(-1, 1, length=out_size)
mult = 0.2
arrows(xr_grid, yr_grid, mult .* dgrid[1,:,:,1], mult .* dgrid[2,:,:,1]; axis=(limits=(extrema(xr_grid), extrema(yr_grid)),))
#%%
# 3D
in_size = 32
out_size = 65
xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)
zr = range(-1, 1, length=in_size)
input3d = zeros(in_size, in_size, in_size, 1, 1)
# Populate the tensor with corresponding values
for i in 1:in_size
for j in 1:in_size
for k in 1:in_size
input3d[k, j, i, 1, 1] = gauss(hypot(xr[i], yr[j], zr[k]))
end
end
end
v = collect(range(-1.3, 1.3, length=out_size))
grid3d = zeros(3, out_size, out_size, out_size, 1)
for i in 1:out_size
for j in 1:out_size
for k in 1:out_size
grid3d[:,k,j,i,1] .= (v[k], v[j], v[i]) .* hypot(v[k], v[j], v[i])
end
end
end
function plot_volume(x,y,z,vol)
fig = Figure()
ax = LScene(fig[1, 1], show_axis=true)
sgrid = SliderGrid(
fig[2, 1],
(label = "yz plane - x axis", range = 1:length(x)),
(label = "xz plane - y axis", range = 1:length(y)),
(label = "xy plane - z axis", range = 1:length(z)),
)
lo = sgrid.layout
nc = ncols(lo)
plt = volumeslices!(ax, x, y, z, vol)
# connect sliders to `volumeslices` update methods
sl_yz, sl_xz, sl_xy = sgrid.sliders
on(sl_yz.value) do v; plt[:update_yz][](v) end
on(sl_xz.value) do v; plt[:update_xz][](v) end
on(sl_xy.value) do v; plt[:update_xy][](v) end
set_close_to!(sl_yz, .5length(x))
set_close_to!(sl_xz, .5length(y))
set_close_to!(sl_xy, .5length(z))
# add toggles to show/hide heatmaps
hmaps = [plt[Symbol(:heatmap_, s)][] for s ∈ (:yz, :xz, :xy)]
toggles = [Toggle(lo[i, nc + 1], active = true) for i ∈ 1:length(hmaps)]
map(zip(hmaps, toggles)) do (h, t)
connect!(h.visible, t.active)
end
# cam3d!(ax.scene, projectiontype=Makie.Orthographic)
display(fig)
end
plot_volume(xr, yr, zr, input3d)
resampled3d = grid_sample(input3d, grid3d)
xr_grid = range(-2, 2, length=out_size) .* 1.3 .* hypot(1,1,2)
yr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)
zr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)
plot_volume(xr_grid, yr_grid, zr_grid, resampled3d)
Δ3d = ones(out_size, out_size, out_size, 1, 1)
dx3d, dgrid3d = ∇grid_sample(Δ3d, input3d, grid3d; padding_mode=:zeros)
plot_volume(xr, yr, zr, dx3d)
mult = 0.1
points = [Point3(x,y,z) for x in xr_grid for y in yr_grid for z in zr_grid]
normals = [Point3(mult .* dgrid3d[:,x,y,z,1]) for x in 1:size(dgrid3d,2) for y in 1:size(dgrid3d,3) for z in 1:size(dgrid3d,4)]
lengths = [hypot(n...) for n in normals]
gz = lengths .> 0
arrows(points[gz][1:47:end], normals[gz][1:47:end];
linewidth=0.03,
color=lengths[gz][1:47:end],
arrowsize=Vec3f(0.03, 0.03, 0.04))
asd = reshape(lengths, out_size, out_size, out_size)
plot_volume(xr_grid, yr_grid, zr_grid, asd)
#%%
resampled
resampled3d
extrema(resampled3d[:,33,:,1,1] - resampled[:,:,1,1])
isapprox(resampled3d[:,33,:,1,1], resampled[:,:,1,1]; rtol=1e-2) |
Thanks. This visualization helps a lot! |
[Current
grid_sampling
function only accept 4D input, which is suitable for 2D images. Adding function for 5D input to support 3D images.]PR Checklist