diff --git a/README.md b/README.md
index 38e3e5c..79f5614 100644
--- a/README.md
+++ b/README.md
@@ -3,3 +3,4 @@
 `torchcvnn/examples` is a repository showcasing examples of using [torchcvnn](https://www.github.com/torchcvnn/torchcvnn).
 
 - [Classification of MNIST in the Fourier space with complex valued CNNs](./mnist_conv/README.md)
+- [Complex valued Neural Implicit Representation for cardiac reconstruction](./nir_cinejense/README.md)
diff --git a/nir_cinejense/README.md b/nir_cinejense/README.md
new file mode 100644
index 0000000..05a3f16
--- /dev/null
+++ b/nir_cinejense/README.md
@@ -0,0 +1,82 @@
+# Neural Implicit Representation for cardiac reconstruction
+
+This example reproduces the code from [https://github.com/MDL-UzL/CineJENSE](https://github.com/MDL-UzL/CineJENSE) accompanying the paper "CineJENSE: Simultaneous Cine MRI Image Reconstruction and Sensitivity Map Estimation using Neural Representations" by Ziad Al-Haj Hemidi, Nora Vogt, Lucile Quillien, Christian Weihsbach, Mattias P. Heinrich, and Julien Oster.
+
+It showcases a Neural Implicit Representation (NIR) for cardiac reconstruction.
+In a few words: MRI samples the Fourier space (k-space) in bands, and the
+longer the exam, the more bands of the Fourier representation you collect. The
+task is to estimate the non-observed bands from the observed ones, so that the
+exam can be kept as short as possible while still imaging the heart as if it
+had been observed for a longer period.
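+
+As a rough illustration of this setting, here is a toy sketch (not part of the
+example code: the image content, the mask layout and the sizes are arbitrary)
+emulating an accelerated acquisition by masking k-space lines and
+reconstructing with a zero-filled inverse FFT:
+
+```python
+import torch
+
+# Stand-in for one complex valued cardiac frame
+image = torch.randn(128, 128, dtype=torch.complex64)
+kspace = torch.fft.fftshift(torch.fft.fft2(image))
+
+# Keep every 4th phase-encoding line plus a fully sampled central band
+mask = torch.zeros(128, 128)
+mask[:, ::4] = 1
+mask[:, 56:72] = 1
+
+# Zero-filled reconstruction: the aliased baseline the model must improve upon
+zero_filled = torch.fft.ifft2(torch.fft.ifftshift(kspace * mask))
+print(zero_filled.abs().shape)  # torch.Size([128, 128])
+```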
+
+To run this example, you need to download the data from the [CMRxRecon MICCAI 2023
+challenge](https://cmrxrecon.github.io/Home.html). See also the GitHub repository of the
+challenge to access the data: [https://github.com/CmrxRecon/CMRxRecon2024](https://github.com/CmrxRecon/CMRxRecon2024)
+
+The data directory is expected to follow the structure used by [torchcvnn](https://torchcvnn.github.io/torchcvnn/modules/datasets.html#torchcvnn.datasets.MICCAI2023):
+
+    rootdir/ChallengeData/MultiCoil/cine/TrainingSet/P{id}/
+        - cine_sax.mat
+        - cine_lax.mat
+    rootdir/ChallengeData/MultiCoil/cine/TrainingSet/AccFactor04/P{id}/
+        - cine_sax.mat
+        - cine_sax_mask.mat
+        - cine_lax.mat
+        - cine_lax_mask.mat
+
+The script supports the three acceleration factors and both the Short Axis (SAX)
+and Long Axis (LAX) views.
+
+If multiple patient data are provided, the script samples one of them randomly.
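+
+As a minimal sanity check of your data layout, the dataset can be loaded
+directly with torchcvnn, exactly as `nir_miccai2023.py` does (a sketch: adapt
+the `rootdir` placeholder to your setup):
+
+```python
+from torchcvnn.datasets.miccai2023 import MICCAI2023, CINEView, AccFactor
+
+dataset = MICCAI2023(
+    rootdir="/path/to/the/data", view=CINEView.SAX, acc_factor=AccFactor.ACC4
+)
+# Each sample is (subsampled k-space, subsampled mask, fully sampled k-space),
+# the k-space arrays being shaped (kx, ky, sc, sz, t)
+subsampled_data, subsampled_mask, fullsampled_data = dataset[0]
+print(subsampled_data.shape)
+```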
+
+The installation is done in two steps:
+
+```bash
+python -m pip install -r requirements.txt
+python -m pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+```
+
+The tiny-cuda-nn installation may complain if the torch version installed during
+the first step was not built for the CUDA version installed on your system. You
+may then need to overwrite the installed torch with a build for the right CUDA
+version, as listed on [https://pytorch.org/get-started/previous-versions/](https://pytorch.org/get-started/previous-versions).
+
+Once the dependencies are installed and you have the data available, you should be able to run the code:
+
+```bash
+python nir_miccai2023.py --rootdir /path/to/the/data --acc_factor ACC10 --view SAX
+```
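+
+For every processed slice, the script saves one PNG per frame (named like
+`slice_5_0000.png`) in `./results/P{id}/`. One possible way to assemble them
+into an animation such as the GIFs shown below (a sketch: it assumes Pillow is
+available, which it usually is as a matplotlib dependency, and the frame
+duration is an arbitrary choice):
+
+```python
+import pathlib
+
+from PIL import Image
+
+frame_paths = sorted(
+    pathlib.Path("results/P002").glob("slice_5_[0-9][0-9][0-9][0-9].png")
+)
+frames = [Image.open(p) for p in frame_paths]
+frames[0].save(
+    "acc4_sax_p002.gif",
+    save_all=True,
+    append_images=frames[1:],
+    duration=100,  # milliseconds per frame
+    loop=0,
+)
+```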
+
+## Examples
+
+The examples below were produced on an NVIDIA GeForce RTX 3090, taking about 2 minutes per slice.
+
+They were executed with `numpy==1.26.4`, `tinycudann==1.7`, `torchcvnn==0.8.0` and `torch==2.0.1`.
+
+### Acceleration Factor 4
+
+- ACC4, SAX, patient P002, mean PSNR $= 42.19$ dB (mean over the $12$ frames, for slice number $5$)
+
+![ACC4, SAX, P002](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc4_sax_p002.gif?raw=true)
+
+- ACC4, LAX, patient P012, mean PSNR $= 41.97$ dB (mean over the $12$ frames, for slice number $1$)
+
+![ACC4, LAX, P012](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc4_lax_p012.gif?raw=true)
+
+### Acceleration Factor 8
+
+- ACC8, SAX, patient P063, mean PSNR $= 34.59$ dB (mean over the $12$ frames, for slice number $5$)
+
+![ACC8, SAX, P063](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc8_sax_p063.gif?raw=true)
+
+- ACC8, LAX, patient P001, mean PSNR $= 36.48$ dB (mean over the $12$ frames, for slice number $1$)
+
+![ACC8, LAX, P001](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc8_lax_p001.gif?raw=true)
+
+### Acceleration Factor 10
+
+- ACC10, SAX, patient P107, mean PSNR $= 34.83$ dB (mean over the $12$ frames, for slice number $5$)
+
+![ACC10, SAX, P107](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc10_sax_p107.gif?raw=true)
+
+- ACC10, LAX, patient P014, mean PSNR $= 31.44$ dB (mean over the $12$ frames, for slice number $1$)
+
+![ACC10, LAX, P014](https://github.com/torchcvnn/examples/blob/nir_cinejense/nir_cinejense/gifs/acc10_lax_p014.gif?raw=true)
diff --git a/nir_cinejense/gifs/acc10_lax_p014.gif b/nir_cinejense/gifs/acc10_lax_p014.gif
new file mode 100644
index 0000000..98df839
Binary files /dev/null and b/nir_cinejense/gifs/acc10_lax_p014.gif differ
diff --git a/nir_cinejense/gifs/acc10_sax_p107.gif b/nir_cinejense/gifs/acc10_sax_p107.gif
new file mode 100644
index 0000000..c4a5a51
Binary files /dev/null and b/nir_cinejense/gifs/acc10_sax_p107.gif differ
diff --git a/nir_cinejense/gifs/acc4_lax_p012.gif b/nir_cinejense/gifs/acc4_lax_p012.gif
new file mode 100644
index 0000000..ef9db51
Binary files /dev/null and b/nir_cinejense/gifs/acc4_lax_p012.gif differ
diff --git a/nir_cinejense/gifs/acc4_sax_p002.gif b/nir_cinejense/gifs/acc4_sax_p002.gif
new file mode 100644
index 0000000..4fd462a
Binary files /dev/null and b/nir_cinejense/gifs/acc4_sax_p002.gif differ
diff --git a/nir_cinejense/gifs/acc8_lax_p001.gif b/nir_cinejense/gifs/acc8_lax_p001.gif
new file mode 100644
index 0000000..8bdfd36
Binary files /dev/null and b/nir_cinejense/gifs/acc8_lax_p001.gif differ
diff --git a/nir_cinejense/gifs/acc8_sax_p063.gif b/nir_cinejense/gifs/acc8_sax_p063.gif
new file mode 100644
index 0000000..2ce9c5c
Binary files /dev/null and b/nir_cinejense/gifs/acc8_sax_p063.gif differ
diff --git a/nir_cinejense/nir_miccai2023.py b/nir_cinejense/nir_miccai2023.py
new file mode 100644
index 0000000..fe06f2d
--- /dev/null
+++ b/nir_cinejense/nir_miccai2023.py
@@ -0,0 +1,471 @@
+# coding: utf-8
+
+# MIT License
+
+# Copyright (c) 2024 Clément Cornet, Jérémy Fix
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Standard imports
+import argparse
+import logging
+import random
+import pathlib
+import shutil
+
+# External imports
+import tqdm
+import torch
+import torch.nn as nn
+import numpy as np
+import matplotlib
+
+matplotlib.use("Agg")  # Non-interactive backend
+
+import matplotlib.pyplot as plt
+
+# Local imports
+from torchcvnn.datasets.miccai2023 import (
+    MICCAI2023,
+    CINEView,
+    kspace_to_image,
+    AccFactor,
+    combine_coils_from_kspace,
+)
+import utils
+
+
+@torch.jit.script
+def FFT(x):
+    # Centered 2D FFT over the two leading (spatial) dimensions
+    return torch.fft.fftshift(
+        torch.fft.fft2(torch.fft.ifftshift(x, dim=(0, 1)), dim=(0, 1)), dim=(0, 1)
+    )
+
+
+@torch.jit.script
+def IFFT(x):
+    # Centered 2D inverse FFT over the two leading (spatial) dimensions
+    return torch.fft.ifftshift(
+        torch.fft.ifft2(torch.fft.fftshift(x, dim=(0, 1)), dim=(0, 1)), dim=(0, 1)
+    )
+
+
+def combine_coils(kspace):
+    """
+    Combine the coils from the given k-space
+
+    Arguments:
+        kspace: Tensor of shape (nrows, ncols, ncoils)
+                or (nrows, ncols, ncoils, nframes), complex valued
+
+    Returns:
+        image: Tensor of shape (nrows, ncols)
+               or (nrows, ncols, nframes), magnitude only
+    """
+    if isinstance(kspace, np.ndarray):
+        kspace = torch.tensor(kspace, dtype=torch.complex64)
+
+    images = IFFT(kspace)
+
+    # Combine the coils in the image space with the root sum of squares (RSS)
+    coils_combined = (images.abs() ** 2).sum(axis=2).sqrt()
+
+    return coils_combined
+
+
+class TVLoss(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, X):
+        """
+        Computes a squared (L2) total variation penalty on the tensor X
+
+        X is of shape (nrows, ncols, ...), e.g. (nrows, ncols, nframes)
+
+        The penalty is computed from the 2D signals nrows x ncols
+        and averaged over the remaining dimensions
+        """
+
+        diff_rows = X[1:, :, ...] - X[:-1, :, ...]
+        diff_cols = X[:, 1:, ...] - X[:, :-1, ...]
+        tv = torch.mean(torch.abs(diff_rows) ** 2) + torch.mean(
+            torch.abs(diff_cols) ** 2
+        )
+        return tv
+
+
+def infer_on_slice(
+    subsampled_slice, subsampled_mask, slice_idx, results_dir, training_cfg
+):
+    """
+    Perform inference on a single slice for all the frames and all the coils
+
+    Arguments:
+        subsampled_slice (torch.Tensor): Subsampled k-space data for a single slice, (ky, kx, sc, t)
+        subsampled_mask (torch.Tensor): Subsampled mask for a single slice (ky, kx)
+        slice_idx (int): the index of the slice, used to determine the results filenames
+        results_dir (pathlib.Path): the path where to save all the results
+        training_cfg (dict): the training parameters (reg_weight, learning rate, number of iterations)
+    """
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    reg_weight = training_cfg["reg_weight"]
+    max_iter = training_cfg["iter"]
+    lr = training_cfg["lr"]
+
+    # Put the slices on the right device
+    subsampled_slice = torch.tensor(subsampled_slice, dtype=torch.complex64).to(device)
+    subsampled_mask = torch.tensor(subsampled_mask, dtype=torch.float32).to(device)
+
+    # Build the models
+    nrows, ncols, ncoils, nframes = subsampled_slice.shape
+
+    # Pre-compute the coordinates for sampling the 3D volume X, Y, T
+    # and view it as (n_samples, 3)
+    coor = utils.build_coordinate_2Dt(nrows, ncols, nframes, device).view(-1, 3)
+
+    # Build the image model
+    encoding_image = {
+        "otype": "Grid",
+        "type": "Hash",
+        "n_levels": 16,
+        "n_features_per_level": 2,
+        "log2_hashmap_size": 19,
+        "base_resolution": 16,
+        "per_level_scale": 2,
+        "interpolation": "Linear",
+    }
+    mlp_image = {"n_hidden_units": 32, "n_hidden_layers": 4}
+    image_model = utils.ComplexNGP(
+        n_inputs=3, n_outputs=1, encoding_cfg=encoding_image, mlp_cfg=mlp_image
+    ).to(device)
+
+    # Build the Coil Sensitivity Map (CSM) network
+    encoding_csm = {
+        "otype": "Grid",
+        "type": "Hash",
+        "n_levels": 4,
+        "n_features_per_level": 8,
+        "log2_hashmap_size": 19,
+        "base_resolution": 2,
+        "per_level_scale": 1.1,
+        "interpolation": "Linear",
+    }
+    mlp_csm = {"n_hidden_units": 32, "n_hidden_layers": 4}
+    csm_model = utils.ComplexNGP(
+        n_inputs=3, n_outputs=ncoils, encoding_cfg=encoding_csm, mlp_cfg=mlp_csm
+    ).to(device)
+
+    # Build the optimizers and losses
+    optim_image = torch.optim.Adam(image_model.parameters(), lr=lr)
+    optim_csm = torch.optim.Adam(csm_model.parameters(), lr=lr)
+
+    # The loss has two components:
+    # - a total variation loss in the image space for the pre-intensity
+    # - a Huber loss between the sampled components of the k-space
+    reg_loss = TVLoss()
+    kspace_loss = torch.nn.HuberLoss()
+
+    # Loop for max_iter
+    with tqdm.tqdm(range(max_iter)) as pbar:
+        for _ in pbar:
+
+            # Switch the models to training mode
+            image_model.train()
+            csm_model.train()
+
+            # Compute the forward pass
+            pre_intensity = image_model(coor).view(
+                nrows, ncols, nframes
+            )  # Nrows, Ncols, Nframes
+            csm = csm_model(coor).view(
+                nrows, ncols, nframes, ncoils
+            )  # Nrows, Ncols, Nframes, Ncoils
+
+            # Compute the RSS over the coils
+            csm_norm = torch.sqrt((csm.conj() * csm).sum(axis=-1))
+            # Unsqueeze over the coil dimension to apply the same scaling for every coil
+            csm = csm / (csm_norm.unsqueeze(-1) + 1e-12)
+
+            # Modulate the same pre-intensity by every coil-specific sensitivity
+            fft_pre_intensity = FFT(pre_intensity.unsqueeze(axis=-1) * csm).transpose(
+                3, 2
+            )  # (Nrows, Ncols, Ncoils, Nframes)
+
+            # Compute the loss with the reconstruction loss
+            # and the regularization loss
+            masked_pred_kspace = torch.view_as_real(
+                fft_pre_intensity[subsampled_mask == 1]
+            )
+            masked_kspace = torch.view_as_real(subsampled_slice[subsampled_mask == 1])
+
+            kspace_loss_value = kspace_loss(masked_pred_kspace, masked_kspace)
+            reg_loss_value = reg_loss(pre_intensity)
+            loss = kspace_loss_value + reg_weight * reg_loss_value
+            pbar.set_postfix(
+                {
+                    "TV": reg_loss_value.item(),
+                    "Data consistency Loss": kspace_loss_value.item(),
+                }
+            )
+
+            # Zero grad, backward and update
+            optim_image.zero_grad()
+            optim_csm.zero_grad()
+
+            loss.backward()
+
+            optim_image.step()
+            optim_csm.step()
+
+    # Inference
+    logging.info("Performing inference")
+    image_model.eval()
+    csm_model.eval()
+    with torch.no_grad():
+        pre_intensity = image_model(coor).view(
+            nrows, ncols, nframes
+        )  # Nrows, Ncols, Nframes
+        csm = csm_model(coor).view(
+            nrows, ncols, nframes, ncoils
+        )  # Nrows, Ncols, Nframes, Ncoils
+
+        # Compute the RSS over the coils
+        csm_norm = torch.sqrt((csm.conj() * csm).sum(axis=-1))
+        # Unsqueeze over the coil dimension to apply the same scaling for every coil
+        csm = csm / (csm_norm.unsqueeze(-1) + 1e-12)
+
+        fft_pre_intensity = FFT(
+            pre_intensity.unsqueeze(axis=-1) * csm
+        )  # (Nrows, Ncols, Nframes, Ncoils)
+
+        recon_kspace = torch.clone(fft_pre_intensity)
+        # Keep the input k-space untouched
+        recon_kspace[subsampled_mask == 1] = subsampled_slice[
+            subsampled_mask == 1
+        ].transpose(1, 2)
+        # Compute the image for the reconstructed k-space
+        recon_img = IFFT(recon_kspace)
+
+        # Merge all the coils, each contribution being modulated by the CSM
+        fused_recon_img = (recon_img * torch.conj(csm)).sum(axis=-1)
+
+        # Plot the reconstruction with the contributions of all the coils
+        scale_factor = fused_recon_img.abs().max()
+        img = fused_recon_img.abs() / scale_factor
+        img = img.cpu()  # (Nrows, Ncols, Nframes)
+
+        # frame_idx = 0
+        # h = plt.figure()
+        # plt.imshow(img[:, :, frame_idx], cmap="gray")
+        # plt.savefig(
+        #     str(results_dir / f"slice_{slice_idx}_frame_{frame_idx}.png"),
+        #     bbox_inches="tight",
+        # )
+        # plt.close(h)
+
+        # Plot the combined coils for every frame
+        # This allows to make a video
+        for frame_idx in range(img.shape[2]):
+            h = plt.figure()
+            plt.imshow(img[:, :, frame_idx], cmap="gray")
+            plt.savefig(
+                str(results_dir / f"slice_{slice_idx}_{frame_idx:04d}.png"),
+                bbox_inches="tight",
+            )
+            plt.close(h)
+
+            # Plot, for every coil, the predicted image
+            # for coil_idx in range(recon_img.shape[3]):
+            #     img = (recon_img[:, :, frame_idx, coil_idx].abs() / scale_factor).cpu()
+            #     h = plt.figure()
+            #     plt.imshow(img, cmap="gray", clim=[0, 1])
+            #     plt.savefig(
+            #         str(
+            #             results_dir
+            #             / f"slice_{slice_idx}_frame_{frame_idx}_coil_{coil_idx}.png"
+            #         ),
+            #         bbox_inches="tight",
+            #     )
+            #     plt.close(h)
+    return pre_intensity, csm, img
+
+
+def train(rootdir, acc_factor, view, training_cfg):
+    dataset = MICCAI2023(
+        rootdir=rootdir,
+        view=view,
+        acc_factor=acc_factor,
+    )
+
+    # Take a random sample
+    sample_idx = random.randint(0, len(dataset) - 1)
+
+    # Prepare the directory in which to store the results
+    patient_path = dataset.patients[sample_idx]
+    results_dir = pathlib.Path("./results") / patient_path.name
+
+    # Prepare the results directory, removing it if it already exists
+    if results_dir.exists():
+        logging.info(f"Removing {results_dir} as it already exists")
+        shutil.rmtree(results_dir)
+
+    logging.info(f"Results will be saved into {results_dir}")
+    results_dir.mkdir(parents=True)
+
+    # Get the sample
+    subsampled_data, subsampled_mask, fullsampled_data = dataset[sample_idx]
+
+    # subsampled_data and fullsampled_data are (kx, ky, sc, sz, t)
+    n_coils = subsampled_data.shape[-3]
+    n_slices = subsampled_data.shape[-2]
+    n_frames = subsampled_data.shape[-1]
+
+    # Iterate over the slices
+    all_psnrs = []
+
+    logging.info(f"Processing {n_slices} slices")
+    for slice_idx in tqdm.tqdm(range(n_slices)):
+
+        # Get the slices from the subsampled and fullsampled data
+        # These slices are (kx, ky, sc, t)
+        subsampled_slice = subsampled_data[:, :, :, slice_idx, :]
+        fullsampled_slice = fullsampled_data[:, :, :, slice_idx, :]
+
+        # Compute the normalization factor as the max of the RSS images
+        # This step is critical for the training to work properly
+        coils_combined = combine_coils(subsampled_slice)
+        norm_factor = coils_combined.max()
+        logging.debug(
+            f"For slice {slice_idx}, using the normalization factor {norm_factor}"
+        )
+
+        subsampled_slice = subsampled_slice / norm_factor.item()
+
+        # Perform inference on this slice
+        pre_intensity, csm, pred_image = infer_on_slice(
+            subsampled_slice, subsampled_mask, slice_idx, results_dir, training_cfg
+        )
+
+        # Compute the PSNR for every frame
+        # and also plot, for comparison:
+        # - the fully sampled image
+        # - the zero-filled image (the unobserved k-space is filled with zeros)
+        # - the predicted image
+
+        full_image = combine_coils(fullsampled_slice)  # nrows, ncols, nframes
+        full_image /= full_image.max()
+
+        zero_filled = combine_coils(subsampled_slice)
+        zero_filled /= zero_filled.max()
+
+        for frame_idx in range(full_image.shape[2]):
+            img = full_image[:, :, frame_idx]
+            zero_img = zero_filled[:, :, frame_idx]
+            pred_img = pred_image[:, :, frame_idx]
+
+            # Compute the PSNR
+            mse = ((img - pred_img) ** 2).mean()
+            data_range = 1.0
+            psnr = 10.0 * np.log10(data_range**2 / mse)
+            all_psnrs.append(psnr)
+            logging.debug(f"PSNR for slice {slice_idx}, frame {frame_idx}: {psnr}")
+
+            fig, axes = plt.subplots(nrows=1, ncols=3)
+            axes[0].imshow(img, cmap="gray", clim=[0, 1])
+            axes[0].set_title("Ground truth")
+            axes[0].axis("off")
+
+            axes[1].imshow(pred_img, cmap="gray", clim=[0, 1])
+            axes[1].set_title("Predicted")
+            axes[1].axis("off")
+
+            axes[2].imshow(zero_img, cmap="gray", clim=[0, 1])
+            axes[2].set_title("Zero filled")
+            axes[2].axis("off")
+
+            plt.tight_layout()
+            plt.savefig(
+                results_dir / f"slice_{slice_idx}_frame_{frame_idx}.png",
+                bbox_inches="tight",
+                dpi=150,
+            )
+            plt.close(fig)
+
+    logging.info(
+        f"Mean PSNR evaluated over all the slices and all the frames: {np.mean(all_psnrs)}"
+    )
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser(
+        description="Implementation of a neural implicit representation for cine MRI reconstruction"
+    )
+    parser.add_argument(
+        "--rootdir",
+        type=str,
+        help="Path to the data directory",
+        required=True,
+    )
+    parser.add_argument(
+        "--acc_factor",
+        default=AccFactor.ACC4,
+        type=AccFactor.__getitem__,
+        help="Acceleration factor (ACC4, ACC8, ACC10)",
+    )
+    parser.add_argument(
+        "--view",
+        default=CINEView.SAX,
+        type=CINEView.__getitem__,
+        help="View of the cine MRI data (SAX, LAX)",
+    )
+
+    # Some training parameters
+    parser.add_argument(
+        "--lr",
+        default=0.01,
+        type=float,
+        help="The base learning rate for the optimizer",
+    )
+    parser.add_argument(
+        "--iter",
+        default=512,
+        type=int,
+        help="The number of training iterations per slice",
+    )
+    parser.add_argument(
+        "--reg_weight",
+        default=4.0,
+        type=float,
+        help="The weight of the TV loss in the total loss",
+    )
+
+    args = parser.parse_args()
+
+    training_cfg = {"lr": args.lr, "iter": args.iter, "reg_weight": args.reg_weight}
+
+    train(args.rootdir, args.acc_factor, args.view, training_cfg)
diff --git a/nir_cinejense/requirements.txt b/nir_cinejense/requirements.txt
new file mode 100644
index 0000000..285992c
--- /dev/null
+++ b/nir_cinejense/requirements.txt
@@ -0,0 +1,4 @@
+torch
+torchcvnn
+tqdm
+matplotlib
diff --git a/nir_cinejense/run_trainings.sh b/nir_cinejense/run_trainings.sh
new file mode 100755
index 0000000..77c6efe
--- /dev/null
+++ b/nir_cinejense/run_trainings.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Run the training script for all the acceleration factors ACC4, ACC8 and ACC10
+# and both views SAX and LAX.
+# The results of every run are saved in "./results"; we copy them into a
+# dedicated directory after each run.
+
+# Example call:
+# python nir_miccai2023.py --rootdir /mounts/Datasets4/MICCAIChallenge2023/ChallegeData/ --acc_factor ACC10 --view SAX
+
+for acc in ACC4 ACC8 ACC10
+do
+    for view in SAX LAX
+    do
+        echo "=================================================="
+        echo "Running for $acc $view"
+        echo "=================================================="
+        python nir_miccai2023.py --rootdir /mounts/Datasets4/MICCAIChallenge2023/ChallegeData/ --acc_factor $acc --view $view
+        mkdir -p ./all_results/$acc/$view
+        cp -r ./results/* ./all_results/$acc/$view
+    done
+done
diff --git a/nir_cinejense/utils.py b/nir_cinejense/utils.py
new file mode 100644
index 0000000..4802403
--- /dev/null
+++ b/nir_cinejense/utils.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+
+# MIT License
+
+# Copyright (c) 2024 Clément Cornet, Jérémy Fix
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# External imports
+import torch
+import torch.nn as nn
+import tinycudann as tcnn
+import torchcvnn.nn as c_nn
+
+
+def build_coordinate_2Dt(Nx, Ny, Nt, device=torch.device("cpu")):
+    """
+    Builds the (x, y, t) sampling coordinates of a 2D+t volume, normalized
+    in [-1, 1] along every axis, as a tensor of shape (Nx, Ny, Nt, 3)
+    """
+    x = torch.linspace(-1, 1, Nx, device=device)
+    y = torch.linspace(-1, 1, Ny, device=device)
+    t = torch.linspace(-1, 1, Nt, device=device)
+
+    x, y, t = torch.meshgrid(x, y, t, indexing="ij")
+    xyt = torch.stack([x, y, t], -1)  # (Nx, Ny, Nt, 3)
+    return xyt
+
+
+class ComplexNGP(nn.Module):
+    """
+    Complex Neural Graphics Primitive
+
+    The first layer is a hash encoding, followed by dense layers to get
+    the value of the function at the provided input coordinates.
+
+    Arguments:
+        n_inputs (int): Number of input coordinates (e.g. 3 for x, y, t)
+        n_outputs (int): Number of output values (e.g. 1 for a scalar field)
+        encoding_cfg (dict): Configuration for the encoding
+        mlp_cfg (dict): Configuration for the MLP after the encoding
+    """
+
+    def __init__(self, n_inputs, n_outputs, encoding_cfg, mlp_cfg):
+        super().__init__()
+
+        # The input layer uses a hash encoding
+        self.hash_encoder = tcnn.Encoding(n_inputs, encoding_cfg)
+        output_hash_encoding = (
+            encoding_cfg["n_levels"] * encoding_cfg["n_features_per_level"]
+        )
+
+        # And then comes the FFNN with dense layers fed with
+        # the above coordinate encoding
+        n_hidden_units = mlp_cfg["n_hidden_units"]
+        n_hidden_layers = mlp_cfg["n_hidden_layers"]
+        hidden_activation = c_nn.modReLU
+        self.cdtype = torch.complex64
+
+        layers = []
+        input_dim = output_hash_encoding
+        for _ in range(n_hidden_layers - 1):
+            layers.append(nn.Linear(input_dim, n_hidden_units, dtype=self.cdtype))
+            layers.append(hidden_activation())
+            input_dim = n_hidden_units
+        # The last dense layer projects onto the output space
+        layers.append(nn.Linear(input_dim, n_outputs, dtype=self.cdtype))
+
+        self.ffnn = nn.Sequential(*layers)
+
+    def forward(self, x):
+        # The real-valued hash encoding is cast to complex before
+        # the complex-valued dense layers
+        x = self.hash_encoder(x).to(self.cdtype)
+        x = self.ffnn(x)
+        return x
+
+
+def test_ngp():
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    n_inputs = 3
+    n_outputs = 4
+    model = ComplexNGP(
+        n_inputs=n_inputs,
+        n_outputs=n_outputs,
+        encoding_cfg={
+            "otype": "Grid",
+            "type": "Hash",
+            "n_levels": 16,
+            "n_features_per_level": 2,
+            "log2_hashmap_size": 19,
+            "base_resolution": 16,
+            "per_level_scale": 2,
+            "interpolation": "Linear",
+        },
+        mlp_cfg={"n_hidden_units": 32, "n_hidden_layers": 2},
+    )
+    model = model.to(device)
+
+    # Build up the volume over which to evaluate the NIR
+    Nx, Ny, Nt = 16, 16, 16
+    coords = build_coordinate_2Dt(Nx, Ny, Nt, device=device).view(-1, 3)
+
+    # Sample the volume and check the output shape
+    outputs = model(coords).reshape(Nx, Ny, Nt, n_outputs)
+    assert outputs.shape == (Nx, Ny, Nt, n_outputs)
+
+
+if __name__ == "__main__":
+    test_ngp()