Initial commit
fabio-sim committed Jan 22, 2024
1 parent 22ccfac commit d6cc8e1
Showing 97 changed files with 9,940 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
.vscode
*.onnx
megadepth_test_1500

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]
Copyright 2024 Fabio Milentiansen Sim

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
82 changes: 82 additions & 0 deletions README.md
@@ -0,0 +1,82 @@
[![ONNX](https://img.shields.io/badge/ONNX-grey)](https://onnx.ai/)
[![GitHub Repo stars](https://img.shields.io/github/stars/fabio-sim/Depth-Anything-ONNX)](https://github.com/fabio-sim/Depth-Anything-ONNX/stargazers)
[![GitHub all releases](https://img.shields.io/github/downloads/fabio-sim/Depth-Anything-ONNX/total)](https://github.com/fabio-sim/Depth-Anything-ONNX/releases)

# Depth Anything ONNX

Open Neural Network Exchange (ONNX) compatible implementation of [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://github.com/LiheYoung/Depth-Anything).

<p align="center"><img src="assets/sample.png" width=90%>

<details>
<summary>Changelog</summary>

- **22 January 2024**: Release.
</details>

## 🔥 ONNX Export

Prior to exporting the ONNX models, please install the [requirements](/requirements.txt).
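For example:

<pre>
pip install -r requirements.txt
</pre>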

To convert the Depth Anything models to ONNX, run [`export.py`](/export.py). The pretrained weights will be downloaded automatically.

<details>
<summary>Export Example</summary>
<pre>
python export.py --model s
</pre>
</details>
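
The example above exports the ViT-S variant. Assuming the `--model` flag likewise accepts `b` and `l` for the Base and Large variants benchmarked below (an educated guess from the S/B/L naming; check `export.py` for the exact choices), all variants can be exported in one go:

<pre>
python export.py --model s  # ViT-S
python export.py --model b  # ViT-B (assumed flag value)
python export.py --model l  # ViT-L (assumed flag value)
</pre>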

If you would like to try out inference right away, you can download already-exported ONNX models [here](https://github.com/fabio-sim/Depth-Anything-ONNX/releases).

## ⚡ ONNX Inference

With ONNX models in hand, one can perform inference in Python using ONNX Runtime. See [`infer.py`](/infer.py).

<details>
<summary>Inference Example</summary>
<pre>
python infer.py --img assets/DSC_0410.JPG --model weights/depth_anything_vits14.onnx --viz
</pre>
</details>
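
For a programmatic route, below is a minimal sketch of ONNX Runtime inference (not a substitute for `infer.py`). It assumes the 518x518 input resolution described in the benchmark section, the standard ImageNet normalization used by the PyTorch implementation, and a single relative-depth output; adjust as needed.

```python
import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "weights/depth_anything_vits14.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

image = cv2.cvtColor(cv2.imread("assets/DSC_0410.JPG"), cv2.COLOR_BGR2RGB)
h, w = image.shape[:2]

# Resize to the model's 518x518 input and normalize (assumed ImageNet statistics).
x = cv2.resize(image, (518, 518)).astype(np.float32) / 255.0
x = (x - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
x = x.transpose(2, 0, 1)[None].astype(np.float32)  # HWC -> NCHW

# Run the model; the graph is assumed to have one input and one output.
(depth,) = session.run(None, {session.get_inputs()[0].name: x})
depth = depth.squeeze()  # relative depth map

# Upsample back to the original resolution and scale to [0, 255] for visualization.
depth = cv2.resize(depth, (w, h))
depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
cv2.imwrite("depth_visualization.png", depth)
```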

## 🚀 TensorRT Support

(To be investigated)

## ⏱️ Inference Time Comparison

<p align="center"><img src="assets/latency.png" alt="Latency Comparison" width=90%>

We report the inference time, or latency, of the model only; that is, time spent on preprocessing, postprocessing, or copying data between the host & device is not measured. The reported inference time is the median over all samples in the [MegaDepth](https://arxiv.org/abs/1804.00607) test dataset. We use the data provided by [LoFTR](https://arxiv.org/abs/2104.00680) [here](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md) - a total of 806 images.

Each image is resized to 518x518 before being fed into the model. Inference time is then measured for all three model variants (S, B, L). See [eval.py](/eval.py) for the measurement code.

All experiments are conducted on an i9-12900HX CPU and an RTX 4080 12GB GPU with `CUDA==11.8.1`, `torch==2.1.2`, and `onnxruntime==1.16.3`.
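
For reference, here is a minimal sketch of how such a median latency could be measured with ONNX Runtime (illustrative only; see [eval.py](/eval.py) for the actual measurement code). Note that this simplified version includes ONNX Runtime's host-device transfers inside `run()`, which the reported numbers exclude.

```python
import time

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "weights/depth_anything_vits14.onnx", providers=["CUDAExecutionProvider"]
)
name = session.get_inputs()[0].name
x = np.random.randn(1, 3, 518, 518).astype(np.float32)

# Warm up so one-time costs (allocation, kernel selection) are excluded.
for _ in range(10):
    session.run(None, {name: x})

# Time the model only; no preprocessing or postprocessing is included.
latencies = []
for _ in range(100):
    start = time.perf_counter()
    session.run(None, {name: x})
    latencies.append(time.perf_counter() - start)

print(f"Median latency: {1000 * np.median(latencies):.2f} ms")
```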

### Notes

- Currently, the inference speed is bottlenecked by Conv operations.
- ONNX Runtime is slightly (20-25%) faster than PyTorch for the ViT-L model variant.

## Credits

If you use any ideas from the papers or code in this repo, please consider citing the authors of [Depth Anything](https://arxiv.org/abs/2401.10891) and [DINOv2](https://arxiv.org/abs/2304.07193). Lastly, if the ONNX versions helped you in any way, please also consider starring this repository.

```bibtex
@article{depthanything,
  title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
  author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
  journal={arXiv:2401.10891},
  year={2024}
}
```

```bibtex
@misc{oquab2023dinov2,
  title={DINOv2: Learning Robust Visual Features without Supervision},
  author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal={arXiv:2304.07193},
  year={2023}
}
```

Binary file added assets/DSC_0410.JPG
Binary file added assets/latency.png
Binary file added assets/sacre_coeur1.jpg
Binary file added assets/sacre_coeur2.jpg
Binary file added assets/sample.png
204 changes: 204 additions & 0 deletions depth_anything/blocks.py
@@ -0,0 +1,204 @@
import torch.nn as nn


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3],
            out_shape4,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
    ):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.
        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
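
As a quick illustration of how these blocks compose, here is a hypothetical usage sketch of `FeatureFusionBlock` (the shapes are illustrative, not taken from the actual DPT decoder). When no target size is given, the block fuses two same-resolution features and upsamples by 2x:

```python
import torch
import torch.nn as nn

from depth_anything.blocks import FeatureFusionBlock

fuse = FeatureFusionBlock(features=64, activation=nn.ReLU(False), bn=False)
coarse = torch.randn(1, 64, 16, 16)  # decoder feature
skip = torch.randn(1, 64, 16, 16)  # same-resolution skip feature

# size=None with self.size=None falls back to scale_factor=2.
out = fuse(coarse, skip)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```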
