From 4e8cd746545b0a0b2a3109763efc93befddb27df Mon Sep 17 00:00:00 2001
From: Caglar Demir
Date: Mon, 20 Nov 2023 16:27:07 +0100
Subject: [PATCH] WIP: ASWA

---
 dicee/callbacks.py              | 62 ++++++++++++++++++---------------
 dicee/run.py                    | 20 ++++++-----
 dicee/trainer/dice_trainer.py   |  4 +--
 tests/test_regression_polyak.py |  6 ++--
 4 files changed, 50 insertions(+), 42 deletions(-)

diff --git a/dicee/callbacks.py b/dicee/callbacks.py
index cdb66ad7..aa28979a 100644
--- a/dicee/callbacks.py
+++ b/dicee/callbacks.py
@@ -191,7 +191,7 @@ def on_train_epoch_end(self, trainer, model) -> None:
         self.store_ensemble(ensemble_state_dict)
 
 
-class ASWE(AbstractPPECallback):
+class ASWA(AbstractPPECallback):
     """ Adaptive Stochastic Weight Averaging (ASWA) keeps track of the validation performance and updates the ensemble model accordingly. """
 
 
@@ -201,18 +201,10 @@ def __init__(self, num_epochs, path):
         self.reports_of_running_model = []
         self.reports_of_ensemble_model = []
         self.last_mrr_ensemble = None
-        self.fit = "LR"
         self.initial_eval_setting = None
-        # Although it can be negative
-        self.max_epoch_r2score = float("-inf")
-        self.r2scores = []
         self.entered_good_regions = None
-        self.start_avg = False
-        if self.fit is None:
-            # Weights for models participating the ensemble
-            self.alphas = np.ones(self.num_ensemble_coefficient) / self.num_ensemble_coefficient
-        else:
-            self.alphas = None
+        self.alphas = None
+        self.val_aswa = -1
 
     def on_fit_end(self, trainer, model):
         """
@@ -317,7 +309,7 @@ def store_ensemble_model(self, ensemble_state_dict, mrr_updated_ensemble_model)
             torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
         else:
             if mrr_updated_ensemble_model > self.last_mrr_ensemble:
-                print(f"Ensemble model is updated: Current MRR: {mrr_updated_ensemble_model}")
+                print(f"ASWA ensemble updated: Current MRR: {mrr_updated_ensemble_model}")
                 self.last_mrr_ensemble = mrr_updated_ensemble_model
                 self.sample_counter += 1
                 torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
@@ -332,22 +324,36 @@ def on_train_epoch_end(self, trainer, model):
             self.initial_eval_setting = trainer.evaluator.args.eval_model
             trainer.evaluator.args.eval_model = "val"
 
-        # (3) Does the validation performance of running model still increase?
-        if self.is_mrr_increasing(trainer, model):
-            return True
-
-        ensemble_state_dict = self.initialize_or_load_ensemble(model)
-        # Update
-        self.inplace_update_parameter_ensemble(ensemble_state_dict, model)
-        # Evaluate
-        ensemble = type(model)(model.args)
-        ensemble.load_state_dict(ensemble_state_dict)
-        mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
-                                                            form_of_labelling=trainer.form_of_labelling,
-                                                            during_training=True)["Val"]["MRR"]
-        # Store
-        self.store_ensemble_model(ensemble_state_dict, mrr_updated_ensemble_model=mrr_updated_ensemble_model)
-
+        val_running_model = self.__compute_mrr(trainer, model)
+        if val_running_model > self.val_aswa:
+            # Hard update: the running model beats the ensemble on validation MRR.
+            self.val_aswa = val_running_model
+            torch.save(model.state_dict(), f=f"{self.path}/trainer_checkpoint_main.pt")
+            self.sample_counter = 1
+            print(f"Hard Update: MRR: {self.val_aswa}")
+        else:
+            # Load the current ASWA ensemble.
+            ensemble_state_dict = torch.load(f"{self.path}/trainer_checkpoint_main.pt", torch.device(model.device))
+            # Perform a provisional parameter update.
+            with torch.no_grad():
+                for k, parameters in model.state_dict().items():
+                    if parameters.dtype == torch.float:
+                        ensemble_state_dict[k] = (ensemble_state_dict[k] * self.sample_counter + parameters) / (1 + self.sample_counter)
+            # Evaluate the provisionally updated ensemble.
+            ensemble = type(model)(model.args)
+            ensemble.load_state_dict(ensemble_state_dict)
+            mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
+                                                                form_of_labelling=trainer.form_of_labelling,
+                                                                during_training=True)["Val"]["MRR"]
+
+            if mrr_updated_ensemble_model > self.val_aswa:
+
+                self.val_aswa = mrr_updated_ensemble_model
+                torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
+                self.sample_counter += 1
+                print(f"Soft Update: MRR: {self.val_aswa} | |ASWA|: {self.sample_counter}")
+            else:
+                print("No update")
 
 class FPPE(AbstractPPECallback):
     """
diff --git a/dicee/run.py b/dicee/run.py
index c3f5462d..92afc639 100755
--- a/dicee/run.py
+++ b/dicee/run.py
@@ -8,7 +8,7 @@ def get_default_arguments(description=None):
     parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
     # Default Trainer param https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#methods
     # Knowledge graph related arguments
-    parser.add_argument("--dataset_dir", type=str, default=None,
+    parser.add_argument("--dataset_dir", type=str, default="KGs/UMLS",
                         help="The path of a folder containing train.txt, and/or valid.txt and/or test.txt"
                              ",e.g., KGs/UMLS")
     parser.add_argument("--sparql_endpoint", type=str, default=None,
@@ -28,9 +28,9 @@ def get_default_arguments(description=None):
                         help='Backend for loading, preprocessing, indexing input knowledge graph.')
     # Model related arguments
     parser.add_argument("--model", type=str,
-                        default="DistMult",
-                        choices=["ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
-                                 "OMult", "Shallom", "DistMult", "TransE", "ComplEx", "Keci",
+                        default="Keci",
+                        choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
+                                 "OMult", "Shallom", "DistMult", "TransE",
                                  "Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP",
                                  "Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE", "Pykeen_TransE", "Pykeen_TransF",
                                  "Pykeen_TransH",
@@ -41,17 +41,17 @@ def get_default_arguments(description=None):
     parser.add_argument('--optim', type=str, default='Adam',
                         help='An optimizer',
                         choices=['Adam', 'SGD'])
-    parser.add_argument('--embedding_dim', type=int, default=2,
+    parser.add_argument('--embedding_dim', type=int, default=32,
                         help='Number of dimensions for an embedding vector. ')
-    parser.add_argument("--num_epochs", type=int, default=100, help='Number of epochs for training. ')
-    parser.add_argument('--batch_size', type=int, default=1096,
+    parser.add_argument("--num_epochs", type=int, default=200, help='Number of epochs for training. ')
+    parser.add_argument('--batch_size', type=int, default=1024,
                         help='Mini batch size. If None, automatic batch finder is applied')
     parser.add_argument("--lr", type=float, default=0.1)
     parser.add_argument('--callbacks', type=json.loads, default={},
                         help='{"PPE":{ "last_percent_to_consider": 10}}'
                              '"Perturb": {"level": "out", "ratio": 0.2, "method": "RN", "scaler": 0.3}')
-    parser.add_argument("--trainer", type=str, default='torchCPUTrainer',
+    parser.add_argument("--trainer", type=str, default='PL',
                         choices=['torchCPUTrainer', 'PL', 'torchDDP'],
                         help='PL (pytorch lightning trainer), torchDDP (custom ddp), torchCPUTrainer (custom cpu only)')
     parser.add_argument('--scoring_technique', default="KvsAll",
@@ -102,7 +102,9 @@ def get_default_arguments(description=None):
 
     parser.add_argument("--byte_pair_encoding", action="store_true",
                         help="Currently only avail. for KGE implemented within dice-embeddings.")
-    parser.add_argument("--adaptive_swa", action="store_true", help="Adaptive stochastic weight averaging")
+    parser.add_argument("--adaptive_swa",
+                        action="store_true",
+                        help="Adaptive stochastic weight averaging")
     if description is None:
         return parser.parse_args()
     return parser.parse_args(description)
diff --git a/dicee/trainer/dice_trainer.py b/dicee/trainer/dice_trainer.py
index 6c32fa10..7855e7b1 100644
--- a/dicee/trainer/dice_trainer.py
+++ b/dicee/trainer/dice_trainer.py
@@ -3,7 +3,7 @@
 from typing import Union
 from dicee.models.base_model import BaseKGE
 from dicee.static_funcs import select_model
-from dicee.callbacks import (ASWE, PPE, FPPE, Eval, KronE, PrintCallback, KGESaveCallback, AccumulateEpochLossCallback,
+from dicee.callbacks import (ASWA, PPE, FPPE, Eval, KronE, PrintCallback, KGESaveCallback, AccumulateEpochLossCallback,
                              Perturb)
 from dicee.dataset_classes import construct_dataset, reload_dataset
 from .torch_trainer import TorchTrainer
@@ -48,7 +48,7 @@ def get_callbacks(args):
                  AccumulateEpochLossCallback(path=args.full_storage_path)
                  ]
     if args.adaptive_swa:
-        callbacks.append(ASWE(num_epochs=args.num_epochs, path=args.full_storage_path))
+        callbacks.append(ASWA(num_epochs=args.num_epochs, path=args.full_storage_path))
 
     if isinstance(args.callbacks, list):
         return callbacks
diff --git a/tests/test_regression_polyak.py b/tests/test_regression_polyak.py
index 841b7115..f5739896 100644
--- a/tests/test_regression_polyak.py
+++ b/tests/test_regression_polyak.py
@@ -66,9 +66,9 @@ def test_ppe_keci_k_vs_all(self):
         args.batch_size = 1024
         args.adaptive_swa = True
         adaptive_swa_report = Execute(args).start()
-        assert adaptive_swa_report["Train"]["MRR"]>=0.987
-        assert adaptive_swa_report["Val"]["MRR"] >=0.872
-        assert adaptive_swa_report["Test"]["MRR"] >= 0.872
+        assert adaptive_swa_report["Train"]["MRR"] >= 0.983
+        assert adaptive_swa_report["Val"]["MRR"] >= 0.861
+        assert adaptive_swa_report["Test"]["MRR"] >= 0.875
         assert adaptive_swa_report["Test"]["MRR"]>ppe_reports["Test"]["MRR"]
 
     @pytest.mark.filterwarnings('ignore::UserWarning')
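
For reference, below is a minimal, self-contained sketch of the hard/soft update rule that the ASWA callback in this patch implements. The helper name aswa_step, its arguments, and evaluate_fn are illustrative only and are not part of the dicee API; the real callback operates on model.state_dict() and obtains the validation MRR from the trainer's evaluator.

import copy
import torch


def aswa_step(aswa_state, sample_counter, val_aswa, running_state, val_running, evaluate_fn):
    """One ASWA decision per epoch; returns (ensemble_state, sample_counter, best_val_score)."""
    if val_running > val_aswa:
        # Hard update: the running model beats the ensemble, so restart the average from it.
        return copy.deepcopy(running_state), 1, val_running
    # Soft (provisional) update: fold the running weights into the running average.
    candidate = copy.deepcopy(aswa_state)
    with torch.no_grad():
        for k, p in running_state.items():
            if p.dtype == torch.float:
                candidate[k] = (candidate[k] * sample_counter + p) / (sample_counter + 1)
    val_candidate = evaluate_fn(candidate)  # e.g. validation MRR of the averaged weights
    if val_candidate > val_aswa:
        # Accept the soft update and grow the ensemble size.
        return candidate, sample_counter + 1, val_candidate
    # Reject the update; the ensemble and its best validation score stay unchanged.
    return aswa_state, sample_counter, val_aswa

With the patch applied, the callback is enabled from the command line via the new flag, e.g. python dicee/run.py --dataset_dir KGs/UMLS --model Keci --adaptive_swa (the exact entry point may differ depending on how dicee is installed).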