Commit

Completed integration. Tests are left to make sure everything runs properly
Jean-KOUAGOU committed Dec 19, 2024
1 parent a10fae8 commit 4a2c847
Showing 8 changed files with 299 additions and 241 deletions.
11 changes: 6 additions & 5 deletions examples/train_nces.py
@@ -3,11 +3,10 @@
 (2) pip install ontolearn
 """
-
-
-from ontolearn.concept_learner import NCES, NCES2, ROCES
 import argparse
 import json, os
+
+from ontolearn.concept_learner import NCES, NCES2, ROCES
 from transformers import set_seed
 
 def str2bool(v):
     if isinstance(v, bool):
@@ -54,24 +53,26 @@ def start(args):
         print("Could not find training data. Will generate some data and train.")
         training_data = NCES2.generate_training_data(knowledge_base_path, beyond_alc=True)
     if args.synthesizer == "NCES":
-        synthesizer = NCES(knowledge_base_path=knowledge_base_path, learner_names=args.models, path_of_embeddings=path_of_embeddings,
+        synthesizer = NCES(knowledge_base_path=knowledge_base_path, learner_names=args.models, path_of_embeddings=path_of_embeddings, path_of_trained_models=args.path_of_trained_models,
                            max_length=48, proj_dim=128, rnn_n_layers=2, drop_prob=0.1, num_heads=4, num_seeds=1, m=32, load_pretrained=args.load_pretrained, verbose=True)
     elif args.synthesizer == "NCES2":
         synthesizer = NCES2(knowledge_base_path=knowledge_base_path, path_of_trained_models=args.path_of_trained_models, nces2_or_roces=True, max_length=48, proj_dim=128,
                             drop_prob=0.1, num_heads=4, num_seeds=1, m=32, verbose=True, load_pretrained=args.load_pretrained)
     else:
         synthesizer = ROCES(knowledge_base_path=knowledge_base_path, path_of_trained_models=args.path_of_trained_models, nces2_or_roces=True, k=5, max_length=48, proj_dim=128,
                             drop_prob=0.1, num_heads=4, num_seeds=1, m=32, load_pretrained=args.load_pretrained, verbose=True)
-    synthesizer.train(training_data, epochs=args.epochs, learning_rate=args.learning_rate, num_workers=2, save_model=True)
+    synthesizer.train(training_data, epochs=args.epochs, learning_rate=args.learning_rate, num_workers=2, save_model=True, storage_path=args.storage_dir)
 
 
 if __name__ == '__main__':
     set_seed(42)
     parser = argparse.ArgumentParser()
     parser.add_argument('--kbs', type=str, nargs='+', default=None, help='Paths of knowledge bases (OWL files)')
     parser.add_argument('--embeddings', type=str, nargs='+', default=None, help='Paths of embeddings for each KB.')
     parser.add_argument('--synthesizer', type=str, default="NCES", help='Neural synthesizer to train')
     parser.add_argument('--path_train_data', type=str, help='Path to training data')
+    parser.add_argument('--path_of_trained_models', type=str, default=None, help='Path to trained models')
+    parser.add_argument('--storage_dir', type=str, default=None, help='Directory where trained models are stored')
     parser.add_argument('--models', type=str, nargs='+', default=['SetTransformer', 'LSTM', 'GRU'],
                         help='Neural models')
     parser.add_argument('--load_pretrained', type=str2bool, default=False, help='Whether to load the pretrained model')
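For orientation, here is a minimal sketch of driving the updated example end to end, mirroring the calls in this diff. The knowledge base path, storage path, epochs, and learning rate below are illustrative placeholders, not values taken from the commit:

# Sketch only: paths and hyperparameters are hypothetical placeholders.
from ontolearn.concept_learner import NCES2

kb_path = "KGs/Family/family.owl"  # placeholder knowledge base path

# Same constructor arguments as the NCES2 branch above; no pretrained
# models are loaded, so the synthesizer trains from scratch.
synthesizer = NCES2(knowledge_base_path=kb_path, path_of_trained_models=None,
                    nces2_or_roces=True, max_length=48, proj_dim=128,
                    drop_prob=0.1, num_heads=4, num_seeds=1, m=32,
                    verbose=True, load_pretrained=False)

# Generate training data, as the script does when none is supplied.
training_data = NCES2.generate_training_data(kb_path, beyond_alc=True)

# Train and persist the models; storage_path is the argument newly
# wired to --storage_dir in this commit.
synthesizer.train(training_data, epochs=50, learning_rate=1e-4,
                  num_workers=2, save_model=True, storage_path="./nces2_runs")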
15 changes: 10 additions & 5 deletions ontolearn/base_nces.py
@@ -50,7 +50,7 @@ def __init__(self, knowledge_base_path, nces2_or_roces, quality_func, num_predic
         concrete_role_names = [rel.iri.get_remainder() for rel in kb.ontology.data_properties_in_signature()]
         vocab.extend(concrete_role_names)
         vocab.extend(['⁻', '≤', '≥', 'True', 'False', 'true', 'false', '{', '}', ':', '[', ']', 'double', 'integer', 'date', 'xsd'])
-        vocab = sorted(vocab) + ['PAD']
+        vocab = sorted(set(vocab)) + ['PAD']
         self.knowledge_base_path = knowledge_base_path
         self.kb = kb
         self.all_individuals = set([ind.str.split("/")[-1] for ind in kb.individuals()])
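The added set() guards against class, role, or data-property names that collide with the hand-added tokens. A toy illustration (not Ontolearn code) of why duplicates matter: the token-to-index dict keeps only the last occurrence of each repeated token, so some slots of inv_vocab become unreachable and the two tables drift apart:

import numpy as np

vocab = ['True', 'hasChild', 'True']              # 'True' collected twice
inv_vocab = np.array(sorted(vocab) + ['PAD'], dtype='object')
vocab_map = {inv_vocab[i]: i for i in range(len(inv_vocab))}
print(list(inv_vocab))  # ['True', 'True', 'hasChild', 'PAD']
print(vocab_map)        # {'True': 1, 'hasChild': 2, 'PAD': 3} -- index 0 unreachable

# Deduplicating first keeps inv_vocab and the map in one-to-one correspondence:
inv_vocab = np.array(sorted(set(vocab)) + ['PAD'], dtype='object')
vocab_map = {inv_vocab[i]: i for i in range(len(inv_vocab))}
print(vocab_map)        # {'True': 0, 'hasChild': 1, 'PAD': 2}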
@@ -85,17 +85,22 @@ def add_data_values(self, data):
         print("\nUpdating vocabulary based on training data...\n")
         quantified_restriction_values = [str(i) for i in range(1,12)]
         vocab = list(self.vocab.keys())
-        vocab.extend(quantified_restriction_values)
+        vocab_set = set(vocab)
+        len_before_update = len(vocab_set)
+        vocab_set.update(set(quantified_restriction_values))
         values = set()
         for ce, examples in data:
             if '[' in ce:
                 for val in re.findall("\[(.*?)\]", ce):
                     values.add(val.split(' ')[-1])
-        vocab.extend(list(values))
-        vocab = sorted(vocab)
+        vocab_set.update(values)
+        vocab = sorted(vocab_set)
         self.inv_vocab = np.array(vocab, dtype='object')
         self.vocab = {vocab[i]: i for i in range(len(vocab))}
-        print("Done.\n")
+        if len_before_update < len(vocab):
+            print("Done.\n")
+        else:
+            print("No update necessary!\n")
 
 
     def collate_batch_inference(self, batch):  # pragma: no cover
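The value-extraction step above can be exercised in isolation. A small sketch with a made-up class expression string (same pattern as the diff, written as a raw string to avoid Python's invalid-escape-sequence warning):

import re

ce = "∃ hasAge.double[≥ 18.0] ⊓ ∃ hasIncome.double[≤ 32.5]"  # made-up expression
values = set()
if '[' in ce:
    for val in re.findall(r"\[(.*?)\]", ce):  # captures '≥ 18.0' and '≤ 32.5'
        values.add(val.split(' ')[-1])        # keep only the numeric literal
print(values)  # {'18.0', '32.5'}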
