Showing 37 changed files with 6,603 additions and 2 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,95 @@ | ||
# Focus-DETR | ||
Focus-DETR is a model that focuses attention on more informative tokens for a better trade-off between computation efficiency and model accuracy. Compared with the state-of-the-art sparse transformer-based detector under the same setting, our Focus-DETR gets comparable complexity while achieving 50.4 AP (+2.2) on COCO. | ||
# Contents | ||
|
||
- [Contents](#contents) | ||
- [Focus-DETR Description](#focus-detr-description) | ||
- [Model architecture](#model-architecture) | ||
- [Dataset](#dataset) | ||
- [Environment Requirements](#environment-requirements) | ||
- [Eval process](#eval-process) | ||
- [Usage](#usage) | ||
- [Launch](#launch) | ||
- [Result](#result) | ||
- [ModelZoo Homepage](#modelzoo-homepage) | ||
|
||
## [Focus-DETR Description](#contents) | ||
|
||
Focus-DETR is a model that focuses attention on more informative tokens for a better trade-off between computation efficiency and model accuracy. Compared with the state-of-the-art sparse transformer-based detector under the same setting, | ||
our Focus-DETR gets comparable complexity while achieving 50.4 AP (+2.2) on COCO. | ||
|
||
> [Paper](https://openreview.net/pdf?id=iuW96ssPQX): Less is More: Focus Attention for Efficient DETR. | ||
> Dehua Zheng*, Wenhui Dong*, Hailin Hu, Xinghao Chen, Yunhe Wang. | ||
## [Model architecture](#contents) | ||
|
||
Our Focus-DETR comprises a backbone network, a Transformer encoder, and a Transformer decoder. We design a foreground token selector (FTS) based on top-down score modulations across multi-scale features. The tokens selected by a multi-category score predictor, together with the foreground tokens, pass through the Pyramid Encoder to remedy the limitation of deformable attention in mixing distant information. | ||
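
The core idea of keeping only the most informative tokens can be illustrated with a minimal sketch. This is not the FTS implementation from this repository: the function name `select_top_tokens` and the `keep_ratio` parameter are hypothetical, and the sketch ignores the multi-scale, top-down score modulation described above. It only shows score-based top-k selection over a flat list of token scores.

```python
def select_top_tokens(scores, keep_ratio):
    """Return the indices of the highest-scoring tokens, in original order.

    `scores` is a flat list of per-token foreground scores and `keep_ratio`
    is the fraction of tokens to keep (both hypothetical, for illustration).
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    if not scores:
        return []
    # Keep at least one token so the encoder always has some input.
    num_keep = max(1, int(len(scores) * keep_ratio))
    # Rank token indices by score, highest first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Restore spatial order among the kept tokens.
    return sorted(ranked[:num_keep])
```

For example, `select_top_tokens([0.1, 0.9, 0.4, 0.8], 0.5)` keeps the two tokens at indices 1 and 3.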
|
||
![Focus-DETR](./figs/model_arch.PNG) | ||
|
||
## [Dataset](#contents) | ||
|
||
Dataset used: [COCO2017](https://cocodataset.org/#download) | ||
|
||
- Dataset size: ~19G | ||
- [Train](http://images.cocodataset.org/zips/train2017.zip) - 18G, 118000 images | ||
- [Val](http://images.cocodataset.org/zips/val2017.zip) - 1G, 5000 images | ||
- [Annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) - 241M, instances, captions, person_keypoints, etc. | ||
- Data format: image and json files | ||
- The directory structure is as follows: | ||
|
||
```text | ||
. | ||
├── annotations # annotation jsons | ||
├── test2017 # test data | ||
├── train2017 # train dataset | ||
└── val2017 # val dataset | ||
``` | ||
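
Before launching evaluation it can help to confirm that the dataset root matches the layout above. The following is a small sketch using only the standard library; the helper name `missing_coco_dirs` is hypothetical and not part of this repository.

```python
from pathlib import Path

# Sub-directories listed in the directory structure above.
EXPECTED_DIRS = ("annotations", "test2017", "train2017", "val2017")


def missing_coco_dirs(root):
    """Return the expected sub-directories that are absent under `root`."""
    root = Path(root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]
```

An empty list means all four directories are present. Evaluation presumably only needs `annotations` and `val2017`, so a missing `test2017` or `train2017` may be harmless in that case.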
|
||
## [Environment Requirements](#contents) | ||
|
||
- Hardware(GPU) | ||
- Prepare hardware environment with GPU. | ||
- Framework | ||
- [MindSpore](https://www.mindspore.cn/install/en) | ||
- For more information, please check the resources below: | ||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html) | ||
- [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html) | ||
|
||
## [Eval process](#contents) | ||
|
||
### Usage | ||
|
||
After installing MindSpore via the official website, you can start evaluation as follows: | ||
|
||
### Launch | ||
|
||
```bash | ||
# inference example | ||
bash scripts/DINO_eval_ms_coco.sh /path/to/your/COCODIR /path/to/your/checkpoint | ||
# bash scripts/DINO_eval_ms_coco.sh coco2017 ./logs/best_ckpt.ckpt | ||
``` | ||
|
||
> checkpoint can be downloaded at xxxx | ||
### Result | ||
|
||
```text | ||
Results of Focus-DETR with Resnet50 backbone: | ||
IoU metric: bbox | ||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.479 | ||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.659 | ||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.521 | ||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.323 | ||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.505 | ||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.619 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.372 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.640 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.720 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.568 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.757 | ||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.878 | ||
``` | ||
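
To compare runs programmatically, the summary lines above can be parsed into a dictionary. This is a sketch under the assumption that the evaluation log uses the COCO-style summary format shown in this section; `parse_coco_summary` is a hypothetical helper, not a function from this repository.

```python
import re

# Matches one COCO-style summary line, tolerating variable spacing.
LINE_RE = re.compile(
    r"Average (?:Precision|Recall)\s+\((AP|AR)\)\s+@\[\s*IoU=([\d.:]+)\s*\|"
    r"\s*area=\s*(\w+)\s*\|\s*maxDets=\s*(\d+)\s*\]\s*=\s*(-?[\d.]+)"
)


def parse_coco_summary(text):
    """Map (metric, iou, area, max_dets) to the reported value."""
    results = {}
    for match in LINE_RE.finditer(text):
        metric, iou, area, max_dets, value = match.groups()
        results[(metric, iou, area, int(max_dets))] = float(value)
    return results
```

With the log above, the headline box AP would be available under the key `("AP", "0.50:0.95", "all", 100)`.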
|
||
## [ModelZoo Homepage](#contents) | ||
|
||
Please check the official [homepage](https://gitee.com/mindspore/models). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright 2023 Huawei Technologies Co., Ltd | ||
# | ||
_base_ = ["coco_transformer.py"] | ||
|
||
num_classes = 91 | ||
|
||
lr = 0.0001 | ||
param_dict_type = "default" | ||
lr_backbone = 1e-05 | ||
lr_backbone_names = ["backbone.0"] | ||
lr_linear_proj_names = ["reference_points", "sampling_offsets"] | ||
lr_linear_proj_mult = 0.1 | ||
ddetr_lr_param = False | ||
batch_size = 1 | ||
weight_decay = 0.0001 | ||
epochs = 12 | ||
lr_drop = 11 | ||
save_checkpoint_interval = 1 | ||
clip_max_norm = 0.1 | ||
onecyclelr = False | ||
multi_step_lr = False | ||
lr_drop_list = [33, 45] | ||
|
||
|
||
modelname = "dino" | ||
frozen_weights = None | ||
backbone = "resnet50" | ||
use_checkpoint = False | ||
|
||
dilation = False | ||
position_embedding = "sine" | ||
pe_temperatureH = 20 | ||
pe_temperatureW = 20 | ||
return_interm_indices = [0, 1, 2, 3] | ||
backbone_freeze_keywords = None | ||
enc_layers = 6 | ||
dec_layers = 6 | ||
unic_layers = 0 | ||
pre_norm = False | ||
dim_feedforward = 2048 | ||
hidden_dim = 256 | ||
dropout = 0.0 | ||
nheads = 8 | ||
num_queries = 900 | ||
query_dim = 4 | ||
num_patterns = 0 | ||
pdetr3_bbox_embed_diff_each_layer = False | ||
pdetr3_refHW = -1 | ||
random_refpoints_xy = False | ||
fix_refpoints_hw = -1 | ||
dabdetr_yolo_like_anchor_update = False | ||
dabdetr_deformable_encoder = False | ||
dabdetr_deformable_decoder = False | ||
use_deformable_box_attn = False | ||
box_attn_type = "roi_align" | ||
dec_layer_number = None | ||
num_feature_levels = 5 | ||
enc_n_points = 4 | ||
dec_n_points = 4 | ||
decoder_layer_noise = False | ||
dln_xy_noise = 0.2 | ||
dln_hw_noise = 0.2 | ||
add_channel_attention = False | ||
add_pos_value = False | ||
two_stage_type = "standard" | ||
two_stage_pat_embed = 0 | ||
two_stage_add_query_num = 0 | ||
two_stage_bbox_embed_share = False | ||
two_stage_class_embed_share = False | ||
two_stage_learn_wh = False | ||
two_stage_default_hw = 0.05 | ||
two_stage_keep_all_tokens = False | ||
num_select = 300 | ||
transformer_activation = "relu" | ||
batch_norm_type = "FrozenBatchNorm2d" | ||
masks = False | ||
aux_loss = True | ||
set_cost_class = 2.0 | ||
set_cost_bbox = 5.0 | ||
set_cost_giou = 2.0 | ||
cls_loss_coef = 1.0 | ||
mask_loss_coef = 1.0 | ||
dice_loss_coef = 1.0 | ||
bbox_loss_coef = 5.0 | ||
giou_loss_coef = 2.0 | ||
enc_loss_coef = 1.0 | ||
interm_loss_coef = 1.0 | ||
no_interm_box_loss = False | ||
focal_alpha = 0.25 | ||
|
||
decoder_sa_type = "sa" # ['sa', 'ca_label', 'ca_content'] | ||
matcher_type = "HungarianMatcher" # or SimpleMinsumMatcher | ||
decoder_module_seq = ["sa", "ca", "ffn"] | ||
nms_iou_threshold = -1 | ||
|
||
dec_pred_bbox_embed_share = True | ||
dec_pred_class_embed_share = True | ||
|
||
# for dn | ||
use_dn = True | ||
dn_number = 100 | ||
dn_box_noise_scale = 0.4 | ||
dn_label_noise_ratio = 0.5 | ||
embed_init_tgt = True | ||
dn_labelbook_size = 91 | ||
|
||
match_unstable_error = True | ||
|
||
# for ema | ||
use_ema = False | ||
ema_decay = 0.9997 | ||
ema_epoch = 0 | ||
|
||
use_detached_boxes_dec_out = False |
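
The interaction of `lr`, `epochs`, and `lr_drop` in the config above can be sketched as a single step decay. The training loop is not part of this excerpt, so the decay factor `gamma=0.1` and the step-decay semantics are assumptions borrowed from common DETR-style training setups, and `step_lr` is a hypothetical helper.

```python
def step_lr(base_lr, epoch, lr_drop, gamma=0.1):
    """Learning rate for a 0-indexed epoch under step decay.

    Assumes the rate is multiplied by `gamma` once every `lr_drop` epochs;
    `gamma` does not appear in the config and is an assumption.
    """
    return base_lr * (gamma ** (epoch // lr_drop))


# Values taken from the config above: lr = 0.0001, epochs = 12, lr_drop = 11.
schedule = [step_lr(0.0001, epoch, 11) for epoch in range(12)]
```

Under these assumptions the first eleven epochs run at the base rate and only the final epoch uses the reduced rate.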
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright 2023 Huawei Technologies Co., Ltd | ||
# | ||
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] | ||
data_aug_max_size = 1333 | ||
data_aug_scales2_resize = [400, 500, 600] | ||
data_aug_scales2_crop = [384, 600] | ||
|
||
|
||
data_aug_scale_overlap = None |
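
The relationship between `data_aug_scales` and `data_aug_max_size` can be sketched as follows. The data pipeline is not shown in this excerpt, so this assumes the common DETR-style resize rule: the shorter image side is resized to the chosen scale unless that would push the longer side past the maximum, in which case the scale is reduced. `resized_hw` is a hypothetical helper.

```python
def resized_hw(width, height, short_side, max_size):
    """Return (height, width) after an aspect-preserving resize.

    Assumed rule: scale the shorter side to `short_side`, but shrink the
    target if the longer side would otherwise exceed `max_size`.
    """
    shorter = float(min(width, height))
    longer = float(max(width, height))
    if longer / shorter * short_side > max_size:
        short_side = int(round(max_size * shorter / longer))
    if width < height:
        return int(short_side * height / width), short_side
    return short_side, int(short_side * width / height)
```

For a 640x480 image and a scale of 800, the longer side stays below 1333, so the image would be resized to 800 pixels high and 1066 pixels wide. A 1000x400 image would hit the cap instead, and the shorter side would be reduced to 533.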