SOT: The Self-Optimal-Transport Feature Transform

This repository provides the official PyTorch implementation of SOT, as described in the paper The Self-Optimal-Transport Feature Transform.

[Figure: overview of the SOT feature transform]

The Self-Optimal-Transport (SOT) feature transform is designed to upgrade the set of features of a data instance to facilitate downstream matching or grouping related tasks.

The transformed set encodes a rich representation of high order relations between the instance features. Distances between transformed features capture their direct original similarity and their third party 'agreement' regarding similarity to other features in the set.

A particular min-cost-max-flow fractional matching problem, whose entropy regularized version can be approximated by an optimal transport (OT) optimization, results in our transductive transform which is efficient, differentiable, equivariant, parameterless and probabilistically interpretable.
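Conceptually, the transform is compact enough to sketch in a few lines of PyTorch. The following is a simplified illustration, not the repository's actual code: it uses cosine distance as the transport cost, a large diagonal penalty to rule out trivial self-matching, plain (non-log-space) Sinkhorn iterations, and the paper's convention of resetting the diagonal of the resulting plan to the maximal value. The `ot_reg` argument mirrors the --ot_reg training flag below.

```python
import torch
import torch.nn.functional as F

def sot_transform(features: torch.Tensor, ot_reg: float = 0.1, n_iters: int = 20) -> torch.Tensor:
    """Minimal sketch of SOT: map a feature set (n, d) to an (n, n) relational embedding."""
    n = features.size(0)
    eye = torch.eye(n, device=features.device)

    # Cosine-distance cost matrix over the L2-normalized feature set.
    v = F.normalize(features, dim=-1)
    cost = 1.0 - v @ v.t()

    # Forbid self-matching: make the diagonal prohibitively expensive,
    # so each feature's transport mass must flow to *other* features.
    cost = cost + eye * 1e3

    # Sinkhorn iterations: entropy-regularized OT between the set and itself,
    # with all-ones marginals, yielding a doubly-stochastic plan W.
    K = torch.exp(-cost / ot_reg)
    r = torch.ones(n, device=features.device)
    u = torch.ones(n, device=features.device)
    for _ in range(n_iters):
        s = r / (K.t() @ u)   # column scaling
        u = r / (K @ s)       # row scaling
    W = u.unsqueeze(1) * K * s.unsqueeze(0)

    # Each feature is maximally similar to itself, so reset the (zeroed-out)
    # diagonal to the maximal value; the rows of W are the transformed features.
    return W * (1 - eye) + eye * W.max()
```

As a usage example, for a 5-way episode with 16 samples per class and, e.g., 640-dimensional backbone features, the input would be an (80, 640) tensor and the output an (80, 80) matrix whose rows serve as the new, relation-aware features.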

Few-Shot Classification Results

| Dataset      | Method        | 5-Way 1-Shot (%) | 5-Way 5-Shot (%) |
|--------------|---------------|------------------|------------------|
| MiniImagenet | PTMAP-SOTp    | 83.19            | 89.56            |
| MiniImagenet | PTMAP-SOTt    | 84.18            | 90.51            |
| MiniImagenet | PTMAP-SF-SOT  | 85.59            | 91.34            |
| CIFAR-FS     | PTMAP-SOTp    | 87.37            | 91.12            |
| CIFAR-FS     | PTMAP-SF-SOT  | 89.94            | 92.83            |
| CUB          | PTMAP-SOTp    | 91.90            | 94.63            |
| CUB          | PTMAP-SF-SOT  | 95.80            | 97.12            |

Running instructions

We provide code for training and evaluating PT-MAP and ProtoNet, with and without SOT. Note that the paper's results are not exactly reproducible with this codebase; to fully reproduce them, plug the SOT implementation provided here into the original repositories of the baseline methods.

Instructions for downloading the datasets can be found in the datasets directory.

Training

You can choose between ProtoNet/PT-MAP and their SOT variations.

To train ProtoNet with SOT on miniimagenet, run:

python train.py --data_path <./datasets/miniimagenet/> --backbone WRN --method proto_sot --ot_reg 0.1 --max_epochs 200 --train_way 5 --scheduler step --step_size 40 --lr 0.0002  --augment false

We also support logging results to the cloud via the Weights & Biases (wandb) logger (highly recommended).

First, install it via:

pip install wandb

Then, set the following arguments:

--wandb true --project <project_name> --entity <wandb_entity>
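For example, appended to the ProtoNet training command above (the project name below is a placeholder):

python train.py --data_path <./datasets/miniimagenet/> --backbone WRN --method proto_sot --wandb true --project sot-experiments --entity <wandb_entity>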

Fine-tuning

For the WRN backbone, use the checkpoints provided by the Manifold Mixup repository. For ResNet-12, use the checkpoints as in FEAT.

Download the weights for the backbone you want and set:

--backbone <model name> --pretrained_path <./path>
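For example, for a WRN backbone (the checkpoint path below is illustrative, not an actual file shipped with this repo):

python train.py --data_path <./datasets/miniimagenet/> --method proto_sot --backbone WRN --pretrained_path ./checkpoints/wrn_s2m2.pth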

Evaluation

Run the same command you used for training, adding:

--eval true --pretrained_path <./path> --backbone <backbone_name>

You can choose the number of evaluation episodes by setting:

--test_episodes
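For example, to evaluate a trained ProtoNet-SOT model (the checkpoint path and episode count below are illustrative):

python train.py --data_path <./datasets/miniimagenet/> --method proto_sot --backbone WRN --eval true --pretrained_path ./checkpoints/proto_sot_wrn.pth --test_episodes 1000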

Citation

If you find this repository useful in your research, please cite:

@article{shalam2022self,
  title={The Self-Optimal-Transport Feature Transform},
  author={Shalam, Daniel and Korman, Simon},
  journal={arXiv preprint arXiv:2204.03065},
  year={2022}
}

Acknowledgment

Leveraging the Feature Distribution in Transfer-based Few-Shot Learning

S2M2 Charting the Right Manifold: Manifold Mixup for Few-shot Learning

Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions
