Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Grid Sampling for 3D images. #627

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

Conversation

gRox167
Copy link

@gRox167 gRox167 commented Jan 20, 2025

[Current grid_sampling function only accept 4D input, which is suitable for 2D images. Adding function for 5D input to support 3D images.]

PR Checklist

  • Tests are added
  • Documentation, if applicable

@gRox167 gRox167 mentioned this pull request Jan 20, 2025
@maxfreu
Copy link
Contributor

maxfreu commented Jan 21, 2025

Thanks for putting this together! The failing integration tests seem unrelated, RNNCell and InstanceNorm timed out.

src/sampling.jl Outdated
# abc are the index of the vertex of the cube (001,010...)

# Initialize gradient accumulators
gix, giy, giz = 0.0, 0.0, 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be zero(V), ...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing me out.

src/sampling.jl Outdated

# ∀ channel: Calculate trilinear weighted voxel value.
@inbounds for c in 1:iC
r = 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r should be zero(T)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right changing to zero(T) can make the function more type stable.

@maxfreu
Copy link
Contributor

maxfreu commented Jan 21, 2025

Oh and here is some barebones code you can play with to visualize the gradients

Click me
using NNlib: grid_sample, ∇grid_sample
using GLMakie

in_size = 32
out_size = 65

xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)

gauss(x) = exp(-x^2)

input = gauss.(hypot.(xr', yr))
input = input[:,:,:,:]
heatmap(input[:,:,1,1])

v = collect(range(-1.3, 1.3, length=out_size))
xgrid = repeat(v, 1, out_size)
ygrid = repeat(v',  out_size)
grid = reshape(stack((xgrid, ygrid), dims=1), 2, out_size, out_size, 1)

# make grid non-uniform so that we can actually see a derivative wrt input
scaler = hypot.(xgrid, ygrid)
grid = grid .* stack((scaler, scaler); dims=1)[:,:,:,:]

Δ = zeros(eltype(input), size(grid)[2:end-1]..., size(input)[end-1:end]...) .+ 1

resampled = grid_sample(input[:,:,:,:], grid)

heatmap(resampled[:,:,1,1])

begin
# dx, dgrid = ∇grid_sample(Δ, Float32.(input)[:,:,:,:], grid);
dx, dgrid = ∇grid_sample(Δ, input[:,:,:,:], grid);
# heatmap(dx[:,:,1,1])
heatmap(dgrid[1,:,:,1])
end

xr_grid = range(-2, 2, length=out_size)
yr_grid = range(-1, 1, length=out_size)

mult = 0.2
arrows(xr_grid, yr_grid, mult .* dgrid[1,:,:,1], mult .* dgrid[2,:,:,1]; axis=(limits=(extrema(xr_grid), extrema(yr_grid)),))


#%%
# 3D
in_size = 32
out_size = 65

xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)
zr = range(-1, 1, length=in_size)

input3d = zeros(in_size, in_size, in_size, 1, 1)

# Populate the tensor with corresponding values
for i in 1:in_size
  for j in 1:in_size
      for k in 1:in_size
          input3d[k, j, i, 1, 1] = gauss(hypot(xr[i], yr[j], zr[k]))
      end
  end
end

v = collect(range(-1.3, 1.3, length=out_size))

grid3d = zeros(3, out_size, out_size, out_size, 1)

for i in 1:out_size
  for j in 1:out_size
      for k in 1:out_size
          grid3d[:,k,j,i,1] .= (v[k], v[j], v[i]) .* hypot(v[k], v[j], v[i])
      end
  end
end


function plot_volume(x,y,z,vol)
  fig = Figure()
  ax = LScene(fig[1, 1], show_axis=true)

  sgrid = SliderGrid(
      fig[2, 1],
      (label = "yz plane - x axis", range = 1:length(x)),
      (label = "xz plane - y axis", range = 1:length(y)),
      (label = "xy plane - z axis", range = 1:length(z)),
  )

  lo = sgrid.layout
  nc = ncols(lo)

  plt = volumeslices!(ax, x, y, z, vol)

  # connect sliders to `volumeslices` update methods
  sl_yz, sl_xz, sl_xy = sgrid.sliders

  on(sl_yz.value) do v; plt[:update_yz][](v) end
  on(sl_xz.value) do v; plt[:update_xz][](v) end
  on(sl_xy.value) do v; plt[:update_xy][](v) end

  set_close_to!(sl_yz, .5length(x))
  set_close_to!(sl_xz, .5length(y))
  set_close_to!(sl_xy, .5length(z))

  # add toggles to show/hide heatmaps
  hmaps = [plt[Symbol(:heatmap_, s)][] for s  (:yz, :xz, :xy)]
  toggles = [Toggle(lo[i, nc + 1], active = true) for i  1:length(hmaps)]

  map(zip(hmaps, toggles)) do (h, t)
      connect!(h.visible, t.active)
  end

  # cam3d!(ax.scene, projectiontype=Makie.Orthographic)

  display(fig)
end

plot_volume(xr, yr, zr, input3d)

resampled3d = grid_sample(input3d, grid3d)

xr_grid = range(-2, 2, length=out_size) .* 1.3 .* hypot(1,1,2)
yr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)
zr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)

plot_volume(xr_grid, yr_grid, zr_grid, resampled3d)

Δ3d = ones(out_size, out_size, out_size, 1, 1)

dx3d, dgrid3d = ∇grid_sample(Δ3d, input3d, grid3d; padding_mode=:zeros)

plot_volume(xr, yr, zr, dx3d)

mult = 0.1

points = [Point3(x,y,z) for x in xr_grid for y in yr_grid for z in zr_grid]
normals = [Point3(mult .* dgrid3d[:,x,y,z,1]) for x in 1:size(dgrid3d,2) for y in 1:size(dgrid3d,3) for z in 1:size(dgrid3d,4)]

lengths = [hypot(n...) for n in normals]
gz = lengths .> 0

arrows(points[gz][1:47:end], normals[gz][1:47:end];
linewidth=0.03,
color=lengths[gz][1:47:end],
arrowsize=Vec3f(0.03, 0.03, 0.04))

asd = reshape(lengths, out_size, out_size, out_size)

plot_volume(xr_grid, yr_grid, zr_grid, asd)


#%%

resampled
resampled3d
extrema(resampled3d[:,33,:,1,1] - resampled[:,:,1,1])
isapprox(resampled3d[:,33,:,1,1], resampled[:,:,1,1]; rtol=1e-2)

@gRox167
Copy link
Author

gRox167 commented Jan 24, 2025

Oh and here is some barebones code you can play with to visualize the gradients
Click me

using NNlib: grid_sample, ∇grid_sample
using GLMakie

in_size = 32
out_size = 65

xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)

gauss(x) = exp(-x^2)

