This repository provides the official PyTorch implementation of SOT (Self-Optimal-Transport), as described in the paper "The Self-Optimal-Transport Feature Transform".
The Self-Optimal-Transport (SOT) feature transform is designed to upgrade the set of features of a data instance to facilitate downstream matching- or grouping-related tasks.
The transformed set encodes a rich representation of high-order relations between the instance features. Distances between transformed features capture both their direct original similarity and their third-party 'agreement' regarding similarity to the other features in the set.
A particular min-cost-max-flow fractional matching problem, whose entropy-regularized version can be approximated by an optimal transport (OT) optimization, yields our transductive transform, which is efficient, differentiable, equivariant, parameterless and probabilistically interpretable.
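To make the idea above more concrete, here is a minimal NumPy sketch of a self-OT transform. It assumes cosine distance as the cost, a large diagonal cost to discourage each feature from matching itself, a fixed number of log-domain Sinkhorn iterations with uniform marginals, and max-rescaling of the resulting plan. It is not the repository's PyTorch implementation; the names `log_sinkhorn`, `sot_transform` and `diag_cost` are hypothetical, and `reg` is only assumed to play the role of the `--ot_reg` flag.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_sinkhorn(cost, reg, n_iters=10):
    """Entropy-regularized OT between uniform marginals, in the log domain.

    Returns the log of the n x n transport plan.
    """
    n = cost.shape[0]
    log_k = -cost / reg              # log of the Gibbs kernel exp(-cost / reg)
    log_mu = np.full(n, -np.log(n))  # uniform marginals (each sums to 1)
    u = np.zeros(n)
    v = np.zeros(n)
    for _ in range(n_iters):
        u = log_mu - _logsumexp(log_k + v[None, :], axis=1)  # fit row sums
        v = log_mu - _logsumexp(log_k + u[:, None], axis=0)  # fit column sums
    return log_k + u[:, None] + v[None, :]


def sot_transform(x, reg=0.1, n_iters=10, diag_cost=1e3):
    """Hypothetical self-OT sketch: row i of the output describes how feature i
    distributes its mass over the other features of the same set."""
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # L2-normalize features
    cost = 1.0 - x @ x.T                              # pairwise cosine distance
    np.fill_diagonal(cost, diag_cost)                 # discourage self-matching
    w = np.exp(log_sinkhorn(cost, reg, n_iters))
    w = w / w.max()                                   # rescale into [0, 1]
    np.fill_diagonal(w, 1.0)                          # full self-agreement
    return w


# Two tight clusters of 2-D features; the output is a 6 x 6 matrix.
feats = np.array([[1.0, 0.0], [0.99, 0.1], [0.98, 0.2],
                  [0.0, 1.0], [0.1, 0.99], [0.2, 0.98]])
print(sot_transform(feats).shape)  # (6, 6)
```

Note that the output has one row per input feature, so its width equals the set size rather than the original feature dimension; this is what makes the transform parameterless and equivariant to the order of the set.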
Few-shot classification accuracy (%), as reported in the paper:

| Dataset | Method | 5-Way 1-Shot | 5-Way 5-Shot |
|---|---|---|---|
| MiniImagenet | PTMAP-SOTp | 83.19 | 89.56 |
| | PTMAP-SOTt | 84.18 | 90.51 |
| | PTMAP-SF-SOT | 85.59 | 91.34 |
| CIFAR-FS | PTMAP-SOTp | 87.37 | 91.12 |
| | PTMAP-SF-SOT | 89.94 | 92.83 |
| CUB | PTMAP-SOTp | 91.90 | 94.63 |
| | PTMAP-SF-SOT | 95.80 | 97.12 |
We provide code for training and evaluating PT-MAP and ProtoNet, with and without SOT. Note that this code does not reproduce the results reported in the paper. To fully reproduce them, plug SOT, as implemented here, into the original repositories.
Instructions for downloading the datasets can be found in the datasets directory.
You can choose between ProtoNet, PT-MAP and their SOT variants.
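ProtoNet classifies each query by its distance to class prototypes, where a prototype is the mean support embedding of a class. The NumPy sketch below illustrates that rule and adds a hypothetical `transform` hook to show where a transductive set transform such as SOT could sit: applied jointly to the concatenated support and query features of an episode, before prototypes are formed. It is an illustration under those assumptions, not the repository's `proto_sot` code; `protonet_predict` is a hypothetical name.

```python
import numpy as np


def protonet_predict(support, support_labels, query, n_way, transform=None):
    """Nearest-prototype classification for one few-shot episode.

    support: (n_way * k_shot, d) array; support_labels: NumPy int array with
    values in [0, n_way); query: (n_query, d) array. `transform` is a
    hypothetical hook applied to all support + query features at once, the
    way a transductive set transform would be inserted.
    """
    feats = np.concatenate([support, query], axis=0)
    if transform is not None:
        feats = transform(feats)  # transform the whole episode jointly
    s, q = feats[: len(support)], feats[len(support):]
    # Prototype of each class = mean of its (possibly transformed) support features.
    protos = np.stack([s[support_labels == c].mean(axis=0) for c in range(n_way)])
    # Squared Euclidean distance from every query to every prototype.
    dists = ((q[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)


support = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels = np.array([0, 0, 1, 1])
query = np.array([[0.5, 0.5], [9.0, 10.0]])
print(protonet_predict(support, labels, query, n_way=2))  # [0 1]
```

Because the hook receives the whole episode, any transform that changes the feature width (as a self-OT matrix does) still works: prototypes are simply averaged in the new space.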
To train ProtoNet with SOT on MiniImagenet, run:
python train.py --data_path <./datasets/miniimagenet/> --backbone WRN --method proto_sot --ot_reg 0.1 --max_epochs 200 --train_way 5 --scheduler step --step_size 40 --lr 0.0002 --augment false
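With `--scheduler step --step_size 40 --lr 0.0002`, the learning rate presumably starts at 2e-4 and is decayed every 40 epochs, in the manner of PyTorch's `torch.optim.lr_scheduler.StepLR`. The decay factor is not stated in this README, so the sketch below assumes a hypothetical `gamma` of 0.5 purely to show the shape of the schedule over the 200 epochs set by `--max_epochs`; `step_lr` is a hypothetical helper.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate at a given epoch under a step schedule: the rate is
    multiplied by `gamma` once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Hypothetical gamma = 0.5; base_lr and step_size mirror the command above.
for epoch in (0, 39, 40, 80, 199):
    print(epoch, step_lr(2e-4, epoch, step_size=40, gamma=0.5))
```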
Results can also be logged to the cloud with the Wandb (Weights & Biases) logger, which we highly recommend.
First, install it via:
pip install wandb
Then, set the following arguments:
--wandb true --project <project_name> --entity <wandb_entity>
For WRN-12, use the checkpoints provided by the Manifold Mixup (S2M2) repository. For ResNet-12, use the checkpoints from FEAT.
Download the weights for the backbone you want and set:
--backbone <model name> --pretrained_path <./path>
To evaluate, run the same command you used for training, adding:
--eval true --pretrained_path <./path> --backbone <backbone_name>
You can change the number of test episodes with:
--test_episodes
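Few-shot results are commonly reported as the mean accuracy over test episodes together with a 95% confidence interval, so the number of episodes controls how tight that interval is. Assuming the evaluation follows this convention (this README does not confirm it), the half-width shrinks roughly as 1/sqrt(n), so quadrupling the number of episodes about halves it. The hypothetical helper below uses the normal approximation with the sample standard deviation.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean accuracy and 95% confidence half-width over test episodes,
    using the normal approximation 1.96 * sample_std / sqrt(n)."""
    n = len(accuracies)
    mean = statistics.mean(accuracies)
    half_width = 1.96 * statistics.stdev(accuracies) / math.sqrt(n)
    return mean, half_width


# Per-episode accuracies (%) from a toy run of three episodes.
mean, ci = mean_and_ci95([80.0, 90.0, 100.0])
print(f"{mean:.2f} +- {ci:.2f}")
```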
@article{shalam2022self,
title={The Self-Optimal-Transport Feature Transform},
author={Shalam, Daniel and Korman, Simon},
journal={arXiv preprint arXiv:2204.03065},
year={2022}
}
PT-MAP: Leveraging the Feature Distribution in Transfer-based Few-Shot Learning
S2M2: Charting the Right Manifold: Manifold Mixup for Few-Shot Learning
FEAT: Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions