Commit

WIP: ASWA
Demirrr committed Nov 20, 2023
1 parent 8666a1c commit 4e8cd74
Showing 4 changed files with 50 additions and 42 deletions.
62 changes: 34 additions & 28 deletions dicee/callbacks.py
@@ -191,7 +191,7 @@ def on_train_epoch_end(self, trainer, model) -> None:
self.store_ensemble(ensemble_state_dict)


class ASWE(AbstractPPECallback):
class ASWA(AbstractPPECallback):
""" Adaptive stochastic weight averaging
ASWA keeps track of the validation performance and updates the ensemble model accordingly.
"""
@@ -201,18 +201,10 @@ def __init__(self, num_epochs, path):
self.reports_of_running_model = []
self.reports_of_ensemble_model = []
self.last_mrr_ensemble = None
self.fit = "LR"
self.initial_eval_setting = None
# Although it can be negative
self.max_epoch_r2score = float("-inf")
self.r2scores = []
self.entered_good_regions = None
self.start_avg = False
if self.fit is None:
# Weights for models participating the ensemble
self.alphas = np.ones(self.num_ensemble_coefficient) / self.num_ensemble_coefficient
else:
self.alphas = None
self.alphas = None
self.val_aswa = -1

def on_fit_end(self, trainer, model):
"""
@@ -317,7 +309,7 @@ def store_ensemble_model(self, ensemble_state_dict, mrr_updated_ensemble_model)
torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
else:
if mrr_updated_ensemble_model > self.last_mrr_ensemble:
print(f"Ensemble model is updated: Current MRR: {mrr_updated_ensemble_model}")
print(f"ASWA ensemble updated: Current MRR: {mrr_updated_ensemble_model}")
self.last_mrr_ensemble = mrr_updated_ensemble_model
self.sample_counter += 1
torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
@@ -332,22 +324,36 @@ def on_train_epoch_end(self, trainer, model):
self.initial_eval_setting = trainer.evaluator.args.eval_model
trainer.evaluator.args.eval_model = "val"

# (3) Does the validation performance of running model still increase?
if self.is_mrr_increasing(trainer, model):
return True

ensemble_state_dict = self.initialize_or_load_ensemble(model)
# Update
self.inplace_update_parameter_ensemble(ensemble_state_dict, model)
# Evaluate
ensemble = type(model)(model.args)
ensemble.load_state_dict(ensemble_state_dict)
mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
form_of_labelling=trainer.form_of_labelling,
during_training=True)["Val"]["MRR"]
# Store
self.store_ensemble_model(ensemble_state_dict, mrr_updated_ensemble_model=mrr_updated_ensemble_model)

val_running_model = self.__compute_mrr(trainer, model)
if val_running_model > self.val_aswa:
# hard update
self.val_aswa = val_running_model
torch.save(model.state_dict(), f=f"{self.path}/trainer_checkpoint_main.pt")
self.sample_counter = 1
print(f"Hard Update: MRR: {self.val_aswa}")
else:
# Load ensemble
ensemble_state_dict = torch.load(f"{self.path}/trainer_checkpoint_main.pt", torch.device(model.device))
# Perform a provisional parameter update.
with torch.no_grad():
for k, parameters in model.state_dict().items():
if parameters.dtype == torch.float:
ensemble_state_dict[k] = (ensemble_state_dict[k] * self.sample_counter + parameters) / (1 + self.sample_counter)
# Evaluate
ensemble = type(model)(model.args)
ensemble.load_state_dict(ensemble_state_dict)
mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
form_of_labelling=trainer.form_of_labelling,
during_training=True)["Val"]["MRR"]

if mrr_updated_ensemble_model > self.val_aswa:

self.val_aswa = mrr_updated_ensemble_model
torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
self.sample_counter += 1
print(f" Soft Update: MRR: {self.val_aswa} | |ASWA|:{self.sample_counter}")
else:
print(" No update")

class FPPE(AbstractPPECallback):
"""
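For reference, the new on_train_epoch_end above works in two modes: if the running model's validation MRR beats the current ASWA score, ASWA is reset to the running model (hard update); otherwise a provisional running average is computed and kept only if it improves the validation MRR (soft update). After n accepted snapshots the soft update sets each float parameter to (n * w_aswa + w_running) / (n + 1). A minimal standalone sketch of that averaging step is given below; the helper name and its standalone form are illustrative, not part of dicee's API.

import torch

def soft_update(ensemble_state: dict, model_state: dict, sample_counter: int) -> dict:
    # Running average over float parameters, mirroring the in-place update in on_train_epoch_end.
    with torch.no_grad():
        for name, param in model_state.items():
            if param.dtype == torch.float:
                ensemble_state[name] = (ensemble_state[name] * sample_counter + param) / (sample_counter + 1)
    return ensemble_state
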
20 changes: 11 additions & 9 deletions dicee/run.py
@@ -8,7 +8,7 @@ def get_default_arguments(description=None):
parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
# Default Trainer param https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#methods
# Knowledge graph related arguments
parser.add_argument("--dataset_dir", type=str, default=None,
parser.add_argument("--dataset_dir", type=str, default="KGs/UMLS",
help="The path of a folder containing train.txt, and/or valid.txt and/or test.txt"
",e.g., KGs/UMLS")
parser.add_argument("--sparql_endpoint", type=str, default=None,
@@ -28,9 +28,9 @@ def get_default_arguments(description=None):
help='Backend for loading, preprocessing, indexing input knowledge graph.')
# Model related arguments
parser.add_argument("--model", type=str,
default="DistMult",
choices=["ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
"OMult", "Shallom", "DistMult", "TransE", "ComplEx", "Keci",
default="Keci",
choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
"OMult", "Shallom", "DistMult", "TransE",
"Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP",
"Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE",
"Pykeen_TransE", "Pykeen_TransF", "Pykeen_TransH",
@@ -41,17 +41,17 @@ def get_default_arguments(description=None):
parser.add_argument('--optim', type=str, default='Adam',
help='An optimizer',
choices=['Adam', 'SGD'])
parser.add_argument('--embedding_dim', type=int, default=2,
parser.add_argument('--embedding_dim', type=int, default=32,
help='Number of dimensions for an embedding vector. ')
parser.add_argument("--num_epochs", type=int, default=100, help='Number of epochs for training. ')
parser.add_argument('--batch_size', type=int, default=1096,
parser.add_argument("--num_epochs", type=int, default=200, help='Number of epochs for training. ')
parser.add_argument('--batch_size', type=int, default=1024,
help='Mini batch size. If None, automatic batch finder is applied')
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument('--callbacks', type=json.loads,
default={},
help='{"PPE":{ "last_percent_to_consider": 10}}'
'"Perturb": {"level": "out", "ratio": 0.2, "method": "RN", "scaler": 0.3}')
parser.add_argument("--trainer", type=str, default='torchCPUTrainer',
parser.add_argument("--trainer", type=str, default='PL',
choices=['torchCPUTrainer', 'PL', 'torchDDP'],
help='PL (pytorch lightning trainer), torchDDP (custom ddp), torchCPUTrainer (custom cpu only)')
parser.add_argument('--scoring_technique', default="KvsAll",
@@ -102,7 +102,9 @@ def get_default_arguments(description=None):
parser.add_argument("--byte_pair_encoding",
action="store_true",
help="Currently only avail. for KGE implemented within dice-embeddings.")
parser.add_argument("--adaptive_swa", action="store_true", help="Adaptive stochastic weight averaging")
parser.add_argument("--adaptive_swa",
action="store_true",
help="Adaptive stochastic weight averaging")
if description is None:
return parser.parse_args()
return parser.parse_args(description)
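With the new defaults above (Keci on KGs/UMLS, 32-dimensional embeddings, 200 epochs, batch size 1024), ASWA is switched on via the --adaptive_swa flag. A rough programmatic sketch follows; the Execute import path is assumed from the regression test below and is not part of this diff.

from dicee.executer import Execute              # import path assumed, not shown in this commit
from dicee.run import get_default_arguments

args = get_default_arguments([])                 # parse the defaults listed above
args.adaptive_swa = True                         # register the ASWA callback
report = Execute(args).start()                   # report keys ("Train"/"Val"/"Test") depend on --eval_model
print(report["Val"]["MRR"], report["Test"]["MRR"])
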
4 changes: 2 additions & 2 deletions dicee/trainer/dice_trainer.py
@@ -3,7 +3,7 @@
from typing import Union
from dicee.models.base_model import BaseKGE
from dicee.static_funcs import select_model
from dicee.callbacks import (ASWE, PPE, FPPE, Eval, KronE, PrintCallback, KGESaveCallback, AccumulateEpochLossCallback,
from dicee.callbacks import (ASWA, PPE, FPPE, Eval, KronE, PrintCallback, KGESaveCallback, AccumulateEpochLossCallback,
Perturb)
from dicee.dataset_classes import construct_dataset, reload_dataset
from .torch_trainer import TorchTrainer
@@ -48,7 +48,7 @@ def get_callbacks(args):
AccumulateEpochLossCallback(path=args.full_storage_path)
]
if args.adaptive_swa:
callbacks.append(ASWE(num_epochs=args.num_epochs, path=args.full_storage_path))
callbacks.append(ASWA(num_epochs=args.num_epochs, path=args.full_storage_path))

if isinstance(args.callbacks, list):
return callbacks
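The wiring above implies the callback is constructed as ASWA(num_epochs=..., path=...). A hypothetical direct instantiation outside of get_callbacks would look like this; the path value is illustrative only.

from dicee.callbacks import ASWA

# Hypothetical manual wiring; normal usage goes through --adaptive_swa and get_callbacks(args).
aswa = ASWA(num_epochs=200, path="Experiments/my_run")
callbacks = [aswa]   # passed to the trainer together with the other callbacks
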
6 changes: 3 additions & 3 deletions tests/test_regression_polyak.py
@@ -66,9 +66,9 @@ def test_ppe_keci_k_vs_all(self):
args.batch_size = 1024
args.adaptive_swa = True
adaptive_swa_report = Execute(args).start()
assert adaptive_swa_report["Train"]["MRR"]>=0.987
assert adaptive_swa_report["Val"]["MRR"] >=0.872
assert adaptive_swa_report["Test"]["MRR"] >= 0.872
assert adaptive_swa_report["Train"]["MRR"]>=0.983
assert adaptive_swa_report["Val"]["MRR"] >=0.861
assert adaptive_swa_report["Test"]["MRR"] >= 0.875
assert adaptive_swa_report["Test"]["MRR"]>ppe_reports["Test"]["MRR"]

@pytest.mark.filterwarnings('ignore::UserWarning')