input = gauss.(hypot.(xr', yr))
input = input[:,:,:,:]
heatmap(input[:,:,1,1])

v = collect(range(-1.3, 1.3, length=out_size))
xgrid = repeat(v, 1, out_size)
ygrid = repeat(v',  out_size)
grid = reshape(stack((xgrid, ygrid), dims=1), 2, out_size, out_size, 1)

# make grid non-uniform so that we can actually see a derivative wrt input
scaler = hypot.(xgrid, ygrid)
grid = grid .* stack((scaler, scaler); dims=1)[:,:,:,:]

Δ = zeros(eltype(input), size(grid)[2:end-1]..., size(input)[end-1:end]...) .+ 1

resampled = grid_sample(input[:,:,:,:], grid)

heatmap(resampled[:,:,1,1])

begin
# dx, dgrid = ∇grid_sample(Δ, Float32.(input)[:,:,:,:], grid);
dx, dgrid = ∇grid_sample(Δ, input[:,:,:,:], grid);
# heatmap(dx[:,:,1,1])
heatmap(dgrid[1,:,:,1])
end

xr_grid = range(-2, 2, length=out_size)
yr_grid = range(-1, 1, length=out_size)

mult = 0.2
arrows(xr_grid, yr_grid, mult .* dgrid[1,:,:,1], mult .* dgrid[2,:,:,1]; axis=(limits=(extrema(xr_grid), extrema(yr_grid)),))


#%%
# 3D
in_size = 32
out_size = 65

xr = range(-2, 2, length=in_size)
yr = range(-1, 1, length=in_size)
zr = range(-1, 1, length=in_size)

input3d = zeros(in_size, in_size, in_size, 1, 1)

# Populate the tensor with corresponding values
for i in 1:in_size
  for j in 1:in_size
      for k in 1:in_size
          input3d[k, j, i, 1, 1] = gauss(hypot(xr[i], yr[j], zr[k]))
      end
  end
end

v = collect(range(-1.3, 1.3, length=out_size))

grid3d = zeros(3, out_size, out_size, out_size, 1)

for i in 1:out_size
  for j in 1:out_size
      for k in 1:out_size
          grid3d[:,k,j,i,1] .= (v[k], v[j], v[i]) .* hypot(v[k], v[j], v[i])
      end
  end
end


function plot_volume(x,y,z,vol)
  fig = Figure()
  ax = LScene(fig[1, 1], show_axis=true)

  sgrid = SliderGrid(
      fig[2, 1],
      (label = "yz plane - x axis", range = 1:length(x)),
      (label = "xz plane - y axis", range = 1:length(y)),
      (label = "xy plane - z axis", range = 1:length(z)),
  )

  lo = sgrid.layout
  nc = ncols(lo)

  plt = volumeslices!(ax, x, y, z, vol)

  # connect sliders to `volumeslices` update methods
  sl_yz, sl_xz, sl_xy = sgrid.sliders

  on(sl_yz.value) do v; plt[:update_yz][](v) end
  on(sl_xz.value) do v; plt[:update_xz][](v) end
  on(sl_xy.value) do v; plt[:update_xy][](v) end

  set_close_to!(sl_yz, .5length(x))
  set_close_to!(sl_xz, .5length(y))
  set_close_to!(sl_xy, .5length(z))

  # add toggles to show/hide heatmaps
  hmaps = [plt[Symbol(:heatmap_, s)][] for s  (:yz, :xz, :xy)]
  toggles = [Toggle(lo[i, nc + 1], active = true) for i  1:length(hmaps)]

  map(zip(hmaps, toggles)) do (h, t)
      connect!(h.visible, t.active)
  end

  # cam3d!(ax.scene, projectiontype=Makie.Orthographic)

  display(fig)
end

plot_volume(xr, yr, zr, input3d)

resampled3d = grid_sample(input3d, grid3d)

xr_grid = range(-2, 2, length=out_size) .* 1.3 .* hypot(1,1,2)
yr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)
zr_grid = range(-1, 1, length=out_size) .* 1.3 .* hypot(1,1,2)

plot_volume(xr_grid, yr_grid, zr_grid, resampled3d)

Δ3d = ones(out_size, out_size, out_size, 1, 1)

dx3d, dgrid3d = ∇grid_sample(Δ3d, input3d, grid3d; padding_mode=:zeros)

plot_volume(xr, yr, zr, dx3d)

mult = 0.1

points = [Point3(x,y,z) for x in xr_grid for y in yr_grid for z in zr_grid]
normals = [Point3(mult .* dgrid3d[:,x,y,z,1]) for x in 1:size(dgrid3d,2) for y in 1:size(dgrid3d,3) for z in 1:size(dgrid3d,4)]

lengths = [hypot(n...) for n in normals]
gz = lengths .> 0

arrows(points[gz][1:47:end], normals[gz][1:47:end];
linewidth=0.03,
color=lengths[gz][1:47:end],
arrowsize=Vec3f(0.03, 0.03, 0.04))

asd = reshape(lengths, out_size, out_size, out_size)

plot_volume(xr_grid, yr_grid, zr_grid, asd)


#%%

resampled
resampled3d
extrema(resampled3d[:,33,:,1,1] - resampled[:,:,1,1])
isapprox(resampled3d[:,33,:,1,1], resampled[:,:,1,1]; rtol=1e-2)

Thanks. This visualization helps a lot!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants