Skip to content

Commit

Permalink
Add support for NeRFstudio
Browse files Browse the repository at this point in the history
  • Loading branch information
HengyiWang committed Oct 25, 2024
1 parent 9fdb9b9 commit e4949e1
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

## Update

[2024-10-25] Add support for [Nerfstudio](assets/spanner-gs.gif)

[2024-10-18] Add camera param estimation

[2024-09-30] [@hugoycj](https://github.com/hugoycj) adds a gradio demo
Expand Down Expand Up @@ -73,11 +75,26 @@
```

For visualization `--vis`, it will give you a window to adjust the rendering view. Once you find the view to render, please click `space key` and close the window. The code will then do the rendering of the incremental reconstruction.

3. Nerfstudio:

```
# Run the demo; use --save_ori to save scaled intrinsics for the original images
python demo.py --demo_path ./examples/s00567 --kf_every 10 --vis --vis_cam --save_ori
# Run splatfacto
ns-train splatfacto --data ./output/demo/s00567 --pipeline.model.camera-optimizer.mode SO3xR3
# Render your results
ns-render interpolate --load-config [path-to-your-config]/config.yml
```

Note that here you can use `--save_ori` to save the scaled intrinsics into `transform.json` to train NeRF/3D Gaussians with original images.


## Gradio interface 🤗
## Gradio interface

We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run by:
We also provide a Gradio interface for a better experience; just run:

```bash
# For Linux and Windows users (and macOS with Intel??)
Expand Down
Binary file added assets/spanner-gs.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
63 changes: 63 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import cv2
import json
import time
import torch
import argparse
Expand All @@ -13,6 +14,7 @@
from dust3r.utils.geometry import inv
from dust3r.inference import inference
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import imread_cv2
from dust3r.post_process import estimate_focal_knowing_depth

from spann3r.datasets import *
Expand All @@ -33,9 +35,41 @@ def get_args_parser():
parser.add_argument('--kf_every', type=int, default=10, help='map every kf_every frames')
parser.add_argument('--vis', action='store_true', help='visualize')
parser.add_argument('--vis_cam', action='store_true', help='visualize camera pose')
parser.add_argument('--save_ori', action='store_true', help='save original parameters for NeRF')

return parser

def get_transform_json(H, W, focal, poses_all, ply_file_path, ori_path=None):
    """Build a Nerfstudio-style ``transforms.json`` dictionary.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length in pixels — a plain float or any 0-d scalar
            convertible via ``float()`` (e.g. a numpy/torch scalar).
        poses_all: iterable of 4x4 camera-to-world poses in OpenCV
            convention (anything with ``.tolist()``, e.g. numpy arrays).
        ply_file_path: path of the point-cloud file to reference.
        ori_path: optional list of original image paths, one per pose;
            when omitted, paths default to ``imgs/img_XXXX.png``.

    Returns:
        dict ready to be serialized as ``transforms.json``.
    """
    fl = float(focal)  # accept plain floats as well as numpy/torch scalars
    transform_dict = {
        'w': W,
        'h': H,
        'fl_x': fl,
        'fl_y': fl,
        # principal point assumed at the image center
        'cx': W/2,
        'cy': H/2,
        # no lens distortion is estimated
        'k1': 0,
        'k2': 0,
        'p1': 0,
        'p2': 0,
        'camera_model': 'OPENCV',
    }
    frames = []

    for i, pose in enumerate(poses_all):
        # Convert OpenCV (CV2) to OpenGL/Nerfstudio camera convention by
        # flipping the y and z axes. Negate on a plain nested-list copy so
        # the caller's pose array is NOT mutated (the original code wrote
        # through `pose[:3, 1] *= -1` into the shared array).
        mat = pose.tolist()
        for r in range(3):
            mat[r][1] = -mat[r][1]
            mat[r][2] = -mat[r][2]
        frame = {
            'file_path': f"imgs/img_{i:04d}.png" if ori_path is None else ori_path[i],
            'transform_matrix': mat
        }
        frames.append(frame)

    transform_dict['frames'] = frames
    transform_dict['ply_file_path'] = ply_file_path

    return transform_dict

@torch.no_grad()
def main(args):

Expand Down Expand Up @@ -188,6 +222,35 @@ def main(args):

render_frames(pts_all, images_all, camera_parameters, save_demo_path, mask=conf_sig_all>args.conf_thresh)
vis_pred_and_imgs(pts_all, save_demo_path, images_all=images_all, conf_all=conf_sig_all)

# Save transform.json
if args.save_ori:
scale_factor = ordered_batch[0]['camera_intrinsics'][0, 0, 0]
assert scale_factor < 1.0, "Scale factor should be less than 1.0"
focal_ori = focal / scale_factor

image = imread_cv2(ordered_batch[0]['label'][0])

H_ori, W_ori = image.shape[:2]

paths_all = [osp.normpath(osp.join(osp.abspath(os.getcwd()), view['label'][0]))
for view in ordered_batch]

transform_dict = get_transform_json(H_ori, W_ori, focal_ori, poses_all,
f"{demo_name}_conf{args.conf_thresh}.ply",
ori_path=paths_all)




else:
transform_dict = get_transform_json(H, W, focal, poses_all, f"{demo_name}_conf{args.conf_thresh}.ply")


# Save to json
with open(os.path.join(save_demo_path, 'transforms.json'), 'w') as f:
json.dump(transform_dict, f, indent=4)




Expand Down
2 changes: 1 addition & 1 deletion spann3r/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_views(self, idx, resolution, rng):
img_idxs = self.sample_frame_idx(img_idxs, rng, full_video=self.full_video)

# pseudo intrinsics
fx, fy = 525, 525
fx, fy = 1.0, 1.0

views = []
imgs_idxs = deque(img_idxs)
Expand Down

0 comments on commit e4949e1

Please sign in to comment.