Feature Map Visualization #42

Open
Made-Gpt opened this issue Dec 7, 2024 · 2 comments

Comments

@Made-Gpt

Made-Gpt commented Dec 7, 2024

Great work on this project!

I am especially interested in the attention map visualization in Figure 8. Could you please share how it was implemented? I'm curious about the specific approach or code you used to produce this visualization.

@HengyiWang
Owner

Hi @Made-Gpt, in Eq. 9 we have an attention map of shape (Bs, P, TP). For each patch in the query frame, you can extract and visualize its attention map, which has shape (TP). This corresponds to one row in Fig. 8.
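The row-extraction step described above can be sketched in NumPy. This is a minimal illustration under assumptions not confirmed by the thread: it treats the TP axis as T memory frames of P patch tokens each, stored frame by frame, with each frame's patches in raster order over a grid_h x grid_w patch grid. If the model prunes or reorders memory tokens, the reshape step does not apply. The function names patch_attention_maps and normalize_for_display are hypothetical.

```python
import numpy as np


def patch_attention_maps(attn, batch_idx, patch_idx, num_frames, grid_h, grid_w):
    """Return one query patch's attention over all memory tokens as heatmaps.

    attn: array of shape (Bs, P, TP), as described for Eq. 9.
    ASSUMPTION (hypothetical layout): TP == num_frames * P, tokens stored
    frame by frame, each frame's P patches in raster order over a
    grid_h x grid_w patch grid.
    Returns an array of shape (num_frames, grid_h, grid_w).
    """
    _, p, tp = attn.shape
    if p != grid_h * grid_w:
        raise ValueError(f"P={p} does not match grid {grid_h}x{grid_w}")
    if tp != num_frames * p:
        raise ValueError(f"TP={tp} is not num_frames * P = {num_frames * p}")
    row = attn[batch_idx, patch_idx]  # one row of the map, shape (TP,)
    return row.reshape(num_frames, grid_h, grid_w)


def normalize_for_display(maps):
    """Min-max scale each frame's heatmap to [0, 1] independently."""
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / np.maximum(hi - lo, 1e-12)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(1, 6, 3 * 6))  # Bs=1, P=6 (2x3 grid), T=3
    # Softmax over the TP axis, so each query row sums to 1.
    attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    maps = patch_attention_maps(attn, batch_idx=0, patch_idx=4,
                                num_frames=3, grid_h=2, grid_w=3)
    print(maps.shape)  # (3, 2, 3)
    print(normalize_for_display(maps).max(axis=(1, 2)))
```

Each of the T slices can then be upsampled to the image resolution and overlaid on its frame, for example with matplotlib's imshow.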

@Made-Gpt
Author

Made-Gpt commented Dec 26, 2024

Hi, I see! Thank you for your response!

By the way, is it possible to train the model on 4 NVIDIA RTX 4090 GPUs?
I attempted to train the model on the KITTI dataset, and here’s the command I used:

torchrun --nproc_per_node 4 train.py --batch_size 2 --num_workers 4 \
         --train_dataset "Kitti(split='train', resolution=[[1216, 224]], skip_frames=2, down_sample=2,
                                        ROOT='/data2/shenghai/kitti/data_depth_annotated',
                                        target_seq_file='./spann3r/data_preprocess/kitti/kitti_valid_list.txt')" \
         --test_dataset "Kitti(split='val', resolution=[[1216, 224]], skip_frames=2, down_sample=2,
                                        ROOT='/data2/shenghai/kitti/data_depth_annotated',
                                        target_seq_file='./spann3r/data_preprocess/kitti/kitti_valid_list.txt')" \
         --model "Spann3R(dus3r_name='/data2/shenghai/ckpts/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth',
                          use_feat=False, mem_pos_enc=False)" \
         --pretrained /data2/shenghai/ckpts/spann3r.pth \
         --output_dir ./outputs/kitti/train

I used spann3r.pth as the pretrained weights and wrote a custom Kitti.py file modeled after the other files in the spann3r/datasets folder.
In Kitti.py, I added a progress bar using tqdm to visualize the data loading process, as shown below:

...
            with tqdm(total=len(gts_idxs), desc=f'Loading {mode} images') as pbar:
                while len(gts_idxs) > 0:
                    gtpath = gts_idxs.popleft()

                    if not gtpath.exists():
                        raise FileNotFoundError(f"Image not found: {gtpath}")

                    # Update tqdm progress bar with the current file being processed
                    pbar.set_postfix(file=str(gtpath.name))
                    pbar.update(1)
...

However, the progress bar appears multiple times in the terminal, which suggests that the entire KITTI training dataset is being loaded repeatedly; this happens inside 'train_one_epoch()'. Initially, more than 20,000 images are queued for loading. Once the progress bar reaches around 4,000 images, however, distributed training errors begin to appear. Checking nvidia-smi, I noticed that the memory on the 4 GPUs is released one GPU at a time, which suggests that the training processes are stopping one after another.

To address this, I changed the command to use skip_frames=10 and down_sample=2 to reduce the dataset size and memory usage. This brought the dataset down to fewer than 1,000 images. However, the reduced dataset was still loaded repeatedly during training, and the same distributed training errors persisted.

I can’t figure out why this is happening. Could it be related to the multi-GPU training setup, or is there an issue with how I implemented the data loader in Kitti.py? I’d greatly appreciate your insight on this!
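One thing worth ruling out regarding the repeated progress bars, offered as an assumption to check rather than a diagnosis: torchrun --nproc_per_node 4 launches four copies of train.py, one per GPU, and each copy builds its own dataset objects, so a tqdm bar created while building the dataset is printed by every process. The sketch below restates the Kitti.py loading loop in a self-contained form and shows the bar only on global rank 0, using the RANK environment variable that torchrun sets for each worker. The function name collect_existing is hypothetical, and this only addresses duplicated bars; it does not by itself explain the distributed errors.

```python
import os
from collections import deque
from pathlib import Path

from tqdm import tqdm


def collect_existing(paths, desc="Loading images"):
    """Walk a queue of paths, failing fast on a missing file.

    Hypothetical helper mirroring the loop in Kitti.py. The progress bar is
    shown only on global rank 0: torchrun sets RANK for every process it
    launches, and when RANK is unset (plain `python train.py`) the process
    is treated as rank 0.
    """
    queue = deque(Path(p) for p in paths)
    rank = int(os.environ.get("RANK", "0"))
    found = []
    # disable=True turns update()/set_postfix() into no-ops on other ranks.
    with tqdm(total=len(queue), desc=desc, disable=(rank != 0)) as pbar:
        while queue:
            path = queue.popleft()
            if not path.exists():
                raise FileNotFoundError(f"Image not found: {path}")
            found.append(path)
            pbar.set_postfix(file=path.name)
            pbar.update(1)
    return found
```

If a single bar still appears several times on rank 0, that would point to the dataset being constructed more than once per process rather than to the multi-process launch.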
