Official implementation of "Domain Generalizable Multiple Domain Clustering", TMLR 2024

AmitRozner/domain-generalizable-multiple-domain-clustering

Domain-Generalizable Multiple-Domain Clustering

Link to our paper: https://arxiv.org/abs/2301.13530

Installation

Please refer to requirement.txt for all required packages. Tested with Python 3.8.

Then clone this repository:

git clone 
cd DomainGMDC

Download the pre-trained model from AdaIN-Style-Transfer-PyTorch.

Put the weights in the ./adain_weights folder.

Data

Prepare the datasets of interest as described in dataset.md.

Training

Example training command for the OfficeHome dataset, training on the RealWorld, Clipart, and Product domains:

 python tools/run_end_to_end.py --domain_names Product_Clipart_RealWorld --seed 0 \
   --embedding_batch_size 512 --domain_loss_weight 0.02 --dist-url tcp://localhost:10026 \
   --keep_strong_heads --multi_q --balance_moco_domains --data_type officehome --num_cluster 65 \
   --center_based_truncate --wandb_run_name <enter_run_name> --epochs 500 \
   --data ./datasets/ --root_save_folder ./results/ --use_wandb --arch resnet18 --soft_balance \
   --domain_size_layers 2048 1024 512 256 128 --train_self_batch_size 256 --batch_size 8 \
   --style_transfer --heads2keep 5 --p_bcd_augment 0.2 --self_smoothing 0.9 \
   --pred_based_smoothing --moco_p_bcd_augment 0.8

The command above should run both pre-training and training in a single run.
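To train the other leave-one-domain-out splits, presumably only --domain_names (plus --seed and --wandb_run_name) needs to change. The following is a minimal sketch that assumes every other flag stays exactly as in the OfficeHome example above; the helper leave_one_out_commands and the run-name format are hypothetical and not part of this repository.

```python
import shlex
from typing import List

# Flags copied from the OfficeHome example above, minus the three that change
# per run (--domain_names, --seed, --wandb_run_name).
COMMON_FLAGS = [
    "--embedding_batch_size", "512", "--domain_loss_weight", "0.02",
    "--dist-url", "tcp://localhost:10026",
    "--keep_strong_heads", "--multi_q", "--balance_moco_domains",
    "--data_type", "officehome", "--num_cluster", "65", "--center_based_truncate",
    "--epochs", "500", "--data", "./datasets/", "--root_save_folder", "./results/",
    "--use_wandb", "--arch", "resnet18", "--soft_balance",
    "--domain_size_layers", "2048", "1024", "512", "256", "128",
    "--train_self_batch_size", "256", "--batch_size", "8", "--style_transfer",
    "--heads2keep", "5", "--p_bcd_augment", "0.2", "--self_smoothing", "0.9",
    "--pred_based_smoothing", "--moco_p_bcd_augment", "0.8",
]

OFFICEHOME_DOMAINS = ["RealWorld", "Clipart", "Product", "Art"]


def leave_one_out_commands(domains: List[str], seed: int = 0) -> List[List[str]]:
    """Hypothetical helper: build one training argv per held-out domain."""
    commands = []
    for held_out in domains:
        # Train on every domain except the held-out one, joined with "_"
        # in the same style as --domain_names Product_Clipart_RealWorld.
        sources = [d for d in domains if d != held_out]
        argv = [
            "python", "tools/run_end_to_end.py",
            "--domain_names", "_".join(sources),
            "--seed", str(seed),
            # Hypothetical run-name format; any name works.
            "--wandb_run_name", f"officehome_heldout_{held_out}_seed{seed}",
        ]
        commands.append(argv + COMMON_FLAGS)
    return commands


if __name__ == "__main__":
    # Print shell-quoted commands instead of executing them.
    for argv in leave_one_out_commands(OFFICEHOME_DOMAINS):
        print(shlex.join(argv))
```

Printing rather than executing keeps the sketch free of side effects. If several runs are launched concurrently, each will likely need its own --dist-url port.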

Evaluation

An example on the OfficeHome dataset:

 python tools/evaluate.py --dir_and_regex '<Root_path>/spice/results/officehome/*Art_Clipart_Product*'

The command above finds all runs in the "officehome" folder that were trained on the "Art", "Clipart", and "Product" domains. It automatically infers the remaining (held-out) domain and evaluates on it.
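The example pattern appears to be a shell-style wildcard matched against run directory names. Assuming the directory names contain the underscore-joined source domains, the domains must appear in the same order as in the run name: a run trained with --domain_names Product_Clipart_Art would not match *Art_Clipart_Product*. A minimal sketch of this matching with Python's fnmatch (which uses the same wildcard syntax as glob); the directory names are hypothetical.

```python
import fnmatch
from typing import List


def select_runs(run_dirs: List[str], pattern: str) -> List[str]:
    """Keep run directory names matching a shell-style wildcard (case-sensitive)."""
    return [d for d in run_dirs if fnmatch.fnmatchcase(d, pattern)]


if __name__ == "__main__":
    # Hypothetical run directory names.
    runs = [
        "officehome_Art_Clipart_Product_seed0",
        "officehome_Product_Clipart_Art_seed0",  # same domains, different order
        "officehome_Art_Clipart_RealWorld_seed0",
    ]
    print(select_runs(runs, "*Art_Clipart_Product*"))
    # ['officehome_Art_Clipart_Product_seed0']
```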

Implemented datasets and domains are the same as in our paper:

"pacs": ["cartoon", "photo", "artpainting", "sketch"]

"officehome": ["RealWorld", "Clipart", "Product", "Art"]

"office31": ["amazon", "dslr", "webcam"]

"DomainNet": ["clipart", "infograph", "quickdraw", "painting", "real", "sketch"]
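Given these domain lists, the held-out domain can be recovered by removing the training domains found in a run's name. A minimal sketch, assuming the run name contains the source domains as underscore-separated tokens; infer_held_out_domains is a hypothetical helper, not the repository's implementation. It returns a list so that setups leaving out more than one domain are also handled.

```python
from typing import List

# Domain lists as given above.
DATASET_DOMAINS = {
    "pacs": ["cartoon", "photo", "artpainting", "sketch"],
    "officehome": ["RealWorld", "Clipart", "Product", "Art"],
    "office31": ["amazon", "dslr", "webcam"],
    "DomainNet": ["clipart", "infograph", "quickdraw", "painting", "real", "sketch"],
}


def infer_held_out_domains(data_type: str, run_name: str) -> List[str]:
    """Return the domains of data_type that do not appear as '_'-separated tokens in run_name."""
    tokens = set(run_name.split("_"))
    return [d for d in DATASET_DOMAINS[data_type] if d not in tokens]


if __name__ == "__main__":
    print(infer_held_out_domains("officehome", "Art_Clipart_Product"))  # ['RealWorld']
```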

The results should include accuracy whenever it can be computed; otherwise, the script prints a list of the clusters that were never predicted.
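Clustering accuracy is commonly computed by finding the best one-to-one mapping between predicted cluster IDs and ground-truth labels (Hungarian matching). The sketch below shows that standard metric plus a list of cluster IDs that receive no samples. It is an illustration under assumptions (integer labels in [0, num_clusters), NumPy and SciPy installed), not the repository's evaluation code, and the exact condition under which this repository skips accuracy is not specified here.

```python
from typing import List

import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_accuracy(y_true: np.ndarray, y_pred: np.ndarray, num_clusters: int) -> float:
    """Accuracy after the best one-to-one mapping of cluster IDs to class labels."""
    # counts[p, t] = number of samples predicted as cluster p whose true label is t
    counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
    np.add.at(counts, (y_pred, y_true), 1)
    # Choose the cluster-to-label assignment that maximizes correctly matched samples.
    rows, cols = linear_sum_assignment(counts, maximize=True)
    return float(counts[rows, cols].sum()) / len(y_true)


def unpredicted_clusters(y_pred: np.ndarray, num_clusters: int) -> List[int]:
    """Cluster IDs in [0, num_clusters) to which no sample was assigned."""
    used = set(np.unique(y_pred).tolist())
    return [c for c in range(num_clusters) if c not in used]


if __name__ == "__main__":
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([1, 1, 0, 0, 0, 0])  # cluster 2 is never predicted
    print(hungarian_accuracy(y_true, y_pred, num_clusters=3))  # 4 of 6 samples matched
    print(unpredicted_clusters(y_pred, num_clusters=3))  # [2]
```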

Acknowledgement for reference repos