This repository provides the official PyTorch implementation of BPA (the Balanced-Pairwise-Affinities feature transform, formerly SOT), as described in the paper *The Balanced-Pairwise-Affinities Feature Transform* (accepted at ICML 2024).
The Balanced-Pairwise-Affinities (BPA) feature transform upgrades the features of a set of input items to facilitate downstream matching- or grouping-related tasks.
The transformed set encodes a rich representation of high-order relations between the instance features. Distances between transformed features capture both the direct original similarity between two features and their third-party 'agreement' about how similar each of them is to the other features in the set.
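To make the idea of balanced pairwise affinities concrete, here is a simplified, dependency-free sketch. It follows our hedged reading of the recipe rather than the official `bpa` code: compute pairwise cosine distances, forbid self-matches, balance the affinities with entropic Sinkhorn iterations into a (nearly) doubly-stochastic matrix, set the diagonal to 1 (an assumption), and use each row as the new feature of the corresponding item. The function name `bpa_sketch` and its `reg`/`n_iters` parameters are hypothetical.

```python
import math


def bpa_sketch(x, reg=0.1, n_iters=100):
    """Simplified BPA-style transform (hypothetical sketch, not the official code).

    x: list of n feature vectors (lists of floats), n >= 2.
    Returns an n x n matrix whose i-th row is the new feature of item i.
    """
    n = len(x)
    if n < 2:
        raise ValueError("need at least two items to compute pairwise affinities")

    # 1. Pairwise cosine similarity -> cosine distance (cost) matrix.
    norms = [math.sqrt(sum(v * v for v in row)) or 1.0 for row in x]
    unit = [[v / nrm for v in row] for row, nrm in zip(x, norms)]
    cost = [[1.0 - sum(a * b for a, b in zip(u, w)) for w in unit] for u in unit]

    # 2. Gibbs kernel exp(-cost / reg); self-affinities are zeroed so that
    #    balancing distributes each item's mass over the *other* items.
    k = [[0.0 if i == j else math.exp(-cost[i][j] / reg) for j in range(n)]
         for i in range(n)]

    # 3. Sinkhorn iterations: alternately normalize rows and columns so the
    #    matrix approaches a doubly-stochastic one (rows and columns sum to 1).
    for _ in range(n_iters):
        row_sums = [sum(row) for row in k]
        k = [[v / s for v in row] for row, s in zip(k, row_sums)]
        col_sums = [sum(k[i][j] for i in range(n)) for j in range(n)]
        k = [[k[i][j] / col_sums[j] for j in range(n)] for i in range(n)]

    # 4. Assumed final step: give every item maximal affinity to itself.
    for i in range(n):
        k[i][i] = 1.0
    return k
```

For `n` input vectors the sketch returns an `n x n` matrix, which matches the `[n_samples, n_samples]` output shape of the real transform. Two items end up close in this representation when they are similar to each other and also distribute their affinities over the rest of the set in a similar way.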
| Dataset | Method | 5-Way 1-Shot | 5-Way 5-Shot |
|---|---|---|---|
| MiniImagenet | PTMAP-BPAp | 83.19 | 89.56 |
| | PTMAP-BPAt | 84.18 | 90.51 |
| | PTMAP-SF-BPA | 85.59 | 91.34 |
| CIFAR-FS | PTMAP-BPAp | 87.37 | 91.12 |
| | PTMAP-SF-BPA | 89.94 | 92.83 |
| CUB | PTMAP-BPAp | 91.90 | 94.63 |
| | PTMAP-SF-BPA | 95.80 | 97.12 |
Thanks to its flexibility, BPA lets you improve your set representation with only two additional lines of code:

```python
import torch
from bpa import BPA

x = torch.randn(100, 128)  # x is of shape [n_samples, dim]
x = BPA()(x)
# after BPA, x is of shape [n_samples, n_samples]
```
We provide the code for training and evaluating PT-MAP and ProtoNet with and without BPA. Note that the results reported in the paper cannot be fully reproduced with this code. To reproduce them, plug BPA into the original repositories as shown above.
Instructions for downloading the datasets are in the `datasets` directory.
Currently, you can choose between ProtoNet and PT-MAP, including their BPA variants.
For example, to train our ProtoNet+BPA version on the MiniImagenet dataset with a WideResNet backbone, run:

```bash
python train.py --data_path <yourdatasetpath/miniimagenet/> --backbone WRN --method proto_bpa --ot_reg 0.1 --max_epochs 200 --train_way 5 --scheduler step --step_size 40 --lr 0.0002 --augment false
```
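In a ProtoNet-style episode, each query is assigned to the class whose prototype (the mean of that class's support features) is nearest. The plain-Python sketch below illustrates that rule. The helper `prototype_classify` and its `transform` hook are hypothetical: the hook marks where a set transform such as BPA would be applied jointly to all support and query features of the episode, under our assumption about how the BPA variant is wired in. It is not the repository's `proto_bpa` implementation.

```python
def prototype_classify(support, support_labels, query, transform=None):
    """Nearest-prototype (ProtoNet-style) classification of one episode.

    support: list of feature vectors; support_labels: their class ids;
    query: list of feature vectors to classify.
    transform: hypothetical hook mapping the whole episode (support rows
    followed by query rows) to new per-item features, e.g. a BPA-like
    transform; None keeps the features unchanged.
    """
    feats = list(support) + list(query)
    if transform is not None:
        feats = transform(feats)  # transform support and query jointly
    sup, qry = feats[:len(support)], feats[len(support):]

    # Class prototype = mean of that class's (transformed) support features.
    classes = sorted(set(support_labels))
    dim = len(sup[0])
    prototypes = {}
    for c in classes:
        members = [f for f, y in zip(sup, support_labels) if y == c]
        prototypes[c] = [sum(f[d] for f in members) / len(members) for d in range(dim)]

    def sq_dist(a, b):
        # Squared Euclidean distance between two vectors.
        return sum((u - v) ** 2 for u, v in zip(a, b))

    # Each query gets the label of its closest prototype.
    return [min(classes, key=lambda c: sq_dist(q, prototypes[c])) for q in qry]
```

Because the transform sees the query features together with the support features, every transformed vector depends on the whole episode, which is what allows a set-level transform to exploit relations between all items.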
We also support logging results to the cloud with the Weights & Biases (wandb) logger (highly recommended).
First, install it:

```bash
pip install wandb
```

Then, set the following arguments:

```bash
--wandb true --project <project_name> --entity <wandb_entity>
```
For WRN-28, we use the pretrained checkpoint from the Manifold Mixup repository.
For ResNet-12, we use the pretrained checkpoint from FEAT.
Download the weights for the backbone you want and set:

```bash
--backbone <model name> --pretrained_path <./path>
```
To evaluate, run the same command you used for training, adding:

```bash
--eval true --pretrained_path <./path> --backbone <backbone_name>
```

You can set the number of test episodes with:

```bash
--test_episodes
```
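Few-shot results are commonly reported as the mean accuracy over the test episodes, often together with a 95% confidence interval; using more test episodes tends to narrow that interval. The sketch below assumes the usual normal-approximation interval, 1.96 times the sample standard deviation divided by the square root of the number of episodes. The helper `mean_and_ci95` is hypothetical, and the repository's own reporting may differ.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean accuracy and 95% confidence half-width over test episodes.

    accuracies: per-episode accuracies (e.g. fractions in [0, 1]).
    Uses the normal approximation 1.96 * stdev / sqrt(n), where stdev is
    the sample standard deviation (n - 1 in the denominator).
    """
    n = len(accuracies)
    if n == 0:
        raise ValueError("need at least one episode")
    mean = statistics.mean(accuracies)
    if n < 2:
        return mean, 0.0  # spread is undefined for a single episode
    half_width = 1.96 * statistics.stdev(accuracies) / math.sqrt(n)
    return mean, half_width
```

For example, `mean, hw = mean_and_ci95(per_episode_acc)` followed by `print(f"{100 * mean:.2f} +- {100 * hw:.2f}")` prints a result in the usual "accuracy +- interval" percentage format.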
```bibtex
@article{shalam2024balanced,
  title={The Balanced-Pairwise-Affinities Feature Transform},
  author={Shalam, Daniel and Korman, Simon},
  journal={arXiv preprint arXiv:2407.01467},
  year={2024}
}
```
- Leveraging the Feature Distribution in Transfer-based Few-Shot Learning
- S2M2: Charting the Right Manifold: Manifold Mixup for Few-shot Learning
- Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions